Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 196 additions & 5 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -187,12 +188,55 @@ def memory_format_differs(shape, spatial_rank):
channel_dim = shape[channel_idx]
return channel_dim > 1 and any(dim > 1 for dim in spatial_dims)

@staticmethod
def _is_nhwc_safe_reshape(input_shape, output_shape) -> bool:
    """Decide whether a >=4-D reshape may be applied to NHWC data as-is.

    The reshape qualifies when the greedy dimension mapping
    (shape_indices) exists and is monotonic, and neither the batch
    dimension (index 0) nor the channel dimension (last index) shares
    an output group with any other dimension.
    """
    if min(len(input_shape), len(output_shape)) < 4:
        return False

    groups = ToTosaMemoryFormatPass._get_shape_indices(
        list(input_shape), list(output_shape)
    )
    if groups is None or not ToTosaMemoryFormatPass._is_monotonic(groups):
        return False

    # Batch (index 0) and channel (last NHWC axis) must each sit in a
    # singleton group — merging either with spatial dims would reorder
    # data or change element pairing semantics.
    protected = {0, len(input_shape) - 1}
    for grp in groups:
        if len(grp) > 1 and protected & set(grp):
            return False

    # Both protected dims must actually be covered by the mapping.
    covered = {idx for grp in groups for idx in grp}
    return protected <= covered

@staticmethod
def is_channel_reshape(
input_shape, output_shape, input_spatial_rank, output_spatial_rank
):
"""Check whether a reshape touches the logical channel or consolidated
batch dimensions, which would invalidate dim-order annotations.
batch dimensions in a way that would invalidate dim-order annotations.
Returns ``False`` (no transposes needed) when either:
- The reshape does not change the channel or batch dimensions at all, OR
- The reshape is NHWC-safe: monotonic shape_indices with both batch
(index 0) and channel (last index) preserved alone in their output
groups, meaning the view_copy can operate directly on NHWC data.
"""

valid_ranks = {4, 5, 6}
Expand Down Expand Up @@ -220,7 +264,15 @@ def get_batch_prod_dim(shape, spatial_rank):
N_old = get_batch_prod_dim(input_shape, input_spatial_rank)
N_new = get_batch_prod_dim(output_shape, output_spatial_rank)

return (N_old != N_new) or (C_old != C_new)
if (N_old == N_new) and (C_old == C_new):
return False

# The reshape touches batch/channel dims — check whether it is
# NHWC-safe (can operate directly on NHWC data without transposes).
if ToTosaMemoryFormatPass._is_nhwc_safe_reshape(input_shape, output_shape):
return False

return True

