From 45d73d8e887433bdcd2202744d49f835d3947617 Mon Sep 17 00:00:00 2001 From: mcremon-meta <134334895+mcremon-meta@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:52:50 -0700 Subject: [PATCH] Revert "Move optimization passes from opt_level=0 to opt_level=1 (#18206)" This reverts commit bf2243a9fa577b3b46723a533c1cd4a0883b99ba. --- backends/cadence/aot/BUCK | 1 - backends/cadence/aot/decompose_ops.py | 6 +- backends/cadence/aot/functions.yaml | 9 +- backends/cadence/aot/fuse_ops.py | 5 +- backends/cadence/aot/ops_registrations.py | 55 +------ backends/cadence/aot/passes.py | 2 - backends/cadence/aot/quantizer/patterns.py | 7 +- backends/cadence/aot/ref_implementations.py | 35 +---- backends/cadence/aot/remove_ops.py | 25 ++- backends/cadence/aot/replace_ops.py | 110 +++---------- backends/cadence/aot/simplify_ops.py | 3 +- .../aot/tests/test_replace_ops_passes.py | 54 ------- .../operators/op_quantized_max_pool2d.cpp | 12 +- .../operators/op_quantized_max_pool2d.h | 2 +- .../op_quantized_max_pool2d_nhwc.cpp | 137 ---------------- .../operators/op_quantized_max_pool2d_nhwc.h | 30 ---- .../cadence/generic/operators/targets.bzl | 12 -- .../op_quantized_max_pool2d_nhwc.cpp | 148 ------------------ backends/cadence/hifi/operators/targets.bzl | 10 -- 19 files changed, 50 insertions(+), 613 deletions(-) delete mode 100644 backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp delete mode 100644 backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h delete mode 100644 backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index e4bc833c183..c85dc23c4bd 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -300,7 +300,6 @@ fbcode_target(_kind = runtime.python_library, ], typing = True, deps = [ - ":fuse_ops", ":ops_registrations", "//caffe2:torch", "//executorch/backends/cadence/aot:pass_utils", diff --git a/backends/cadence/aot/decompose_ops.py b/backends/cadence/aot/decompose_ops.py index 0e0d15dd7fb..7ee1bb36fef 100644 --- a/backends/cadence/aot/decompose_ops.py +++ b/backends/cadence/aot/decompose_ops.py @@ -23,12 +23,10 @@ from torch.fx.node import Argument -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class DecomposeAtenApproxGeluPass(ExportPass): """ - Decompose the aten gelu op with an approximate arg to a series of simpler ops. - This is an optimization - gelu has a portable kernel fallback, but decomposing - may be more efficient on some backends. + Decompose the aten gelu op with an approximate arg to a series of simpler ops """ def call_operator( diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 1c4dd3e06f3..80de190fedf 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -309,15 +309,10 @@ - arg_meta: null kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out -- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::generic::quantized_max_pool2d_nchw_out - -- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) - kernels: - - arg_meta: null - kernel_name: impl::generic::quantized_max_pool2d_nhwc_out + kernel_name: impl::generic::quantized_max_pool2d_out - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index e71803c03bb..34bf3b11684 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -1170,10 +1170,7 @@ def can_fuse_for_chain( return False # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions - producer_input = cast(torch.fx.Node, producer.args[0]) - if "val" not in producer_input.meta: - return False - input_shape = producer_input.meta["val"].shape + input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape ident_dims = list(range(len(input_shape))) # this mapping helps to handle both transpose and permutations f: dict[Any, Callable] = { diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 060702becec..601d54fe49b 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -214,16 +214,10 @@ def register_fake( ) lib.define( - "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" + "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" ) lib.define( - "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" -) -lib.define( - "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" -) -lib.define( - "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( @@ -2283,8 +2277,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta( return input.new_empty(input.size(), dtype=input.dtype) -@register_fake("cadence::quantized_max_pool2d_nchw") -def quantized_max_pool2d_nchw_meta( +@register_fake("cadence::quantized_max_pool2d") +def quantized_max_pool2d_meta( input: torch.Tensor, kernel_size: list[int], stride: list[int], @@ -2324,47 +2318,6 @@ def quantized_max_pool2d_nchw_meta( return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype) -@register_fake("cadence::quantized_max_pool2d_nhwc") -def quantized_max_pool2d_nhwc_meta( - input: torch.Tensor, - kernel_size: list[int], - stride: list[int], - padding: list[int], - dilation: list[int], - ceil_mode: bool, -) -> torch.Tensor: - assert ( - len(kernel_size) == 2 - ), f"kernel_size must have 2 elements, got {len(kernel_size)}" - assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}" - assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}" - assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}" - assert ( - len(input.size()) == 4 - ), f"input must be 4D (N, H, W, C), got {len(input.size())}D" - - batch = input.size(0) - height_in = input.size(1) - width_in = input.size(2) - channels = input.size(3) - - height_out_raw = ( - height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 - ) / stride[0] + 1 - width_out_raw = ( - width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 - ) / stride[1] + 1 - - if ceil_mode: - height_out = ceil(height_out_raw) - width_out = ceil(width_out_raw) - else: - height_out = int(height_out_raw) - width_out = int(width_out_raw) - - return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype) - - @register_fake("cadence::fully_connected") def fully_connected_meta( src: torch.Tensor, diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index 647819d91ab..bb4a8f065d5 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -25,7 +25,6 @@ from executorch.backends.cadence.aot.remove_ops import ( CadenceRemoveNops, RemoveNopSliceOrViewOpPass, - RemovePermutesAroundElementwiseOps, RemoveRedundantOps, ) from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph @@ -90,7 +89,6 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]: CadenceSimplifyOpsInGraph.passes, FinalizePipeline, FuseFullThenReshapePass, - RemovePermutesAroundElementwiseOps, FuseTransposeOrPermuteOpPairsPass, RemoveNopSliceOrViewOpPass, CompileTimeTypeDispatchPass, diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 204f066ebf4..0d52c004dea 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -459,7 +459,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_max_pool2d_nchw.default + return torch.ops.cadence.quantized_max_pool2d.default class MaxPool2dWithoutIndicesPattern(QuantizationPattern): @@ -498,10 +498,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_max_pool2d_nchw.default - - -# This is a base class for ReLU + return torch.ops.cadence.quantized_max_pool2d.default # This is a base class for ReLU, since it can be used with two different aten ops diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index f985718c150..ed8b3ca60ae 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1868,8 +1868,8 @@ def rms_norm( return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X) -@impl_tracked(m, "quantized_max_pool2d_nchw") -def quantized_max_pool2d_nchw( +@impl_tracked(m, "quantized_max_pool2d") +def quantized_max_pool2d( input: torch.Tensor, kernel_size: list[int], stride: list[int], @@ -1897,37 +1897,6 @@ def quantized_max_pool2d_nchw( ) -@impl_tracked(m, "quantized_max_pool2d_nhwc") -def quantized_max_pool2d_nhwc( - input: torch.Tensor, - kernel_size: list[int], - stride: list[int], - padding: list[int], - dilation: list[int], - ceil_mode: bool, -) -> torch.Tensor: - """ - Quantized max pooling in NHWC layout. - - Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC. - """ - # Convert NHWC [N, H, W, C] to NCHW [N, C, H, W] - input_nchw = input.permute(0, 3, 1, 2).contiguous() - - # Call the NCHW version - output_nchw = quantized_max_pool2d_nchw( - input_nchw, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - ) - - # Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C] - return output_nchw.permute(0, 2, 3, 1).contiguous() - - @impl_tracked(m, "where_Scalar") def where_Scalar( condition: torch.Tensor, diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index f5a47ea2603..8e1d6d1f07e 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -14,8 +14,6 @@ import torch import torch.fx - -from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, get_arg, @@ -23,6 +21,7 @@ RemoveOrReplacePassInterface, set_arg, ) + from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform @@ -34,7 +33,7 @@ from torch.fx.node import Node -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveCloneOpsTransformImported(ExportPass): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: finalize_passes: List[PassType] = [ @@ -45,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: return result -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveDetachCopyPass(RemoveOrReplacePassInterface): @property def targets(self) -> list[EdgeOpOverload]: @@ -67,7 +66,7 @@ class RemoveRedundantOps: ] -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface): @property def targets(self) -> list[EdgeOpOverload]: @@ -121,11 +120,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool: return False -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveNopExpandOpPass(RemoveOrReplacePassInterface): """ For an expand op, if the operator shape matches the expand shape, then the - expand is a nop. This is an optimization that removes unnecessary ops. + expand is a nop. """ @property @@ -144,9 +143,9 @@ def maybe_remove_or_replace(self, node: Node) -> bool: return False -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveToOpsPass(RemoveOrReplacePassInterface): - # aten.to.* ops are no-ops in inference - this is an optimization + # aten.to.* as of now are all nops @property def targets(self) -> list[EdgeOpOverload]: return [ @@ -265,11 +264,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface): """ + alias_copy is a no-op and can be removed. - This is an optimization that removes unnecessary ops. """ @property @@ -413,9 +412,6 @@ class Subgraph: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.cadence.quantize_per_tensor.default, exir_ops.edge.cadence.dequantize_per_tensor.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, - exir_ops.edge.cadence.requantize.per_tensor, - exir_ops.edge.cadence.quantized_add.per_tensor, # Ops that require special handling. exir_ops.edge.aten.cat.default, exir_ops.edge.aten.mean.dim, @@ -808,7 +804,6 @@ class CommonRemovePasses: RemoveToOpsPass, RemoveZeroSizedCatArgsPass, RemovePermutesAroundElementwiseOps, - FuseTransposeOrPermuteOpPairsPass, RemoveSqueezeViewBeforeElementwiseOps, RemoveCatFromSliceCopyPass, RemoveCloneOpsTransformImported, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index dfc7d90b4ea..14a35c01baf 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -45,12 +45,11 @@ } -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface): """ A where op with a logical_not and a boolean tensor can be replaced by a where op with flipped inputs and the initial boolean tensor. - This is an optimization that simplifies the graph. """ @property @@ -89,11 +88,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface): +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface): # keep """ - Replace _safe_softmax with _softmax. - This is an optimization - both ops are functionally equivalent for inference. + Replace _safe_softmax with _softmax """ @property @@ -171,11 +169,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface): """ When the shape is static, replace squeeze_copy and unsqueeze_copy ops with - view_copy op. This is an optimization that reduces op variety in the graph. + view_copy op """ @property @@ -208,12 +206,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface): """ Replace an op with a functionally equivalent op by just switching the op target, but without incurring any change to the op args. - This is an optimization that normalizes the graph to use canonical op variants. """ @property @@ -278,12 +275,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return False -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceMMWithAddMMPass(RemoveOrReplacePassInterface): """ This pass replaces mm with addmm by introducing a zero bias. - This is an optimization - mm has a portable kernel fallback, but addmm - may be more efficient on some backends. + mm is not supported, so this is an opt_level=0 pass. """ @property @@ -475,12 +471,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return False -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface): """ Replace optional tensors with concrete tensors. Currently, we replace the optional bias tensor with a zero tensor. - This is an optimization that simplifies kernel dispatch. """ @property @@ -533,12 +528,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceRepeatWithCatPass(RemoveOrReplacePassInterface): """ Replace repeat op as successive cat ops along different dimensions. - This is an optimization - repeat has a portable kernel fallback, but - cat may be more efficient on some backends. + repeat is not supported, so this is an opt_level=0 pass. """ @property @@ -1188,67 +1182,6 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=3)) -class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface): - """ - Replace NCHW max pooling with NHWC (channel-last) max pooling by adding - permute operations before and after the max pooling. - """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [ - exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, - ] - - def _change_nchw_to_nhwc( - self, graph: torch.fx.Graph, node: torch.fx.Node - ) -> torch.fx.Node: - """Convert NCHW format to NHWC format.""" - permute_node = graph.call_function( - exir_ops.edge.aten.permute_copy.default, (node, [0, 2, 3, 1]), {} - ) - permute_node.meta = node.meta - return permute_node - - def _change_nhwc_to_nchw( - self, graph: torch.fx.Graph, node: torch.fx.Node - ) -> torch.fx.Node: - """Convert NHWC format to NCHW format.""" - permute_node = graph.call_function( - exir_ops.edge.aten.permute_copy.default, (node, [0, 3, 1, 2]), {} - ) - permute_node.meta = node.meta - return permute_node - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - graph = node.graph - - # Get input node - input_node = cast(torch.fx.Node, node.args[0]) - - with graph.inserting_before(node): - # Convert input from NCHW to NHWC - input_nhwc = self._change_nchw_to_nhwc(graph, input_node) - - # Create the NHWC max pooling with the same args (kernel_size, stride, padding, dilation, ceil_mode) - new_args = (input_nhwc,) + tuple(node.args[1:]) - - new_pool = graph.call_function( - exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default, - new_args, - node.kwargs, - ) - new_pool.meta = node.meta - - # Convert output back from NHWC to NCHW - nchw_output = self._change_nhwc_to_nchw(graph, new_pool) - - # Replace all uses with the final output - node.replace_all_uses_with(nchw_output) - return True - - @register_cadence_pass(CadencePassAttribute(opt_level=3)) class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): """ @@ -1837,12 +1770,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass) -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface): """ aten.scalar_tensor can be replaced by aten.full with a shape of [1]. - This is an optimization - scalar_tensor has a portable kernel fallback, - but using full may reduce op variety in the graph. + scalar_tensor is not supported, so this is an opt_level=0 pass. """ @property @@ -1867,12 +1799,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface): """ aten.full_like can be replaced by aten.full with the shape of the arg tensor. - This is an optimization - full_like has a portable kernel fallback, - but using full may reduce op variety in the graph. + full_like is not supported, so this is an opt_level=0 pass. """ @property @@ -1896,12 +1827,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface): """ aten.full allows "-inf" and "inf" as inputs. The profiler cannot handle that, so replace them with the maximum value of the type. - This is an optimization for tooling compatibility, not runtime correctness. """ @property @@ -2075,7 +2005,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=0)) +@register_cadence_pass(CadencePassAttribute(opt_level=1)) class ReplaceEmptyTensorsWithFullPass(ExportPass): """Replaces nodes that produce empty tensors with full nodes.""" @@ -2288,12 +2218,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceMatmulWithTransposedMatmulPass(RemoveOrReplacePassInterface): """ For certain backends, we have efficient kernels for transposed matmul. We replace AxB with AxB' for such backends. - This is a performance optimization. """ @property @@ -2632,7 +2561,6 @@ class CadenceReplaceOpsInGraph: ReplacePadWithCatPass, ReplaceConstantPadNdWithSlicePass, ReplaceConvWithChannelLastConvPass, - ReplaceMaxPool2dWithChannelLastMaxPool2dPass, ReplaceTrivialConvWithLinear, ReplaceConvWithIm2RowAndLinear, ReplaceTransposedConvWithLinearPass, diff --git a/backends/cadence/aot/simplify_ops.py b/backends/cadence/aot/simplify_ops.py index b8842a38045..0cd16e18721 100644 --- a/backends/cadence/aot/simplify_ops.py +++ b/backends/cadence/aot/simplify_ops.py @@ -25,11 +25,10 @@ from torch.fx import Node -@register_cadence_pass(CadencePassAttribute(opt_level=1)) +@register_cadence_pass(CadencePassAttribute(opt_level=0)) class SimplifySliceOpPass(RemoveOrReplacePassInterface): """ Simplify the start and end indices of slice and slice_scatter ops. - This is an optimization that normalizes slice indices for easier processing. """ def adjust_slice_range( diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 5d9f8c0784b..95d470644a0 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -36,7 +36,6 @@ ReplaceLinearWithFullyConnectedOpPass, ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, - ReplaceMaxPool2dWithChannelLastMaxPool2dPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, ReplaceNopTransposeOrPermuteWithViewPass, @@ -2587,59 +2586,6 @@ def test_cat_insert_transpose(self) -> None: ) -class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase): - def test_replace_max_pool2d_nchw_with_nhwc(self) -> None: - # Create a graph with a single quantized_max_pool2d_nchw node. - x = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8) - gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, - args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False), - ) - self.assertEqual( - count_node(gm, exir_ops.edge.cadence.quantized_max_pool2d_nchw.default), 1 - ) - self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) - - # Deepcopy before the pass - original = copy.deepcopy(gm) - - # Apply replacement pass. - p = ReplaceMaxPool2dWithChannelLastMaxPool2dPass() - result = p.call(gm) - self.assertTrue(result.modified) - gm_after_replacement = result.graph_module - - # Check that replacement was made. - self.assertEqual( - count_node( - gm_after_replacement, - exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default, - ), - 1, - ) - self.assertEqual( - count_node( - gm_after_replacement, - exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, - ), - 0, - ) - # Two permutes: one for input NCHW->NHWC, one for output NHWC->NCHW - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - 2, - ) - - # Validate numerical accuracy - validate( - original, - gm_after_replacement, - (x,), - "ReplaceMaxPool2dWithChannelLastMaxPool2dPass", - ) - - class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase): def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]: builder = GraphBuilder() diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp index f843ad84080..b241b0851a9 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp @@ -27,7 +27,7 @@ using ::executorch::runtime::KernelRuntimeContext; namespace { template -void quantized_max_pool2d_nchw_impl( +void quantized_max_pool2d_impl( const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -98,7 +98,7 @@ void quantized_max_pool2d_nchw_impl( } // namespace -Tensor& quantized_max_pool2d_nchw_out( +Tensor& quantized_max_pool2d_out( ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, IntArrayRef kernel_size, @@ -107,9 +107,9 @@ Tensor& quantized_max_pool2d_nchw_out( IntArrayRef dilation, bool ceil_mode, Tensor& output) { -#define typed_quantized_max_pool2d_nchw(ctype, dtype) \ +#define typed_quantized_max_pool2d(ctype, dtype) \ case ScalarType::dtype: { \ - quantized_max_pool2d_nchw_impl( \ + quantized_max_pool2d_impl( \ input, kernel_size, stride, padding, dilation, ceil_mode, output); \ break; \ } @@ -117,14 +117,14 @@ Tensor& quantized_max_pool2d_nchw_out( ScalarType dtype = input.scalar_type(); // NOLINTBEGIN(clang-diagnostic-switch-enum) switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nchw) + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d) default: ET_DCHECK_MSG( false, "Unhandled dtype %s", torch::executor::toString(dtype)); } // NOLINTEND(clang-diagnostic-switch-enum) -#undef typed_quantized_max_pool2d_nchw +#undef typed_quantized_max_pool2d return output; } diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.h b/backends/cadence/generic/operators/op_quantized_max_pool2d.h index 453dd5a2582..07f406a37a7 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d.h +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.h @@ -15,7 +15,7 @@ namespace impl { namespace generic { namespace native { -::executorch::aten::Tensor& quantized_max_pool2d_nchw_out( +::executorch::aten::Tensor& quantized_max_pool2d_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, ::executorch::aten::IntArrayRef kernel_size, diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp deleted file mode 100644 index cb4a9616394..00000000000 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include -#include -#include - -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -namespace { - -template -void quantized_max_pool2d_nhwc_impl( - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - ET_UNUSED bool ceil_mode, - Tensor& output) { - const T* __restrict__ in_data = input.const_data_ptr(); - T* __restrict__ out_data = output.mutable_data_ptr(); - - // Input dimensions: [N, H, W, C] - const int64_t batch_size = input.size(0); - const int64_t in_height = input.size(1); - const int64_t in_width = input.size(2); - const int64_t channels = input.size(3); - - // Output dimensions: [N, H_out, W_out, C] - const int64_t out_height = output.size(1); - const int64_t out_width = output.size(2); - - // Pooling parameters - const int64_t kernel_h = kernel_size[0]; - const int64_t kernel_w = kernel_size[1]; - const int64_t stride_h = stride[0]; - const int64_t stride_w = stride[1]; - const int64_t pad_h = padding[0]; - const int64_t pad_w = padding[1]; - const int64_t dilation_h = dilation[0]; - const int64_t dilation_w = dilation[1]; - - for (int64_t n = 0; n < batch_size; ++n) { - for (int64_t oh = 0; oh < out_height; ++oh) { - for (int64_t ow = 0; ow < out_width; ++ow) { - const int64_t ih_start = oh * stride_h - pad_h; - const int64_t iw_start = ow * stride_w - pad_w; - - T* __restrict__ out_ptr = - out_data + ((n * out_height + oh) * out_width + ow) * channels; - - // Initialize all channels to the minimum value. - for (int64_t c = 0; c < channels; ++c) { - out_ptr[c] = std::numeric_limits::lowest(); - } - - // For each kernel position, compute element-wise max across all - // channels. The inner loop over channels is a stride-1 contiguous - // access in NHWC layout, enabling SIMD auto-vectorization. - for (int64_t kh = 0; kh < kernel_h; ++kh) { - const int64_t ih = ih_start + kh * dilation_h; - if (ih < 0 || ih >= in_height) { - continue; - } - for (int64_t kw = 0; kw < kernel_w; ++kw) { - const int64_t iw = iw_start + kw * dilation_w; - if (iw < 0 || iw >= in_width) { - continue; - } - - const T* __restrict__ in_ptr = - in_data + ((n * in_height + ih) * in_width + iw) * channels; - - for (int64_t c = 0; c < channels; ++c) { - out_ptr[c] = std::max(out_ptr[c], in_ptr[c]); - } - } - } - } - } - } -} - -} // namespace - -Tensor& quantized_max_pool2d_nhwc_out( - ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode, - Tensor& output) { -#define typed_quantized_max_pool2d_nhwc(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_max_pool2d_nhwc_impl( \ - input, kernel_size, stride, padding, dilation, ceil_mode, output); \ - break; \ - } - - ScalarType dtype = input.scalar_type(); - // NOLINTBEGIN(clang-diagnostic-switch-enum) - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nhwc) - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - // NOLINTEND(clang-diagnostic-switch-enum) - -#undef typed_quantized_max_pool2d_nhwc - return output; -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h deleted file mode 100644 index 2b0c02e4bb7..00000000000 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -namespace impl { -namespace generic { -namespace native { - -::executorch::aten::Tensor& quantized_max_pool2d_nhwc_out( - ::executorch::runtime::KernelRuntimeContext& ctx, - const ::executorch::aten::Tensor& input, - ::executorch::aten::IntArrayRef kernel_size, - ::executorch::aten::IntArrayRef stride, - ::executorch::aten::IntArrayRef padding, - ::executorch::aten::IntArrayRef dilation, - bool ceil_mode, - ::executorch::aten::Tensor& output); - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index fa6708a188e..bf1de9e009a 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -225,18 +225,6 @@ def define_common_targets(): visibility = ["PUBLIC"], ) - runtime.cxx_library( - name = "op_quantized_max_pool2d_nhwc", - srcs = ["op_quantized_max_pool2d_nhwc.cpp"], - exported_headers = ["op_quantized_max_pool2d_nhwc.h"], - platforms = CXX, - deps = [ - "//executorch/runtime/kernel:kernel_includes", - ":cadence_type_util", - ], - visibility = ["PUBLIC"], - ) - runtime.cxx_library( name = "op_quantized_matmul", srcs = ["op_quantized_matmul.cpp"], diff --git a/backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp deleted file mode 100644 index 69c4a3fbc45..00000000000 --- a/backends/cadence/hifi/operators/op_quantized_max_pool2d_nhwc.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -namespace impl { -namespace HiFi { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -Tensor& quantized_max_pool2d_nhwc_out( - KernelRuntimeContext& ctx, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode, - Tensor& output) { - // NHWC layout: [N, H, W, C] - const int32_t batch_size = input.size(0); - const int32_t in_height = input.size(1); - const int32_t in_width = input.size(2); - const int32_t channels = input.size(3); - - const int32_t out_height = output.size(1); - const int32_t out_width = output.size(2); - - const int32_t kernel_h = kernel_size[0]; - const int32_t kernel_w = kernel_size[1]; - const int32_t stride_h = stride[0]; - const int32_t stride_w = stride[1]; - const int32_t pad_h = padding[0]; - const int32_t pad_w = padding[1]; - - // Determine NNLIB precision constants based on dtype - ScalarType dtype = input.scalar_type(); - int32_t nnlib_precision; - switch (dtype) { - case ScalarType::Char: // int8 - nnlib_precision = PREC_SYM8S; - break; - case ScalarType::Byte: // uint8 - nnlib_precision = PREC_ASYM8U; - break; - default: - ET_DCHECK_MSG( - false, - "Unsupported dtype %s for HiFi quantized_max_pool2d_nhwc", - torch::executor::toString(dtype)); - return output; - } - - // Compute scratch buffer size for NNLIB maxpool - int32_t scratch_size = xa_nn_maxpool_getsize( - channels, - nnlib_precision, - nnlib_precision, - in_height, - in_width, - kernel_h, - kernel_w, - stride_w, // x_stride - stride_h, // y_stride - pad_w, // x_padding - pad_h, // y_padding - out_height, - out_width, - 0, // inp_data_format: 0 = NHWC - 0); // out_data_format: 0 = NHWC - ET_DCHECK_MSG(scratch_size >= 0, "xa_nn_maxpool_getsize failed"); - - // Allocate aligned scratch memory - void* p_scratch = kernels::allocate_temp_memory(ctx, scratch_size); - - // Process each batch using NNLIB optimized maxpool kernel - for (int32_t n = 0; n < batch_size; ++n) { - const int32_t spatial_size = in_height * in_width * channels; - const int32_t out_spatial_size = out_height * out_width * channels; - - int32_t ret; - if (dtype == ScalarType::Char) { - const int8_t* in_batch = - input.const_data_ptr() + n * spatial_size; - int8_t* out_batch = - output.mutable_data_ptr() + n * out_spatial_size; - - ret = xa_nn_maxpool_8( - out_batch, - in_batch, - in_height, - in_width, - channels, - kernel_h, - kernel_w, - stride_w, // x_stride - stride_h, // y_stride - pad_w, // x_padding - pad_h, // y_padding - out_height, - out_width, - 0, // inp_data_format: NHWC - 0, // out_data_format: NHWC - p_scratch); - } else { - const uint8_t* in_batch = - input.const_data_ptr() + n * spatial_size; - uint8_t* out_batch = - output.mutable_data_ptr() + n * out_spatial_size; - - ret = xa_nn_maxpool_asym8( - out_batch, - in_batch, - in_height, - in_width, - channels, - kernel_h, - kernel_w, - stride_w, // x_stride - stride_h, // y_stride - pad_w, // x_padding - pad_h, // y_padding - out_height, - out_width, - 0, // inp_data_format: NHWC - 0, // out_data_format: NHWC - p_scratch); - } - ET_DCHECK_MSG(ret == 0, "HiFi xa_nn_maxpool failed"); - } - - return output; -} - -} // namespace native -} // namespace HiFi -} // namespace impl diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index 1ea57862cf6..9753051bf72 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -632,16 +632,6 @@ def define_common_targets(): compatible_with = ["ovr_config//cpu:xtensa"], ) - runtime.cxx_library( - name = "op_quantized_max_pool2d_nhwc", - srcs = ["op_quantized_max_pool2d_nhwc.cpp"], - exported_headers = ["operators.h"], - platforms = CXX, - deps = COMMON_DEPS, - visibility = ["PUBLIC"], - compatible_with = ["ovr_config//cpu:xtensa"], - ) - runtime.cxx_library( name = "op_quantized_relu_asym8s_asym8s_per_tensor_out", srcs = ["op_quantized_relu_asym8s_asym8s_per_tensor_out.cpp"],