[PyTorch] torch.compile support for permutation functions #2686
pggPL wants to merge 13 commits into NVIDIA:main
Conversation
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Force-pushed from 41e22ef to 8159d26
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch
Greptile Summary: This PR converts all MoE permutation `torch.autograd.Function` implementations to PyTorch custom operators so that they can be traced by torch.compile.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Compile as torch.compile
    participant Dispatch as QuantizedTensor.__torch_dispatch__
    participant CustomOp as te_moe custom_op (forward)
    participant Fake as register_fake (shape inference)
    participant AutogradCtx as register_autograd (backward)
    participant TEKernel as tex / triton kernel
    User->>Compile: compiled_fn(inp, routing_map, ...)
    Compile->>Fake: shape inference pass (abstract tensors)
    Fake-->>Compile: output shapes inferred
    Compile->>User: compiled graph ready
    User->>CustomOp: torch.ops.te_moe.permute_mask_map_fwd(inp, ...)
    alt inp is QuantizedTensor
        CustomOp->>Dispatch: __torch_dispatch__ called
        Dispatch->>Dispatch: check passthrough_ops set
        Dispatch->>CustomOp: forward without unwrapping
    end
    CustomOp->>TEKernel: triton_permutation.permute_with_mask_map(...)
    TEKernel-->>CustomOp: (output, row_id_map, permuted_probs)
    CustomOp-->>AutogradCtx: setup_context saves row_id_map, shapes
    CustomOp-->>User: (output, row_id_map, permuted_probs)
    User->>AutogradCtx: output.backward(grad)
    AutogradCtx->>CustomOp: torch.ops.te_moe.permute_mask_map_bwd(grad, ...)
    CustomOp->>TEKernel: triton_permutation.unpermute_with_mask_map(...)
    TEKernel-->>CustomOp: (act_grad, probs_grad)
    CustomOp-->>User: gradients
```

Last reviewed commit: aa4884a
/te-ci pytorch
```python
import torch._functorch.config as functorch_config

functorch_config.donated_buffer = False
```
What does it do and why do we need it? Could we add a comment here, especially since we would be using an internal function (so it will most probably break at some point)?
This is an optimization in torch.compile that is not compatible with the retain_graph=True used in the tests.
I added a comment.
```python
# ===================== _moe_permute_index_map custom ops =====================

topK = index.size(1)
# Workspace state for moe_permute_index_map
```
I don't like it (although I realize this is not really a problem with this PR, but rather with the original implementation).
If we can figure out how to change that, however, that would be great. Maybe we could make moe_compute a functor (a struct MoECompute with __call__ methods and the workspaces); then moe_compute would just be an object of that class that we create at the very beginning.
Why? I mean, what don't you like about it?
Well, the main thing is that we implicitly rely on there being only one permutation happening at a time (and that problem would not be solved by my proposal, BTW; it would need this to become an actual nn.Module, which has its own problems by effectively being an API break, though we should still do it for TE 3.0). If you run 2 permutations on 2 streams, there is a chance of silent data corruption, since both kernels would be using the same underlying workspace. The user has no way of knowing about this without consulting the code. And with torch.compile the chance of this happening may be even bigger: we are at the whim of the compiler optimizations at this point.
Can we change it in TE 3.0 then? I can indeed change it to a functor, but as you said that will not solve the problem.
Reply to offline discussion:
- there is no autograd support for ops which mutate their args,
- torch.compile does not put things on different streams.
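The MoECompute functor proposed above could be sketched as follows (hypothetical names, not a real TE class): each instance owns its workspace, so two permutations running through two different instances no longer share a buffer. As noted in the thread, a single module-level instance would still have the concurrency problem.

```python
import torch

class MoECompute:
    """Hypothetical functor owning its own permutation workspace."""

    def __init__(self):
        self._workspace = None  # per-instance, not module-level

    def _ensure_workspace(self, ref: torch.Tensor) -> torch.Tensor:
        if self._workspace is None or self._workspace.shape != ref.shape:
            self._workspace = torch.empty_like(ref, dtype=torch.int32)
        return self._workspace

    def __call__(self, inp: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        ws = self._ensure_workspace(index)
        ws.copy_(index)  # stands in for kernel scratch writes
        return inp.index_select(0, index.reshape(-1))

# "created at the very beginning", as proposed
moe_compute = MoECompute()
```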
Signed-off-by: root <pgadzinski@nvidia.com>
/te-ci pytorch
Conflict in transformer_engine/pytorch/permutation.py resolved by keeping the torch.library.custom_op structure from this branch and applying upstream's assert→raise ValueError improvements (commit 61f9594). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…sk_map_forward num_out_tokens is typed as int in the custom_op signature and can never be None; the check was incorrectly carried over from the class-based upstream version during merge conflict resolution. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Description
This PR adds `torch.compile(fullgraph=True)` support for MoE permutation operations (`moe_permute`, `moe_unpermute`, `moe_sort_chunks_by_index`) by converting all `torch.autograd.Function` implementations to PyTorch custom operators using `torch.library.custom_op`. Note that this PR does not add torch.compile support for `QuantizedTensor` as an input.
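A hedged usage sketch of what the description enables: capturing a permutation-style call under `fullgraph=True`, which raises a hard error on any graph break (the thing `torch.autograd.Function` implementations cause). `permute_tokens` here is a stand-in gather, not TE's `moe_permute`, and `backend="eager"` keeps the example light by skipping Inductor codegen:

```python
import torch

def permute_tokens(tokens: torch.Tensor, dest_rows: torch.Tensor) -> torch.Tensor:
    # stand-in for an MoE permutation: reorder token rows by index map
    return tokens.index_select(0, dest_rows)

# fullgraph=True fails loudly on any graph break instead of silently
# splitting the graph, which is what the converted custom ops avoid.
compiled = torch.compile(permute_tokens, fullgraph=True, backend="eager")

tokens = torch.randn(4, 8)
dest_rows = torch.tensor([2, 0, 3, 1])
out = compiled(tokens, dest_rows)
```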
Related to #2590
Type of change
Checklist: