Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
18906bd
compiler: Start adding machinery to specialise operators with hardcod…
EdCaunt Dec 17, 2025
a03766c
tests: Start adding tests for operator specialization
EdCaunt Dec 18, 2025
27a81e6
tests: Introduce further tests
EdCaunt Dec 19, 2025
9430b4c
tests: Add tests for specialising ConditionalDimension factors
EdCaunt Dec 23, 2025
a48ce35
tests: Added test applying a specialized operator
EdCaunt Dec 23, 2025
5e6e10d
misc: flake8
EdCaunt Dec 23, 2025
c75d4d0
api: Start enabling specialization at operator apply
EdCaunt Dec 23, 2025
85d6288
dsl: Tweak specialization at apply
EdCaunt Dec 30, 2025
c6a9389
tests: Add initial test for specialization at operator apply
EdCaunt Dec 30, 2025
863e0e6
compiler: Enhance logging of arguments and apply specialization test
EdCaunt Jan 6, 2026
0f4f10e
compiler: Emit arguments used to invoke kernels and add test for spec…
EdCaunt Jan 6, 2026
14d03b8
compiler: Add KernelLaunch handling to Specializer
EdCaunt Jan 9, 2026
259ce71
compiler: Make Specializer visit _func_table of an Operator
EdCaunt Jan 16, 2026
fd865d3
compiler: Refactor func table specialization to use a visitor
EdCaunt Jan 16, 2026
20fcb1c
compiler: Update Specializer with handler for BlockGrid
EdCaunt Jan 19, 2026
ca1ccf2
tests: Start work on diffusion-like test
EdCaunt Jan 21, 2026
f3ded23
API: Refactor operator specialization API
EdCaunt Feb 20, 2026
4d784a6
tests: Expand specialization tests
EdCaunt Mar 4, 2026
7a00263
compiler: Fix stack corruption bug in specialization
EdCaunt Mar 6, 2026
76c46e8
tests: Add test for acoustic wave equation
EdCaunt Mar 16, 2026
b071524
tests: Add elastic test for specialization
EdCaunt Mar 17, 2026
32d75cd
compiler: Add more checks prior to specialization
EdCaunt Mar 18, 2026
9956264
tests: Add to specialization tests
EdCaunt Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from devito.finite_differences.differentiable import diff2sympy
from devito.ir.equations.algorithms import dimension_sort, lower_exprs
from devito.ir.support import (
GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
detect_io
)
from devito.ir.support.guards import GuardFactorEq
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
Expand Down Expand Up @@ -221,11 +222,11 @@ def __new__(cls, *args, **kwargs):
if not d.is_Conditional:
continue
if d.condition is None:
conditionals[d] = GuardFactor(d)
conditionals[d] = GuardFactorEq.new_from_dim(d)
else:
cond = diff2sympy(lower_exprs(d.condition))
if d._factor is not None:
cond = d.relation(cond, GuardFactor(d))
cond = d.relation(cond, GuardFactorEq.new_from_dim(d))
conditionals[d] = cond
# Replace dimension with index
index = d.index
Expand Down
125 changes: 121 additions & 4 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
from typing import Any, Generic, TypeVar

import cgen as c
from sympy import IndexedBase
from sympy import IndexedBase, Number
from sympy.core.function import Application

from devito.exceptions import CompilationError
from devito.ir.iet.nodes import (
BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node,
Section
BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, MetaCall,
Node, Section
)
from devito.ir.support.space import Backward
from devito.symbolics import (
FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace
FieldFromComposite, FieldFromPointer, IndexedPointer, ListInitializer, uxreplace
)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (
Expand All @@ -45,6 +45,7 @@
'MapExprStmts',
'MapHaloSpots',
'MapNodes',
'Specializer',
'Transformer',
'Uxreplace',
'printAST',
Expand Down Expand Up @@ -1498,6 +1499,122 @@ def visit_KernelLaunch(self, o):
arguments=arguments)


class Specializer(Uxreplace):
    """
    A Transformer to "specialize" a pre-built Operator - that is to replace a given
    set of (scalar) symbols with hard-coded values to free up registers. This will
    yield a "specialized" version of the Operator, specific to a particular setup.

    Note that the Operator is not re-optimized in response to this replacement - this
    transformation could nominally result in expressions of the form `f + 0` in the
    generated code. If one wants to construct an Operator where such expressions are
    considered, then use of `subs=...` at construction time is a better choice. However,
    it is likely that such expressions will be optimized away by the C-level compiler.

    Parameters
    ----------
    mapper : dict
        Maps each symbol to hard-code (an AbstractSymbol or an IndexedPointer) to
        the SymPy Number replacing it.
    nested : bool, optional
        Forwarded unchanged to Uxreplace.

    Raises
    ------
    ValueError
        If a key is neither an AbstractSymbol nor an IndexedPointer, if a value is
        not a SymPy Number, or if casting a value through the key's dtype raises
        a ValueError.
    """

    def __init__(self, mapper, nested=False):
        super().__init__(mapper, nested=nested)

        # Sanity check
        for k, v in self.mapper.items():
            if not isinstance(k, (AbstractSymbol, IndexedPointer)):
                raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")

            if not isinstance(v, Number):
                raise ValueError("Only SymPy Numbers can be used to replace values "
                                 f"during specialization. Value {v} was supplied "
                                 f"for symbol {k}, but is of type {type(v)}.")

            # The cast result is discarded; only its failure matters
            try:
                k.dtype(v)
            except ValueError as e:
                raise ValueError(f"Value {v} is incompatible with {k.dtype} dtype "
                                 f"of {k}") from e

    # TODO: Should probably be moved to Uxreplace at least (as should some of these
    # others I think?)
    def visit_DifferentiableFunction(self, o):
        """Apply the mapper to `o` through `uxreplace`."""
        return uxreplace(o, self.mapper)

    def visit_Definition(self, o):
        """
        Rebuild the Definition around its visited function. If a KeyError is raised
        while visiting or rebuilding, the node is returned unchanged.
        """
        try:
            function = self._visit(o.function)
            return o._rebuild(function=function)
        except KeyError:
            return o

    def visit_BlockGrid(self, o):
        """Rebuild the BlockGrid with visited `cargs` and `shape`."""
        # TODO: Should probably be made into a uxreplace handler of some description
        cargs = self._visit(o.cargs)
        shape = self._visit(o.shape)
        return o._rebuild(cargs=cargs, shape=shape)

    def visit_OrderedDict(self, o):
        """Return a new OrderedDict with the same keys and visited values."""
        return OrderedDict((k, self._visit(v)) for k, v in o.items())

    def visit_MetaCall(self, o):
        """Return a MetaCall wrapping the visited root, `local` left untouched."""
        root = self._visit(o.root)
        return MetaCall(root=root, local=o.local)

    def visit_Callable(self, o):
        """Visit the body and drop any parameter that is being hard-coded."""
        body = self._visit(o.body)
        parameters = [i for i in o.parameters if i not in self.mapper]
        return o._rebuild(body=body, parameters=parameters)

    def visit_KernelLaunch(self, o):
        """Rebuild the launch without the arguments that are being hard-coded."""
        # Remove kernel args if they are to be hardcoded
        arguments = [i for i in o.arguments if i not in self.mapper]
        return o._rebuild(arguments=arguments)

    def visit_Operator(self, o, **kwargs):
        """
        Build and return a new Operator, of the same class as `o`, whose body and
        `_func_table` have been visited and whose parameters no longer include the
        symbols in the mapper. The new Operator is named `<name>Specialized`.

        Raises
        ------
        ValueError
            If the mapper contains an AbstractSymbol that is not one of the
            Operator parameters.
        """
        # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this
        # is the intended use case
        body = self._visit(o.body)

        # NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in
        # the Operator parameters. Perhaps this check wants relaxing.
        not_params = tuple(i for i in self.mapper
                           if i not in o.parameters and isinstance(i, AbstractSymbol))
        if not_params:
            raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
                             " found in the Operator parameters")

        parameters = tuple(i for i in o.parameters if i not in self.mapper)

        # Note: the following is not dissimilar to unpickling an Operator
        state = o.__getstate__()
        state['parameters'] = parameters
        state['body'] = body
        # Modify the _func_table to ensure callbacks are specialized
        state['_func_table'] = self._visit(o._func_table)

        state.pop('ccode', None)

        # The specialized operator must be compiled fresh - strip any pre-existing
        # compiled binary state inherited from a previously-applied operator.
        # Without this, __setstate__ reloads the old binary (which expects the full
        # parameter list), while the new operator has fewer parameters after
        # specialization, causing a stack corruption (SIGABRT) at call time.
        state.pop('binary', None)
        state.pop('soname', None)
        state.pop('_soname', None)  # Clear cached soname so it is recomputed

        # Tag operator to indicate that it's specialized
        name = state['name']
        state['name'] = f"{name}Specialized"

        newargs, newkwargs = o.__getnewargs_ex__()
        newop = o.__class__(*newargs, **newkwargs)

        newop.__setstate__(state)

        return newop


# Utils

blankline = c.Line("")
Expand Down
24 changes: 10 additions & 14 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,34 @@ def canonical(self):

@property
def negated(self):
    """
    The negation of this guard, built unevaluated from `_args_rebuild`.

    Raises
    ------
    ValueError
        If the class of this guard has no entry in `negations`.
    """
    # Only the registry lookup is guarded, so that a KeyError raised while
    # constructing the negation is not misreported as a missing negation
    try:
        negation = negations[self.__class__]
    except KeyError:
        raise ValueError(
            f"Class {self.__class__.__name__} does not have a negation"
        ) from None

    return negation(*self._args_rebuild, evaluate=False)


# *** GuardFactor


class GuardFactor(Guard, CondEq, Pickable):
class GuardFactor(Guard, Pickable):

"""
A guard for factor-based ConditionalDimensions.

Given the ConditionalDimension `d` with factor `k`, create the
symbolic relational `d.parent % k == 0`.
Introduces a constructor where, given the ConditionalDimension `d` with factor `k`,
the symbolic relational `d.parent % k == 0` is created.
"""

__rargs__ = ('d',)
__rargs__ = ('lhs', 'rhs')

def __new__(cls, d, **kwargs):
@classmethod
def new_from_dim(cls, d, **kwargs):
assert d.is_Conditional

obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
obj.d = d

return obj

@property
def _args_rebuild(self):
return (self.d,)


class GuardFactorEq(GuardFactor, CondEq):
pass
Expand All @@ -85,9 +84,6 @@ class GuardFactorNe(GuardFactor, CondNe):
pass


GuardFactor = GuardFactorEq


# *** GuardBound


Expand Down
64 changes: 63 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from devito.ir.equations import LoweredEq, concretize_subdims, lower_exprs
from devito.ir.iet import (
Callable, CInterface, DeviceFunction, EntryFunction, FindSymbols, MetaCall,
derive_parameters, iet_build
Specializer, derive_parameters, iet_build
)
from devito.ir.stree import stree_build
from devito.ir.support import AccessMode, SymbolRegistry
Expand Down Expand Up @@ -924,6 +924,43 @@ def _enrich_memreport(self, args):
# Hook for enriching memory report with additional metadata
return {}

def specialize(self, **kwargs):
"""
Build a specialized version of this Operator in which selected
parameters are replaced by the values they take for the given `kwargs`.

The names of the parameters to hard-code are read from the optional
`specialize` keyword (a single name or an iterable of names), which is
removed from `kwargs`. Only names matching the `name` of one of
`self.parameters` are hard-coded; any other name is ignored here.

Returns
-------
tuple
    `(op, kwargs)`. If `specialize` is empty or absent, `op` is this
    Operator and `kwargs` is returned as given. Otherwise `op` is the
    Operator produced by `Specializer` and `kwargs` no longer contains
    the keys listed in `specialize`.
"""

specialize = as_tuple(kwargs.pop('specialize', []))

if not specialize:
# Nothing to hard-code: hand back this Operator and the remaining kwargs
return self, kwargs

# FIXME: Cannot cope with things like sizes/strides yet since it only
# looks at the parameters

# Build the arguments list for specialization
with self._profiler.timer_on('specialization'):
args = self.arguments(**kwargs)
# Uses parameters here since Specializer needs {symbol: sympy value}
specialized_values = {p: sympify(args[p.name])
for p in self.parameters
if p.name in specialize}

op = Specializer(specialized_values).visit(self)

with switch_log_level(comm=args.comm):
self._emit_args_profiling('specialization')

# Drop the overrides of the hard-coded parameters from the returned kwargs
unspecialized_kwargs = {k: v for k, v in kwargs.items()
if k not in specialize}

return op, unspecialized_kwargs

def apply_specialize(self, **kwargs):
    """
    Specialize this Operator via `specialize(**kwargs)`, then execute the
    resulting Operator with the keyword arguments that `specialize` hands back.

    Returns whatever the `apply` of the resulting Operator returns.
    """
    specialized, remaining = self.specialize(**kwargs)
    return specialized.apply(**remaining)

