Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
a419b17
compiler: Start adding machinery to specialise operators with hardcod…
EdCaunt Dec 17, 2025
2fe3d38
tests: Start adding tests for operator specialization
EdCaunt Dec 18, 2025
a60a0de
tests: Introduce further tests
EdCaunt Dec 19, 2025
8b1c74d
tests: Add tests for specialising ConditionalDimension factors
EdCaunt Dec 23, 2025
d53af66
tests: Added test applying a specialized operator
EdCaunt Dec 23, 2025
c6885c1
misc: flake8
EdCaunt Dec 23, 2025
dd983b2
api: Start enabling specialization at operator apply
EdCaunt Dec 23, 2025
11fa94a
dsl: Tweak specialization at apply
EdCaunt Dec 30, 2025
af95137
tests: Add initial test for specialization at operator apply
EdCaunt Dec 30, 2025
47e88a8
compiler: Enhance logging of arguments and apply specialization test
EdCaunt Jan 6, 2026
ba37c25
compiler: Emit arguments used to invoke kernels and add test for spec…
EdCaunt Jan 6, 2026
8697a96
compiler: Add KernelLaunch handling to Specializer
EdCaunt Jan 9, 2026
b92ff5f
compiler: Make Specializer visit _func_table of an Operator
EdCaunt Jan 16, 2026
cc27a34
compiler: Refactor func table specialization to use a visitor
EdCaunt Jan 16, 2026
35dfdae
compiler: Update Specializer with handler for BlockGrid
EdCaunt Jan 19, 2026
289676d
tests: Start work on diffusion-like test
EdCaunt Jan 21, 2026
29a255e
API: Refactor operator specialization API
EdCaunt Feb 20, 2026
7afa536
tests: Expand specialization tests
EdCaunt Mar 4, 2026
14738d7
compiler: Fix stack corruption bug in specialization
EdCaunt Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from devito.finite_differences.differentiable import diff2sympy
from devito.ir.equations.algorithms import dimension_sort, lower_exprs
from devito.ir.support import (
GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
detect_io
)
from devito.ir.support.guards import GuardFactorEq
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
Expand Down Expand Up @@ -221,11 +222,11 @@ def __new__(cls, *args, **kwargs):
if not d.is_Conditional:
continue
if d.condition is None:
conditionals[d] = GuardFactor(d)
conditionals[d] = GuardFactorEq.new_from_dim(d)
else:
cond = diff2sympy(lower_exprs(d.condition))
if d._factor is not None:
cond = d.relation(cond, GuardFactor(d))
cond = d.relation(cond, GuardFactorEq.new_from_dim(d))
conditionals[d] = cond
# Replace dimension with index
index = d.index
Expand Down
122 changes: 118 additions & 4 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
from typing import Any, Generic, TypeVar

import cgen as c
from sympy import IndexedBase
from sympy import IndexedBase, Number
from sympy.core.function import Application

from devito.exceptions import CompilationError
from devito.ir.iet.nodes import (
BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node,
Section
BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, MetaCall,
Node, Section
)
from devito.ir.support.space import Backward
from devito.symbolics import (
FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace
FieldFromComposite, FieldFromPointer, IndexedPointer, ListInitializer, uxreplace
)
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (
Expand All @@ -45,6 +45,7 @@
'MapExprStmts',
'MapHaloSpots',
'MapNodes',
'Specializer',
'Transformer',
'Uxreplace',
'printAST',
Expand Down Expand Up @@ -1498,6 +1499,119 @@ def visit_KernelLaunch(self, o):
arguments=arguments)


class Specializer(Uxreplace):
"""
A Transformer to "specialize" a pre-built Operator - that is to replace a given
set of (scalar) symbols with hard-coded values to free up registers. This will
yield a "specialized" version of the Operator, specific to a particular setup.

Note that the Operator is not re-optimized in response to this replacement - this
transformation could nominally result in expressions of the form `f + 0` in the
generated code. If one wants to construct an Operator where such expressions are
considered, then use of `subs=...` at construction time is a better choice. However,
it is likely that such expressions will be optimized away by the C-level compiler.
"""

def __init__(self, mapper, nested=False):
super().__init__(mapper, nested=nested)

# Sanity check
for k, v in self.mapper.items():
if not isinstance(k, (AbstractSymbol, IndexedPointer)):
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")

if not isinstance(v, Number):
raise ValueError("Only SymPy Numbers can used to replace values during "
f"specialization. Value {v} was supplied for symbol "
f"{k}, but is of type {type(v)}.")

def _visit(self, o, *args, **kwargs):
retval = super()._visit(o, *args, **kwargs)
# print(f"Visiting {o.__class__}")
# print(retval)
# print("--------------------------------------------")
return retval

# TODO: Should probably be moved to Uxreplace at least (as should some of these
# others I think?)
def visit_DifferentiableFunction(self, o):
return uxreplace(o, self.mapper)

def visit_Definition(self, o):
try:
function = self._visit(o.function)
return o._rebuild(function=function)
except KeyError:
return o

def visit_BlockGrid(self, o):
# TODO: Should probably be made into a uxreplace handler of some description
cargs = self._visit(o.cargs)
shape = self._visit(o.shape)
return o._rebuild(cargs=cargs, shape=shape)

def visit_OrderedDict(self, o):
return OrderedDict((k, self._visit(v)) for k, v in o.items())

def visit_MetaCall(self, o):
root = self._visit(o.root)
return MetaCall(root=root, local=o.local)

def visit_Callable(self, o):
body = self._visit(o.body)
parameters = [i for i in o.parameters if i not in self.mapper]
return o._rebuild(body=body, parameters=parameters)

def visit_KernelLaunch(self, o):
# Remove kernel args if they are to be hardcoded
arguments = [i for i in o.arguments if i not in self.mapper]
return o._rebuild(arguments=arguments)

def visit_Operator(self, o, **kwargs):
# Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this
# is the intended use case
body = self._visit(o.body)

# NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in
# the Operator parameters. Perhaps this check wants relaxing.
not_params = tuple(i for i in self.mapper
if i not in o.parameters and isinstance(i, AbstractSymbol))
if not_params:
raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
" found in the Operator parameters")

# FIXME: Should also type-check the values supplied against the symbols they are
# replacing (and cast them if needed?) -> use a try-except on the cast in
# python-land

parameters = tuple(i for i in o.parameters if i not in self.mapper)

# Note: the following is not dissimilar to unpickling an Operator
state = o.__getstate__()
state['parameters'] = parameters
state['body'] = body
# Modify the _func_table to ensure callbacks are specialized
state['_func_table'] = self._visit(o._func_table)

state.pop('ccode', None)

# The specialized operator must be compiled fresh - strip any pre-existing
# compiled binary state inherited from a previously-applied operator.
# Without this, __setstate__ reloads the old binary (which expects the full
# parameter list), while the new operator has fewer parameters after
# specialization, causing a stack corruption (SIGABRT) at call time.
state.pop('binary', None)
state.pop('soname', None)
state.pop('_soname', None) # Clear cached soname so it is recomputed

newargs, newkwargs = o.__getnewargs_ex__()
newop = o.__class__(*newargs, **newkwargs)

newop.__setstate__(state)

return newop


# Utils

blankline = c.Line("")
Expand Down
24 changes: 10 additions & 14 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,34 @@ def canonical(self):

@property
def negated(self):
return negations[self.__class__](*self._args_rebuild, evaluate=False)
try:
return negations[self.__class__](*self._args_rebuild, evaluate=False)
except KeyError:
raise ValueError(f"Class {self.__class__.__name__} does not have a negation")


# *** GuardFactor


class GuardFactor(Guard, CondEq, Pickable):
class GuardFactor(Guard, Pickable):

"""
A guard for factor-based ConditionalDimensions.

Given the ConditionalDimension `d` with factor `k`, create the
symbolic relational `d.parent % k == 0`.
Introduces a constructor where, given the ConditionalDimension `d` with factor `k`,
the symbolic relational `d.parent % k == 0` is created.
"""

__rargs__ = ('d',)
__rargs__ = ('lhs', 'rhs')

def __new__(cls, d, **kwargs):
@classmethod
def new_from_dim(cls, d, **kwargs):
assert d.is_Conditional

obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
obj.d = d

return obj

@property
def _args_rebuild(self):
return (self.d,)


class GuardFactorEq(GuardFactor, CondEq):
pass
Expand All @@ -85,9 +84,6 @@ class GuardFactorNe(GuardFactor, CondNe):
pass


GuardFactor = GuardFactorEq


# *** GuardBound


Expand Down
64 changes: 63 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from devito.ir.equations import LoweredEq, concretize_subdims, lower_exprs
from devito.ir.iet import (
Callable, CInterface, DeviceFunction, EntryFunction, FindSymbols, MetaCall,
derive_parameters, iet_build
Specializer, derive_parameters, iet_build
)
from devito.ir.stree import stree_build
from devito.ir.support import AccessMode, SymbolRegistry
Expand Down Expand Up @@ -924,6 +924,43 @@ def _enrich_memreport(self, args):
# Hook for enriching memory report with additional metadata
return {}

def specialize(self, **kwargs):
"""
"""

specialize = as_tuple(kwargs.pop('specialize', []))

if not specialize:
return self, kwargs

# FIXME: Cannot cope with things like sizes/strides yet since it only
# looks at the parameters

# Build the arguments list for specialization
with self._profiler.timer_on('specialization'):
args = self.arguments(**kwargs)
# Uses parameters here since Specializer needs {symbol: sympy value}
specialized_values = {p: sympify(args[p.name])
for p in self.parameters
if p.name in specialize}

