From a419b17308349f15cd8fe86f43834a1189f307a0 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 17 Dec 2025 18:04:06 +0000 Subject: [PATCH 01/19] compiler: Start adding machinery to specialise operators with hardcoded values --- devito/ir/iet/visitors.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index fb73d25004..3421698850 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1498,6 +1498,41 @@ def visit_KernelLaunch(self, o): arguments=arguments) +class Specializer(Uxreplace): + """ + A Transformer to "specialize" a pre-built Operator - that is to replace a given + set of (scalar) symbols with hard-coded values to free up registers. This will + yield a "specialized" version of the Operator, specific to a particular setup. + """ + + def __init__(self, mapper, nested=False): + super().__init__(mapper, nested=nested) + + # Sanity check + for k in self.mapper.keys(): + if not isinstance(k, AbstractSymbol): + raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") + + def visit_Operator(self, o, **kwargs): + # Entirely fine to apply this to an Operator + body = self._visit(o.body) + parameters = tuple(i for i in o.parameters if i not in self.mapper) + + # Note: the following is not dissimilar to unpickling an Operator + state = o.__getstate__() + state['parameters'] = parameters + state['body'] = body + state.pop('ccode') + + # FIXME: These names aren't great + newargs, newkwargs = o.__getnewargs_ex__() + newop = o.__class__(*newargs, **newkwargs) + + newop.__setstate__(state) + + return newop + + # Utils blankline = c.Line("") From 2fe3d3842b6eb888f1f61a0d06a2bb6a71c4f1b3 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 18 Dec 2025 16:39:36 +0000 Subject: [PATCH 02/19] tests: Start adding tests for operator specialization --- devito/ir/iet/visitors.py | 25 +++++- devito/types/dimension.py | 5 ++ tests/test_specialization.py | 146 +++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 tests/test_specialization.py diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 3421698850..caab738ce6 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1503,6 +1503,11 @@ class Specializer(Uxreplace): A Transformer to "specialize" a pre-built Operator - that is to replace a given set of (scalar) symbols with hard-coded values to free up registers. This will yield a "specialized" version of the Operator, specific to a particular setup. + + Note that the Operator is not re-optimized in response to this replacement - this + transformation could nominally result in expressions of the form `f + 0` in the + generated code. If one wants to construct an Operator where such expressions are + considered, then use of `subs=...` is a better choice. """ def __init__(self, mapper, nested=False): @@ -1514,15 +1519,31 @@ def __init__(self, mapper, nested=False): raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") def visit_Operator(self, o, **kwargs): - # Entirely fine to apply this to an Operator + # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this + # is the intended use case body = self._visit(o.body) + + not_params = tuple(i for i in self.mapper if i not in o.parameters) + if not_params: + raise ValueError(f"Attempted to specialize symbols {not_params} which are not" + " found in the Operator parameters") + + # FIXME: Should also type-check the values supplied against the symbols they are + # replacing (and cast them if needed?) -> use a try-except on the cast in + # python-land + parameters = tuple(i for i in o.parameters if i not in self.mapper) # Note: the following is not dissimilar to unpickling an Operator state = o.__getstate__() state['parameters'] = parameters state['body'] = body - state.pop('ccode') + + try: + state.pop('ccode') + except KeyError: + # C code has not previously been generated for this Operator + pass # FIXME: These names aren't great newargs, newkwargs = o.__getnewargs_ex__() diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 4010472bef..325b07ec55 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -188,6 +188,11 @@ def symbolic_max(self): """Symbol defining the maximum point of the Dimension.""" return Scalar(name=self.max_name, dtype=np.int32, is_const=True) + @property + def symbolic_extrema(self): + """Symbols for the minimum and maximum points of the Dimension""" + return (self.symbolic_min, self.symbolic_max) + @property def symbolic_incr(self): """The increment value while iterating over the Dimension.""" diff --git a/tests/test_specialization.py b/tests/test_specialization.py new file mode 100644 index 0000000000..0962ce33bb --- /dev/null +++ b/tests/test_specialization.py @@ -0,0 +1,146 @@ +import sympy +import pytest + +from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, + ConditionalDimension) +from devito.ir.iet.visitors import Specializer + +# Test that specializer replaces symbols as expected + +# Create a couple of arbitrary operators +# Reference bounds, subdomains, spacings, constants, conditionaldimensions with symbolic +# factor +# Create a couple of different substitution sets + +# Check that all the instances in the kernel are replaced +# Check that all the instances in the parameters are removed + +# Check that sanity check catches attempts to specialize non-scalar types +# Check that trying to specialize symbols not in the Operator parameters results +# in an error being thrown + +# Check that sizes and strides get specialized when using `linearize=True` + + +class TestSpecializer: + """Tests for the Specializer transformer""" + + @pytest.mark.parametrize('pre_gen', [True, False]) + @pytest.mark.parametrize('expand', [True, False]) + def test_bounds(self, pre_gen, expand): + """Test specialization of dimension bounds""" + grid = Grid(shape=(11, 11)) + + ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions] + time_m = grid.time_dim.symbolic_min + minima = (x_m, y_m, time_m) + maxima = (x_M, y_M) + + def check_op(mapper, operator): + for k, v in mapper.items(): + assert k not in operator.parameters + assert k.name not in str(operator.ccode) + # Check that the loop bounds are modified correctly + if k in minima: + assert f"{k.name.split('_')[0]} = {v}" in str(operator.ccode) + elif k in maxima: + assert f"{k.name.split('_')[0]} <= {v}" in str(operator.ccode) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = TimeFunction(name='h', grid=grid) + + eq0 = Eq(f, f + 1) + eq1 = Eq(g, f.dx) + eq2 = Eq(h.forward, (g + x_m).dy) + eq3 = Eq(f, x_M) + + # Check behaviour with expansion since we have a replaced symbol inside a + # derivative + if expand: + kwargs = {'opt': ('advanced', {'expand': True})} + else: + kwargs = {'opt': ('advanced', {'expand': False})} + + op = Operator([eq0, eq1, eq2, eq3], **kwargs) + + if pre_gen: + # Generate C code for the unspecialized Operator - the result should be + # the same regardless, but it ensures that the old generated code is + # purged and replaced in the specialized Operator + _ = op.ccode + + mapper0 = {x_m: sympy.S.Zero} + mapper1 = {x_M: sympy.Integer(20), y_m: sympy.S.Zero} + mapper2 = {**mapper0, **mapper1} + mapper3 = {y_M: sympy.Integer(10), time_m: sympy.Integer(5)} + + mappers = (mapper0, mapper1, mapper2, mapper3) + ops = tuple(Specializer(m).visit(op) for m in mappers) + + for m, o in zip(mappers, ops): + check_op(m, o) + + def test_subdomain(self): + """Test that SubDomain thicknesses can be specialized""" + + def check_op(mapper, operator): + for k in mapper.keys(): + assert k not in operator.parameters + assert k.name not in str(operator.ccode) + + class SD(SubDomain): + name = 'sd' + + def define(self, dimensions): + x, y = dimensions + return {x: ('middle', 1, 1), y: ('right', 2)} + + grid = Grid(shape=(11, 11)) + sd = SD(grid=grid) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=sd) + + eqs = [Eq(f, f+1, subdomain=sd), + Eq(g, g+1, subdomain=sd)] + + op = Operator(eqs) + + subdims = [d for d in op.dimensions if d.is_Sub] + ((xltkn, xrtkn), (_, yrtkn)) = [d.thickness for d in subdims] + + mapper0 = {xltkn: sympy.S.Zero} + mapper1 = {xrtkn: sympy.Integer(2), yrtkn: sympy.S.Zero} + mapper2 = {**mapper0, **mapper1} + + mappers = (mapper0, mapper1, mapper2) + ops = tuple(Specializer(m).visit(op) for m in mappers) + + for m, o in zip(mappers, ops): + check_op(m, o) + + # FIXME: Currently throws an error + # def test_factor(self): + # """Test that ConditionalDimensions can have their symbolic factors specialized""" + # size = 16 + # factor = 4 + # i = Dimension(name='i') + # ci = ConditionalDimension(name='ci', parent=i, factor=factor) + + # g = Function(name='g', shape=(size,), dimensions=(i,)) + # f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) + + # op0 = Operator([Eq(f, g)]) + + # mapper = {ci.symbolic_factor: sympy.Integer(factor)} + + # op1 = Specializer(mapper).visit(op0) + + # assert ci.symbolic_factor not in op1.parameters + # assert ci.symbolic_factor.name not in str(op1.ccode) + # assert "if ((i)%(4) == 0)" in str(op1.ccode) + + # Spacings + + # Strides/sizes From a60a0de3f95da1d0ac678fe0e46b466880077113 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 19 Dec 2025 11:36:38 +0000 Subject: [PATCH 03/19] tests: Introduce further tests --- devito/ir/iet/visitors.py | 11 +++++++++-- tests/test_specialization.py | 27 +++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index caab738ce6..32a7ba911c 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -11,7 +11,7 @@ from typing import Any, Generic, TypeVar import cgen as c -from sympy import IndexedBase +from sympy import IndexedBase, Number from sympy.core.function import Application from devito.exceptions import CompilationError @@ -1514,10 +1514,17 @@ def __init__(self, mapper, nested=False): super().__init__(mapper, nested=nested) # Sanity check - for k in self.mapper.keys(): + for k, v in self.mapper.items(): + # FIXME: Erronously blocks f_vec->size[1] + # Apparently this is an IndexedPointer if not isinstance(k, AbstractSymbol): raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") + if not isinstance(v, Number): + raise ValueError("Only SymPy Numbers can used to replace values during " + f"specialization. Value {v} was supplied for symbol " + f"{k}, but is of type {type(v)}.") + def visit_Operator(self, o, **kwargs): # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this # is the intended use case diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 0962ce33bb..c8e8f54c59 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -120,7 +120,8 @@ def define(self, dimensions): for m, o in zip(mappers, ops): check_op(m, o) - # FIXME: Currently throws an error + # FIXME: Currently throws an error - probably a missing handler for GuardFactor + # in Uxreplace # def test_factor(self): # """Test that ConditionalDimensions can have their symbolic factors specialized""" # size = 16 @@ -141,6 +142,28 @@ def define(self, dimensions): # assert ci.symbolic_factor.name not in str(op1.ccode) # assert "if ((i)%(4) == 0)" in str(op1.ccode) - # Spacings + def test_spacing(self): + """Test that grid spacings can be specialized""" + grid = Grid(shape=(11,)) + f = Function(name='f', grid=grid) + + op0 = Operator(Eq(f, f.dx)) + + mapper = {grid.dimensions[0].spacing: sympy.Float(grid.spacing[0])} + op1 = Specializer(mapper).visit(op0) + + assert grid.dimensions[0].spacing not in op1.parameters + assert grid.dimensions[0].spacing.name not in str(op1.ccode) + assert "/1.0e-1F;" in str(op1.ccode) # Strides/sizes + def test_strides(self): + """Test that strides and sizes generated for linearization can be specialized""" + grid = Grid(shape=(11, 11)) + + f = TimeFunction(name='f', grid=grid, space_order=2) + + op0 = Operator(Eq(f.forward, f.dx2), + opt=('advanced', {'expand': True, 'linearize': True})) + + from IPython import embed; embed() \ No newline at end of file From 8b1c74d41d6fc4bff8c0bffb7836290cd80de5fb Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 12:28:47 +0000 Subject: [PATCH 04/19] tests: Add tests for specialising ConditionalDimension factors --- devito/ir/equations/equation.py | 7 +++-- devito/ir/iet/visitors.py | 11 ++++---- devito/ir/support/guards.py | 24 +++++++--------- devito/symbolics/extended_sympy.py | 3 +- devito/symbolics/manipulation.py | 4 +-- tests/test_pickle.py | 5 ++-- tests/test_specialization.py | 45 +++++++++++++++++------------- 7 files changed, 52 insertions(+), 47 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 29589ed8f9..79413b6291 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -6,9 +6,10 @@ from devito.finite_differences.differentiable import diff2sympy from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.ir.support import ( - GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses, + Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses, detect_io ) +from devito.ir.support.guards import GuardFactorEq from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min @@ -221,11 +222,11 @@ def __new__(cls, *args, **kwargs): if not d.is_Conditional: continue if d.condition is None: - conditionals[d] = GuardFactor(d) + conditionals[d] = GuardFactorEq.new_from_dim(d) else: cond = diff2sympy(lower_exprs(d.condition)) if d._factor is not None: - cond = d.relation(cond, GuardFactor(d)) + cond = d.relation(cond, GuardFactorEq.new_from_dim(d)) conditionals[d] = cond # Replace dimension with index index = d.index diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 32a7ba911c..2afe85ef42 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -21,7 +21,7 @@ ) from devito.ir.support.space import Backward from devito.symbolics import ( - FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace + FieldFromComposite, FieldFromPointer, IndexedPointer, ListInitializer, uxreplace ) from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import ( @@ -1515,9 +1515,7 @@ def __init__(self, mapper, nested=False): # Sanity check for k, v in self.mapper.items(): - # FIXME: Erronously blocks f_vec->size[1] - # Apparently this is an IndexedPointer - if not isinstance(k, AbstractSymbol): + if not isinstance(k, (AbstractSymbol, IndexedPointer)): raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") if not isinstance(v, Number): @@ -1530,7 +1528,10 @@ def visit_Operator(self, o, **kwargs): # is the intended use case body = self._visit(o.body) - not_params = tuple(i for i in self.mapper if i not in o.parameters) + # NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in + # the Operator parameters. Perhaps this check wants relaxing. + not_params = tuple(i for i in self.mapper + if i not in o.parameters and isinstance(i, AbstractSymbol)) if not_params: raise ValueError(f"Attempted to specialize symbols {not_params} which are not" " found in the Operator parameters") diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index b8a335b1f4..9610f274f8 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -47,35 +47,34 @@ def canonical(self): @property def negated(self): - return negations[self.__class__](*self._args_rebuild, evaluate=False) + try: + return negations[self.__class__](*self._args_rebuild, evaluate=False) + except KeyError: + raise ValueError(f"Class {self.__class__.__name__} does not have a negation") # *** GuardFactor -class GuardFactor(Guard, CondEq, Pickable): +class GuardFactor(Guard, Pickable): """ A guard for factor-based ConditionalDimensions. - Given the ConditionalDimension `d` with factor `k`, create the - symbolic relational `d.parent % k == 0`. + Introduces a constructor where, given the ConditionalDimension `d` with factor `k`, + the symbolic relational `d.parent % k == 0` is created. """ - __rargs__ = ('d',) + __rargs__ = ('lhs', 'rhs') - def __new__(cls, d, **kwargs): + @classmethod + def new_from_dim(cls, d, **kwargs): assert d.is_Conditional obj = super().__new__(cls, d.parent % d.symbolic_factor, 0) - obj.d = d return obj - @property - def _args_rebuild(self): - return (self.d,) - class GuardFactorEq(GuardFactor, CondEq): pass @@ -85,9 +84,6 @@ class GuardFactorNe(GuardFactor, CondNe): pass -GuardFactor = GuardFactorEq - - # *** GuardBound diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 7df8f430fc..1780757fc3 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -36,7 +36,8 @@ class CondEq(sympy.Eq): """ def __new__(cls, *args, **kwargs): - return sympy.Eq.__new__(cls, *args, evaluate=False) + kwargs['evaluate'] = False + return sympy.Eq.__new__(cls, *args, **kwargs) @property def canonical(self): diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 57d9314e16..f11ffc8b3c 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -152,9 +152,9 @@ def _(mapper, rule): @singledispatch def _uxreplace_handle(expr, args, kwargs): try: - return expr.func(*args, evaluate=False) + return expr.func(*args, evaluate=False, **kwargs) except TypeError: - return expr.func(*args) + return expr.func(*args, **kwargs) @_uxreplace_handle.register(Min) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 407e0433ff..e159873e79 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -13,7 +13,8 @@ ) from devito.data import LEFT, OWNED from devito.finite_differences.tools import centered, direct, left, right, transpose -from devito.ir import Backward, Forward, GuardBound, GuardBoundNext, GuardFactor +from devito.ir import Backward, Forward, GuardBound, GuardBoundNext +from devito.ir.support.guards import GuardFactorEq from devito.mpi.halo_scheme import Halo from devito.mpi.routines import ( MPIMsgEnriched, MPIRegion, MPIRequestObject, MPIStatusObject @@ -503,7 +504,7 @@ def test_guard_factor(self, pickle): d = Dimension(name='d') cd = ConditionalDimension(name='cd', parent=d, factor=4) - gf = GuardFactor(cd) + gf = GuardFactorEq.new_from_dim(cd) pkl_gf = pickle.dumps(gf) new_gf = pickle.loads(pkl_gf) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index c8e8f54c59..5350d7675f 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -120,27 +120,25 @@ def define(self, dimensions): for m, o in zip(mappers, ops): check_op(m, o) - # FIXME: Currently throws an error - probably a missing handler for GuardFactor - # in Uxreplace - # def test_factor(self): - # """Test that ConditionalDimensions can have their symbolic factors specialized""" - # size = 16 - # factor = 4 - # i = Dimension(name='i') - # ci = ConditionalDimension(name='ci', parent=i, factor=factor) + def test_factor(self): + """Test that ConditionalDimensions can have their symbolic factors specialized""" + size = 16 + factor = 4 + i = Dimension(name='i') + ci = ConditionalDimension(name='ci', parent=i, factor=factor) - # g = Function(name='g', shape=(size,), dimensions=(i,)) - # f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) + g = Function(name='g', shape=(size,), dimensions=(i,)) + f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) - # op0 = Operator([Eq(f, g)]) + op0 = Operator([Eq(f, g)]) - # mapper = {ci.symbolic_factor: sympy.Integer(factor)} + mapper = {ci.symbolic_factor: sympy.Integer(factor)} - # op1 = Specializer(mapper).visit(op0) + op1 = Specializer(mapper).visit(op0) - # assert ci.symbolic_factor not in op1.parameters - # assert ci.symbolic_factor.name not in str(op1.ccode) - # assert "if ((i)%(4) == 0)" in str(op1.ccode) + assert ci.symbolic_factor not in op1.parameters + assert ci.symbolic_factor.name not in str(op1.ccode) + assert "if ((i)%(4) == 0)" in str(op1.ccode) def test_spacing(self): """Test that grid spacings can be specialized""" @@ -156,9 +154,8 @@ def test_spacing(self): assert grid.dimensions[0].spacing.name not in str(op1.ccode) assert "/1.0e-1F;" in str(op1.ccode) - # Strides/sizes - def test_strides(self): - """Test that strides and sizes generated for linearization can be specialized""" + def test_sizes(self): + """Test that strides generated for linearization can be specialized""" grid = Grid(shape=(11, 11)) f = TimeFunction(name='f', grid=grid, space_order=2) @@ -166,4 +163,12 @@ def test_strides(self): op0 = Operator(Eq(f.forward, f.dx2), opt=('advanced', {'expand': True, 'linearize': True})) - from IPython import embed; embed() \ No newline at end of file + mapper = {f.symbolic_shape[1]: sympy.Integer(11), + f.symbolic_shape[2]: sympy.Integer(11)} + + op1 = Specializer(mapper).visit(op0) + + assert "const int x_fsz0 = 11;" in str(op1.ccode) + assert "const int y_fsz0 = 11;" in str(op1.ccode) + + # TODO: Should strides get linearized? If so, how? From d53af668339e6b49975e3f1e1f11e2ef105d6d83 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 16:57:03 +0000 Subject: [PATCH 05/19] tests: Added test applying a specialized operator --- tests/test_specialization.py | 44 +++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 5350d7675f..40d3248cea 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -1,6 +1,8 @@ import sympy import pytest +import numpy as np + from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, ConditionalDimension) from devito.ir.iet.visitors import Specializer @@ -171,4 +173,44 @@ def test_sizes(self): assert "const int x_fsz0 = 11;" in str(op1.ccode) assert "const int y_fsz0 = 11;" in str(op1.ccode) - # TODO: Should strides get linearized? If so, how? + # TODO: Should strides get specialized? If so, how? + + def test_apply_basic(self): + """ + Test that a trivial operator runs and returns the same results when + specialized. + """ + grid = Grid(shape=(11, 11)) + ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions] + f = Function(name='f', grid=grid, dtype=np.int32) + + op0 = Operator(Eq(f, f+1)) + + mapper = {x_m: sympy.Integer(2), x_M: sympy.Integer(7), + y_m: sympy.Integer(3), y_M: sympy.Integer(8)} + + op1 = Specializer(mapper).visit(op0) + + assert op1.cfunction is not op0.cfunction + + op1.apply() + + check = np.array(f.data[:]) + f.data[:] = 0 + + op0.apply(x_m=2, x_M=7, y_m=3, y_M=8) + + assert np.all(check == f.data) + + +# class TestApply: +# """Tests for specialization of operators at apply time""" + +# def test_basic(self): +# grid = Grid(shape=(11, 11)) + +# f = TimeFunction(name='f', grid=grid, space_order=2) + +# op = Operator(Eq(f.forward, f + 1)) + +# op.apply(time_M=10, specialize=('x_m', 'x_M')) \ No newline at end of file From c6885c1a6743d781b9941bf0e674cdb646401fdf Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 16:59:20 +0000 Subject: [PATCH 06/19] misc: flake8 --- tests/test_specialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 40d3248cea..759b6cc2db 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -213,4 +213,4 @@ def test_apply_basic(self): # op = Operator(Eq(f.forward, f + 1)) -# op.apply(time_M=10, specialize=('x_m', 'x_M')) \ No newline at end of file +# op.apply(time_M=10, specialize=('x_m', 'x_M')) From dd983b231aead9e647014ab7be85902fcb321360 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 16:32:23 +0000 Subject: [PATCH 07/19] api: Start enabling specialization at operator apply --- devito/ir/iet/visitors.py | 1 + devito/operator/operator.py | 26 ++++++++++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 2afe85ef42..450846e53b 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -45,6 +45,7 @@ 'MapExprStmts', 'MapHaloSpots', 'MapNodes', + 'Specializer', 'Transformer', 'Uxreplace', 'printAST', diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 779106d75a..d401a7aba6 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -22,7 +22,7 @@ from devito.ir.equations import LoweredEq, concretize_subdims, lower_exprs from devito.ir.iet import ( Callable, CInterface, DeviceFunction, EntryFunction, FindSymbols, MetaCall, - derive_parameters, iet_build + Specializer, derive_parameters, iet_build ) from devito.ir.stree import stree_build from devito.ir.support import AccessMode, SymbolRegistry @@ -986,9 +986,13 @@ def apply(self, **kwargs): >>> op = Operator(Eq(u3.forward, u3 + 1)) >>> summary = op.apply(time_M=10) """ - # Compile the operator before building the arguments list - # to avoid out of memory with greedy compilers - cfunction = self.cfunction + # Get items expected to be specialized + specialize = as_tuple(kwargs.pop('specialize', [])) + + if not specialize: + # Compile the operator before building the arguments list + # to avoid out of memory with greedy compilers + cfunction = self.cfunction # Build the arguments list to invoke the kernel function with self._profiler.timer_on('arguments-preprocess'): @@ -996,6 +1000,20 @@ def apply(self, **kwargs): with switch_log_level(comm=args.comm): self._emit_args_profiling('arguments-preprocess') + # In the case of specialization, arguments must be processed before + # the operator can be compiled + if specialize: + specialized_args = {p: sympify(args.pop(p.name)) + for p in self.parameters if p.name in specialize} + + op = Specializer(specialized_args).visit(self) + else: + op = self + + from IPython import embed; embed() + + # TODO: Whose profiler should get used here? + # Invoke kernel function with args arg_values = [args[p.name] for p in self.parameters] try: From 11fa94a5b77a1eb485ca7a2260215879dd05a4ca Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 30 Dec 2025 16:24:57 +0000 Subject: [PATCH 08/19] dsl: Tweak specialization at apply --- devito/operator/operator.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index d401a7aba6..d78274bbb5 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1003,16 +1003,18 @@ def apply(self, **kwargs): # In the case of specialization, arguments must be processed before # the operator can be compiled if specialize: + # FIXME: Cannot cope with things like sizes/strides yet since it only + # looks at the parameters specialized_args = {p: sympify(args.pop(p.name)) for p in self.parameters if p.name in specialize} op = Specializer(specialized_args).visit(self) - else: - op = self - from IPython import embed; embed() + specialized_kwargs = {k: v for k, v in kwargs.items() + if k not in specialize} - # TODO: Whose profiler should get used here? + # TODO: Does this cause problems for profilers? + return op.apply(**specialized_kwargs) # Invoke kernel function with args arg_values = [args[p.name] for p in self.parameters] From af95137bc22a8d34e63f32459a6f038037c43263 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 30 Dec 2025 16:51:38 +0000 Subject: [PATCH 09/19] tests: Add initial test for specialization at operator apply --- devito/operator/operator.py | 2 ++ tests/test_specialization.py | 23 ++++++++++++++++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index d78274bbb5..db2675553c 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1014,6 +1014,8 @@ def apply(self, **kwargs): if k not in specialize} # TODO: Does this cause problems for profilers? + # FIXME: Need some way to inspect this Operator for testing + # FIXME: Perhaps this should use some separate method return op.apply(**specialized_kwargs) # Invoke kernel function with args diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 759b6cc2db..9c48010d1a 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -203,14 +203,23 @@ def test_apply_basic(self): assert np.all(check == f.data) -# class TestApply: -# """Tests for specialization of operators at apply time""" +class TestApply: + """Tests for specialization of operators at apply time""" -# def test_basic(self): -# grid = Grid(shape=(11, 11)) + def test_basic(self): + grid = Grid(shape=(11, 11)) + f = Function(name='f', grid=grid, dtype=np.int32) + op = Operator(Eq(f, f+1)) + + # TODO: Need to verify that specialized operator is actually the one + # being run. How can I achieve this? + op.apply(specialize=('x_m', 'x_M')) -# f = TimeFunction(name='f', grid=grid, space_order=2) + check = np.array(f.data[:]) + f.data[:] = 0 + op.apply() -# op = Operator(Eq(f.forward, f + 1)) + assert np.all(check == f.data) -# op.apply(time_M=10, specialize=('x_m', 'x_M')) +# Need to test combining specialization and overrides (a range of them) +# Need to test specialization with MPI (both at) From 47e88a8ed8d0fa84b77f32e3f473e024dfa14c97 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 6 Jan 2026 14:23:14 +0000 Subject: [PATCH 10/19] compiler: Enhance logging of arguments and apply specialization test --- devito/operator/operator.py | 45 ++++++++++++++++++++++-------------- devito/operator/profiling.py | 4 ++-- tests/test_specialization.py | 28 ++++++++++++++++------ 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index db2675553c..feccfb08bc 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -989,34 +989,45 @@ def apply(self, **kwargs): # Get items expected to be specialized specialize = as_tuple(kwargs.pop('specialize', [])) - if not specialize: - # Compile the operator before building the arguments list - # to avoid out of memory with greedy compilers - cfunction = self.cfunction - - # Build the arguments list to invoke the kernel function - with self._profiler.timer_on('arguments-preprocess'): - args = self.arguments(**kwargs) - with switch_log_level(comm=args.comm): - self._emit_args_profiling('arguments-preprocess') - # In the case of specialization, arguments must be processed before # the operator can be compiled if specialize: # FIXME: Cannot cope with things like sizes/strides yet since it only # looks at the parameters - specialized_args = {p: sympify(args.pop(p.name)) - for p in self.parameters if p.name in specialize} - op = Specializer(specialized_args).visit(self) + # Build the arguments list for specialization + with self._profiler.timer_on('specialized-arguments-preprocess'): + args = self.arguments(**kwargs) + with switch_log_level(comm=args.comm): + self._emit_args_profiling('specialized-arguments-preprocess') - specialized_kwargs = {k: v for k, v in kwargs.items() - if k not in specialize} + # Uses parameters here since Specializer needs {symbol: sympy value} mapper + specialized_values = {p: sympify(args[p.name]) + for p in self.parameters if p.name in specialize} + + op = Specializer(specialized_values).visit(self) # TODO: Does this cause problems for profilers? # FIXME: Need some way to inspect this Operator for testing # FIXME: Perhaps this should use some separate method - return op.apply(**specialized_kwargs) + unspecialized_kwargs = {k: v for k, v in kwargs.items() + if k not in specialize} + + return op.apply(**unspecialized_kwargs) + + # Compile the operator before building the arguments list + # to avoid out of memory with greedy compilers + cfunction = self.cfunction + + # Build the arguments list to invoke the kernel function + with self._profiler.timer_on('arguments-preprocess'): + args = self.arguments(**kwargs) + with switch_log_level(comm=args.comm): + self._emit_args_profiling('arguments-preprocess') + + args_string = ", ".join([f"{p.name}={args[p.name]}" + for p in self.parameters if p.is_Symbol]) + debug(f"Invoking `{self.name}` with scalar arguments: {args_string}") # Invoke kernel function with args arg_values = [args[p.name] for p in self.parameters] diff --git a/devito/operator/profiling.py b/devito/operator/profiling.py index 6a82928277..6082e16d2b 100644 --- a/devito/operator/profiling.py +++ b/devito/operator/profiling.py @@ -426,8 +426,8 @@ def add(self, name, rank, time, if not ops or any(not np.isfinite(i) for i in [ops, points, traffic]): self[k] = PerfEntry(time, 0.0, 0.0, 0.0, 0, []) else: - gflops = float(ops)/10**9 - gpoints = float(points)/10**9 + gflops = float(ops)/10e9 + gpoints = float(points)/10e9 gflopss = gflops/time gpointss = gpoints/time oi = float(ops/traffic) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 9c48010d1a..cc4f0c900f 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -1,10 +1,12 @@ +import logging + import sympy import pytest import numpy as np from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, - ConditionalDimension) + ConditionalDimension, switchconfig) from devito.ir.iet.visitors import Specializer # Test that specializer replaces symbols as expected @@ -202,24 +204,36 @@ def test_apply_basic(self): assert np.all(check == f.data) + # TODO: Need a test to check that block sizes can be specialized + # TODO: Need to test that tile sizes can be specialized + class TestApply: """Tests for specialization of operators at apply time""" - def test_basic(self): + @pytest.mark.parametrize('override', [False, True]) + def test_basic(self, caplog, override): grid = Grid(shape=(11, 11)) f = Function(name='f', grid=grid, dtype=np.int32) op = Operator(Eq(f, f+1)) - # TODO: Need to verify that specialized operator is actually the one - # being run. How can I achieve this? - op.apply(specialize=('x_m', 'x_M')) + specialize = ('x_m', 'x_M') + + kwargs = {} + if override: + kwargs['x_m'] = 3 + + with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG): + op.apply(specialize=specialize, **kwargs) + + # Ensure that the specialized operator was run + assert all(s not in caplog.text for s in specialize) + assert "specialized arguments preprocess" in caplog.text check = np.array(f.data[:]) f.data[:] = 0 - op.apply() + op.apply(**kwargs) assert np.all(check == f.data) -# Need to test combining specialization and overrides (a range of them) # Need to test specialization with MPI (both at) From ba37c2592a7b321be57efecb45c8b7bad355ee7f Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 6 Jan 2026 16:26:31 +0000 Subject: [PATCH 11/19] compiler: Emit arguments used to invoke kernels and add test for specialization with MPI --- devito/operator/operator.py | 45 ++++++++++++++++++++++++------------ tests/test_specialization.py | 9 +++++--- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index feccfb08bc..d64ee9f9d9 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -989,27 +989,22 @@ def apply(self, **kwargs): # Get items expected to be specialized specialize = as_tuple(kwargs.pop('specialize', [])) - # In the case of specialization, arguments must be processed before - # the operator can be compiled if specialize: # FIXME: Cannot cope with things like sizes/strides yet since it only # looks at the parameters # Build the arguments list for specialization - with self._profiler.timer_on('specialized-arguments-preprocess'): + with self._profiler.timer_on('specialization'): args = self.arguments(**kwargs) - with switch_log_level(comm=args.comm): - self._emit_args_profiling('specialized-arguments-preprocess') + # Uses parameters here since Specializer needs {symbol: sympy value} + specialized_values = {p: sympify(args[p.name]) + for p in self.parameters if p.name in specialize} - # Uses parameters here since Specializer needs {symbol: sympy value} mapper - specialized_values = {p: sympify(args[p.name]) - for p in self.parameters if p.name in specialize} + op = Specializer(specialized_values).visit(self) - op = Specializer(specialized_values).visit(self) + with switch_log_level(comm=args.comm): + self._emit_args_profiling('specialization') - # TODO: Does this cause problems for profilers? - # FIXME: Need some way to inspect this Operator for testing - # FIXME: Perhaps this should use some separate method unspecialized_kwargs = {k: v for k, v in kwargs.items() if k not in specialize} @@ -1025,9 +1020,7 @@ def apply(self, **kwargs): with switch_log_level(comm=args.comm): self._emit_args_profiling('arguments-preprocess') - args_string = ", ".join([f"{p.name}={args[p.name]}" - for p in self.parameters if p.is_Symbol]) - debug(f"Invoking `{self.name}` with scalar arguments: {args_string}") + self._emit_arguments(args) # Invoke kernel function with args arg_values = [args[p.name] for p in self.parameters] @@ -1063,6 +1056,28 @@ def _emit_args_profiling(self, tag=''): tagstr = ' '.join(tag.split('-')) debug(f"Operator `{self.name}` {tagstr}: {elapsed:.2f} s") + def _emit_arguments(self, args): + comm = args.comm + scalar_args = ", ".join([f"{p.name}={args[p.name]}" + for p in self.parameters + if p.is_Symbol]) + + rank = f"[rank{args.comm.Get_rank()}] " if comm is not MPI.COMM_NULL else "" + + msg = f"* {rank}{scalar_args}" + + with switch_log_level(comm=comm): + debug(f"Scalar arguments used to invoke `{self.name}`") + + if comm is not MPI.COMM_NULL: + # With MPI enabled, we add one entry per rank + allmsg = comm.allgather(msg) + if comm.Get_rank() == 0: + for m in allmsg: + debug(m) + else: + debug(msg) + def _emit_build_profiling(self): if not is_log_enabled_for('PERF'): return diff --git a/tests/test_specialization.py b/tests/test_specialization.py index cc4f0c900f..6684093c79 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -228,12 +228,15 @@ def test_basic(self, caplog, override): # Ensure that the specialized operator was run assert all(s not in caplog.text for s in specialize) - assert "specialized arguments preprocess" in caplog.text + assert "specialization" in caplog.text check = np.array(f.data[:]) f.data[:] = 0 op.apply(**kwargs) - assert np.all(check == f.data) + assert np.all(check == f.data[:]) -# Need to test specialization with MPI (both at) + @pytest.mark.parallel(mode=[2, 4]) + @pytest.mark.parametrize('override', [False, True]) + def test_basic_mpi(self, caplog, mode, override): + self.test_basic(caplog, override) From 8697a962f411be28df819e3e088959361ecfd525 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 9 Jan 2026 16:17:07 +0000 Subject: [PATCH 12/19] compiler: Add KernelLaunch handling to Specializer --- devito/ir/iet/visitors.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 450846e53b..7b11705a3b 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1523,6 +1523,11 @@ def __init__(self, mapper, nested=False): raise ValueError("Only SymPy Numbers can used to replace values during " f"specialization. Value {v} was supplied for symbol " f"{k}, but is of type {type(v)}.") + + def visit_KernelLaunch(self, o): + # Remove kernel args if they are to be hardcoded + arguments = [i for i in o.arguments if i not in self.mapper] + return o._rebuild(arguments=arguments) def visit_Operator(self, o, **kwargs): # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this From b92ff5f522a4d75b6e0de20a09f1da6e8e4f0b3f Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 16 Jan 2026 11:35:59 +0000 Subject: [PATCH 13/19] compiler: Make Specializer visit _func_table of an Operator --- devito/ir/iet/visitors.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 7b11705a3b..4670679a3a 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -17,7 +17,7 @@ from devito.exceptions import CompilationError from devito.ir.iet.nodes import ( BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, - Section + MetaCall, Section ) from devito.ir.support.space import Backward from devito.symbolics import ( @@ -1508,7 +1508,8 @@ class Specializer(Uxreplace): Note that the Operator is not re-optimized in response to this replacement - this transformation could nominally result in expressions of the form `f + 0` in the generated code. If one wants to construct an Operator where such expressions are - considered, then use of `subs=...` is a better choice. + considered, then use of `subs=...` at construction time is a better choice. However, + it is likely that such expressions will be optimized away by the C-level compiler. """ def __init__(self, mapper, nested=False): @@ -1523,7 +1524,7 @@ def __init__(self, mapper, nested=False): raise ValueError("Only SymPy Numbers can used to replace values during " f"specialization. Value {v} was supplied for symbol " f"{k}, but is of type {type(v)}.") - + def visit_KernelLaunch(self, o): # Remove kernel args if they are to be hardcoded arguments = [i for i in o.arguments if i not in self.mapper] @@ -1553,6 +1554,23 @@ def visit_Operator(self, o, **kwargs): state['parameters'] = parameters state['body'] = body + # TODO: Also rebuild the _func_table for the Operator + # TODO: This is somewhat incongruent with the visitor and should be refactored + + func_table = OrderedDict() + for k, v in o._func_table.items(): + root = v.root + local = v.local + + body = self._visit(root.body) + parameters = tuple(i for i in root.parameters if i not in self.mapper) + + new_root = root._rebuild(body=body, parameters=parameters) + + func_table[k] = MetaCall(root=new_root, local=local) + + state['_func_table'] = func_table + try: state.pop('ccode') except KeyError: From cc27a34632e4a65e7abbbad536d020845d0e21da Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 16 Jan 2026 11:54:38 +0000 Subject: [PATCH 14/19] compiler: Refactor func table specialization to use a visitor --- devito/ir/iet/visitors.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 4670679a3a..28171c57a1 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1525,6 +1525,18 @@ def __init__(self, mapper, nested=False): f"specialization. Value {v} was supplied for symbol " f"{k}, but is of type {type(v)}.") + def visit_OrderedDict(self, o): + return OrderedDict((k, self._visit(v)) for k, v in o.items()) + + def visit_MetaCall(self, o): + root = self._visit(o.root) + return MetaCall(root=root, local=o.local) + + def visit_Callable(self, o): + body = self._visit(o.body) + parameters = [i for i in o.parameters if i not in self.mapper] + return o._rebuild(body=body, parameters=parameters) + def visit_KernelLaunch(self, o): # Remove kernel args if they are to be hardcoded arguments = [i for i in o.arguments if i not in self.mapper] @@ -1553,23 +1565,8 @@ def visit_Operator(self, o, **kwargs): state = o.__getstate__() state['parameters'] = parameters state['body'] = body - - # TODO: Also rebuild the _func_table for the Operator - # TODO: This is somewhat incongruent with the visitor and should be refactored - - func_table = OrderedDict() - for k, v in o._func_table.items(): - root = v.root - local = v.local - - body = self._visit(root.body) - parameters = tuple(i for i in root.parameters if i not in self.mapper) - - new_root = root._rebuild(body=body, parameters=parameters) - - func_table[k] = MetaCall(root=new_root, local=local) - - state['_func_table'] = func_table + # Modify the _func_table to ensure callbacks are specialized + state['_func_table'] = self._visit(o._func_table) try: state.pop('ccode') From 35dfdaeefb44646d9ccaef269dea2cc4ef62047c Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 19 Jan 2026 13:46:03 +0000 Subject: [PATCH 15/19] compiler: Update Specializer with handler for BlockGrid --- devito/ir/iet/visitors.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 28171c57a1..5966a7fc05 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1525,6 +1525,31 @@ def __init__(self, mapper, nested=False): f"specialization. Value {v} was supplied for symbol " f"{k}, but is of type {type(v)}.") + def _visit(self, o, *args, **kwargs): + retval = super()._visit(o, *args, **kwargs) + # print(f"Visiting {o.__class__}") + # print(retval) + # print("--------------------------------------------") + return retval + + # TODO: Should probably be moved to Uxreplace at least (as should some of these + # others I think?) + def visit_DifferentiableFunction(self, o): + return uxreplace(o, self.mapper) + + def visit_Definition(self, o): + try: + function = self._visit(o.function) + return o._rebuild(function=function) + except KeyError: + return o + + def visit_BlockGrid(self, o): + # TODO: Should probably be made into a uxreplace handler of some description + cargs = self._visit(o.cargs) + shape = self._visit(o.shape) + return o._rebuild(cargs=cargs, shape=shape) + def visit_OrderedDict(self, o): return OrderedDict((k, self._visit(v)) for k, v in o.items()) From 289676d8aa25d0e2888fa06a4940152f05d41c41 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 21 Jan 2026 13:36:20 +0000 Subject: [PATCH 16/19] tests: Start work on diffusion-like test --- tests/test_specialization.py | 42 ++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 6684093c79..9074e18405 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -9,22 +9,6 @@ ConditionalDimension, switchconfig) from devito.ir.iet.visitors import Specializer -# Test that specializer replaces symbols as expected - -# Create a couple of arbitrary operators -# Reference bounds, subdomains, spacings, constants, conditionaldimensions with symbolic -# factor -# Create a couple of different substitution sets - -# Check that all the instances in the kernel are replaced -# Check that all the instances in the parameters are removed - -# Check that sanity check catches attempts to specialize non-scalar types -# Check that trying to specialize symbols not in the Operator parameters results -# in an error being thrown - -# Check that sizes and strides get specialized when using `linearize=True` - class TestSpecializer: """Tests for the Specializer transformer""" @@ -240,3 +224,29 @@ def test_basic(self, caplog, override): @pytest.mark.parametrize('override', [False, True]) def test_basic_mpi(self, caplog, mode, override): self.test_basic(caplog, override) + + def test_diffusion_like(self): + grid = Grid(shape=(11, 11)) + + dt = 2.5e-5 + + f = TimeFunction(name='f', grid=grid, space_order=4) + f.data[:, 4:-4, 4:-4] = 1 + + op = Operator(Eq(f.forward, f - grid.time_dim.spacing*f.laplace)) + + op.apply(t_M=100, dt=dt) + + check = np.array(f.data[0]) + f.data[:] = 0 + f.data[:, 4:-4, 4:-4] = 1 + + op.apply(t_M=100, dt=dt, specialize=tuple()) + + print(f.data[0]) + print(check) + assert False + + # Diffusion-like test + # Acoustic-like test (with and without source injection) + # Elastic-like test (with and without source injection) From 29a255ebd5e543ff71352588c43c93bec2df12b3 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 20 Feb 2026 16:27:40 +0000 Subject: [PATCH 17/19] API: Refactor operator specialization API --- devito/operator/operator.py | 60 +++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index d64ee9f9d9..2ddb1ff1c2 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -924,6 +924,43 @@ def _enrich_memreport(self, args): # Hook for enriching memory report with additional metadata return {} + def specialize(self, **kwargs): + """ + """ + + specialize = as_tuple(kwargs.pop('specialize', [])) + + if not specialize: + return self, kwargs + + # FIXME: Cannot cope with things like sizes/strides yet since it only + # looks at the parameters + + # Build the arguments list for specialization + with self._profiler.timer_on('specialization'): + args = self.arguments(**kwargs) + # Uses parameters here since Specializer needs {symbol: sympy value} + specialized_values = {p: sympify(args[p.name]) + for p in self.parameters + if p.name in specialize} + + op = Specializer(specialized_values).visit(self) + + with switch_log_level(comm=args.comm): + self._emit_args_profiling('specialization') + + unspecialized_kwargs = {k: v for k, v in kwargs.items() + if k not in specialize} + + return op, unspecialized_kwargs + + def apply_specialize(self, **kwargs): + """ + """ + + op, unspecialized_kwargs = self.specialize(**kwargs) + return op.apply(**unspecialized_kwargs) + def apply(self, **kwargs): """ Execute the Operator. @@ -986,29 +1023,6 @@ def apply(self, **kwargs): >>> op = Operator(Eq(u3.forward, u3 + 1)) >>> summary = op.apply(time_M=10) """ - # Get items expected to be specialized - specialize = as_tuple(kwargs.pop('specialize', [])) - - if specialize: - # FIXME: Cannot cope with things like sizes/strides yet since it only - # looks at the parameters - - # Build the arguments list for specialization - with self._profiler.timer_on('specialization'): - args = self.arguments(**kwargs) - # Uses parameters here since Specializer needs {symbol: sympy value} - specialized_values = {p: sympify(args[p.name]) - for p in self.parameters if p.name in specialize} - - op = Specializer(specialized_values).visit(self) - - with switch_log_level(comm=args.comm): - self._emit_args_profiling('specialization') - - unspecialized_kwargs = {k: v for k, v in kwargs.items() - if k not in specialize} - - return op.apply(**unspecialized_kwargs) # Compile the operator before building the arguments list # to avoid out of memory with greedy compilers From 7afa5367698c7d1b9c8881919e01306f992b7ac2 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 4 Mar 2026 16:08:32 +0000 Subject: [PATCH 18/19] tests: Expand specialization tests --- tests/test_specialization.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 9074e18405..0988987b60 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -1,12 +1,13 @@ import logging -import sympy -import pytest - import numpy as np +import pytest +import sympy -from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, - ConditionalDimension, switchconfig) +from devito import ( + ConditionalDimension, Dimension, Eq, Function, Grid, Operator, SubDomain, + TimeFunction, switchconfig +) from devito.ir.iet.visitors import Specializer @@ -66,14 +67,14 @@ def check_op(mapper, operator): mappers = (mapper0, mapper1, mapper2, mapper3) ops = tuple(Specializer(m).visit(op) for m in mappers) - for m, o in zip(mappers, ops): + for m, o in zip(mappers, ops, strict=True): check_op(m, o) def test_subdomain(self): """Test that SubDomain thicknesses can be specialized""" def check_op(mapper, operator): - for k in mapper.keys(): + for k in mapper: assert k not in operator.parameters assert k.name not in str(operator.ccode) @@ -105,7 +106,7 @@ def define(self, dimensions): mappers = (mapper0, mapper1, mapper2) ops = tuple(Specializer(m).visit(op) for m in mappers) - for m, o in zip(mappers, ops): + for m, o in zip(mappers, ops, strict=True): check_op(m, o) def test_factor(self): @@ -208,7 +209,7 @@ def test_basic(self, caplog, override): kwargs['x_m'] = 3 with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG): - op.apply(specialize=specialize, **kwargs) + op.apply_specialize(specialize=specialize, **kwargs) # Ensure that the specialized operator was run assert all(s not in caplog.text for s in specialize) @@ -225,7 +226,14 @@ def test_basic(self, caplog, override): def test_basic_mpi(self, caplog, mode, override): self.test_basic(caplog, override) - def test_diffusion_like(self): + @pytest.mark.parametrize('specialize', + [('x_m',), + ('y_M',), + ('t_m',), + ('t_m', 't_M'), + ('x_m', 'y_M'), + ('x_m', 'x_M', 'y_m', 'y_M')]) + def test_diffusion_like(self, specialize): grid = Grid(shape=(11, 11)) dt = 2.5e-5 @@ -237,15 +245,13 @@ def test_diffusion_like(self): op.apply(t_M=100, dt=dt) - check = np.array(f.data[0]) + check = np.array(f.data) f.data[:] = 0 f.data[:, 4:-4, 4:-4] = 1 - op.apply(t_M=100, dt=dt, specialize=tuple()) + op.apply_specialize(t_M=100, dt=dt, specialize=specialize) - print(f.data[0]) - print(check) - assert False + assert np.all(np.isclose(check, f.data)) # Diffusion-like test # Acoustic-like test (with and without source injection) From 14738d7a5c3069008ae1a8cb28bd6eb6ef30dd39 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 6 Mar 2026 11:53:21 +0000 Subject: [PATCH 19/19] compiler: Fix stack corruption bug in specialization --- devito/ir/iet/visitors.py | 20 ++++++++++++-------- tests/test_specialization.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 5966a7fc05..fcac32658e 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -16,8 +16,8 @@ from devito.exceptions import CompilationError from devito.ir.iet.nodes import ( - BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, - MetaCall, Section + BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, MetaCall, + Node, Section ) from devito.ir.support.space import Backward from devito.symbolics import ( @@ -1593,13 +1593,17 @@ def visit_Operator(self, o, **kwargs): # Modify the _func_table to ensure callbacks are specialized state['_func_table'] = self._visit(o._func_table) - try: - state.pop('ccode') - except KeyError: - # C code has not previously been generated for this Operator - pass + state.pop('ccode', None) + + # The specialized operator must be compiled fresh - strip any pre-existing + # compiled binary state inherited from a previously-applied operator. + # Without this, __setstate__ reloads the old binary (which expects the full + # parameter list), while the new operator has fewer parameters after + # specialization, causing a stack corruption (SIGABRT) at call time. + state.pop('binary', None) + state.pop('soname', None) + state.pop('_soname', None) # Clear cached soname so it is recomputed - # FIXME: These names aren't great newargs, newkwargs = o.__getnewargs_ex__() newop = o.__class__(*newargs, **newkwargs) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 0988987b60..0332fa9031 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -229,8 +229,8 @@ def test_basic_mpi(self, caplog, mode, override): @pytest.mark.parametrize('specialize', [('x_m',), ('y_M',), - ('t_m',), - ('t_m', 't_M'), + ('time_m',), + ('time_m', 'time_M'), ('x_m', 'y_M'), ('x_m', 'x_M', 'y_m', 'y_M')]) def test_diffusion_like(self, specialize):