Almost-exact graph recognizes equivalence in the split-split pattern #5986

Open
wujingyue wants to merge 11 commits into main from wjy/split
Conversation

@wujingyue
Collaborator

@wujingyue wujingyue commented Feb 19, 2026

A spin-off from #4404

For #3987

@wujingyue
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 19, 2026

Review updated until commit 42937d3

Description

  • Add mapSplitOfSplit function to recognize equivalence in split-split pattern

  • Modify getUses to return empty groups instead of error when no uses found

  • Add test cases for split-reshape patterns with different extent scenarios

  • Simplify test code by removing unnecessary unique_ptr usage

Changes walkthrough

Relevant files

Enhancement
id_model.cpp (csrc/id_model/id_model.cpp) — Add split-split pattern mapping functionality

  • Remove unnecessary includes for trivial_broadcast and val_graph_visitor
  • Add mapSplitOfSplit function to handle split-split equivalence pattern
  • Integrate mapSplitOfSplit into buildAlmostExactGraph workflow
  • +57/-2

Error handling
val_graph.cpp (csrc/val_graph.cpp) — Improve getUses error handling

  • Modify getUses to return empty ExprGroups instead of throwing error
  • Add static empty_expr_groups for cases with no uses
  • +5/-4

Tests
test_id_model.cpp (tests/cpp/test_id_model.cpp) — Add split-reshape tests and simplify test code

  • Remove unused includes for fstream and graphviz
  • Simplify Fusion creation in multiple tests using direct instantiation
  • Add SplitingReshape test for basic split-reshape equivalence
  • Add SplitingReshape_DifferentExtents test for extent mismatch case
  • +55/-26

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Error Handling Regression

    The removal of NVF_ERROR in getUses() method may hide legitimate bugs. Previously, if a val_group was expected to have uses but didn't exist in unique_uses_, it would throw an error. Now it silently returns an empty ExprGroups. This could mask programming errors where a val_group is incorrectly expected to have uses.

    const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
      NVF_ERROR(val_group, "Nullptr not allowed");
    
      static ExprGroups empty_expr_groups;
      const auto it = unique_uses_.find(val_group);
      if (it == unique_uses_.end()) {
        return empty_expr_groups;
      }
      return it->second;
    }
    Algorithm Correctness

    The mapSplitOfSplit function implements a complex pattern matching algorithm for split-split patterns. The algorithm should be carefully reviewed to ensure it correctly handles all edge cases, particularly around the conditions for mapping outermost_grand and outer' IDs. The logic around extent comparison and the requirement that outer and inner must not be conflated needs validation.

    void mapSplitOfSplit(ValGraph& graph) {
      // The following is a subpattern of
      // https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
      //
      // outer, _ = split(root)
      // outermost_grand, _ = split(outer)
      // outer', _ = split(root)
      //
      // If outermost_grand and outer' have the same extent, map them.
      std::vector<std::pair<Val*, Val*>> ids_to_map;
      for (const ValGroup& root : graph.disjointValSets().disjointSets()) {
        const ExprGroups& uses_of_root = graph.getUses(root);
        std::vector<ValGroup> outermost_grands;
        for (const ExprGroup& use_of_root : uses_of_root) {
          auto* split0 = dynamic_cast<Split*>(use_of_root->front());
          if (split0 == nullptr) {
            continue;
          }
          // Only follow the outer output of the first split; outer and inner
          // must not be conflated.
          const ValGroup& outer = graph.toGroup(split0->outer());
          for (const ExprGroup& use_of_outer : graph.getUses(outer)) {
            auto* split1 = dynamic_cast<Split*>(use_of_outer->front());
            if (split1 == nullptr) {
              continue;
            }
            const ValGroup& outermost_grand = graph.toGroup(split1->outer());
            outermost_grands.push_back(outermost_grand);
          }
        }
    
        for (const ValGroup& outermost_grand : outermost_grands) {
          Val* extent_of_grand =
              outermost_grand->front()->as<IterDomain>()->extent();
    
          for (const ExprGroup& use_of_root : uses_of_root) {
            auto* split = dynamic_cast<Split*>(use_of_root->front());
            if (split == nullptr) {
              continue;
            }
    
            const ValGroup& outer = graph.toGroup(split->outer());
            if (outer->front()->as<IterDomain>()->extent()->sameAs(
                    extent_of_grand)) {
              ids_to_map.emplace_back(outermost_grand->front(), outer->front());
            }
          }
        }
      }
    
      for (const auto& [id1, id2] : ids_to_map) {
        graph.mapVals(id1, id2);
      }
    }

    Test failures

    • (Medium, 34) NVFuser TMA load & inner-reduction tests hitting internal assertions (validator_utils.cpp, indexing.cpp) across multiple runners

      Test Name GB200 H100 Source
      TMASimpleLdstTest.Load/1D_128B___half Link
      TMASimpleLdstTest.Load/1D_128B_float Link
      TMASimpleLdstTest.Load/1D_32B___half Link
      TMASimpleLdstTest.Load/1D_32B_float Link
      TMASimpleLdstTest.Load/1D_64B___half Link
      TMASimpleLdstTest.Load/1D_64B_float Link
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_1048576 Link
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_131072 Link
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_524288 Link
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_65536 Link
      ... with 7 more test failures omitted. Check internal logs.

    @wujingyue
    Collaborator Author

    !test

    @wujingyue wujingyue changed the title [IdModel] almost-exact graph recognizes (partially) equivalence in the split-split pattern [IdModel] almost-exact graph recognizes equivalence in the split-split pattern Feb 19, 2026
    @wujingyue wujingyue changed the title [IdModel] almost-exact graph recognizes equivalence in the split-split pattern Almost-exact graph recognizes equivalence in the split-split pattern Feb 19, 2026
    @wujingyue wujingyue requested a review from naoyam February 19, 2026 22:28
    @wujingyue
    Collaborator Author

@naoyam while I'm cleaning things up and verifying tests, do you think this is moving in the right direction?

    @wujingyue
    Collaborator Author

    !test

    @wujingyue wujingyue requested a review from naoyam February 20, 2026 01:10
    @naoyam
    Collaborator

    naoyam commented Feb 20, 2026

    Looks good overall.

    @wujingyue
    Collaborator Author

I'm running into some interesting test failures. One of them is a validation error:

    [ RUN      ] TMASimpleLdstTest.Load/1D_128B___half
    ...
    Validation error in output 0 on line 524 in file /opt/pytorch/nvfuser/tests/cpp/test_memory.cpp.
      Detected max abs error of: 7.34375
        absolute tolerance was set to 0.00390625
        and relative tolerance set to 0.0078125
    

    The symptom is around this TensorView

    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
      logical domain: (iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
        Split: iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS16{64}
        Split: iS16{64} by factor 64 -> iS17{1}, iS18{64}
        Split: iS18{64} by factor 8 -> iS23{8}, iS24{8}
        Split: iS17{1} by factor 8 -> iS19{1}, iS20{8}
        Split: iS20{8} by factor 1 -> iS21{8}, iS22{1}
        Xor(2D): iS21{8} , iS23{8} -> iS25{8} , iS26{8}
        Merge: iS19{1} and iS25{8} -> iS27{8}
        Merge: iS27{8} and iS22{1} -> iS28{8}
        Merge: iS28{8} and iS26{8} -> iS29{64}
        Merge: iS29{64} and iS24{8} -> ithreadIdx.x30{512}
      loop domain: (iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512})
    

    The new code maps iS19{1} and iS17{1}.

    This is mathematically correct because these two IterDomains can share the same index -- the index is 0 all the time. However, codegen doesn't seem to like the mapping.

    Before I throw more if-elses at it, what's the right contract so people can DbC? cc @naoyam

    @naoyam
    Collaborator

    naoyam commented Feb 20, 2026

    Can you show the diff of generated codes? I'm guessing something isn't working around predicates.

    @wujingyue
    Collaborator Author

    TMASimpleLdstTest.Load/1D_128B___half

    git fetch origin wjy/split
    git checkout wjy/split
    _bn && bin/test_nvfuser --gtest_filter=TMASimpleLdstTest.Load/1D_128B___half
    

    cc @naoyam

    @wujingyue
    Collaborator Author

    As @naoyam requested:

    [ RUN      ] TMASimpleLdstTest.Load/1D_128B___half
    Inputs:
      T0_g___half[iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )}]
    Outputs:
      T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
    
    %kernel {
    T1_s___half[iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8}]
       = CpAsyncBulkTensorTile( T0_g___half[iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )}] )
    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
       = Set( T1_s___half[iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8}], cache_op=Streaming )
    
    TransformPrinter :
    T0_g___half[iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )}]
      logical domain: (iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
      loop domain: (iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )})
    T1_s___half[iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8}]
      logical domain: (iS32{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      allocation domain: (iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8})
      contiguity: t t t t t t
        Split: iS32{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS4{64}
        Split: iS4{64} by factor 64 -> iS5{1}, iS6{64}
        Split: iS5{1} by factor 8 -> iB7{1}, iS8{8}
        Split: iS6{64} by factor 8 -> iS11{8}, iB12{8}
        Split: iS8{8} by factor 1 -> iS9{8}, iB10{1}
        Xor(2D): iS9{8} , iS11{8} -> iB13{8} , iB14{8}
      loop domain: (iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8})
    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
      logical domain: (iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
        Split: iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS16{64}
        Split: iS16{64} by factor 64 -> iS17{1}, iS18{64}
        Split: iS18{64} by factor 8 -> iS23{8}, iS24{8}
        Split: iS17{1} by factor 8 -> iS19{1}, iS20{8}
        Split: iS20{8} by factor 1 -> iS21{8}, iS22{1}
        Xor(2D): iS21{8} , iS23{8} -> iS25{8} , iS26{8}
        Merge: iS19{1} and iS25{8} -> iS27{8}
        Merge: iS27{8} and iS22{1} -> iS28{8}
        Merge: iS28{8} and iS26{8} -> iS29{64}
        Merge: iS29{64} and iS24{8} -> ithreadIdx.x30{512}
      loop domain: (iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512})
    } // %kernel
    iS19{1} <==> iS17{1}
    
    ======= Codegen output for kernel: nvfuser_none_f0_c0_r0_g0 =======
    
    // Codegen generated code
    __global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 1, 1> T0, const __grid_constant__ TensorMap var0, Tensor<__half, 1, 1> T2) {
      alignas(128) extern __shared__ char array[];
      const unsigned smem_offset = 0;
      const TensorMap* ptr1;
      ptr1 = &var0;
      nvfuser_index_t i2;
      i2 = 64 * ((nvfuser_index_t)blockIdx.x);
      Array<int, 1, 1> a3;
      a3 = Array<int, 1, 1>{__to_int32(i2)};
      nvfuser_index_t i4;
      i4 = ((8 * ((((nvfuser_index_t)threadIdx.x) / 64) ^ ((((nvfuser_index_t)threadIdx.x) / 8) % 8))) + (((nvfuser_index_t)threadIdx.x) % 8)) + i2;
      __half* T1 = reinterpret_cast<__half*>(array + smem_offset + 0);
      uint64_t* T3 = reinterpret_cast<uint64_t*>(array + smem_offset + 1024);
      mbarrier::init(toSmem(T3), 1U);
      __syncthreads();
      if ((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.x) < 32ULL))) {
        uint64_t i5;
        i5 = mbarrier::arriveExpectTX(toSmem(T3), 128U);
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<1>{ ptr1, a3, toSmem(T3) }), toSmem(T1));
        mbarrier::wait(toSmem(T3), i5);
      }
      __syncthreads();
      mbarrier::inval(toSmem(T3));
      if (((i4 >= 0) && (i4 < T0.logical_size[0LL]))) {
        T2[i4]
           = T1[((nvfuser_index_t)threadIdx.x)];
      }
    }
    
    ======================================
    
    unknown file: Failure
    C++ exception with description " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/validator_utils.cpp:505, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
    Expected aten_output_in_common_dtype.allclose( fusion_output_in_common_dtype, tolerance_values.second, tolerance_values.first, true) .
    
    Validation error in output 0 on line 524 in file /opt/pytorch/nvfuser/tests/cpp/test_memory.cpp.
      Detected max abs error of: 8.17969
        absolute tolerance was set to 0.00390625
        and relative tolerance set to 0.0078125
    

    @naoyam
    Collaborator

    naoyam commented Feb 20, 2026

    I'd like to see the diff result comparing the generated kernels. Please run the test with NVFUSER_DUMP=cuda_to_file to save the code to a file and run them through the diff command.

    @wujingyue
    Collaborator Author

    NVFUSER_DUMP=cuda_kernel without this PR:

    __global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 1, 1> T0, const __grid_constant__ TensorMap var0, Tensor<__half, 1, 1> T2) {
      alignas(128) extern __shared__ char array[];
      const unsigned smem_offset = 0;
      const TensorMap* ptr1;
      ptr1 = &var0;
      nvfuser_index_t i2;
      i2 = 64 * ((nvfuser_index_t)blockIdx.x);
      Array<int, 1, 1> a3;
      a3 = Array<int, 1, 1>{__to_int32(i2)};
      nvfuser_index_t i4;
      i4 = ((8 * ((((nvfuser_index_t)threadIdx.x) / 64) ^ ((((nvfuser_index_t)threadIdx.x) / 8) % 8))) + (((nvfuser_index_t)threadIdx.x) % 8)) + i2;
      __half* T1 = reinterpret_cast<__half*>(array + smem_offset + 0);
      uint64_t* T3 = reinterpret_cast<uint64_t*>(array + smem_offset + 1024);
      mbarrier::init(toSmem(T3), 1U);
      __syncthreads();
      if ((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.x) < 32ULL))) {
        uint64_t i5;
        i5 = mbarrier::arriveExpectTX(toSmem(T3), 128U);
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<1>{ ptr1, a3, toSmem(T3) }), toSmem(T1));
        mbarrier::wait(toSmem(T3), i5);
      }
      __syncthreads();
      mbarrier::inval(toSmem(T3));
      if ((((((nvfuser_index_t)threadIdx.x) < 64) && (i4 >= 0)) && (i4 < T0.logical_size[0LL]))) {
        T2[i4]
           = T1[((nvfuser_index_t)threadIdx.x)];
      }
    }
    

    I think you are right about predication:

    $ diff -ruN /tmp/old_kernel.txt /tmp/new_kernel.txt
    --- /tmp/old_kernel.txt 2026-02-20 14:49:44.624896719 -0800
    +++ /tmp/new_kernel.txt 2026-02-20 14:50:04.110394445 -0800
    @@ -21,7 +21,7 @@
       }
       __syncthreads();
       mbarrier::inval(toSmem(T3));
    -  if ((((((nvfuser_index_t)threadIdx.x) < 64) && (i4 >= 0)) && (i4 < T0.logical_size[0LL]))) {
    +  if (((i4 >= 0) && (i4 < T0.logical_size[0LL]))) {
         T2[i4]
            = T1[((nvfuser_index_t)threadIdx.x)];
       }
    

    cc @naoyam

    @wujingyue
    Collaborator Author

    wujingyue commented Feb 24, 2026

    Copying messages from @naoyam for https://abseil.io/resources/swe-book/html/ch03.html


    I looked into the issue. The issue happens due to the predication for non-divisible splits.

    https://github.com/NVIDIA/Fuser/blob/main/csrc/id_model/indexing.cpp#L778

    IIRC, Xiang had some writeup.

    https://github.com/NVIDIA/Fuser/blob/main/doc/reading/divisibility-of-split.md

In this case, T2 has a non-divisible split with iS17:

    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
      logical domain: (iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
        Split: iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS16{64}
        Split: iS16{64} by factor 64 -> iS17{1}, iS18{64}
        Split: iS18{64} by factor 8 -> iS23{8}, iS24{8}
        Split: iS17{1} by factor 8 -> iS19{1}, iS20{8}
        Split: iS20{8} by factor 1 -> iS21{8}, iS22{1}
        Xor(2D): iS21{8} , iS23{8} -> iS25{8} , iS26{8}
        Merge: iS19{1} and iS25{8} -> iS27{8}
        Merge: iS27{8} and iS22{1} -> iS28{8}
        Merge: iS28{8} and iS26{8} -> iS29{64}
        Merge: iS29{64} and iS24{8} -> ithreadIdx.x30{512}
    

    iS17 is split by 8, which effectively expands the domain by a factor of 8, and so we would need to make sure indexing would not go beyond the original extent of iS17, which is just 1.

    getNonDivisibleIdsToPredicate used here returns iS17 in this case.

    In main, this line creates this predicate: ( ( ( threadIdx.x / 8 ) / 8 ) < 1 )

    Now, the PR adds another mapping: iS17 and iS19 . When we do the traversal, iS19 simply gets assigned with index value of zero. That is because of Merge: iS19{1} and iS25{8} -> iS27{8}. Here, iS25 and iS27 are mapped as part of the almost-exact mappings, so we simply forward the assigned index of iS27 to iS25, and for iS19, I think we simply assign zero (I need to confirm this). Since iS19 gets zero, so does iS17.

    This results in the non-divisible split predicate of 0 < 1, instead of ( ( ( threadIdx.x / 8 ) / 8 ) < 1 ) . As a result, since 0 < 1 is always true, the resulting code doesn't get any predicate for the non-divisible split.

    The almost-exact mapping is used for indexing traversal, so its mapping needs to take indexing equality into consideration. Even if two iter domains have the same extent, it doesn't automatically mean they should use the same index. In this case, for the purpose of indexing, I'd question if iS17 and iS19 should be mapped.

    @wujingyue
    Collaborator Author

    !test

    @wujingyue wujingyue marked this pull request as ready for review March 15, 2026 19:33
    @wujingyue
    Collaborator Author

    @naoyam this is ready for review

    @wujingyue
    Collaborator Author

    !test

    @greptile-apps
    Contributor

    greptile-apps bot commented Mar 15, 2026

    Greptile Summary

    This PR extends the almost-exact graph construction in IdModel to recognise equivalence in the split-split pattern: when a root dimension is split twice (producing outermost_grand) and the same root is also split once (producing outer') with the same resulting extent, the two outputs are now mapped together, provided all splits are divisible. The change is motivated by reshape fusions where the scheduler independently splits the same logical dimension via different paths.

    Key changes:

    • csrc/id_model/id_model.cpp — New mapDivisibleSplits pass added at the end of buildAlmostExactGraph. It iterates over all val groups, collects depth-2 outer outputs (outermost_grand) via split-split chains, then maps them with any depth-1 outer output of the same root whose extent matches.
    • csrc/val_graph.cpp — getUses now returns an empty ExprGroups instead of throwing when a val group has no registered uses. This is necessary because mapDivisibleSplits calls getUses on leaf val groups.
    • tests/cpp/test_id_model.cpp — Three new focused unit tests verify: correct mapping when extents agree (SplittingReshape_Mapped), no mapping when extents differ, and no mapping when splits are non-divisible. Several tests also refactored from heap-allocated to stack-allocated Fusion.
    • tests/cpp/test_indexing.cpp — Hard-coded IR SSA names updated from i114-116 to i126-128 to account for additional IR values allocated by the new simplifyExpr calls inside mapDivisibleSplits.

    Confidence Score: 4/5

    • Safe to merge with minor caveats; the core algorithm is correct and validated by three new unit tests, with low-risk style concerns remaining.
    • The algorithm is mathematically sound for the described split-split pattern and is guarded by divisibility checks. The deferred-mapping design prevents iterator invalidation. Post-pass validateConsistency() and assertNoSelfMapping() calls provide runtime safety nets. The main concerns are: (1) a degenerate factor-1 split edge case where the parent and grandchild of a split could share the same extent and be queued for mapping — harmless today because mapVals would detect the self-mapping, but worth an explicit guard; (2) the getUses behavioral change silently relaxes an invariant without documentation; (3) fragile hard-coded SSA names in test_indexing.cpp. None of these are blocking issues.
    • csrc/id_model/id_model.cpp (mapDivisibleSplits second loop — consider factor-1 split guard and deduplication of outermost_grands); csrc/val_graph.cpp (asymmetry between getUses and getDefinitions error behaviour).

    Important Files Changed

    Filename Overview
    csrc/id_model/id_model.cpp Adds mapDivisibleSplits, a new pass run after trivial-expression forwarding in buildAlmostExactGraph. It recognises the split-split pattern and maps the depth-2 outer output with any depth-1 outer output of the same root that has the same extent, provided both splits are divisible. Logic is correct in the common case; minor concerns around degenerate factor-1 splits and possible duplicate entries in outermost_grands.
    csrc/val_graph.cpp Changes getUses to return an empty static ExprGroups instead of throwing when a val group has no registered uses (leaf nodes). This is required by the new mapDivisibleSplits pass that iterates over all val groups. The change relaxes a previous invariant and creates a slight asymmetry with getDefinitions (which still throws); otherwise all call sites handle the empty result correctly.
    tests/cpp/test_id_model.cpp Adds three new unit tests covering the split-split mapping (happy path, mismatched extents, non-divisible splits). Also cleans up several tests to use stack-allocated Fusion objects instead of std::make_unique<Fusion>, and removes unused #include headers. Minor: the test name SplitingReshape_DifferentExtents_NotMapped contains a typo ("Spliting").
    tests/cpp/test_indexing.cpp Updates hard-coded IR variable names in Reshape test from i114/i115/i116 to i126/i127/i128 to reflect that 12 additional IR values are now allocated by mapDivisibleSplits's calls to simplifyExpr. The test logic is unchanged but remains fragile due to reliance on global IR counter state.

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[buildAlmostExactGraph] --> B[Copy EXACT graph]
        B --> C[setUnmappable on root/logical/loop domains]
        C --> D[Collect trivially-mapped ID pairs\nvia getTriviallyMappedIds]
        D --> E[Apply trivial mappings\nalmost_exact_graph.mapVals]
        E --> F[mapDivisibleSplits]
    
        subgraph mapDivisibleSplits
            F1[For each ValGroup root] --> F2[Find divisible splits of root]
            F2 --> F3[Follow outer output → find\ndivisible splits of outer\nCollect outermost_grand groups]
            F3 --> F4[For each outermost_grand\ncompare extent to outer outputs\nof all splits of root]
            F4 --> F5{Same extent\nAND both splits\ndivisible?}
            F5 -- Yes --> F6[Queue ids_to_map pair]
            F5 -- No --> F4
            F6 --> F7[Apply all deferred mapVals]
        end
    
        F --> G[validateConsistency]
        G --> H[assertNoSelfMapping]
    

    Last reviewed commit: 40a2cc6

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    Comment on lines +521 to +537
    for (const ValGroup& outermost_grand : outermost_grands) {
      Val* extent_of_grand =
          outermost_grand->front()->as<IterDomain>()->extent();

      for (const ExprGroup& use_of_root : uses_of_root) {
        auto* split = dynamic_cast<Split*>(use_of_root->front());
        if (split == nullptr || !is_divisible(split)) {
          continue;
        }

        const ValGroup& outer = graph.toGroup(split->outer());
        if (outer->front()->as<IterDomain>()->extent()->sameAs(
                extent_of_grand)) {
          ids_to_map.emplace_back(outermost_grand->front(), outer->front());
        }
      }
    }

    Duplicate mapping entries possible

    When multiple split-split paths from the same root share the same outermost_grand, that group gets added to outermost_grands more than once. The second loop will then emit duplicate (id1, id2) pairs into ids_to_map for every repeated grand. mapVals is idempotent, so correctness is preserved, but the deduplication of outermost_grands (e.g. using an UnorderedSetOfValGroup rather than a std::vector) would prevent the redundant work and keep ids_to_map minimal.

    Comment on lines +521 to +534
    for (const ValGroup& outermost_grand : outermost_grands) {
      Val* extent_of_grand =
          outermost_grand->front()->as<IterDomain>()->extent();

      for (const ExprGroup& use_of_root : uses_of_root) {
        auto* split = dynamic_cast<Split*>(use_of_root->front());
        if (split == nullptr || !is_divisible(split)) {
          continue;
        }

        const ValGroup& outer = graph.toGroup(split->outer());
        if (outer->front()->as<IterDomain>()->extent()->sameAs(
                extent_of_grand)) {
          ids_to_map.emplace_back(outermost_grand->front(), outer->front());

    Second loop can self-map outermost_grand through its own ancestor

    The second loop re-scans every use of root and maps outer->front() to outermost_grand->front() whenever their extents agree. Nothing prevents that outer from being the very outer that feeds into split1 to produce outermost_grand — i.e. the direct parent of outermost_grand in the split chain.

    In practice this case cannot occur when split1.factor > 1 (the parent has a strictly larger extent), but when split1.factor == 1 (a degenerate, length-1 inner split), outermost_grand.extent == outer.extent. Then ids_to_map receives a pair (outermost_grand->front(), outer->front()) that maps the outer output of split1 to its own input. mapVals will not fire because is_divisible checks split1->isDivisible(), and a factor-1 split is always divisible, so the pair would be queued.

    A guard comparing the two ValGroup pointers before emitting the pair would make the intent explicit and protect against future degenerate fusions:

    if (outer != outermost_grand &&
        outer->front()->as<IterDomain>()->extent()->sameAs(extent_of_grand)) {
      ids_to_map.emplace_back(outermost_grand->front(), outer->front());
    }

    Comment on lines 398 to 407
    const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
      NVF_ERROR(val_group, "Nullptr not allowed");

    + static const ExprGroups empty_expr_groups;
      const auto it = unique_uses_.find(val_group);
    - NVF_ERROR(
    -     it != unique_uses_.end(),
    -     "Use group not found for ",
    -     nvfuser::toString(val_group));
    + if (it == unique_uses_.end()) {
    +   return empty_expr_groups;
    + }
      return it->second;
    }

    Behavioral change silently relaxes a previous invariant

    The old implementation treated a missing val_group entry in unique_uses_ as a hard error (via NVF_ERROR). The new implementation silently returns an empty set. While this is functionally required by mapDivisibleSplits (which calls getUses on leaf nodes that have no entries), it also removes the diagnostic for callers that previously relied on the error to detect graphs built incorrectly. By contrast, getDefinitions still throws on a missing entry.

    Consider whether a comment, or a separate hasUses()/tryGetUses() accessor, would make the relaxed contract explicit without silently hiding misuse.

    @wujingyue
    Collaborator Author

    !test

    @wujingyue wujingyue requested a review from mdavis36 March 17, 2026 17:42