From e1f809d059dc5820ec1c70d71ce2285818e712a8 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:11:18 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20Bolt:=20optimize=20AlgebraicSimplif?= =?UTF-8?q?y=20and=20ConstantFold=20with=20indexed=20matching?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modified PatternRewritePass to support multiple patterns and updated AlgebraicSimplifyPass and ConstantFoldPass to use specific Op patterns instead of a wildcard Any() pattern. This enables O(1) indexed matching in the GraphOptimizer instead of O(N) wildcard matching for every node, significantly improving optimization performance on large graphs. Co-authored-by: Iorest <16451699+Iorest@users.noreply.github.com> --- core.py | 56 ++++++++++++++++++++++--- transforms/scalar/algebraic_simplify.py | 12 +++++- transforms/scalar/constant_fold.py | 15 +++++-- 3 files changed, 73 insertions(+), 10 deletions(-) diff --git a/core.py b/core.py index 25d8232..c079c23 100644 --- a/core.py +++ b/core.py @@ -1514,13 +1514,58 @@ class PatternRewritePass(BasePass): Uses BasePass's iterative framework with GraphOptimizer.match_patterns_once() for the actual pattern matching. Iterates until convergence (no more matches). + + Supports multiple patterns, each optionally with its own rewriter. """ - def __init__(self, pattern, rewriter, name=None, optimizer_alias=None): + def __init__( + self, + pattern: Optional[Pattern] = None, + rewriter: Optional[AnyType] = None, + name: Optional[str] = None, + optimizer_alias: Optional[str] = None, + patterns: Optional[List[Union[Pattern, Tuple[Pattern, AnyType]]]] = None, + ): + """ + Initialize PatternRewritePass. + + Args: + pattern: Single pattern to match (backward compatibility) + rewriter: Rewriter function for the single pattern + name: Pass name + optimizer_alias: Alias for node naming + patterns: List of Pattern objects or (Pattern, Rewriter) tuples. + If a list of Pattern objects is provided, 'rewriter' must be provided. + """ # Use iterative mode - run until convergence super().__init__(name, optimizer_alias, iterative=True, max_iterations=100) - self.pattern = pattern - self.rewriter = trace_transformation(rewriter) + + self.rules: List[Tuple[Pattern, AnyType]] = [] + + if patterns: + # Wrap the common rewriter once if provided to avoid redundant logging + wrapped_common_rewriter = ( + trace_transformation(rewriter) if rewriter else None + ) + + # Handle multiple patterns + for item in patterns: + if isinstance(item, tuple): + p, r = item + self.rules.append((p, trace_transformation(r))) + else: + if wrapped_common_rewriter is None: + raise ValueError( + "rewriter must be provided if 'patterns' contains Pattern objects" + ) + self.rules.append((item, wrapped_common_rewriter)) + elif pattern and rewriter: + # Handle single pattern (backward compatibility) + self.rules.append((pattern, trace_transformation(rewriter))) + else: + raise ValueError( + "Must provide either 'patterns' or both 'pattern' and 'rewriter'" + ) def transform_once( self, @@ -1534,9 +1579,10 @@ def transform_once( Returns: int: Number of changes made """ - # Register the pattern (clear first to avoid duplicates) + # Register all patterns (clear first to avoid duplicates) optimizer.clear_transformations() - optimizer.add_transformation(self.pattern, self.rewriter) + for pattern, rewriter in self.rules: + optimizer.add_transformation(pattern, rewriter) # Run one pattern matching iteration new_graph_def, changes = optimizer.match_patterns_once( diff --git a/transforms/scalar/algebraic_simplify.py b/transforms/scalar/algebraic_simplify.py index 76c94af..59b87af 100644 --- a/transforms/scalar/algebraic_simplify.py +++ b/transforms/scalar/algebraic_simplify.py @@ -88,8 +88,16 @@ class AlgebraicSimplifyPass(PatternRewritePass): def __init__(self): # We'll handle multiple patterns manually in _rewrite - pattern = Any(alias="op") # fallback, we check inside - super().__init__(pattern, self._rewrite, name="AlgebraicSimplify") + # By providing a list of specific ops, we enable O(1) indexed matching + # instead of O(N) wildcard matching. + supported_ops = [ + "Add", "Sub", "Mul", "Div", "Neg", "LogicalNot", "Abs", + "Square", "Sqrt", "Pow", "Equal", "NotEqual", "Less", + "Greater", "LessEqual", "GreaterEqual", "LogicalAnd", + "LogicalOr", "Select", "Identity" + ] + patterns = [Op(op, alias="op") for op in supported_ops] + super().__init__(patterns=patterns, rewriter=self._rewrite, name="AlgebraicSimplify") def _rewrite(self, match, optimizer): node = match.matched_nodes["op"] diff --git a/transforms/scalar/constant_fold.py b/transforms/scalar/constant_fold.py index 187cff4..d5803b5 100644 --- a/transforms/scalar/constant_fold.py +++ b/transforms/scalar/constant_fold.py @@ -58,9 +58,18 @@ class ConstantFoldPass(PatternRewritePass): """ def __init__(self): - # Matches any operation with all inputs as Const - pattern = Any(alias="op") - super().__init__(pattern, self._rewrite_constant_op, name="ConstantFold") + # By providing a list of specific ops, we enable O(1) indexed matching + # instead of O(N) wildcard matching. + supported_ops = [ + "Add", "Mul", "Sub", "Div", "Neg", "Equal", "NotEqual", "Less", + "Greater", "LessEqual", "GreaterEqual", "LogicalAnd", "LogicalOr", + "LogicalNot", "BitwiseAnd", "BitwiseOr", "BitwiseXor", "Abs", "Exp", + "Expm1", "Log", "Log1p", "Sqrt", "Pow", "Rsqrt", "Square", "Sin", + "Cos", "Tan", "Asin", "Acos", "Atan", "Atan2", "Floor", "Ceil", + "Round", "Sign", "Reshape", "Transpose", "ConcatV2", "Select", "Cast" + ] + patterns = [Op(op, alias="op") for op in supported_ops] + super().__init__(patterns=patterns, rewriter=self._rewrite_constant_op, name="ConstantFold") def _is_all_const(self, inputs, optimizer): """Check if all inputs are Const nodes.