diff --git a/core.py b/core.py index 25d8232..c079c23 100644 --- a/core.py +++ b/core.py @@ -1514,13 +1514,58 @@ class PatternRewritePass(BasePass): Uses BasePass's iterative framework with GraphOptimizer.match_patterns_once() for the actual pattern matching. Iterates until convergence (no more matches). + + Supports multiple patterns, each optionally with its own rewriter. """ - def __init__(self, pattern, rewriter, name=None, optimizer_alias=None): + def __init__( + self, + pattern: Optional[Pattern] = None, + rewriter: Optional[AnyType] = None, + name: Optional[str] = None, + optimizer_alias: Optional[str] = None, + patterns: Optional[List[Union[Pattern, Tuple[Pattern, AnyType]]]] = None, + ): + """ + Initialize PatternRewritePass. + + Args: + pattern: Single pattern to match (backward compatibility) + rewriter: Rewriter function for the single pattern + name: Pass name + optimizer_alias: Alias for node naming + patterns: List of Pattern objects or (Pattern, Rewriter) tuples. + If a list of Pattern objects is provided, 'rewriter' must be provided. + """ # Use iterative mode - run until convergence super().__init__(name, optimizer_alias, iterative=True, max_iterations=100) - self.pattern = pattern - self.rewriter = trace_transformation(rewriter) + + self.rules: List[Tuple[Pattern, AnyType]] = [] + + if patterns: + # Wrap the common rewriter once if provided to avoid redundant logging + wrapped_common_rewriter = ( + trace_transformation(rewriter) if rewriter else None + ) + + # Handle multiple patterns + for item in patterns: + if isinstance(item, tuple): + p, r = item + self.rules.append((p, trace_transformation(r))) + else: + if wrapped_common_rewriter is None: + raise ValueError( + "rewriter must be provided if 'patterns' contains Pattern objects" + ) + self.rules.append((item, wrapped_common_rewriter)) + elif pattern and rewriter: + # Handle single pattern (backward compatibility) + self.rules.append((pattern, trace_transformation(rewriter))) + else: + raise ValueError( + "Must provide either 'patterns' or both 'pattern' and 'rewriter'" + ) def transform_once( self, @@ -1534,9 +1579,10 @@ def transform_once( Returns: int: Number of changes made """ - # Register the pattern (clear first to avoid duplicates) + # Register all patterns (clear first to avoid duplicates) optimizer.clear_transformations() - optimizer.add_transformation(self.pattern, self.rewriter) + for pattern, rewriter in self.rules: + optimizer.add_transformation(pattern, rewriter) # Run one pattern matching iteration new_graph_def, changes = optimizer.match_patterns_once( diff --git a/transforms/scalar/algebraic_simplify.py b/transforms/scalar/algebraic_simplify.py index 76c94af..59b87af 100644 --- a/transforms/scalar/algebraic_simplify.py +++ b/transforms/scalar/algebraic_simplify.py @@ -88,8 +88,16 @@ class AlgebraicSimplifyPass(PatternRewritePass): def __init__(self): # We'll handle multiple patterns manually in _rewrite - pattern = Any(alias="op") # fallback, we check inside - super().__init__(pattern, self._rewrite, name="AlgebraicSimplify") + # By providing a list of specific ops, we enable O(1) indexed matching + # instead of O(N) wildcard matching. + supported_ops = [ + "Add", "Sub", "Mul", "Div", "Neg", "LogicalNot", "Abs", + "Square", "Sqrt", "Pow", "Equal", "NotEqual", "Less", + "Greater", "LessEqual", "GreaterEqual", "LogicalAnd", + "LogicalOr", "Select", "Identity" + ] + patterns = [Op(op, alias="op") for op in supported_ops] + super().__init__(patterns=patterns, rewriter=self._rewrite, name="AlgebraicSimplify") def _rewrite(self, match, optimizer): node = match.matched_nodes["op"] diff --git a/transforms/scalar/constant_fold.py b/transforms/scalar/constant_fold.py index 187cff4..d5803b5 100644 --- a/transforms/scalar/constant_fold.py +++ b/transforms/scalar/constant_fold.py @@ -58,9 +58,18 @@ class ConstantFoldPass(PatternRewritePass): """ def __init__(self): - # Matches any operation with all inputs as Const - pattern = Any(alias="op") - super().__init__(pattern, self._rewrite_constant_op, name="ConstantFold") + # By providing a list of specific ops, we enable O(1) indexed matching + # instead of O(N) wildcard matching. + supported_ops = [ + "Add", "Mul", "Sub", "Div", "Neg", "Equal", "NotEqual", "Less", + "Greater", "LessEqual", "GreaterEqual", "LogicalAnd", "LogicalOr", + "LogicalNot", "BitwiseAnd", "BitwiseOr", "BitwiseXor", "Abs", "Exp", + "Expm1", "Log", "Log1p", "Sqrt", "Pow", "Rsqrt", "Square", "Sin", + "Cos", "Tan", "Asin", "Acos", "Atan", "Atan2", "Floor", "Ceil", + "Round", "Sign", "Reshape", "Transpose", "ConcatV2", "Select", "Cast" + ] + patterns = [Op(op, alias="op") for op in supported_ops] + super().__init__(patterns=patterns, rewriter=self._rewrite_constant_op, name="ConstantFold") def _is_all_const(self, inputs, optimizer): """Check if all inputs are Const nodes.