@staticmethod
def insert_input_transpose(node, input_node, graph_module):
Expand Down Expand Up @@ -271,7 +323,7 @@ def insert_output_transpose(node, graph_module):
# Guard: mem_format must be a true permutation for the current rank
assert sorted(mem_format) == list(
range(rank)
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"

with graph_module.graph.inserting_after(node):
permute_node = create_node(
Expand All @@ -296,6 +348,65 @@ def insert_output_transpose(node, graph_module):
for user in users:
user.replace_input_with(node, permute_node)

@staticmethod
def _get_shape_indices(
src_shape: list[int], tgt_shape: list[int]
) -> list[list[int]] | None:
"""Greedy dimension matching for reshape operations.
For each target dimension, greedily consumes contiguous source
dimensions whose product equals the target size. Size-1 target
dimensions that do not correspond to any source dimension produce
empty index lists (inserted dims).
Returns ``None`` when no valid mapping exists.
"""
src_idx = 0
result: list[list[int]] = []

for tgt_dim in tgt_shape:
if tgt_dim <= 0:
return None

indices: list[int] = []
remaining = tgt_dim

while src_idx < len(src_shape):
if src_shape[src_idx] == 0:
return None
if remaining % src_shape[src_idx] != 0:
break
indices.append(src_idx)
remaining //= src_shape[src_idx]
src_idx += 1
if remaining == 1:
break

if remaining != 1:
return None

result.append(indices)

if src_idx != len(src_shape):
return None

return result

@staticmethod
def _is_monotonic(indices: list[list[int]]) -> bool:
"""Return ``True`` when all non-empty index groups are strictly ordered
— i.e. each group's indices follow the previous group's.
"""
last_max = -1
for group in indices:
if not group:
continue
if group[0] <= last_max:
return False
last_max = group[-1]
return True

@staticmethod
def _insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
Expand Down Expand Up @@ -329,6 +440,83 @@ def _insert_view_transpose(
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr):
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

def _try_replace_redundant_permute(
    self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
) -> bool:
    """Remove a permute_copy if it duplicates tosa_dim_order.
    When a permute_copy's permutation matches the channels-last order
    (or its inverse) AND the input is already in NHWC dim_order, the
    permute does the same NCHW<>NHWC conversion that tosa_dim_order
    already handles — keeping both would double-convert. The permute is
    replaced with an identity view_copy to the permuted shape, so its
    users keep seeing the correct output shape metadata.
    Returns ``True`` if the node was removed.
    """
    if node.target not in (
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.permute.default,
    ):
        return False

    # args[1] is the permutation list of the permute op.
    perm_arg = node.args[1]
    assert isinstance(perm_arg, (list, tuple))
    perm = list(perm_arg)
    rank = len(perm)
    # Spatial rank annotated on the node; the 0 default disables the
    # rewrite via the guard below.
    sr = node.meta.get("tosa_spatial_rank", 0)

    if rank < 3 or sr < 1:
        return False

    cl_order = list(self._channels_last_order(rank, sr))
    cl_inv = list(self._channels_last_inverse_order(rank, sr))
    if perm != cl_order and perm != cl_inv:
        return False

    # Only replace when the permute is genuinely redundant with the
    # tosa_dim_order annotation. When the input is already in
    # channels-last order (NHWC), any permute matching cl_order or
    # cl_inv is a format conversion that tosa_dim_order already
    # handles — keeping both would double-convert.
    #
    # When the input is in NCHW (e.g., from a placeholder or a
    # non-spatial op), the permute is the model's intended
    # computation and must NOT be replaced.
    input_node = node.args[0]
    if not isinstance(input_node, torch.fx.Node):
        return False
    input_dim_order = input_node.meta.get("tosa_dim_order")
    # NOTE(review): when the input carries no tosa_dim_order at all we
    # fall through and still remove the permute — confirm every
    # producer is annotated before this pass runs, otherwise an
    # intended NCHW permute could be dropped here.
    if input_dim_order is not None:
        if list(input_dim_order) != cl_order:
            return False

    # The permute is redundant — tosa_dim_order already handles
    # the format conversion. Replace with a view_copy (identity
    # reshape to the permuted shape) so consumers still see the
    # correct shape. The view_copy must NOT be further processed
    # by _insert_view_transpose (it's a no-op reshape, not a
    # channel-crossing reshape).
    output_shape = list(node.meta["val"].shape)
    with graph_module.graph.inserting_before(node):
        # Materialize the target shape as a TOSA CONST_SHAPE operand
        # for the view_copy, tagged with the SHAPE special dtype.
        const_shape_node = graph_module.graph.call_function(
            exir_ops.backend.tosa.CONST_SHAPE.default,
            (output_shape,),
        )
        const_shape_node.meta["val"] = output_shape
        const_shape_node.meta["tosa_dim_order"] = node.meta.get(
            "tosa_dim_order", tuple(range(rank))
        )
        const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
        view_node = graph_module.graph.call_function(
            exir_ops.edge.aten.view_copy.default,
            (input_node, const_shape_node),
        )
        # Inherit the permute's meta (val, dim_order, ...) so the
        # replacement carries identical annotations.
        view_node.meta = dict(node.meta)
    node.replace_all_uses_with(view_node)
    graph_module.graph.erase_node(node)
    return True

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
"""Transposes are needed for operators transforming the input to a
different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
Expand All @@ -345,12 +533,15 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
- 1D/2D tensors
"""
for node in graph_module.graph.nodes:
for node in list(graph_module.graph.nodes):
if node.op != "call_function":
continue

if self._try_replace_redundant_permute(node, graph_module):
continue

# Transpose views
elif node.target == exir_ops.edge.aten.view_copy.default:
if node.target == exir_ops.edge.aten.view_copy.default:
input_node = node.args[0]
input_shape = input_node.meta["val"].shape
output_shape = node.meta["val"].shape
Expand Down
Loading
Loading