def apply(self, **kwargs):
"""
Execute the Operator.
Expand Down Expand Up @@ -986,6 +1023,7 @@ def apply(self, **kwargs):
>>> op = Operator(Eq(u3.forward, u3 + 1))
>>> summary = op.apply(time_M=10)
"""

# Compile the operator before building the arguments list
# to avoid out of memory with greedy compilers
cfunction = self.cfunction
Expand All @@ -996,6 +1034,8 @@ def apply(self, **kwargs):
with switch_log_level(comm=args.comm):
self._emit_args_profiling('arguments-preprocess')

self._emit_arguments(args)

# Invoke kernel function with args
arg_values = [args[p.name] for p in self.parameters]
try:
Expand Down Expand Up @@ -1030,6 +1070,28 @@ def _emit_args_profiling(self, tag=''):
tagstr = ' '.join(tag.split('-'))
debug(f"Operator `{self.name}` {tagstr}: {elapsed:.2f} s")

def _emit_arguments(self, args):
    """
    Log, at DEBUG level, the arguments used to invoke the Operator, restricted
    to the parameters for which `is_Symbol` is true.

    Without MPI (`args.comm` is `MPI.COMM_NULL`) a single line is logged.
    Otherwise the per-rank lines are gathered and logged by rank 0, each
    prefixed with the rank it came from.
    """
    # Same early exit as `_emit_build_profiling`: skip the string building and,
    # above all, the collective `allgather` when nothing would be logged.
    # NOTE(review): assumes the log level is identical on all ranks, so that
    # every rank takes the same branch - confirm
    if not is_log_enabled_for('DEBUG'):
        return

    comm = args.comm
    with_mpi = comm is not MPI.COMM_NULL

    scalar_args = ", ".join(f"{p.name}={args[p.name]}"
                            for p in self.parameters
                            if p.is_Symbol)
    rank = f"[rank{comm.Get_rank()}] " if with_mpi else ""
    msg = f"* {rank}{scalar_args}"

    with switch_log_level(comm=comm):
        debug(f"Scalar arguments used to invoke `{self.name}`")

        if with_mpi:
            # With MPI enabled, we add one entry per rank
            allmsg = comm.allgather(msg)
            if comm.Get_rank() == 0:
                for m in allmsg:
                    debug(m)
        else:
            debug(msg)

def _emit_build_profiling(self):
if not is_log_enabled_for('PERF'):
return
Expand Down
4 changes: 2 additions & 2 deletions devito/operator/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,8 @@ def add(self, name, rank, time,
if not ops or any(not np.isfinite(i) for i in [ops, points, traffic]):
self[k] = PerfEntry(time, 0.0, 0.0, 0.0, 0, [])
else:
# Giga-units: divide by 10**9. Beware that the float literal `10e9` is 1e10,
# which would under-report both figures by a factor of 10
gflops = float(ops)/10**9
gpoints = float(points)/10**9
gflopss = gflops/time
gpointss = gpoints/time
oi = float(ops/traffic)
Expand Down
3 changes: 2 additions & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class CondEq(sympy.Eq):
"""

def __new__(cls, *args, **kwargs):
    """
    Create the relational through `sympy.Eq.__new__`, forwarding `kwargs`.

    `evaluate` is always forced to False, overriding any value supplied by
    the caller.
    """
    kwargs['evaluate'] = False
    return sympy.Eq.__new__(cls, *args, **kwargs)

@property
def canonical(self):
Expand Down
4 changes: 2 additions & 2 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def _(mapper, rule):
@singledispatch
def _uxreplace_handle(expr, args, kwargs):
try:
return expr.func(*args, evaluate=False)
return expr.func(*args, evaluate=False, **kwargs)
except TypeError:
return expr.func(*args)
return expr.func(*args, **kwargs)


@_uxreplace_handle.register(Min)
Expand Down
5 changes: 5 additions & 0 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def symbolic_max(self):
"""Symbol defining the maximum point of the Dimension."""
return Scalar(name=self.max_name, dtype=np.int32, is_const=True)

@property
def symbolic_extrema(self):
    """The pair `(symbolic_min, symbolic_max)` of the Dimension."""
    lower = self.symbolic_min
    upper = self.symbolic_max
    return lower, upper

@property
def symbolic_incr(self):
"""The increment value while iterating over the Dimension."""
Expand Down
Loading
Loading