CUDA-graph capturable Dispatch and Combine #6031
Conversation
!test
Greptile Summary

This PR replaces TCPStore-based CPU synchronization in the `kCuda` backend of Dispatch / Combine with a graph-capturable, device-side implementation. Key issues to address before merge:
Confidence Score: 1/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant A as Rank A (stream)
    participant B as Rank B (stream)
    Note over A,B: prepareAlltoallvMetadataGpu (counts exchange)
    A->>A: cudaMemcpyD2D send_counts → sync_buf_[A]
    B->>B: cudaMemcpyD2D send_counts → sync_buf_[B]
    A-->>B: cuStreamWriteValue32(B.countsSem[A], kInProgress)
    B-->>A: cuStreamWriteValue32(A.countsSem[B], kInProgress)
    A->>A: cuStreamWaitValue32(A.countsSem[B], EQ kInProgress)
    B->>B: cuStreamWaitValue32(B.countsSem[A], EQ kInProgress)
    A->>B: cudaMemcpyD2D syncRemotePtr[B] → counts_matrix[B]
    B->>A: cudaMemcpyD2D syncRemotePtr[A] → counts_matrix[A]
    A->>A: cuStreamWriteValue32(A.countsSem[B], kIdle) [reset]
    B->>B: cuStreamWriteValue32(B.countsSem[A], kIdle) [reset]
    Note over A,B: alltoallvWithCudaBackend ×N (payload RDMA writes)
    A-->>B: NVLink write recv_x/topk_idx/topk_weights/src_idx chunks
    B-->>A: NVLink write recv_x/topk_idx/topk_weights/src_idx chunks
    Note over A,B: doneBarrier
    A-->>B: cuStreamWriteValue32(B.doneSem[A], kInProgress)
    B-->>A: cuStreamWriteValue32(A.doneSem[B], kInProgress)
    A->>A: cuStreamWaitValue32(A.doneSem[B], EQ kInProgress)
    B->>B: cuStreamWaitValue32(B.doneSem[A], EQ kInProgress)
    A->>A: cuStreamWriteValue32(A.doneSem[B], kIdle) [reset]
    B->>B: cuStreamWriteValue32(B.doneSem[A], kIdle) [reset]
```
Last reviewed commit: 1f622c7

!test
…_total to avoid kernel grid over-provisioning
```cpp
const int64_t elem_stride =
    metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
```
Divisibility guard removed — silent wrong elem_stride if mismatched
The PR removes the checks:
```cpp
NVF_CHECK(
    metadata.max_send_total == 0 ||
        send.numel() % metadata.max_send_total == 0, ...);
NVF_CHECK(
    metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, ...);
```

`elem_stride` is computed as `send.numel() / metadata.max_send_total`. If `send.numel()` is not divisible by `max_send_total` (e.g. because a caller passes mismatched metadata), integer truncation silently gives a wrong stride. Every `send_offsets`, `send_counts`, and `recv_offsets` is then scaled by this wrong value before being passed to the kernel, producing corrupted data without any error. The checks were cheap and provided essential diagnostic value; removing them for graph-capturability does not improve performance because they are CPU-side and never execute inside a captured region.
```cpp
    const at::Tensor& src_idx,
    const at::Tensor& n_tokens_to_rank,
    const at::Tensor& n_tokens_from_rank,
    int64_t num_tokens,
    Communicator* communicator,
    CommunicatorBackend backend) {
  NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator.");
  NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA.");
  const bool has_topk_weights = topk_weights.numel() > 0;
  if (has_topk_weights) {
    NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA.");
    NVF_CHECK(
        topk_weights.is_floating_point(),
        "Combine topk_weights must be floating point.");
    NVF_CHECK(
        topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) &&
            topk_weights.size(1) == 1,
        "topk_weights must be shape [T, 1], got: ",
        topk_weights.sizes());
  }
  NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA.");
  NVF_CHECK(
      n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA.");
  NVF_CHECK(
      n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA.");
  NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden].");
  NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D.");
      n_tokens_to_rank.is_cuda() && n_tokens_from_rank.is_cuda(),
      "Combine count tensors must be on CUDA.");
  NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D.");
  NVF_CHECK_EQ(
      src_idx.size(0), x.size(0), "src_idx size must match x first dimension.");
  NVF_CHECK_EQ(
      n_tokens_to_rank.numel(),
      communicator->size(),
      "n_tokens_to_rank must match world size.");
  NVF_CHECK_EQ(
      n_tokens_from_rank.numel(),
      communicator->size(),
      "n_tokens_from_rank must match world size.");
