Follow-up to “Frontier MoE sleep/wake at TP=4 on consumer Blackwell” from May 16. That post described the production state. This one describes what changed underneath when we re-validated the upstream story end-to-end.
The earlier sleep/wake post listed vllm-project/vllm#34600 (sleep wake_up partial-map rollback) and a small handful of related PRs as the open dependencies — fixes we were carrying as patches but waiting on upstream to merge. One more dependency landed as an upstream draft after that post went up: vllm-project/vllm#43020 — Make CuMemAllocator free callback stream-aware.
The #43020 fix is a single line, but the bug it closes is why GLM-5.1-REAP-478B-A42B-NVFP4 was retired from our TP=4 heavyweight rotation pool in April: “needs a hard reset every N uses.” We re-tested it this week with the patch applied, and it survives.
What 43020 actually fixes
CuMemAllocator (the sleep-mode allocator at vllm/device_allocator/cumem.py) is a CUDA pluggable allocator. When a pool-backed tensor’s refcount hits zero, the C extension csrc/cumem_allocator.cpp:unmap_and_release calls cuMemUnmap and cuMemRelease synchronously from the host.
PyTorch’s regular caching allocator records a CUDA event on the stream that last touched the freed block and defers reclaim until the event completes. The pluggable path does not. If a kernel is still in flight against the storage being torn down, cuMemUnmap races with it. The failure surfaces asynchronously, either as CUDA_ERROR_ILLEGAL_ADDRESS at the next synchronous CUDA call or as CUDA Error: invalid argument at cumem_allocator.cpp:146 on the next /wake_up.
The fix: one torch.cuda.synchronize() at the top of the Python free callback, which runs before the C extension calls cuMemUnmap. The cost lands only on cumem-allocator frees (model load, KV cache init, sleep/wake); the regular caching allocator is unaffected, so steady-state inference pays nothing.
```diff
     def _python_free_callback(self, ptr: int) -> HandleType:
         """
         Internal method to look up the allocation data
         when memory is freed in the memory pool.
+        Drain pending CUDA work before the C extension cuMemUnmaps the
+        pool-backed storage. Without this, kernels still in flight against
+        the freed storage race with cuMemUnmap and surface
+        CUDA_ERROR_ILLEGAL_ADDRESS asynchronously at the next sync.
         """
+        torch.cuda.synchronize()
         data = self.pointer_to_data.pop(ptr)
```
We caught the bug live during validation
Cleanup sequence inside the validation cycle: sleep DSv4 → boot GLM with its own (unpatched) cumem allocator on the same TP=4 GPUs → stop GLM → wake DSv4. The wake failed:
```
{"error":{"message":"Call to wake_up method failed: Worker failed with error
'CUDA Error: invalid argument at /build/vllm/csrc/cumem_allocator.cpp:146',
please check the stack trace above for the root cause",
"type":"InternalServerError","param":null,"code":500}}
```
That is exactly the bug class #43020 fixes. Running a second model with an unpatched allocator on the same physical GPUs corrupted the cumem state across the two images and left DSv4 unable to re-allocate; a hard restart cleared it. With #43020 baked in, the same sequence does not corrupt the allocator state in the first place.
Validation results
We ran these tests without rebuilding any container images: a Nomad bind mount overlays just the patched cumem.py at /usr/local/lib/python3.12/dist-packages/vllm/device_allocator/cumem.py (and at the equivalent /opt/vllm/vllm/... path in GLM’s image). One-line patch, zero image builds.
Test 1 — Q3.6-27B-NVFP4 sleep/wake cycle
- /sleep level=1 → HTTP 200 in 39.4 s
- VRAM 58101 → 32307 MiB (25,794 MiB ≈ 25.2 GiB released; engine alive, /health 200)
- /wake_up → HTTP 200 in 1.9 s
- VRAM after wake: 58149 MiB
- Post-wake chat: coherent ("Hello! How can I help you today?"), no !!!!! garbage
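A minimal sketch for timing a cycle like Test 1. It assumes a vLLM server at `base_url` exposing the same /sleep, /wake_up and /health routes used above, and reads VRAM with `nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits`. The function names are hypothetical, the HTTP call is injectable so the logic can be checked without a GPU, and this is not the script that produced the numbers above.

```python
import subprocess
import time
import urllib.request
from typing import Callable, Dict, List


def http_status(method: str, url: str, timeout: float = 600.0) -> int:
    """Send a bodyless request and return the HTTP status (non-2xx raises HTTPError)."""
    req = urllib.request.Request(url, method=method)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.status


def parse_memory_used_mib(csv_text: str) -> List[int]:
    """One MiB value per GPU line of nvidia-smi's csv,noheader,nounits output."""
    return [int(line.strip()) for line in csv_text.splitlines() if line.strip()]


def nvidia_smi_memory_used_mib() -> List[int]:
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_memory_used_mib(out)


def released_gib(before_mib: int, after_mib: int) -> float:
    """nvidia-smi reports MiB, so divide by 1024 (not 1000) to get GiB."""
    return (before_mib - after_mib) / 1024


def sleep_wake_cycle(base_url: str, level: int = 1,
                     call: Callable[[str, str], int] = http_status) -> Dict[str, float]:
    """Time /sleep and /wake_up, and check the engine still answers /health while asleep."""
    t0 = time.monotonic()
    sleep_status = call("POST", f"{base_url}/sleep?level={level}")
    t1 = time.monotonic()
    health_status = call("GET", f"{base_url}/health")
    t2 = time.monotonic()
    wake_status = call("POST", f"{base_url}/wake_up")
    t3 = time.monotonic()
    return {
        "sleep_status": sleep_status, "sleep_s": t1 - t0,
        "health_status": health_status,
        "wake_status": wake_status, "wake_s": t3 - t2,
    }
```

Sampling `nvidia_smi_memory_used_mib()` before the cycle, between /sleep and /wake_up, and after the wake gives the three VRAM readings; `released_gib` turns the first two into the released amount.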
Test 2 — GLM-5.1-REAP-478B-A42B-NVFP4 30-request stress
The retirement bug needed N uses to manifest, so we hammered it with 30 back-to-back chat completions on a freshly resurrected GLM (TP=4 on the same 4 RTX PRO 6000 GPUs DSv4 lives on).
```
=== 30-request stress test starting 19:08:14 ===
[01] ok 2.2s | 17 + 25 = 42
[02] ok 11.2s | Blue
[03] ok 17.9s | The translation of "good morning" as a greeting in Spanish is "Buenos días"
[04] ok 12.9s | 1. **Understand the User's Request:** ...
[05] ok 2.8s | The capital of France is Paris.
...
[28] ok 6.2s | The largest planet in our solar system is **Jupiter**. ...
[29] ok 12.6s | 1. **Understand the Request:** ...
[30] ok 12.8s | Here are 3 popular Python web frameworks: 1. **Django** ...
PASS=30 FAIL=0 GARBAGE=0 / 30
nomad alloc logs grep CUDA_ERROR_ILLEGAL_ADDRESS: 0
```
Zero CUDA errors in engine stderr across 30 sequential generations. The original retirement symptom did not recur.
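A harness in the spirit of Test 2 can be sketched as follows: it fires sequential, non-streaming requests at vLLM's OpenAI-compatible /v1/chat/completions route and buckets each reply as ok, fail, or garbage. The garbage heuristic (a run of five or more '!', the corruption symptom mentioned in Test 1), the `max_tokens` value and all function names are our assumptions, not the script behind the log above.

```python
import json
import re
import time
import urllib.request
from typing import Callable, Dict, Iterable, Optional

# Assumed garbage signature: a run of five or more '!' characters.
GARBAGE_RE = re.compile(r"!{5,}")


def classify(reply: Optional[str]) -> str:
    """Bucket one completion as 'fail' (no text), 'garbage', or 'ok'."""
    if reply is None or not reply.strip():
        return "fail"
    if GARBAGE_RE.search(reply):
        return "garbage"
    return "ok"


def chat(base_url: str, model: str, prompt: str, timeout: float = 300.0) -> Optional[str]:
    """One non-streaming request to the OpenAI-compatible chat endpoint."""
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 256,
    }).encode()
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        payload = json.load(resp)
    return payload["choices"][0]["message"]["content"]


def stress(prompts: Iterable[str], send: Callable[[str], Optional[str]]) -> Dict[str, int]:
    """Send prompts back to back; any exception counts as a failure."""
    counts = {"ok": 0, "fail": 0, "garbage": 0}
    for i, prompt in enumerate(prompts, start=1):
        start = time.monotonic()
        try:
            reply = send(prompt)
            preview = (reply or "").replace("\n", " ")[:60]
        except Exception as exc:  # HTTP errors, timeouts, dropped connections
            reply, preview = None, repr(exc)
        status = classify(reply)
        counts[status] += 1
        print(f"[{i:02d}] {status} {time.monotonic() - start:.1f}s | {preview}")
    return counts


def count_matching_lines(log_text: str, marker: str = "CUDA_ERROR_ILLEGAL_ADDRESS") -> int:
    """Equivalent of `grep -c marker` over captured engine logs."""
    return sum(marker in line for line in log_text.splitlines())
```

Against a live server this would run as `stress(prompts, lambda p: chat("http://localhost:8000", MODEL, p))`, where `MODEL` is whatever name the server registers; the engine's stderr then goes through `count_matching_lines`.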
Plumbing it across the fleet
GLM is back in the rotation scheduler. The same overlay applies to MiMo-V2.5-Flash, DeepSeek-V4-Flash, and the AB-1 GPU 5 sleep/wake cohort (Q3.6-27B, Gemma-4-26B, Nemotron-Omni, Qwen3-VL-30B, Q3.6-35B-A3B). The implementation plan is staged at .claude/plans/cumem-stream-aware-overlay-fleet-rollout-20260518.md: extract cumem.py from each image’s vLLM install root, apply the one-line patch, stage the result as a bind-mount overlay, and walk the rotation one peer at a time. This is the same rolling-restart pattern our Cumem Bundle v4 / v5 post described for the earlier wave, and it is the third overlay on that timeline.
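The "apply the one-line patch" step is mechanical enough to script. This is a minimal sketch, assuming each image's cumem.py has a module-level `import torch` and the `data = self.pointer_to_data.pop(ptr)` line inside `_python_free_callback`, as in the diff above. `patch_free_callback` and `patch_file` are hypothetical names. The script refuses to patch rather than guess when the anchor is missing, and it is idempotent, so re-running a rollout step is harmless.

```python
import re
from pathlib import Path

FIX_LINE = "torch.cuda.synchronize()"
ANCHOR = "data = self.pointer_to_data.pop(ptr)"
NEXT_DEF = re.compile(r"^\s*def ", re.MULTILINE)


def patch_free_callback(source: str) -> str:
    """Insert FIX_LINE before ANCHOR inside _python_free_callback (idempotent)."""
    if not re.search(r"^import torch\b", source, re.MULTILINE):
        raise ValueError("cumem.py does not import torch at module level")
    start = source.find("def _python_free_callback(")
    if start == -1:
        raise ValueError("_python_free_callback not found")
    # The method body ends where the next `def` (at any indentation) begins.
    nxt = NEXT_DEF.search(source, start + 1)
    end = nxt.start() if nxt else len(source)
    anchor = source.find(ANCHOR, start, end)
    if anchor == -1:
        raise ValueError("anchor line not found inside _python_free_callback")
    line_start = source.rfind("\n", 0, anchor) + 1
    indent = source[line_start:anchor]
    if indent.strip():
        raise ValueError("anchor is not at the start of its line")
    # Already patched if the line just above the anchor is the fix.
    prev_start = source.rfind("\n", 0, line_start - 1) + 1
    if source[prev_start:line_start].strip() == FIX_LINE:
        return source
    return source[:line_start] + indent + FIX_LINE + "\n" + source[line_start:]


def patch_file(src: Path, dst: Path) -> None:
    """Write a patched copy of an image's cumem.py to dst (the bind-mount source)."""
    dst.write_text(patch_free_callback(src.read_text()))
```

Patching each image's own file, rather than copying one file across images, keeps the overlay aligned with that image's vLLM version.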
The community-contribution side: a generic Triton shmem-budget helper
While we were validating #43020 we also fixed our own outstanding upstream contribution — [Core] Add shmem-aware autotune pruner for non-H100 Triton kernels. This is a separate, complementary win.
The problem: Triton kernels in vLLM often ship @triton.autotune configurations tuned for the H100/H200 per-block opt-in shared-memory budget of 227 KiB (228 KiB per SM). On GPUs with less shared memory (Turing ~64 KiB, Ampere A100 ~163 KiB, consumer Blackwell SM_120 RTX 5090 / PRO 6000 ~99 KiB), the larger configs raise triton.runtime.errors.OutOfResources at JIT time and kill the worker mid-cold-load.
The existing upstream pattern is per-kernel hand-rolled bucket switches:
```python
# in vllm/model_executor/layers/fla/ops/chunk_o.py — current upstream
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
```
That works for the worst-case bucket, but it is binary, fires only once at module import, and isn’t always enough. Concretely, on SM_120 the small-bin BK=BV=64 num_stages=4 BT=64 config in chunk_fwd_kernel_o still needs ~131 KiB of shared memory, which exceeds the ~99 KiB (101,376-byte) opt-in even after the hand-roll’s small-bin selection.
Our helper at vllm/triton_utils/shmem_budget.py replaces hand-rolls with two primitives:
- infer_shmem_budget(device): read the actual per-block opt-in shmem from torch.cuda.get_device_properties. Cached per device. Falls back to 48 KiB if torch.cuda isn’t available.
- make_shmem_pruner(estimate_shmem_bytes, *, safety_margin_bytes=1024, on_empty="smallest"): return a Triton early_config_prune callback. The caller supplies a kernel-specific byte estimator; the pruner drops configs that exceed the device budget and, if every config is too big, falls back to the smallest one with a one-shot warning.
Wired into the autotuner via Triton’s existing prune_configs_by={"early_config_prune": ...} hook — zero changes to Triton itself, zero changes to anyone else’s kernels, zero change to the H100/H200 fast path (every config that fits stays).
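A minimal sketch of the two primitives as described above. It is not the code in #43047: the signatures follow the description, but the `budget_bytes` override, the warning text, the `FakeConfig` stand-in for `triton.Config`, and `example_estimate` (a made-up per-stage tile formula, not the real chunk_fwd_kernel_o estimator) are our assumptions. `shared_memory_per_block_optin` is read with `getattr` because older PyTorch builds may not expose it. As far as we know, Triton's autotuner calls the hook as `hook(configs, named_args, **kwargs)` and uses the returned list.

```python
import functools
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence

FALLBACK_BUDGET_BYTES = 48 * 1024  # conservative default without CUDA


@functools.lru_cache(maxsize=None)
def infer_shmem_budget(device: int = 0) -> int:
    """Per-block opt-in shared memory (bytes) for `device`, cached per device."""
    try:
        import torch
    except ImportError:
        return FALLBACK_BUDGET_BYTES
    if not torch.cuda.is_available():
        return FALLBACK_BUDGET_BYTES
    props = torch.cuda.get_device_properties(device)
    # Exposed by recent PyTorch builds; treat a missing attribute like "no CUDA".
    optin = getattr(props, "shared_memory_per_block_optin", None)
    return int(optin) if optin else FALLBACK_BUDGET_BYTES


def make_shmem_pruner(
    estimate_shmem_bytes: Callable[[Any, Dict[str, Any]], int],
    *,
    safety_margin_bytes: int = 1024,
    on_empty: str = "smallest",
    budget_bytes: Optional[int] = None,  # our addition: pin the budget for tests
) -> Callable[..., List[Any]]:
    """Build an early_config_prune callback that keeps configs fitting the budget."""
    warned = False

    def prune(configs: Sequence[Any], named_args: Dict[str, Any], **kwargs: Any) -> List[Any]:
        nonlocal warned
        # Sketch simplification: defaults to device 0 rather than the current device.
        budget = budget_bytes if budget_bytes is not None else infer_shmem_budget()
        limit = budget - safety_margin_bytes
        sized = [(estimate_shmem_bytes(cfg, named_args), cfg) for cfg in configs]
        fitting = [cfg for need, cfg in sized if need <= limit]
        if fitting or not sized:
            return fitting
        if on_empty != "smallest":
            raise RuntimeError(f"no autotune config fits {limit} bytes of shared memory")
        if not warned:  # warn once per pruner, not once per autotune call
            warned = True
            warnings.warn(f"no config fits {limit} bytes of shared memory; "
                          "falling back to the smallest estimate")
        return [min(sized, key=lambda pair: pair[0])[1]]

    return prune


@dataclass
class FakeConfig:
    """Stand-in for triton.Config with the attributes the estimator reads."""
    kwargs: Dict[str, int] = field(default_factory=dict)
    num_stages: int = 2


def example_estimate(cfg: Any, named_args: Dict[str, Any]) -> int:
    """Hypothetical estimator: bf16 [BT,BK], [BT,BV] and [BK,BV] tiles per stage."""
    bt, bk, bv = cfg.kwargs["BT"], cfg.kwargs["BK"], cfg.kwargs["BV"]
    return cfg.num_stages * 2 * (bt * bk + bt * bv + bk * bv)
```

In a kernel file, the pruner would plug into the existing hook as `@triton.autotune(configs=..., key=[...], prune_configs_by={"early_config_prune": make_shmem_pruner(estimate_fn)})`. Configs that fit are returned unchanged, which is why the H100/H200 path is unaffected.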
We wired the helper into two reference call sites: chunk_gated_delta_rule_fwd_kernel_h_blockdim64 in chunk_delta_h.py (the kernel we originally hit the OOM on) and chunk_fwd_kernel_o in chunk_o.py (the one that already had the hand-rolled BKV_LIST and still wasn’t sufficient). 21 unit tests cover the helper + both estimators across H100, A100, and SM_120 budgets.
The PR is being submitted to vllm-project/vllm as we publish this post. Linked at the bottom.
What’s still open
These four upstream PRs each carry a piece of the consumer-Blackwell sleep/wake story. All four are open as of publication; in the meantime we carry the patches we depend on ourselves, as image bundles or Nomad bind-mount overlays (see the “Our state” column).
| PR | What it does | Our state |
|---|---|---|
| vllm#34600 | Sleep wake_up partial-map rollback | Already in our bundle:v4 / v5 |
| vllm#43020 | CuMemAllocator stream-aware free | Validated this week, bind-mount overlay live |
| vllm#41602 | Hybrid Mamba/DeltaNet wake_up | Already in cu129-nightly + 41564fix-v2 |
| vllm#42856 | Sparse-MLA + indexer workspace bounds on SM_120 | DSv4-only path; not on our hot path yet |
Plus our own contribution:
| PR | What it adds | State |
|---|---|---|
| vllm-project/vllm#43047 | [Core] shmem-aware autotune pruner for non-H100 Triton kernels | Submitted |
The one bug not fixed by any of these: Qwen3-Next-80B-A3B-Thinking’s /sleep level=2 returning HTTP 200 without actually releasing weights. That’s a hybrid DeltaNet+SWA architectural issue where the release codepath never fires, not a CuMemAllocator race. No upstream PR addresses it yet — separate work, separate post.
Reproducing
The bind-mount overlay pattern is small enough to drop into any vLLM Nomad spec:
```hcl
volumes = [
  # bake every vLLM patch we carry into the live container
  "/path/to/cumem-43020-patched.py:/usr/local/lib/python3.12/dist-packages/vllm/device_allocator/cumem.py:ro",
  # /root/.cache is needed writable for FlashInfer JIT on readonly_rootfs containers
  "/tmp/<job-name>-rootcache:/root/.cache:rw",
  ...
]
```
The patched cumem.py is each image’s own file with a minimal port of the vllm-project/vllm#43020 head: just the torch.cuda.synchronize() line at the top of _python_free_callback. Swapping in the full file from upstream main breaks FlashInfer imports on our cu129-nightly base because of vLLM version skew, so the surgical one-line patch is the safer overlay.
Links
- Earlier post: Frontier MoE sleep/wake at TP=4 on consumer Blackwell
- vllm-project/vllm#43020 — CuMemAllocator stream-aware free
- vllm-project/vllm#34600 — Sleep wake_up partial-map rollback
- vllm-project/vllm#42856 — SM_120 sparse-MLA workspace bounds
- Our PR: vllm-project/vllm#43047 — [Core] Add shmem-aware autotune pruner for non-H100 Triton kernels
- r/LocalLLaMA discussion thread: link will be added when posted
- Hugging Face Posts thread