Arm ToTosaMemoryFormatPass collapses duplicated output slots when equal placeholders are fused #18320

@Rob-Hughes-Arm


🐛 Describe the bug

Summary

There is a bug in the Arm Python lowering path where duplicated logical graph outputs can be collapsed onto the same output slot before TOSA serialization.

The issue is triggered when:

  1. FuseEqualPlaceholdersPass merges two equal constant placeholders into one shared placeholder.
  2. ToTosaMemoryFormatPass inserts output transposes.
  3. The same node appears more than once in the graph output tuple.

At that point, ToTosaMemoryFormatPass.insert_input_transpose(...) rewrites the graph output node using replace_input_with(...), which replaces all matching occurrences of the shared node, not just the intended output slot.

As a result, a single transpose node is reused for multiple logical output slots, while the transpose nodes inserted for the later slots become orphaned.

In the VGF lowered path, this can surface as broken output materialization.


Root Cause

In backends/arm/_passes/to_tosa_memory_format_pass.py, graph outputs are handled by iterating output_node.args[0] and calling:

self.insert_input_transpose(output_node, output_node_input, graph_module)

Inside insert_input_transpose(...), the problematic rewrite is:

node.replace_input_with(input_node, permute_node)

When node is the graph output node and input_node appears multiple times in the output tuple, this rewrites every matching slot at once.

So if the output tuple is effectively:

return (shared_node, shared_node, other_output)

the first transpose insertion rewrites both shared_node outputs simultaneously, collapsing the two logical outputs onto the same transpose node.
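The collapse can be demonstrated on a plain torch.fx graph, without the Arm backend. In this sketch, `torch.clone` stands in for the inserted TOSA transpose, and the module and node names are illustrative, not the Arm pass internals:

```python
import torch
import torch.fx as fx

class SharedOutput(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y, y, x  # the same node feeds two output slots

gm = fx.symbolic_trace(SharedOutput())
graph = gm.graph
output_node = next(n for n in graph.nodes if n.op == "output")
shared = output_node.args[0][0]

# Insert a stand-in "transpose" intended only for slot 0.
with graph.inserting_before(output_node):
    stand_in = graph.call_function(torch.clone, (shared,))

# replace_input_with rewrites EVERY occurrence of `shared` in the
# output node's args, so slot 0 and slot 1 now both point at `stand_in`.
output_node.replace_input_with(shared, stand_in)

outputs = list(output_node.args[0])
print(outputs[0] is outputs[1])  # True: the two slots were collapsed
```

This mirrors the failure mode exactly: one call intended for one slot silently rewrites both.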


Minimal Repro

A very small module is enough to reproduce this:

class DuplicateConstantOutputs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
        self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))

    def forward(self, x):
        return self.grid0, self.grid1, x

Run these passes in sequence:

  • RemoveGetItemPass
  • AnnotateOutputDimOrderPass
  • FuseEqualPlaceholdersPass
  • ToTosaMemoryFormatPass

Observed bad graph after ToTosaMemoryFormatPass

return (tosa_transpose_default_1, tosa_transpose_default_1, x)

with another transpose node inserted but unused.

Expected

return (tosa_transpose_default_1, tosa_transpose_default_2, x)

The two logical output slots should remain distinct.


Why This Matters

This is not just a cosmetic graph issue.

In a real VGF lowered-path failure, this was the mechanism behind duplicated logical outputs becoming corrupted downstream. By the time TOSA/VGF serialization happens, the output tuple has already been collapsed incorrectly, so later stages cannot recover the intended slot structure.


Focused Regression Test

The following unit test reproduces the bug and fails before the fix:

def test_to_tosa_memory_format_preserves_duplicate_output_slots() -> None:
    pipeline = PassPipeline(
        DuplicateConstantOutputs(),
        (torch.rand(1, 2, 32, 32),),
        quantize=False,
        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
        passes_with_exported_program=[
            FuseEqualPlaceholdersPass,
            ToTosaMemoryFormatPass,
        ],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    graph_module = pipeline.tester.get_artifact().exported_program().graph_module
    output_node = graph_module.graph.output_node()
    outputs = list(output_node.args[0])

    assert outputs[0] is not outputs[1], (
        "Duplicate output slots were collapsed onto the same node after "
        "FuseEqualPlaceholdersPass + ToTosaMemoryFormatPass."
    )

Failure before fix

AssertionError: Duplicate output slots were collapsed onto the same node after FuseEqualPlaceholdersPass + ToTosaMemoryFormatPass.
assert tosa_transpose_default_1 is not tosa_transpose_default_1

Proposed Fix

Do not use replace_input_with(...) on the whole output node when inserting output transposes.

Instead, rewrite the selected graph-output slot positionally:

  1. Convert the output tuple/list to a mutable list.
  2. Replace only the target index with the new transpose node.
  3. Assign the rebuilt tuple/list back to output_node.args.

That preserves duplicated logical output slots even when they share the same source node after FuseEqualPlaceholdersPass.
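The three steps above can be sketched on a plain torch.fx graph. This is a hedged illustration, not the actual patch: `torch.clone` again stands in for the TOSA transpose the Arm pass inserts, and the placeholder check is a simplification of the pass's real slot-selection logic:

```python
import torch
import torch.fx as fx

class SharedOutput(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y, y, x  # one node feeds two output slots

gm = fx.symbolic_trace(SharedOutput())
graph = gm.graph
output_node = next(n for n in graph.nodes if n.op == "output")

# 1. Convert the output tuple to a mutable list.
outputs = list(output_node.args[0])
for i, slot in enumerate(outputs):
    if slot.op == "placeholder":
        continue  # pass-through inputs need no transpose in this sketch
    with graph.inserting_before(output_node):
        # 2. Replace only this slot's index with a fresh stand-in node.
        outputs[i] = graph.call_function(torch.clone, (slot,))
# 3. Assign the rebuilt tuple back to the output node's args.
output_node.args = (tuple(outputs),)
graph.lint()
gm.recompile()

final = list(output_node.args[0])
print(final[0] is final[1])  # False: the two slots stay distinct
```

Because each slot is rewritten by index rather than by node identity, duplicated slots each receive their own transpose, matching the expected graph shown above.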


Scope

This appears to be a bug in the Arm Python lowering path, specifically:

  • FuseEqualPlaceholdersPass creates the shared-node precondition.
  • ToTosaMemoryFormatPass incorrectly rewrites duplicated output slots.
  • TOSA / VGF serialization only exposes the already-corrupted graph shape.

So the primary bug is in ToTosaMemoryFormatPass output rewriting, not in downstream scenario export or runtime code.

Versions

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell
