Sleep/wake rotation works fine for small dense models on a single GPU — load weights into pinned CPU RAM, swap them back to VRAM in under a second. That’s not the interesting case. The interesting case is frontier MoE at TP=4: ~17 GiB of sharded weights per GPU, sparse routing tables, KV cache across four ranks, and a co-resident peer model on the same physical cards. For months that combination on consumer Blackwell either OOM’d, hung in cuMemMap, or produced silent garbage outputs after a few rotation cycles.
It now works. Same-peer /wake_up: 1 second. Cross-peer swap (sleep model A, wake model B on the same 4 GPUs): 3 seconds. Down from 50. Live in production on 4× RTX PRO 6000 Blackwell with DeepSeek-V4-Flash and MiMo-V2.5-Flash sharing the same TP=4 pool. This is what it took — a specific image stack, two cherry-picked upstream PRs, and one non-obvious config knob.
The recipe is at github.com/DoradusResearch/vllm-blackwell-sm12x-bundle.
## The setup
4× NVIDIA RTX PRO 6000 Blackwell Workstation Edition. 95 GiB per GPU. SM_120 (consumer Blackwell, no Fabric/RDMA). Goal: run a 2-model rotation pool where one is awake serving inference and the other is /sleep’d (weights copied to pinned CPU memory) so they share the same GPUs without doubling hardware.
In vLLM terms: `--enable-sleep-mode` with level-1 sleep. Documented use case, well-tested on H100/H200.
On consumer Blackwell, in our experience: an obstacle course.
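For concreteness, here is the rotation control flow in miniature. A sketch, assuming two vLLM servers launched with `--enable-sleep-mode` (upstream gates the `/sleep` and `/wake_up` endpoints behind `VLLM_SERVER_DEV_MODE=1`); hosts and ports are placeholders:

```python
import requests

PEER_A = "http://localhost:8000"  # awake, serving
PEER_B = "http://localhost:8001"  # asleep, weights parked in pinned CPU RAM

def rotate(awake: str, sleeping: str) -> None:
    # Level-1 sleep: offload weights to pinned CPU memory, discard KV cache.
    requests.post(f"{awake}/sleep", params={"level": "1"}, timeout=300).raise_for_status()
    # Wake the peer: map its weights back into VRAM on the same GPUs.
    requests.post(f"{sleeping}/wake_up", timeout=300).raise_for_status()

rotate(awake=PEER_A, sleeping=PEER_B)
```

Every failure mode below lives inside those two POSTs.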
## What broke, in order
### 1. The b12x community image had cumem state bugs
The closest existing image is the voipmonitor b12x fork, which patches vLLM for SM_120-specific kernel paths. An excellent baseline, but on cross-peer /wake_up we reliably got `CUDA Error: invalid argument` at cumem_allocator.cpp:145.
Digging in, we found this is the bug PR #35489 addresses: vLLM’s cumem allocator queries `CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED` to decide whether to use Fabric handles, and on hardware without Fabric (all consumer Blackwell), that query returns `CUDA_ERROR_INVALID_VALUE`. The error code gets cached in a global, and the next cuMemMap returns EINVAL because it sees the stale error state.
The fix is a one-line `error_code = no_error;` reset at the top of `create_and_map`. The PR has been open since March; it isn’t in any released vLLM image yet.
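For readability, here is the shape of the bug in a hypothetical Python rendering; the real code is C++ in cumem_allocator.cpp, and every name below is illustrative:

```python
CUDA_SUCCESS = 0
CUDA_ERROR_INVALID_VALUE = 1

error_code = CUDA_SUCCESS  # global error state, shared across calls

def fabric_handles_supported() -> bool:
    # Simulates the CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED query:
    # on hardware without Fabric it fails, and the failure gets cached.
    global error_code
    error_code = CUDA_ERROR_INVALID_VALUE
    return False

def create_and_map() -> int:
    global error_code
    error_code = CUDA_SUCCESS  # the PR #35489 fix: clear stale state first
    if error_code != CUDA_SUCCESS:
        return CUDA_ERROR_INVALID_VALUE  # pre-fix path: spurious EINVAL
    return CUDA_SUCCESS  # proceed to cuMemCreate/cuMemMap

fabric_handles_supported()               # poisons the global on consumer Blackwell
assert create_and_map() == CUDA_SUCCESS  # with the reset, the next map succeeds
```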
### 2. The cumem cycle leak is real but manageable
Even with #35489 applied, repeated /sleep + /wake_up cycles leak pinned GPU memory. We measured up to ~5 GiB per cycle, eventually overlapping live weight tensors and producing silent garbage outputs around cycle 4-5.
There are about 7 open vLLM issues tracking this (#36651, #21336, #34600, #37111, etc.). PR #34600 adds proper rollback in wake_up when partial allocations fail — important for not leaking on every failed wake even if it doesn’t directly fix the slow per-cycle accumulation. We cherry-pick that too.
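The rollback itself is the classic unwind-on-partial-failure discipline. A hedged sketch of the pattern (not vLLM’s actual code; `map_allocation`/`unmap_allocation` are illustrative stand-ins):

```python
def map_allocation(alloc: dict) -> None:
    # Stand-in for the real cuMemMap + cuMemSetAccess sequence.
    if alloc.get("poisoned"):
        raise RuntimeError("cuMemMap failed")
    alloc["mapped"] = True

def unmap_allocation(alloc: dict) -> None:
    # Stand-in for cuMemUnmap + physical-handle release.
    alloc["mapped"] = False

def wake_up(allocations: list[dict]) -> None:
    mapped = []
    try:
        for alloc in allocations:
            map_allocation(alloc)
            mapped.append(alloc)
    except RuntimeError:
        # Partial failure: unwind everything mapped so far instead of
        # leaking it. Without this, every failed wake strands GPU memory.
        for alloc in reversed(mapped):
            unmap_allocation(alloc)
        raise
```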
### 3. The PR that fixes everything for DeepSeek-V4 also breaks the memory budget
PR #41834 is the upstream effort to land native DeepSeek-V4 support on SM12x: Triton sparse-MLA fallback, DeepGEMM-free paths, MLA prefix-cache fix. It’s the only credible path to native DSv4 without a community fork, and it makes /sleep synchronous, eliminating the 35-second POST_SLEEP_GAP_S workaround we needed with the jasl-on-cu130 image.
But PR #41834 also adds ~22 GiB of GPU state per GPU that lives outside the vLLM cumem allocator’s budget — sparse-MLA workspace, marlin scratch, cuda-graph private pools. On H200 (140 GiB per GPU) this is invisible headroom. On consumer Blackwell (95 GiB) it overflows.
vLLM’s default `--gpu-memory-utilization 0.85` gives 0.85 × 95 GiB ≈ 80 GiB of cumem budget. Add the 22 GiB of non-cumem state, plus a few GiB for any sleeping co-tenant peer, and the total is ~105 GiB on a 95 GiB GPU. OOM.
### 4. The 22 GiB pool ignores every config knob you’d reach for
We spent a couple of days trying to shrink the 22 GiB:
- `--max-cudagraph-capture-size` from 12 to 4 to 1: no effect.
- `cudagraph_mode: PIECEWISE` only (vs FULL+PIECEWISE): no effect.
- `--enforce-eager` (disable cuda graphs entirely): no effect. ← the diagnostic.
- `--max-num-batched-tokens` 8192 → 2048: no effect.
- `--max-num-seqs` 12 → 4: no effect.
- `--max-model-len` 131072 → 32768: no effect.
The 22 GiB is hardcoded into PR #41834’s compilation framework. It’s workspace tensors and pre-allocated buffers that the user can’t tune from spec.
The misleading thing: PyTorch’s OOM message labels this “22.7 GiB allocated in private pools (e.g., CUDA Graphs)”. The “(e.g., CUDA Graphs)” sent us chasing graph-capture configuration for days. The label is technically correct but misleading — PyTorch’s “private pools” catch any tensor allocated via a `torch.cuda.MemPool`, including allocations done outside graph capture. The 22 GiB is workspace tensors, not graphs.
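A minimal demonstration of the attribution gotcha, assuming a recent PyTorch where the experimental `torch.cuda.MemPool` API is available:

```python
import torch

pool = torch.cuda.MemPool()  # a private pool; no CUDA graph involved

with torch.cuda.use_mem_pool(pool):
    # A plain workspace tensor. It's routed through the MemPool, so an OOM
    # report counts it under "private pools (e.g., CUDA Graphs)" even
    # though graph capture never happens.
    workspace = torch.empty(2 << 30, dtype=torch.uint8, device="cuda")  # 2 GiB

print(f"{torch.cuda.memory_allocated() / 2**30:.1f} GiB allocated")
```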
### 5. The actual fix is operational, not source-level
Once we understood that the 22 GiB is a hardcoded non-cumem footprint, the fix became obvious:
`--gpu-memory-utilization 0.70  # not 0.85`
cuMemMap race conditions. PR #35489. PR #34600. PR #41834. Sparse-MLA workspace archaeology. PyTorch private-pool misattribution. And the final unlock was one config knob.

0.70 × 95 GiB = 66.5 GiB cumem budget. Plus 22 GiB non-cumem. Plus 2.4-4.8 GiB of sleeping-peer residue. Total: ~92 GiB on a 95 GiB GPU. ~3 GiB margin. Fits.
The cost: the cumem budget is smaller, so there’s less room for weights + KV cache. DeepSeek-V4 weights at TP=4 are ~17 GiB per GPU; KV cache at FP8 for max_model_len=131072, max_num_seqs=12 is ~7 GiB. Total ~24 GiB, comfortably inside 66.5 GiB.
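The arithmetic from both the failure and the fix, as a back-of-envelope check (the helper function is ours; the numbers come from above):

```python
def per_gpu_budget(vram_gib: float, util: float, non_cumem_gib: float,
                   peer_residue_gib: float) -> tuple[float, float]:
    cumem = util * vram_gib  # what vLLM's cumem allocator budgets for itself
    return cumem, cumem + non_cumem_gib + peer_residue_gib

for util in (0.85, 0.70):
    for residue in (2.4, 4.8):  # sleeping-peer residue range from above
        cumem, total = per_gpu_budget(95, util, non_cumem_gib=22,
                                      peer_residue_gib=residue)
        verdict = "fits" if total <= 95 else "OOM"
        print(f"util={util} residue={residue}: "
              f"cumem={cumem:.1f} GiB total={total:.1f} GiB -> {verdict}")
# 0.85 lands at ~105-108 GiB on a 95 GiB card (OOM);
# 0.70 lands at ~91-93 GiB (fits, with a few GiB of margin).
```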
## Measured outcome
| Operation | Before (jasl-on-cu130 image) | After (this stack) |
|---|---|---|
| /sleep | Async: returns 200 in ~5s, actual unmap takes another 35s | Synchronous: 2-25s, returns when actually done |
| /wake_up (same peer) | 2s | 1s |
| /wake_up (cross-peer, after /sleep) | Needed the 35s POST_SLEEP_GAP_S workaround to avoid the cuMemMap race | 1s, no workaround |
| Cross-peer swap total | ~50s | ~3s |
3 seconds is the headline number. For a rotation pool where one user request can trigger a cross-peer swap, 50s → 3s is the difference between “annoying” and “invisible.”
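Measuring that number is two POSTs and a stopwatch, and the timing is only meaningful because /sleep is now synchronous. A sketch, with the same placeholder endpoints as the setup section:

```python
import time

import requests

def cross_peer_swap(awake: str, sleeping: str) -> float:
    # Sleep the active peer, wake the other; both calls block until done.
    t0 = time.monotonic()
    requests.post(f"{awake}/sleep", params={"level": "1"}, timeout=300).raise_for_status()
    requests.post(f"{sleeping}/wake_up", timeout=300).raise_for_status()
    return time.monotonic() - t0

elapsed = cross_peer_swap("http://localhost:8000", "http://localhost:8001")
print(f"cross-peer swap: {elapsed:.1f}s")  # ~3s on this stack, ~50s before
```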
## What we learned that wasn’t obvious
- PyTorch’s “private pools” OOM label includes the cumem allocator’s MemPool, not just CUDA graphs. If you’re chasing graph-capture configuration based on this label, you might be looking at the wrong thing.
- `current_platform.is_device_capability_family(120)` is the right pattern for SM12x-specific config. If you’re adding consumer-Blackwell adaptations, gate them behind this so non-SM12x users see no change; see the sketch after this list.
- PR #41834’s perf wins come with memory costs that aren’t called out in the PR description. The 22 GiB non-cumem footprint isn’t documented. If you’re testing on H200 you’d never notice; on consumer Blackwell you’ll hit it immediately.
- `--enforce-eager` as a diagnostic is more useful than as a deployment posture. It’s the cleanest way to ask “is this OOM coming from graphs?” Even when you don’t want to deploy with it, run one experiment with it to disambiguate.
- The cumem cycle leak is real but secondary. We chased it for weeks thinking it was the primary blocker. With PR #35489 + PR #34600 + the `0.70` config, we no longer have to alloc-restart for cumem accumulation under our normal rotation load.
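A sketch of the gating pattern from the second bullet; `current_platform` is vLLM’s platform helper, and the 0.70 fallback is just our config from above:

```python
from vllm.platforms import current_platform

def pick_gpu_memory_utilization(default: float = 0.85) -> float:
    # Consumer Blackwell (SM 12.x): leave headroom for the ~22 GiB
    # non-cumem footprint. Everyone else keeps the default untouched.
    if current_platform.is_device_capability_family(120):
        return 0.70
    return default
```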
## What about other attention types
Sleep-mode behavior on consumer Blackwell is attention-type sensitive. What works for one architecture’s KV cache layout and routing table doesn’t necessarily work for another. What we’ve validated on this stack:
| Model | Attention | Status |
|---|---|---|
| DeepSeek-V4-Flash | sparse-MLA (Triton fallback via PR #41834) | Full sleep/wake at TP=4. Live. |
| MiMo-V2.5-Flash | dense MLA | Full sleep/wake at TP=4. Live. |
| Q3-Coder-Next-80B-A3B | hybrid DeltaNet + SWA | /sleep level=2 does not release cleanly on the cu129 image — VRAM stays held. We keep this one always-awake. Tracking vllm#41602. |
| Single-GPU 7B-class dense | standard MHA / GQA | Sleep/wake works trivially. Sub-second on llama-swap pools. This is the easy case. |
If you’re trying sleep-mode on a model with an unusual attention variant (hybrid linear/SWA, RWKV-style, Mamba) on consumer Blackwell, budget extra time. The cumem allocator path is uniform; the per-architecture /sleep + /wake_up release paths are not.
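The validation harness doesn’t need to be clever: cycle /sleep + /wake_up and watch free VRAM with the model asleep. A sketch (pynvml is the standard NVML binding; the URL and cycle count are placeholders):

```python
import requests
from pynvml import (nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
                    nvmlDeviceGetMemoryInfo, nvmlInit)

BASE = "http://localhost:8000"

def free_vram_gib() -> float:
    return sum(
        nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)).free
        for i in range(nvmlDeviceGetCount())
    ) / 2**30

nvmlInit()
baseline = None
for cycle in range(8):
    requests.post(f"{BASE}/sleep", params={"level": "1"}, timeout=300).raise_for_status()
    free_now = free_vram_gib()  # measured with the model asleep
    baseline = free_now if baseline is None else baseline
    # A steadily growing drift across cycles is the cumem leak; a large
    # one-off drift is an architecture whose release path doesn't work.
    print(f"cycle {cycle}: free {free_now:.1f} GiB "
          f"(drift {baseline - free_now:+.1f} GiB)")
    requests.post(f"{BASE}/wake_up", timeout=300).raise_for_status()
```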
## What’s in the repo
- `Dockerfile.builder` + `Dockerfile.bundle` (the two-stage build)
- `patches/cumem-fix-stack-pr35489-pr34600.patch` (the cherry-pick stack, regenerable)
- `nomad-examples/vllm-dsv4-flash.nomad` (a working production spec)
- Full README with the gotchas above
Apache-2.0. PRs welcome — especially benchmarks on hardware we haven’t tested.
## Acknowledgements
@jasl + @aabbccddwasd for PR #41834, which is the entire SM12x DSv4 enablement. @haosdent for PR #35489, the one-line fix that took us a week to find. And the vLLM cumem allocator authors for the sleep/wake_up design that makes a 3-second model swap possible at all.