Add support for NVIDIA DGX Spark (GB10 / sm_121a, arm64) #1835

Open

boots-coder wants to merge 5 commits into THUDM:main from boots-coder:feat/gb10-support

Conversation

@boots-coder

Add NVIDIA DGX Spark (GB10 / sm_121a) support

Summary

This PR adds support for NVIDIA DGX Spark (Project Digits, GB10 chip — Grace
CPU + consumer Blackwell GPU at sm_121a, aarch64, 128 GB unified memory) to
slime's docker build pipeline. GB10 is an explicit gap in the current support
matrix:

  • slime's published images (slimerl/slime:*) are x86_64-only.
  • The arm64 base slimerl/sglang:v0.5.9 ships CUDA 12.9, whose ptxas has no
    sm_121a target, so Triton's JIT crashes on the first kernel compile.
  • The ENABLE_CUDA_13=1 branch in the upstream Dockerfile is aimed at GB200/GB300
    (sm_100a) with an x86-only sgl-router wheel.

This PR adds a second Dockerfile (docker/Dockerfile.gb10) based on
nvcr.io/nvidia/vllm:26.03-py3 (arm64), which ships CUDA 13.2, PyTorch 2.11.0a0
with compute_120 PTX, Triton 3.6.0, and flash-attn 2.7.4.post1 — the minimal
viable baseline for GB10. Fifteen small blockers are resolved in the Dockerfile
and supporting patches.

End-to-end validation: Qwen2.5-0.5B + GRPO + dapo-math-17k, 1 GB10 GPU,
colocated actor + rollout, one full rollout → reward → policy-update cycle
completed in 2m10s (step 0 metrics logged, weight sync to SGLang succeeded,
checkpoint saved).

What's in the patch

New files

  • docker/Dockerfile.gb10 — digest-pinned, arm64-native build on NGC vllm base
  • docker/patch/gb10/patch_sgl_kernel.py — adds SGL_KERNEL_GB10_ONLY CMake option
  • docker/patch/gb10/sgl-kernel-arch.patch — unified diff of the same
  • docker/patch/gb10/cuda_profiler_api.h — 20-line shim for a CUDA 13 dropped header
  • scripts/run-qwen2.5-0.5B-gb10-smoke.sh — minimal 1-GPU smoke script
  • NOTES_GB10.md — walkthrough of the 15 blockers with root causes

Existing file edits

  • slime/utils/arguments.py — remove a trailing comma that turned an argparse
    help= string into a single-element tuple (causes 'tuple' object has no attribute
    'strip' during --help; affects all platforms, not just GB10)
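
The trailing-comma failure mode is easy to reproduce in isolation. A minimal sketch (not slime's actual argument set); the exact exception raised by `format_help()` depends on the Python version — `AttributeError: 'tuple' object has no attribute 'strip'` on 3.13+, a `TypeError` from help-string `%`-formatting on earlier versions:

```python
import argparse

# Correct: help= is a plain string.
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, help="learning rate")
assert "learning rate" in parser.format_help()

# Bug: a trailing comma inside parentheses turns help= into a 1-tuple.
buggy = argparse.ArgumentParser()
buggy.add_argument(
    "--lr",
    type=float,
    help=("learning rate",),  # note the trailing comma
)
try:
    buggy.format_help()
except (AttributeError, TypeError) as e:
    # 3.13+: 'tuple' object has no attribute 'strip'
    # earlier: unsupported operand type(s) for %: 'tuple' and 'dict'
    print(type(e).__name__)
```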

The 15 blockers (quick reference)

| # | Blocker | Resolution |
|---|---------|------------|
| 1 | ptxas fatal: sm_121a | Rebase on NGC CUDA 13.2 image |
| 2 | libnvrtc.so.12 missing from sgl_kernel wheel | Build sgl_kernel from source |
| 3 | sgl_kernel wheel ABI-mismatch vs NGC libtorch | Build against NGC torch |
| 4 | sgl_kernel arm64 wheels only have sm_100 variant | Build sm_121 variant |
| 5 | sgl_kernel source build OOM (7 archs × cutlass) | SGL_KERNEL_GB10_ONLY CMake option |
| 6 | TE: CUDNN::cudnn_engines_precompiled not found | Install nvidia-cudnn-cu13==9.20.0.48 PyPI wheel, symlink |
| 7 | TE: nvtx3/nvToolsExt.h missing | Copy headers from NVIDIA/NVTX repo |
| 8 | TE ptx.cuh: "use smXXXf, not smXXX" | NVTE_CUDA_ARCHS="120f;121f" |
| 9 | CMake 3.31 rejects the f suffix | Upgrade to cmake==4.3.1, override PIP_CONSTRAINT |
| 10 | cuda_profiler_api.h missing in CUDA 13 | 20-line shim header |
| 11 | 'tuple' object has no attribute 'strip' | Remove trailing comma in arguments.py |
| 12 | sglang_router x86_64 only | Use upstream sglang-router==0.3.2 (arm64 wheel) |
| 13 | antlr4==4.13.2 breaks omegaconf | Pin antlr4-python3-runtime==4.9.3 |
| 14 | megatron.training not found after pip install -e | PYTHONPATH=/root/src/Megatron-LM |
| 15 | tilelang needs libz3.so | apt install libz3-dev |
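
Blocker 5's fix can be sketched in miniature. This is a hypothetical illustration of what a patch script like patch_sgl_kernel.py might do — gate the arch list behind a CMake option so only the GB10 arches are emitted — not its actual contents; the `project()` anchor and the option block below are assumptions:

```python
import re

# Illustrative CMake snippet (assumed, not the real sgl-kernel CMakeLists):
# an opt-in switch that collapses the 7-arch gencode list to sm_120a/sm_121a.
GB10_BLOCK = """\
option(SGL_KERNEL_GB10_ONLY "Build only sm_120a/sm_121a" OFF)
if(SGL_KERNEL_GB10_ONLY)
  set(CMAKE_CUDA_ARCHITECTURES 120a 121a)
endif()
"""

def patch_cmake(text: str) -> str:
    # Insert the option block right after the project() declaration.
    return re.sub(r"(project\([^)]*\)\n)", r"\1" + GB10_BLOCK, text, count=1)

example = (
    "cmake_minimum_required(VERSION 3.26)\n"
    "project(sgl_kernel LANGUAGES CUDA CXX)\n"
)
print(patch_cmake(example))
```

Restricting the arch list matters because each extra cutlass FP8 gemm arch costs roughly 10-15 GB of RAM per translation unit during compilation, which OOM-kills cicc on a 128 GB Spark host.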

Testing

Run inside the image:

```bash
cd /root/slime
python train.py --help                        # prints 3714 lines, exits 0
bash scripts/run-qwen2.5-0.5B-gb10-smoke.sh   # end-to-end GRPO run, 2m10s
```

Scope this PR does NOT cover

  • FA3 (Hopper-only) kernels: deliberately skipped; GB10 cannot run sm_90a binaries.
  • Any multi-GPU or multi-node GB10 topology (DGX Spark is a single-GPU SKU).
  • Convergence / training quality validation (separate ongoing benchmark).

Reproducibility

```bash
cd slime
docker build -f docker/Dockerfile.gb10 -t slime-gb10:latest .
```

Base image is pinned by digest:

```
nvcr.io/nvidia/vllm@sha256:13e327dad79e6e417f6687fec2ba76b0386d597082ec0ee003c1e964ec6ad0e7
```

@boots-coder (Author)

(Screenshots: smoke-test logs and metrics.) End-to-end smoke on a real GB10 (DGX Spark). Qwen2.5-0.5B + GRPO + dapo-math-17k, 1 GPU colocated, 1 rollout × 1 policy step. Full cycle: 1m26s, Ray job succeeded, 8/8 weight-sync HTTP calls returned 200 OK.

boots-coder force-pushed the feat/gb10-support branch 2 times, most recently from 37561d6 to 9447a85 on April 15, 2026.
The help= kwarg was written as a single-element tuple (trailing comma),
which makes argparse format_help() raise:
  AttributeError: 'tuple' object has no attribute 'strip'
on any `python train.py --help` invocation.

Remove the comma so the help text is a plain string.
slime's published images are x86_64-only. The arm64 slimerl/sglang base
ships CUDA 12.9 whose ptxas lacks an sm_121a target, so Triton crashes on
first JIT. The ENABLE_CUDA_13=1 branch in the upstream Dockerfile is
aimed at GB200/GB300 (sm_100a) with an x86-only sgl-router wheel and does
not work on GB10.

This change adds a second Dockerfile targeting GB10 specifically:

  - Rebased on nvcr.io/nvidia/vllm:26.03-py3 (arm64), pinned by digest.
    Ships CUDA 13.2, PyTorch 2.11.0a0 with compute_120 PTX,
    Triton 3.6.0, flash-attn 2.7.4.post1.

  - sgl-kernel is rebuilt from source with a new CMake option
    SGL_KERNEL_GB10_ONLY=ON (patch_sgl_kernel.py / sgl-kernel-arch.patch)
    that restricts gencode to sm_120a + sm_121a. The stock 7-arch
    emission OOM-kills cicc on 128 GB Spark hosts (each extra cutlass
    FP8 gemm arch costs ~10-15 GB RAM per TU).

  - TransformerEngine 2.10 is built with NVTE_CUDA_ARCHS=120f;121f.
    The 'f' (family-specific) arch suffix is required by TE's ptx.cuh
    static_assert and is only parsed by CMake >= 4.0, so cmake is
    upgraded to 4.3.1 (NGC's PIP_CONSTRAINT must be cleared).

  - Small CUDA 13 gaps that NGC does not ship are filled:
    * cuda_profiler_api.h shim (symbols remain in libcudart.so)
    * NVTX3 headers (copied from github.com/NVIDIA/NVTX)
    * libcudnn_engines_precompiled.so.9 (from nvidia-cudnn-cu13 wheel)

  - The x86-only zhuzilin/sgl-router wheel is swapped for upstream
    sglang-router==0.3.2 (has arm64 wheel). slime's version-compare
    code still works; only the 'slime' in version wandb branch differs.

Also adds scripts/run-qwen2.5-0.5B-gb10-smoke.sh: a 1-GPU colocated
smoke that exercises the full rollout -> reward -> policy-update ->
weight-sync cycle. Validated end to end with Qwen2.5-0.5B on
dapo-math-17k; a single step completes in ~2m10s on GB10.
Captures the root causes and resolutions for every issue encountered
while porting the full stack (sgl-kernel, TransformerEngine, apex,
Megatron-LM, sglang, slime) to GB10 / sm_121a. Placed next to
Dockerfile.gb10 so reviewers see the reference material when reading
the Dockerfile.
…nc, math_verify, pyarrow, accelerate)

These packages were installed manually during the interactive GB10 build
session but were missing from the Dockerfile, breaking clean reproduction.
- numpy<2: Megatron requires numpy 1.x (NGC ships 2.x)
- pylatexenc, math_verify, word2number: reward function runtime deps
- pyarrow: GSM8K parquet data preprocessing
- accelerate: HuggingFace transformers device_map support
@boots-coder (Author)

boots-coder commented Apr 16, 2026

100-rollout convergence run + held-out test set evaluation on GB10

Following the smoke test posted earlier, I ran a full 100-rollout GRPO convergence experiment on the GB10 (single GPU, colocated mode), then evaluated on the held-out GSM8K test set (1,319 questions, greedy decoding).

Test Set Results (Pass@1, greedy, 1319 questions)

| Model | Correct | Accuracy | No \boxed{} output | Avg response length |
|---|---|---|---|---|
| Baseline (Qwen2.5-0.5B-Instruct) | 575 / 1319 | 43.59% | 33 (2.5%) | 1071 chars |
| + 100-step GRPO (iter_99) | 662 / 1319 | 50.19% | 3 (0.2%) | 907 chars |
| Improvement | +87 | +6.60 pp | −30 (format compliance ↑) | −15.3% (more concise) |

Three dimensions improved simultaneously:

  1. Accuracy +6.60 pp on never-seen test questions
  2. Format compliance 97.5% → 99.8% — model learned to reliably output \boxed{}
  3. Responses 15% shorter — RL optimized for quality, not length

Training Configuration

  • Model: Qwen2.5-0.5B-Instruct (494M params)
  • Data: GSM8K train split (7,473 questions), --rm-type math (grade_answer_verl)
  • Scale: 20 prompts × 8 samples × 100 rollouts = 16,000 total samples (< 1 epoch)
  • Hyperparams: lr=1e-6 constant, GRPO with eps-clip=0.2/0.28, colocated single GPU
  • Training time: 1h 11min (41.9s/rollout avg)
  • Eval time: 28min (2 × 1319 greedy, HF transformers backend)

Training Stability

| Metric | Value |
|---|---|
| grad_norm | 0.93 mean (range 0.73–1.18, no explosion) |
| entropy_loss | 0.295 → 0.172 (monotonic decrease) |
| kl_loss vs ref | 0.0 → 0.044 (smooth, controlled) |
| Rollout throughput | 5,495 tok/gpu/s |
| Actor train TFLOPS | 16.83 (mean) |

Reward Curve (training rollout, temp=1.0)

(Figures: reward curve, training stability, response length.)

Rollout raw_reward trajectory: 0.375 → 0.656 (peak 0.763 at rollout 58).
Note: rollout metrics use temp=1.0 sampling on training data (8 samples/prompt), so they're higher than the greedy test-set Pass@1 reported above.
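
The greedy Pass@1 and "No \boxed{} output" tallies above can be illustrated with a rough sketch. This is illustrative only — the PR's actual scoring goes through grade_answer_verl, and the regex and helper names here are assumptions:

```python
import re

def extract_boxed(text: str):
    # Take the last \boxed{...} in the completion; None if absent.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def score(completions, references):
    # Returns (correct, no_box): exact-match count and the count of
    # completions with no \boxed{} at all (tallied as wrong).
    correct = no_box = 0
    for out, ref in zip(completions, references):
        ans = extract_boxed(out)
        if ans is None:
            no_box += 1
        elif ans == ref.strip():
            correct += 1
    return correct, no_box

outs = ["The answer is \\boxed{42}.", "So the result is 42."]
print(score(outs, ["42", "42"]))  # (1, 1): one correct, one missing \boxed{}
```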

Reproduction

All scripts, configs, and the updated Dockerfile.gb10 (with 6 missing runtime deps fixed in commit 8f55ecb1) are in this PR.