```
```cpp
  // Reconstruct source ranks from per-rank counts. alltoall_base concatenates
  // received chunks in rank order, so this matches the receive layout.
  auto src_rank = at::arange(
                      n_tokens_from_rank.numel(),
                      at::TensorOptions().dtype(at::kLong).device(x.device()))
                      .repeat_interleave(n_tokens_from_rank.to(at::kLong));
  NVF_CHECK_EQ(
      src_rank.size(0),
      x.size(0),
      "Reconstructed src_rank must match x first dimension.");
  // Sort by source rank so alltoall can send contiguous chunks per rank.
  auto sorted_indices = at::argsort(src_rank);
  auto send_x = x.index_select(0, sorted_indices);
  auto send_src_idx = src_idx.index_select(0, sorted_indices);
```
Several input validations removed from doMoeCombine
The following checks that existed in the old code were removed and are not replaced:
- `NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D.")` — a 2D `src_idx` would still pass the `size(0) == x.size(0)` check and cause a silent runtime error inside `index_copy_`.
- `NVF_CHECK_EQ(n_tokens_to_rank.numel(), communicator->size(), ...)` and the equivalent for `n_tokens_from_rank` — without these, `toSplitSizes` (NCCL path) silently operates on a wrong-sized tensor, and `prepareAlltoallvMetadataGpu` (CUDA path) reads/writes `W` entries from a tensor that may have fewer.
These checks are CPU-side and do not execute during graph capture, so removing them provides no graph-capturability benefit.
```cpp
  // Scatter by original token index to restore local order.
  auto combined_x = at::empty({total_recv, hidden}, x.options());
  combined_x.index_copy_(0, recv_src_idx, recv_x);
  auto combined_x = at::zeros({num_tokens, hidden}, x.options());
  combined_x.index_copy_(
      0,
      rs.buffer.narrow(0, 0, num_tokens),
      rx.buffer.narrow(0, 0, num_tokens));
```
at::zeros initialization of combined_x silently masks index errors
combined_x is initialized to all-zeros, then index_copy_ is expected to fill all num_tokens rows. If rs.buffer.narrow(0, 0, num_tokens) ever contains duplicate indices or misses some positions (due to a bug in the alltoallv or send_counts mismatch), the affected rows will silently be zero rather than triggering any error. Using at::empty plus an assertion that each index in [0, num_tokens) appears exactly once would make the contract explicit.
If zeros are intentional as a safety net, a brief comment explaining the guarantee (that all num_tokens positions are covered exactly once) would clarify the intent.
```cpp
      ctx,
      n_tokens_to_rank,
      /*max_send_total=*/num_tokens,
      /*max_send_bytes=*/num_tokens,
      /*max_recv=*/capacity,
      stream);
  auto n_tokens_from_rank = metadata.recv_counts;
  const int64_t total_recv = metadata.total_recv;
  const int64_t max_recv = metadata.max_recv;
```
```cpp
  // Allocate symmetric buffers for send/recv payloads.
  auto send_x_sym = SymmetricTensor::allocate(
      {metadata.max_send_total, hidden}, x.scalar_type(), x.device());
  send_x_sym.narrow(0, 0, num_tokens).copy_(send_x);
  auto send_topk_idx_sym = SymmetricTensor::allocate(
      {metadata.max_send_total, topk_idx.size(1)},
```
```cpp
  auto& rx = ctx.recv("x", capacity, {hidden}, x.scalar_type(), x.device());
  auto& ri = ctx.recv(
      "topk_idx",
      capacity,
      {topk_idx.size(1)},
      topk_idx.scalar_type(),
      x.device());
```
DispatchResult from CUDA backend returns full capacity buffers — callers must not narrow by recv_counts
recv_x, recv_topk_idx, recv_topk_weights, and recv_src_idx all have shape [C, ...] = [T*R, ...] with only the first V = sum(n_tokens_from_rank) rows valid. Any downstream operation that passes these tensors to kernel code treating the first dimension as the actual token count (e.g., a linear layer, softmax, or gradient accumulation) will silently process T*R - V garbage rows.
The header now documents this contract (DispatchResult comments), but the HostIrEvaluator that schedules the fusion kernel between dispatch and combine receives these over-allocated tensors with no enforcement that the kernel only operates on valid rows. Consider adding a runtime check in the evaluator path or document that any kernel scheduled between dispatch and combine must accept and correctly ignore padding rows.
!test
```cpp
  auto& entry = recv_entries_[name];
  if (entry.sym && entry.cached_first_dim >= first_dim) {
    return entry.handle;
  }
```
Cache hit doesn't validate dtype or extra_sizes
The cache returns the existing RecvHandle whenever entry.cached_first_dim >= first_dim, but does not check that dtype or extra_sizes match the originally-allocated buffer. If the same logical name (e.g., "x") were ever called with a different scalar type or a different extra dimension (e.g., hidden size changes), the kernel would silently receive a buffer with the wrong element type or stride, leading to data corruption.
While the current callers always pass consistent types for a given name, this implicit contract is not enforced. A defensive check would make it explicit:
```cpp
if (entry.sym && entry.cached_first_dim >= first_dim) {
  NVF_CHECK(
      entry.cached_dtype == dtype,
      "SymMemForAlltoallv::recv: buffer '", name,
      "' dtype mismatch (cached ", entry.cached_dtype,
      " vs requested ", dtype, ").");
  // similarly for extra_sizes
  return entry.handle;
}
```

