Description
🐛 Describe the bug
Arm ToTosaMemoryFormatPass collapses duplicated output slots when equal placeholders are fused
Summary
There is a bug in the Arm Python lowering path where duplicated logical graph outputs can be collapsed onto the same output slot before TOSA serialization.
The issue is triggered when:
- `FuseEqualPlaceholdersPass` merges two equal constant placeholders into one shared placeholder.
- `ToTosaMemoryFormatPass` inserts output transposes.
- The same node appears more than once in the graph output tuple.
At that point, ToTosaMemoryFormatPass.insert_input_transpose(...) rewrites the graph output node using replace_input_with(...), which replaces all matching occurrences of the shared node, not just the intended output slot.
This causes one transpose node to be reused for multiple logical output slots, and later inserted transpose nodes become orphaned.
In the VGF lowered path, this can surface as broken output materialization.
Root Cause
In `backends/arm/_passes/to_tosa_memory_format_pass.py`, graph outputs are handled by iterating `output_node.args[0]` and calling:

```python
self.insert_input_transpose(output_node, output_node_input, graph_module)
```

Inside `insert_input_transpose(...)`, the problematic rewrite is:

```python
node.replace_input_with(input_node, permute_node)
```

When `node` is the graph output node and `input_node` appears multiple times in the output tuple, this rewrites every matching slot at once.
So if the output tuple is effectively:

```python
return (shared_node, shared_node, other_output)
```

the first transpose insertion rewrites both `shared_node` outputs simultaneously, collapsing the two logical outputs onto the same transpose node.
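The all-occurrences behavior of `replace_input_with(...)` can be demonstrated on a plain `torch.fx` graph, independent of the Arm passes (the module and the `torch.clone` stand-in for the transpose are illustrative only):

```python
import torch
import torch.fx as fx


class TwoSlots(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y, y, x  # the same node fills two output slots


gm = fx.symbolic_trace(TwoSlots())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
shared = output_node.args[0][0]  # referenced by slots 0 and 1

# Mimic inserting a transpose-like node intended for slot 0 only...
with gm.graph.inserting_before(output_node):
    new_node = gm.graph.call_function(torch.clone, (shared,))

# ...but replace_input_with rewrites EVERY occurrence of `shared`
# in the output node's args, so slots 0 and 1 both collapse onto it.
output_node.replace_input_with(shared, new_node)
assert output_node.args[0][0] is output_node.args[0][1]
```

This is exactly the shape of the collapse seen after `FuseEqualPlaceholdersPass` has made two output slots share one source node.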
Minimal Repro
A very small module is enough to reproduce this:
```python
class DuplicateConstantOutputs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
        self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))

    def forward(self, x):
        return self.grid0, self.grid1, x
```

Run these passes in sequence:

- `RemoveGetItemPass`
- `AnnotateOutputDimOrderPass`
- `FuseEqualPlaceholdersPass`
- `ToTosaMemoryFormatPass`
Observed bad graph after `ToTosaMemoryFormatPass`:

```python
return (tosa_transpose_default_1, tosa_transpose_default_1, x)
```

with another transpose node inserted but unused.
Expected:

```python
return (tosa_transpose_default_1, tosa_transpose_default_2, x)
```

The two logical output slots should remain distinct.
Why This Matters
This is not just a cosmetic graph issue.
In a real VGF lowered-path failure, this was the mechanism behind duplicated logical outputs becoming corrupted downstream. By the time TOSA/VGF serialization happens, the output tuple has already been collapsed incorrectly, so later stages cannot recover the intended slot structure.
Focused Regression Test
The following unit test reproduces the bug and fails before the fix:
```python
def test_to_tosa_memory_format_preserves_duplicate_output_slots() -> None:
    pipeline = PassPipeline(
        DuplicateConstantOutputs(),
        (torch.rand(1, 2, 32, 32),),
        quantize=False,
        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
        passes_with_exported_program=[
            FuseEqualPlaceholdersPass,
            ToTosaMemoryFormatPass,
        ],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    graph_module = pipeline.tester.get_artifact().exported_program().graph_module
    output_node = graph_module.graph.output_node()
    outputs = list(output_node.args[0])
    assert outputs[0] is not outputs[1]
```

Failure before fix:
```
AssertionError: Duplicate output slots were collapsed onto the same node after FuseEqualPlaceholdersPass + ToTosaMemoryFormatPass.
assert tosa_transpose_default_1 is not tosa_transpose_default_1
```
Proposed Fix
Do not use `replace_input_with(...)` on the whole output node when inserting output transposes.
Instead, rewrite the selected graph-output slot positionally:

- Convert the output tuple/list to a mutable list.
- Replace only the target index with the new transpose node.
- Assign the rebuilt tuple/list back to `output_node.args`.
That preserves duplicated logical output slots even when they share the same source node after FuseEqualPlaceholdersPass.
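A minimal sketch of the positional rewrite (the `replace_output_slot` helper name is hypothetical; the real fix would live inside `ToTosaMemoryFormatPass`'s output handling):

```python
import torch
import torch.fx as fx


def replace_output_slot(output_node: fx.Node, slot: int, new_node: fx.Node) -> None:
    """Rewrite exactly one graph-output slot, leaving other slots that
    share the same source node untouched. Hypothetical helper sketching
    the proposed fix."""
    outputs = list(output_node.args[0])   # mutable copy of the output tuple
    outputs[slot] = new_node              # replace only the target index
    output_node.args = (tuple(outputs),)  # rebuild and reassign args


class TwoSlots(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y, y, x


gm = fx.symbolic_trace(TwoSlots())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
shared = output_node.args[0][0]
with gm.graph.inserting_before(output_node):
    new_node = gm.graph.call_function(torch.clone, (shared,))

replace_output_slot(output_node, 0, new_node)
# Slot 0 now points at the new node; slot 1 still points at the original.
assert output_node.args[0][0] is new_node
assert output_node.args[0][1] is shared
```

Assigning to `node.args` goes through `torch.fx.Node`'s args setter, which keeps the graph's use-def bookkeeping consistent, so no manual `users` maintenance is needed.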
Scope
This appears to be a bug in the Arm Python lowering path, specifically:
- `FuseEqualPlaceholdersPass` creates the shared-node precondition.
- `ToTosaMemoryFormatPass` incorrectly rewrites duplicated output slots.
- TOSA / VGF serialization only exposes the already-corrupted graph shape.
So the primary bug is in the `ToTosaMemoryFormatPass` output rewriting, not in downstream scenario export or runtime code.
Versions
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell