diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 52cbf924fc4..018dc51ea19 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -16,6 +16,7 @@ ) from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -187,12 +188,55 @@ def memory_format_differs(shape, spatial_rank): channel_dim = shape[channel_idx] return channel_dim > 1 and any(dim > 1 for dim in spatial_dims) + @staticmethod + def _is_nhwc_safe_reshape(input_shape, output_shape) -> bool: + """Return ``True`` when a 4-D+ reshape can operate directly on NHWC + data. + + A reshape is NHWC-safe when its shape_indices are monotonic and both the + batch dimension (index 0) and the channel dimension (last index) are + preserved alone in their output groups. + + """ + rank_in = len(input_shape) + rank_out = len(output_shape) + if rank_in < 4 or rank_out < 4: + return False + + indices = ToTosaMemoryFormatPass._get_shape_indices( + list(input_shape), list(output_shape) + ) + if indices is None or not ToTosaMemoryFormatPass._is_monotonic(indices): + return False + + # The channel dim (last axis in NHWC) and batch dim (index 0) + # must each appear alone — merging either with spatial dims + # would reorder data or change element pairing semantics. + channel_idx = rank_in - 1 + batch_idx = 0 + for group in indices: + if channel_idx in group and len(group) != 1: + return False + if batch_idx in group and len(group) != 1: + return False + + batch_found = any(batch_idx in g for g in indices) + channel_found = any(channel_idx in g for g in indices) + return batch_found and channel_found + @staticmethod def is_channel_reshape( input_shape, output_shape, input_spatial_rank, output_spatial_rank ): """Check whether a reshape touches the logical channel or consolidated - batch dimensions, which would invalidate dim-order annotations. + batch dimensions in a way that would invalidate dim-order annotations. + + Returns ``False`` (no transposes needed) when either: + - The reshape does not change the channel or batch dimensions at all, OR + - The reshape is NHWC-safe: monotonic shape_indices with both batch + (index 0) and channel (last index) preserved alone in their output + groups, meaning the view_copy can operate directly on NHWC data. + """ valid_ranks = {4, 5, 6} @@ -220,7 +264,15 @@ def get_batch_prod_dim(shape, spatial_rank): N_old = get_batch_prod_dim(input_shape, input_spatial_rank) N_new = get_batch_prod_dim(output_shape, output_spatial_rank) - return (N_old != N_new) or (C_old != C_new) + if (N_old == N_new) and (C_old == C_new): + return False + + # The reshape touches batch/channel dims — check whether it is + # NHWC-safe (can operate directly on NHWC data without transposes). + if ToTosaMemoryFormatPass._is_nhwc_safe_reshape(input_shape, output_shape): + return False + + return True @staticmethod def insert_input_transpose(node, input_node, graph_module): @@ -271,7 +323,7 @@ def insert_output_transpose(node, graph_module): # Guard: mem_format must be a true permutation for the current rank assert sorted(mem_format) == list( range(rank) - ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" + ), f"bad perm {mem_format} for rank {rank} in insert_output_transpose" with graph_module.graph.inserting_after(node): permute_node = create_node( @@ -296,6 +348,65 @@ def insert_output_transpose(node, graph_module): for user in users: user.replace_input_with(node, permute_node) + @staticmethod + def _get_shape_indices( + src_shape: list[int], tgt_shape: list[int] + ) -> list[list[int]] | None: + """Greedy dimension matching for reshape operations. + + For each target dimension, greedily consumes contiguous source + dimensions whose product equals the target size. Size-1 target + dimensions that do not correspond to any source dimension produce + empty index lists (inserted dims). + + Returns ``None`` when no valid mapping exists. + + """ + src_idx = 0 + result: list[list[int]] = [] + + for tgt_dim in tgt_shape: + if tgt_dim <= 0: + return None + + indices: list[int] = [] + remaining = tgt_dim + + while src_idx < len(src_shape): + if src_shape[src_idx] == 0: + return None + if remaining % src_shape[src_idx] != 0: + break + indices.append(src_idx) + remaining //= src_shape[src_idx] + src_idx += 1 + if remaining == 1: + break + + if remaining != 1: + return None + + result.append(indices) + + if src_idx != len(src_shape): + return None + + return result + + @staticmethod + def _is_monotonic(indices: list[list[int]]) -> bool: + """Return ``True`` when all non-empty index groups are strictly ordered + — i.e. each group's indices follow the previous group's. + """ + last_max = -1 + for group in indices: + if not group: + continue + if group[0] <= last_max: + return False + last_max = group[-1] + return True + @staticmethod def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module @@ -329,6 +440,83 @@ def _insert_view_transpose( ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr): ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) + def _try_replace_redundant_permute( + self, node: torch.fx.Node, graph_module: torch.fx.GraphModule + ) -> bool: + """Remove a permute_copy if it duplicates tosa_dim_order. + + When a permute_copy's permutation matches the channels-last order + (or its inverse) AND the input is already in NHWC dim_order, the + permute does the same NCHW<>NHWC conversion that tosa_dim_order + already handles — keeping both would double-convert. Remove the + permute by wiring its users directly to its input. + + Returns ``True`` if the node was removed. + + """ + if node.target not in ( + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.permute.default, + ): + return False + + perm_arg = node.args[1] + assert isinstance(perm_arg, (list, tuple)) + perm = list(perm_arg) + rank = len(perm) + sr = node.meta.get("tosa_spatial_rank", 0) + + if rank < 3 or sr < 1: + return False + + cl_order = list(self._channels_last_order(rank, sr)) + cl_inv = list(self._channels_last_inverse_order(rank, sr)) + if perm != cl_order and perm != cl_inv: + return False + + # Only replace when the permute is genuinely redundant with the + # tosa_dim_order annotation. When the input is already in + # channels-last order (NHWC), any permute matching cl_order or + # cl_inv is a format conversion that tosa_dim_order already + # handles — keeping both would double-convert. + # + # When the input is in NCHW (e.g., from a placeholder or a + # non-spatial op), the permute is the model's intended + # computation and must NOT be replaced. + input_node = node.args[0] + if not isinstance(input_node, torch.fx.Node): + return False + input_dim_order = input_node.meta.get("tosa_dim_order") + if input_dim_order is not None: + if list(input_dim_order) != cl_order: + return False + + # The permute is redundant — tosa_dim_order already handles + # the format conversion. Replace with a view_copy (identity + # reshape to the permuted shape) so consumers still see the + # correct shape. The view_copy must NOT be further processed + # by _insert_view_transpose (it's a no-op reshape, not a + # channel-crossing reshape). + output_shape = list(node.meta["val"].shape) + with graph_module.graph.inserting_before(node): + const_shape_node = graph_module.graph.call_function( + exir_ops.backend.tosa.CONST_SHAPE.default, + (output_shape,), + ) + const_shape_node.meta["val"] = output_shape + const_shape_node.meta["tosa_dim_order"] = node.meta.get( + "tosa_dim_order", tuple(range(rank)) + ) + const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE + view_node = graph_module.graph.call_function( + exir_ops.edge.aten.view_copy.default, + (input_node, const_shape_node), + ) + view_node.meta = dict(node.meta) + node.replace_all_uses_with(view_node) + graph_module.graph.erase_node(node) + return True + def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): """Transposes are needed for operators transforming the input to a different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC- @@ -345,12 +533,15 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - 1D/2D tensors """ - for node in graph_module.graph.nodes: + for node in list(graph_module.graph.nodes): if node.op != "call_function": continue + if self._try_replace_redundant_permute(node, graph_module): + continue + # Transpose views - elif node.target == exir_ops.edge.aten.view_copy.default: + if node.target == exir_ops.edge.aten.view_copy.default: input_node = node.args[0] input_shape = input_node.meta["val"].shape output_shape = node.meta["val"].shape diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index dfd57aa7e61..1eb4954e6a2 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -177,11 +177,79 @@ def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) +class NHWCSafeSpatialMerge(torch.nn.Module): + """Test-module with a 4D->4D reshape that merges spatial dims H*W while + preserving the last-dim channel. + + For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2 + sits at NCHW position 1 and the last dim (72) is the NHWC channel that gets + preserved. ``is_channel_reshape`` returns False (NHWC-safe: monotonic + shape_indices with batch and channel alone), so no transposes are inserted + around the view_copy. + + Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC). + + """ + + ops_before_pass: Dict[str, int] = {} + # Only the 2 I/O transposes for the conv, NO extra transposes from view_copy + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=2, out_channels=2, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72] + x = x.view(1, 28, 1, 72) # spatial merge: H*W=2*14->28, last dim 72 preserved + return x + x # keep result 4-D in NHWC + + def get_inputs(self) -> input_t: + return (torch.randn(1, 2, 14, 72),) + + +class NHWCUnsafeChannelChange(torch.nn.Module): + """Test-module with a 4D->4D reshape that is NOT NHWC-safe because the + target shape cannot be produced by a monotonic merge of NHWC input dims. + + The pass MUST still insert transposes around the view_copy. + + """ + + ops_before_pass: Dict[str, int] = {} + # conv I/O transposes (2) + view_copy transposes (2) = 4 + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=72, out_channels=72, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # output [1, 72, 2, 14] + x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled) + return x + x + + def get_inputs(self) -> input_t: + return (torch.randn(1, 72, 2, 14),) + + modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), "reshapes": Reshapes(), + "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(), + "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(), } @@ -209,3 +277,79 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No module_nn = cast(torch.nn.Module, module) pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() + + +# --- Direct unit tests for NHWC-safe reshape detection in is_channel_reshape --- + + +def test_get_shape_indices_spatial_merge(): + """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([1, 2, 14, 72], [1, 28, 1, 72]) + assert indices == [[0], [1, 2], [], [3]] + + +def test_get_shape_indices_identity(): + """Same shape => each dim maps to itself.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4]) + assert indices == [[0], [1], [2]] + + +def test_get_shape_indices_full_merge(): + """[2, 3, 4] -> [24]: merge all dims into one.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24]) + assert indices == [[0, 1, 2]] + + +def test_get_shape_indices_incompatible(): + """Sizes that don't divide => None.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4]) + assert indices is None + + +def test_get_shape_indices_size_one_insert(): + """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4]) + assert indices is not None + assert indices == [[0], [], [1]] + + +def test_is_monotonic_true(): + assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]]) + + +def test_is_monotonic_false(): + assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]]) + assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]]) + + +def test_channel_reshape_nhwc_safe(): + """Shapes already NHWC by the time the pass runs. + + [1,2,14,72] -> [1,28,1,72], sr=2 -> NHWC-safe (spatial merge, C=72 + preserved). is_channel_reshape should return False (no transposes needed). + + """ + assert not ToTosaMemoryFormatPass.is_channel_reshape( + [1, 2, 14, 72], [1, 28, 1, 72], input_spatial_rank=2, output_spatial_rank=2 + ) + + +def test_channel_reshape_non_4d(): + """Reshapes below rank 4 always return False from is_channel_reshape.""" + assert not ToTosaMemoryFormatPass.is_channel_reshape( + [6, 4], [24], input_spatial_rank=0, output_spatial_rank=0 + ) + + +def test_channel_reshape_batch_merge(): + """Reshapes merging batch with spatial dims are NOT NHWC-safe.""" + # [1,2,5,10] -> [2,1,5,10]: merges N(=1) with H(=2) — not safe + assert ToTosaMemoryFormatPass.is_channel_reshape( + [1, 2, 5, 10], [2, 1, 5, 10], input_spatial_rank=2, output_spatial_rank=2 + ) + # [5,10,25,20] -> [1250,20,1,1]: merges N+H+W — not safe (Linear decomp) + assert ToTosaMemoryFormatPass.is_channel_reshape( + [5, 10, 25, 20], [1250, 20, 1, 1], input_spatial_rank=2, output_spatial_rank=2 + )