At minimum, storing and asserting on dtype on every cache hit would catch mismatches early.
```cpp
SymMemForAlltoallv::SymMemForAlltoallv(
    at::Device device,
    const std::string& tag)
    : tag_(tag) {
  Communicator& comm = Communicator::getInstance();
  world_size_ = comm.size();
  my_rank_ = comm.deviceId();

  sync_buf_ = SymmetricTensor::allocate({world_size_ + 2}, at::kLong, device);
  sync_buf_.zero_();

  sync_sym_ = std::make_unique<SymmetricTensor>(sync_buf_);
  sync_sym_->setupRemoteHandles(tag + "_sync");

  sync_ptrs_.resize(world_size_);
  for (int64_t r = 0; r < world_size_; r++) {
    sync_ptrs_[r] =
        reinterpret_cast<CUdeviceptr>(sync_sym_->remoteTensor(r).data_ptr());
  }
```
SymMemForAlltoallv uses global communicator singleton, ignoring the caller's communicator
The constructor captures world size and rank from Communicator::getInstance(), but both doMoeDispatch and doMoeCombine accept an arbitrary Communicator* communicator parameter. If a caller passes a non-singleton communicator (e.g., a sub-communicator for a subset of ranks), the SymMemForAlltoallv context will be initialised with the wrong world size and rank. This leads to incorrect sync-buffer layout (wrong number of semaphore/count slots), wrong send_count reads in prepareAlltoallvMetadataGpu (which iterates ctx.worldSize()), and potential out-of-bounds access into sync_ptrs_.
The getOrCreateAlltoallv helper does not accept a communicator, so there is currently no way for the caller's communicator to propagate down to the context. The communicator (or at least its size() and deviceId()) should be threaded through to the constructor.
```cpp
  auto n_tokens_to_rank =
      at::zeros({world_size}, gpu_long_opts)
          .scatter_add(
              0, rank_for_token_long, at::ones({num_tokens}, gpu_long_opts));
```
scatter_add silently accepts out-of-range indices — regression from bincount
rank_for_token_long = floor_divide(topk_idx_long, experts_per_rank). If any token's expert ID is invalid (e.g., >= num_experts), the resulting rank index will be >= world_size. PyTorch's scatter_add with an output of size world_size will silently write to a position outside the allocated tensor, corrupting adjacent memory and producing wrong n_tokens_to_rank counts without any diagnostic.
The old `at::bincount(rank_for_token_cpu, {}, world_size)` would have produced a tensor longer than `world_size` if values exceeded `world_size - 1`, which at least makes the anomaly detectable downstream (e.g., the earlier `NVF_CHECK_EQ` for tensor sizes would have fired). Adding an explicit bounds check before `scatter_add` preserves the error-detection property:
```cpp
NVF_CHECK(
    rank_for_token_long.min().item<int64_t>() >= 0 &&
        rank_for_token_long.max().item<int64_t>() < world_size,
    "topk_idx maps to an out-of-range rank.");
```

Note: this check would be a CPU-GPU sync and would need to be performed outside the captured graph region (e.g., as a one-time warmup assertion).
!test
csrc/multidevice/ipc_handle.cpp (outdated)
```cpp
  sync_buf_ = SymmetricTensor::allocate({3 * world_size_}, at::kLong, device);
  sync_buf_.zero_();

  sync_sym_ = std::make_unique<SymmetricTensor>(sync_buf_);
  sync_sym_->setupRemoteHandles(tag + "_sync");

  sync_ptrs_.resize(world_size_);
  for (int64_t r = 0; r < world_size_; r++) {
    sync_ptrs_[r] =
        reinterpret_cast<CUdeviceptr>(sync_sym_->remoteTensor(r).data_ptr());
  }
```
sync_buf_ allocated as int64 but semaphore slots use 32-bit stream ops
sync_buf_ is allocated with at::kLong (8 bytes/slot). The semaphore addresses are computed as sync_ptrs_[rank] + N * sizeof(int64_t), placing each semaphore 8 bytes apart. However, CU_STREAM_MEM_OP_WRITE_VALUE_32 and CU_STREAM_MEM_OP_WAIT_VALUE_32 operate on 4-byte quantities.
On little-endian NVIDIA hardware this happens to work — the 32-bit write goes to the lower 4 bytes of the 8-byte slot and the upper 4 bytes remain zero — but it is a type mismatch that relies on:
- The GPU being little-endian.
- The upper 4 bytes of each semaphore slot never being touched by any 64-bit op.
A more explicit design would allocate the semaphore region as at::kInt (or in a separate tensor), so that WRITE/WAIT_VALUE_32 addresses map directly to element boundaries. Alternatively, computing semaphore addresses with sizeof(int32_t) strides within a dedicated 32-bit allocation would make the intent and layout unambiguous.
```cpp
void SymMemForAlltoallv::doneBarrier(CUstream stream) {
  batchSignal(
      stream,
      static_cast<cuuint32_t>(IpcSemaphore::kInProgress),
      &SymMemForAlltoallv::doneSemAddr);
  batchWait(
      stream,
      static_cast<cuuint32_t>(IpcSemaphore::kInProgress),
      &SymMemForAlltoallv::doneSemAddr);
  batchReset(
      stream,
      static_cast<cuuint32_t>(IpcSemaphore::kIdle),
      &SymMemForAlltoallv::doneSemAddr);
}
```
doneBarrier resets own slots but write ordering across ranks is unspecified
The barrier sequence is:
1. `batchSignal(kInProgress)` — write to peers' `done_sem` slots
2. `batchWait(kInProgress)` — wait for peers to write to MY `done_sem` slots
3. `batchReset(kIdle)` — reset MY `done_sem` slots
Step 3 resets doneSemAddr(my_rank, r) (my own memory). However, the batchReset uses CU_STREAM_WRITE_VALUE_DEFAULT. Between step 3 on rank A and step 1 of the NEXT replay on rank B, there is no formal guarantee that the reset write at A is visible before rank B's next signal arrives. Because CU_STREAM_WAIT_VALUE_EQ is used (not GEQ), correctness relies on the reset reaching the memory before the next iteration's batchSignal on the peer, which the doneBarrier's completion semantics guarantee (peers can only start the next replay after their own barrier completes). This is correct by the protocol ordering, but the dependency is non-obvious and would benefit from a short comment explaining why reset on one rank cannot race with the next signal from the peer.
```cpp
  NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
  NVF_CHECK(
      (int64_t)recv_ptrs.size() == metadata.world_size,
      "recv_ptrs size must match world size.");

  auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
  auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
  auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
  for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
    ptrs[rank] =
        static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
  }
  auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());

  const int64_t elem_stride =
      metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
  NVF_CHECK(
      metadata.max_send_total == 0 ||
          send.numel() % metadata.max_send_total == 0,
```
Missing validation for `recv_ptrs_gpu` — no size or device check
The old call-site accepted const std::vector<void*>& recv_ptrs and explicitly verified:
```cpp
NVF_CHECK(
    (int64_t)recv_ptrs.size() == metadata.world_size,
    "recv_ptrs size must match world size.");
```

It also coerced the pointer table to the send device via `.to(send.device())`.
The new at::Tensor recv_ptrs_gpu has neither check: if it has fewer than world_size entries the kernel silently reads garbage pointers; if it lives on the wrong device the launch will fault. remotePointersTensor() always produces a [world_size] tensor on the right device by construction, but the API contract is now implicit and fragile for any future caller. Consider adding:
```cpp
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
    recv_ptrs_gpu.is_cuda() && recv_ptrs_gpu.device() == send.device(),
    "recv_ptrs_gpu must be a CUDA tensor on the same device as send.");
NVF_CHECK(
    recv_ptrs_gpu.dim() == 1 &&
        recv_ptrs_gpu.size(0) == metadata.world_size,
    "recv_ptrs_gpu must have shape [world_size].");
```
```cpp
  auto counts_matrix = at::empty({W, W}, gpu_opts);
  for (int64_t r = 0; r < W; r++) {
    NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
        counts_matrix[r].data_ptr<int64_t>(),
        reinterpret_cast<void*>(a2av.syncRemotePtr(r)),
        W * sizeof(int64_t),
        cudaMemcpyDeviceToDevice,
        reinterpret_cast<cudaStream_t>(stream)));
  }

  a2av.resetCountsSem(stream);

  auto recv_counts = counts_matrix.select(1, my_rank).contiguous();

  auto send_offsets = at::zeros({W}, gpu_opts);
  if (W > 1) {
    send_offsets.narrow(0, 1, W - 1)
        .copy_(send_counts.cumsum(0).narrow(0, 0, W - 1));
  }

  at::Tensor recv_offsets = my_rank > 0
```
Counts signal lacks flush semantics — `waitCountsReady` alone does not guarantee NVLink visibility
The order in prepareAlltoallvMetadataGpu is correct (waitCountsReady precedes the cudaMemcpyAsync loop), so the architectural concern here is memory-model ordering rather than statement ordering.
signalCountsReady uses CU_STREAM_WRITE_VALUE_DEFAULT, which does not provide release/flush semantics for preceding device-memory writes visible over NVLink. Per CUDA documentation, CU_STREAM_WRITE_VALUE_FLUSH is required to guarantee that the cudaMemcpyAsync that copied send_counts into sync_buf_ (step 1) is visible to a remote peer that later observes the semaphore value via CU_STREAM_WAIT_VALUE_EQ.
Without the flush flag, the counts-ready signal from peer B could be seen by rank A before B's sync_buf_ write is observable through NVLink. In practice NVLink ordering has made this work, but it violates the documented memory model. The same issue applies to doneBarrier's batchSignal call (ipc_handle.cpp, the signalCountsReady and batchSignal helpers).
```cpp
// Change in batchSignal:
ops[idx].writeValue.flags = CU_STREAM_WRITE_VALUE_FLUSH; // release semantics
```
Replace TCPStore-based synchronization and CPU barriers in the `kCuda` backend of Dispatch / Combine with a fully graph-capturable implementation:
- Allocate outputs at full capacity `[C = T*R]` to avoid data-dependent shapes.
- The `MoeCombine` IR node carries `num_tokens` as an attribute to allocate the output (this could be removed when we support pre-allocated output buffers).
- A per-process alltoallv context (`SymMemForAlltoallv`) with static buffer allocation -- buffers are allocated and "rendezvous-ed" once and reused; re-allocation is `NVF_CHECK`-guarded because captured CUDA graphs hold baked pointers to these buffers.
- `at::bincount` replaced with `scatter_add` because bincount has hidden CPU-GPU syncs.
- Added `SymmetricTensor::remotePointersTensor` to pack all the remote pointers into a GPU buffer for device-initiated comms. Changed the signature of `alltoallvWithCudaBackend` to account for that.
- `DispatchCombineCudaGraphTest` captures dispatch+combine into a CUDA graph and exercises replay.
The NCCL backend path is unchanged and not graph-capturable.