Skip to content

[PyTorch] torch.compile support for permutation functions#2686

Open
pggPL wants to merge 13 commits intoNVIDIA:mainfrom
pggPL:moe_torch_compile
Open

[PyTorch] torch.compile support for permutation functions#2686
pggPL wants to merge 13 commits intoNVIDIA:mainfrom
pggPL:moe_torch_compile

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Feb 17, 2026

Description

This PR adds torch.compile(fullgraph=True) support for MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) by converting all torch.autograd.Function implementations to PyTorch custom operators using torch.library.custom_op.

Note that this PR does not add torch.compile support for QuantizedTensor as an input.

Related to #2590

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the moe_torch_compile branch from 41e22ef to 8159d26 Compare February 18, 2026 17:31
pre-commit-ci bot and others added 4 commits February 18, 2026 17:32
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review February 19, 2026 15:45
@pggPL
Copy link
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR converts all MoE permutation torch.autograd.Function subclasses (_moe_permute_index_map, _moe_unpermute_index_map, _moe_permute_mask_map, _moe_unpermute_mask_map, _moe_chunk_sort) into torch.library.custom_op + register_fake + register_autograd triplets, enabling torch.compile(fullgraph=True) support for moe_permute, moe_unpermute, and moe_sort_chunks_by_index. A new _quantized_tensor_passthrough_ops hook in quantized_tensor.py prevents FP8 tensor unwrapping for the newly registered ops, and all permute/unpad test functions are parametrized with use_torch_compile=[False, True].

Key issues found:

  • Dead use_torch_compile parameter in two test helpers: _test_permutation_and_padding_mask_map (line 692) and _test_permutation_and_padding_with_merging_probs (line 984) accept use_torch_compile but never forward it to _maybe_compile. The parametrized True variants of test_permutation_and_padding_mask_map and test_permutation_and_padding_with_merging_probs therefore provide no torch.compile coverage for those code paths.
  • assert instead of raise: Three FP8 guard checks (lines 517–519, 564–566, 802–804 in permutation.py) were changed from explicit exceptions to assert statements. These are silently dropped under python -O, potentially masking type errors in production.

Confidence Score: 4/5

  • Safe to merge after fixing the dead use_torch_compile parameter in the two test helpers; the core production logic is sound.
  • The production-facing changes (custom_op migration, fake shapes, autograd wrappers, passthrough ops) are architecturally correct and preserve existing semantics. The main concerns are: (1) two test helpers silently ignoring use_torch_compile=True, meaning torch.compile coverage is incomplete for the pad/unpad paths; and (2) assert-instead-of-raise for FP8 guards, which is a correctness regression under -O. Neither issue causes failures in the default test run, but they lower confidence in test coverage and production robustness.
  • tests/pytorch/test_permutation.py — the two helper functions _test_permutation_and_padding_mask_map and _test_permutation_and_padding_with_merging_probs need the use_torch_compile parameter wired up.

Important Files Changed

Filename Overview
transformer_engine/pytorch/permutation.py Full rewrite of all MoE permutation autograd Functions to torch.library.custom_op for torch.compile support; introduces module-level workspace globals and registers new ops as QuantizedTensor passthrough ops. Style regression: three assert-instead-of-raise guards.
tests/pytorch/test_permutation.py Adds use_torch_compile parametrization to all test functions and a _maybe_compile helper; however, _test_permutation_and_padding_mask_map and _test_permutation_and_padding_with_merging_probs accept but never use the parameter, making the torch.compile=True variants no-ops for those code paths.
transformer_engine/pytorch/quantized_tensor.py Adds the _quantized_tensor_passthrough_ops set and a dispatch hook in torch_dispatch so registered custom ops receive FP8 tensors without being unwrapped first. Minimal, well-isolated change.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Compile as torch.compile
    participant Dispatch as QuantizedTensor.__torch_dispatch__
    participant CustomOp as te_moe custom_op (forward)
    participant Fake as register_fake (shape inference)
    participant AutogradCtx as register_autograd (backward)
    participant TEKernel as tex / triton kernel

    User->>Compile: compiled_fn(inp, routing_map, ...)
    Compile->>Fake: shape inference pass (abstract tensors)
    Fake-->>Compile: output shapes inferred
    Compile->>User: compiled graph ready

    User->>CustomOp: torch.ops.te_moe.permute_mask_map_fwd(inp, ...)
    alt inp is QuantizedTensor
        CustomOp->>Dispatch: __torch_dispatch__ called
        Dispatch->>Dispatch: check passthrough_ops set
        Dispatch->>CustomOp: forward without unwrapping
    end
    CustomOp->>TEKernel: triton_permutation.permute_with_mask_map(...)
    TEKernel-->>CustomOp: (output, row_id_map, permuted_probs)
    CustomOp-->>AutogradCtx: setup_context saves row_id_map, shapes
    CustomOp-->>User: (output, row_id_map, permuted_probs)

    User->>AutogradCtx: output.backward(grad)
    AutogradCtx->>CustomOp: torch.ops.te_moe.permute_mask_map_bwd(grad, ...)
    CustomOp->>TEKernel: triton_permutation.unpermute_with_mask_map(...)
    TEKernel-->>CustomOp: (act_grad, probs_grad)
    CustomOp-->>User: gradients
Loading

Last reviewed commit: aa4884a

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

pggPL and others added 2 commits February 19, 2026 15:57
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +225 to +227
import torch._functorch.config as functorch_config

functorch_config.donated_buffer = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it do and why do we need to do that? Could we add a comment here, especially since we would be using the internal function here (and so it will most probably break at some point).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is optimization of torch.compile which is not compatible with retain_graph=True used in tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where?

# ===================== _moe_permute_index_map custom ops =====================

topK = index.size(1)
# Workspace state for moe_permute_index_map
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like it (although I realize this is not really the problem with this PR, but rather the original implementation).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can figure out how to change that however, that would be great. Maybe we could make moe_compute a functor (struct MoECompute with __call__ methods and the workspaces, then moe_compute would just be a object of that class that we would create at the very beginning).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? I mean what you don't like about it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the main thing is the fact that we implicitly rely on the fact that there is only one permutation happening at a time (and that problem would not be solved by my proposal BTW - this would need a change of this to be actual nn.Module but that has its own problems by effectively being an API break, we should still do it for TE 3.0 though). If you run 2 permutations in 2 streams then that has a chance of silent data corruption since both of those kernels would be using the same underlying workspace. This is something that the user has no way of knowing about without consulting the code. And with torch.compile the chance of this happening may be even bigger - we are at the whim of the compiler optimizations at this point.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change it in TE 3.0 then? I can indeed change it to functor, but as you said this will not solve a problem.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reply to offline discussion:

  1. there is no support for autograd for ops which mutate args,
  2. torch.compile does not put thing in different streams

Signed-off-by: root <pgadzinski@nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@pggPL pggPL closed this Mar 5, 2026
@pggPL pggPL reopened this Mar 5, 2026
@pggPL
Copy link
Collaborator Author

pggPL commented Mar 5, 2026

/te-ci pytorch

pggPL and others added 2 commits March 17, 2026 15:44
Conflict in transformer_engine/pytorch/permutation.py resolved by keeping
the torch.library.custom_op structure from this branch and applying
upstream's assert→raise ValueError improvements (commit 61f9594).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…sk_map_forward

num_out_tokens is typed as int in the custom_op signature and can never
be None; the check was incorrectly carried over from the class-based
upstream version during merge conflict resolution.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants