Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,13 +1514,58 @@ class PatternRewritePass(BasePass):

Uses BasePass's iterative framework with GraphOptimizer.match_patterns_once()
for the actual pattern matching. Iterates until convergence (no more matches).

Supports multiple patterns, each optionally with its own rewriter.
"""

def __init__(
    self,
    pattern: Optional[Pattern] = None,
    rewriter: Optional[AnyType] = None,
    name: Optional[str] = None,
    optimizer_alias: Optional[str] = None,
    patterns: Optional[List[Union[Pattern, Tuple[Pattern, AnyType]]]] = None,
):
    """
    Initialize PatternRewritePass.

    Args:
        pattern: Single pattern to match (backward compatibility).
        rewriter: Rewriter callable. Required when ``pattern`` is given,
            or when ``patterns`` contains bare Pattern objects (it then
            serves as the shared rewriter for those patterns).
        name: Pass name.
        optimizer_alias: Alias used for node naming.
        patterns: List of Pattern objects or (Pattern, Rewriter) tuples.

    Raises:
        ValueError: If neither ``patterns`` nor both ``pattern`` and
            ``rewriter`` are provided, or if ``patterns`` contains a bare
            Pattern object while ``rewriter`` is None.
    """
    # Use iterative mode - run until convergence
    super().__init__(name, optimizer_alias, iterative=True, max_iterations=100)

    # Each rule is a (pattern, wrapped_rewriter) pair; transform_once()
    # registers all of them on the optimizer every iteration.
    self.rules: List[Tuple[Pattern, AnyType]] = []

    if patterns:
        # Wrap the shared rewriter once (not per pattern) to avoid
        # redundant tracing/logging overhead.
        # NOTE: use explicit `is not None` - a rewriter object could be
        # falsy (e.g. a callable defining __bool__) yet still valid.
        wrapped_common_rewriter = (
            trace_transformation(rewriter) if rewriter is not None else None
        )

        for item in patterns:
            if isinstance(item, tuple):
                # Per-pattern rewriter supplied explicitly.
                p, r = item
                self.rules.append((p, trace_transformation(r)))
            elif wrapped_common_rewriter is not None:
                # Bare Pattern: fall back to the shared rewriter.
                self.rules.append((item, wrapped_common_rewriter))
            else:
                raise ValueError(
                    "rewriter must be provided if 'patterns' contains Pattern objects"
                )
    elif pattern is not None and rewriter is not None:
        # Handle single pattern (backward compatibility).
        self.rules.append((pattern, trace_transformation(rewriter)))
    else:
        raise ValueError(
            "Must provide either 'patterns' or both 'pattern' and 'rewriter'"
        )

def transform_once(
self,
Expand All @@ -1534,9 +1579,10 @@ def transform_once(
Returns:
int: Number of changes made
"""
# Register the pattern (clear first to avoid duplicates)
# Register all patterns (clear first to avoid duplicates)
optimizer.clear_transformations()
optimizer.add_transformation(self.pattern, self.rewriter)
for pattern, rewriter in self.rules:
optimizer.add_transformation(pattern, rewriter)

# Run one pattern matching iteration
new_graph_def, changes = optimizer.match_patterns_once(
Expand Down
12 changes: 10 additions & 2 deletions transforms/scalar/algebraic_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,16 @@ class AlgebraicSimplifyPass(PatternRewritePass):

def __init__(self):
    """
    Configure the pass with one indexed pattern per supported op.

    Registering explicit ``Op`` patterns (instead of a single wildcard)
    enables O(1) indexed matching per node rather than O(N) wildcard
    matching; the shared ``_rewrite`` then dispatches on the matched
    node's op internally.
    """
    # Ops _rewrite knows how to simplify; anything else is never matched.
    supported_ops = [
        "Add", "Sub", "Mul", "Div", "Neg", "LogicalNot", "Abs",
        "Square", "Sqrt", "Pow", "Equal", "NotEqual", "Less",
        "Greater", "LessEqual", "GreaterEqual", "LogicalAnd",
        "LogicalOr", "Select", "Identity",
    ]
    patterns = [Op(op, alias="op") for op in supported_ops]
    super().__init__(
        patterns=patterns, rewriter=self._rewrite, name="AlgebraicSimplify"
    )

def _rewrite(self, match, optimizer):
node = match.matched_nodes["op"]
Expand Down
15 changes: 12 additions & 3 deletions transforms/scalar/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,18 @@ class ConstantFoldPass(PatternRewritePass):
"""

def __init__(self):
    """
    Configure the pass with one indexed pattern per foldable op.

    Explicit ``Op`` patterns enable O(1) indexed matching instead of the
    O(N) wildcard matching a single wildcard pattern would require; the
    rewriter still verifies at match time that every input of a matched
    node is a Const before folding.
    """
    # Ops the folding rewriter can evaluate ahead of time.
    supported_ops = [
        "Add", "Mul", "Sub", "Div", "Neg", "Equal", "NotEqual", "Less",
        "Greater", "LessEqual", "GreaterEqual", "LogicalAnd", "LogicalOr",
        "LogicalNot", "BitwiseAnd", "BitwiseOr", "BitwiseXor", "Abs", "Exp",
        "Expm1", "Log", "Log1p", "Sqrt", "Pow", "Rsqrt", "Square", "Sin",
        "Cos", "Tan", "Asin", "Acos", "Atan", "Atan2", "Floor", "Ceil",
        "Round", "Sign", "Reshape", "Transpose", "ConcatV2", "Select", "Cast",
    ]
    patterns = [Op(op, alias="op") for op in supported_ops]
    super().__init__(
        patterns=patterns,
        rewriter=self._rewrite_constant_op,
        name="ConstantFold",
    )

def _is_all_const(self, inputs, optimizer):
"""Check if all inputs are Const nodes.
Expand Down