op = Specializer(specialized_values).visit(self)

with switch_log_level(comm=args.comm):
self._emit_args_profiling('specialization')

unspecialized_kwargs = {k: v for k, v in kwargs.items()
if k not in specialize}

return op, unspecialized_kwargs

def apply_specialize(self, **kwargs):
"""
"""

op, unspecialized_kwargs = self.specialize(**kwargs)
return op.apply(**unspecialized_kwargs)

def apply(self, **kwargs):
"""
Execute the Operator.
Expand Down Expand Up @@ -986,6 +1023,7 @@ def apply(self, **kwargs):
>>> op = Operator(Eq(u3.forward, u3 + 1))
>>> summary = op.apply(time_M=10)
"""

# Compile the operator before building the arguments list
# to avoid out of memory with greedy compilers
cfunction = self.cfunction
Expand All @@ -996,6 +1034,8 @@ def apply(self, **kwargs):
with switch_log_level(comm=args.comm):
self._emit_args_profiling('arguments-preprocess')

self._emit_arguments(args)

# Invoke kernel function with args
arg_values = [args[p.name] for p in self.parameters]
try:
Expand Down Expand Up @@ -1030,6 +1070,28 @@ def _emit_args_profiling(self, tag=''):
tagstr = ' '.join(tag.split('-'))
debug(f"Operator `{self.name}` {tagstr}: {elapsed:.2f} s")

def _emit_arguments(self, args):
comm = args.comm
scalar_args = ", ".join([f"{p.name}={args[p.name]}"
for p in self.parameters
if p.is_Symbol])

rank = f"[rank{args.comm.Get_rank()}] " if comm is not MPI.COMM_NULL else ""

msg = f"* {rank}{scalar_args}"

with switch_log_level(comm=comm):
debug(f"Scalar arguments used to invoke `{self.name}`")

if comm is not MPI.COMM_NULL:
# With MPI enabled, we add one entry per rank
allmsg = comm.allgather(msg)
if comm.Get_rank() == 0:
for m in allmsg:
debug(m)
else:
debug(msg)

def _emit_build_profiling(self):
if not is_log_enabled_for('PERF'):
return
Expand Down
4 changes: 2 additions & 2 deletions devito/operator/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,8 @@ def add(self, name, rank, time,
if not ops or any(not np.isfinite(i) for i in [ops, points, traffic]):
self[k] = PerfEntry(time, 0.0, 0.0, 0.0, 0, [])
else:
gflops = float(ops)/10**9
gpoints = float(points)/10**9
gflops = float(ops)/10e9
gpoints = float(points)/10e9
gflopss = gflops/time
gpointss = gpoints/time
oi = float(ops/traffic)
Expand Down
3 changes: 2 additions & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class CondEq(sympy.Eq):
"""

def __new__(cls, *args, **kwargs):
return sympy.Eq.__new__(cls, *args, evaluate=False)
kwargs['evaluate'] = False
return sympy.Eq.__new__(cls, *args, **kwargs)

@property
def canonical(self):
Expand Down
4 changes: 2 additions & 2 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def _(mapper, rule):
@singledispatch
def _uxreplace_handle(expr, args, kwargs):
try:
return expr.func(*args, evaluate=False)
return expr.func(*args, evaluate=False, **kwargs)
except TypeError:
return expr.func(*args)
return expr.func(*args, **kwargs)


@_uxreplace_handle.register(Min)
Expand Down
5 changes: 5 additions & 0 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def symbolic_max(self):
"""Symbol defining the maximum point of the Dimension."""
return Scalar(name=self.max_name, dtype=np.int32, is_const=True)

@property
def symbolic_extrema(self):
"""Symbols for the minimum and maximum points of the Dimension"""
return (self.symbolic_min, self.symbolic_max)

@property
def symbolic_incr(self):
"""The increment value while iterating over the Dimension."""
Expand Down
5 changes: 3 additions & 2 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
)
from devito.data import LEFT, OWNED
from devito.finite_differences.tools import centered, direct, left, right, transpose
from devito.ir import Backward, Forward, GuardBound, GuardBoundNext, GuardFactor
from devito.ir import Backward, Forward, GuardBound, GuardBoundNext
from devito.ir.support.guards import GuardFactorEq
from devito.mpi.halo_scheme import Halo
from devito.mpi.routines import (
MPIMsgEnriched, MPIRegion, MPIRequestObject, MPIStatusObject
Expand Down Expand Up @@ -503,7 +504,7 @@ def test_guard_factor(self, pickle):
d = Dimension(name='d')
cd = ConditionalDimension(name='cd', parent=d, factor=4)

gf = GuardFactor(cd)
gf = GuardFactorEq.new_from_dim(cd)

pkl_gf = pickle.dumps(gf)
new_gf = pickle.loads(pkl_gf)
Expand Down
Loading
Loading