diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py
index 29589ed8f9..79413b6291 100644
--- a/devito/ir/equations/equation.py
+++ b/devito/ir/equations/equation.py
@@ -6,9 +6,10 @@
 from devito.finite_differences.differentiable import diff2sympy
 from devito.ir.equations.algorithms import dimension_sort, lower_exprs
 from devito.ir.support import (
-    GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
+    Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
     detect_io
 )
+from devito.ir.support.guards import GuardFactorEq
 from devito.symbolics import IntDiv, limits_mapper, uxreplace
 from devito.tools import Pickable, Tag, frozendict
 from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
@@ -221,11 +222,11 @@ def __new__(cls, *args, **kwargs):
             if not d.is_Conditional:
                 continue
             if d.condition is None:
-                conditionals[d] = GuardFactor(d)
+                conditionals[d] = GuardFactorEq.new_from_dim(d)
             else:
                 cond = diff2sympy(lower_exprs(d.condition))
                 if d._factor is not None:
-                    cond = d.relation(cond, GuardFactor(d))
+                    cond = d.relation(cond, GuardFactorEq.new_from_dim(d))
                 conditionals[d] = cond
             # Replace dimension with index
             index = d.index
diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py
index fb73d25004..fcac32658e 100644
--- a/devito/ir/iet/visitors.py
+++ b/devito/ir/iet/visitors.py
@@ -11,17 +11,17 @@
 from typing import Any, Generic, TypeVar
 
 import cgen as c
-from sympy import IndexedBase
+from sympy import IndexedBase, Number
 from sympy.core.function import Application
 
 from devito.exceptions import CompilationError
 from devito.ir.iet.nodes import (
-    BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node,
-    Section
+    BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, MetaCall,
+    Node, Section
 )
 from devito.ir.support.space import Backward
 from devito.symbolics import (
-    FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace
+    FieldFromComposite, FieldFromPointer, IndexedPointer, ListInitializer, uxreplace
 )
 from devito.symbolics.extended_dtypes import NoDeclStruct
 from devito.tools import (
@@ -45,6 +45,7 @@
     'MapExprStmts',
     'MapHaloSpots',
     'MapNodes',
+    'Specializer',
     'Transformer',
     'Uxreplace',
     'printAST',
@@ -1498,6 +1499,112 @@ def visit_KernelLaunch(self, o):
                           arguments=arguments)
 
 
+class Specializer(Uxreplace):
+    """
+    A Transformer to "specialize" a pre-built Operator - that is to replace a given
+    set of (scalar) symbols with hard-coded values to free up registers. This will
+    yield a "specialized" version of the Operator, specific to a particular setup.
+
+    Note that the Operator is not re-optimized in response to this replacement - this
+    transformation could nominally result in expressions of the form `f + 0` in the
+    generated code. If one wants to construct an Operator where such expressions are
+    considered, then use of `subs=...` at construction time is a better choice. However,
+    it is likely that such expressions will be optimized away by the C-level compiler.
+    """
+
+    def __init__(self, mapper, nested=False):
+        super().__init__(mapper, nested=nested)
+
+        # Sanity check
+        for k, v in self.mapper.items():
+            if not isinstance(k, (AbstractSymbol, IndexedPointer)):
+                raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
+
+            if not isinstance(v, Number):
+                raise ValueError("Only SymPy Numbers can be used to replace values "
+                                 f"during specialization. Value {v} was supplied for "
+                                 f"symbol {k}, but is of type {type(v)}.")
+
+    # TODO: Should probably be moved to Uxreplace at least (as should some of these
+    # others I think?)
+    def visit_DifferentiableFunction(self, o):
+        return uxreplace(o, self.mapper)
+
+    def visit_Definition(self, o):
+        try:
+            function = self._visit(o.function)
+            return o._rebuild(function=function)
+        except KeyError:
+            return o
+
+    def visit_BlockGrid(self, o):
+        # TODO: Should probably be made into a uxreplace handler of some description
+        cargs = self._visit(o.cargs)
+        shape = self._visit(o.shape)
+        return o._rebuild(cargs=cargs, shape=shape)
+
+    def visit_OrderedDict(self, o):
+        return OrderedDict((k, self._visit(v)) for k, v in o.items())
+
+    def visit_MetaCall(self, o):
+        root = self._visit(o.root)
+        return MetaCall(root=root, local=o.local)
+
+    def visit_Callable(self, o):
+        body = self._visit(o.body)
+        parameters = [i for i in o.parameters if i not in self.mapper]
+        return o._rebuild(body=body, parameters=parameters)
+
+    def visit_KernelLaunch(self, o):
+        # Remove kernel args if they are to be hardcoded
+        arguments = [i for i in o.arguments if i not in self.mapper]
+        return o._rebuild(arguments=arguments)
+
+    def visit_Operator(self, o, **kwargs):
+        # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this
+        # is the intended use case
+        body = self._visit(o.body)
+
+        # NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in
+        # the Operator parameters. Perhaps this check wants relaxing.
+        not_params = tuple(i for i in self.mapper
+                           if i not in o.parameters and isinstance(i, AbstractSymbol))
+        if not_params:
+            raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
+                             " found in the Operator parameters")
+
+        # FIXME: Should also type-check the values supplied against the symbols they are
+        # replacing (and cast them if needed?) -> use a try-except on the cast in
+        # python-land
+
+        parameters = tuple(i for i in o.parameters if i not in self.mapper)
+
+        # Note: the following is not dissimilar to unpickling an Operator
+        state = o.__getstate__()
+        state['parameters'] = parameters
+        state['body'] = body
+        # Modify the _func_table to ensure callbacks are specialized
+        state['_func_table'] = self._visit(o._func_table)
+
+        state.pop('ccode', None)
+
+        # The specialized operator must be compiled fresh - strip any pre-existing
+        # compiled binary state inherited from a previously-applied operator.
+        # Without this, __setstate__ reloads the old binary (which expects the full
+        # parameter list), while the new operator has fewer parameters after
+        # specialization, causing a stack corruption (SIGABRT) at call time.
+        state.pop('binary', None)
+        state.pop('soname', None)
+        state.pop('_soname', None)  # Clear cached soname so it is recomputed
+
+        newargs, newkwargs = o.__getnewargs_ex__()
+        newop = o.__class__(*newargs, **newkwargs)
+
+        newop.__setstate__(state)
+
+        return newop
+
+
 # Utils
 
 blankline = c.Line("")
diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py
index b8a335b1f4..9610f274f8 100644
--- a/devito/ir/support/guards.py
+++ b/devito/ir/support/guards.py
@@ -47,35 +47,37 @@ def canonical(self):
 
     @property
     def negated(self):
-        return negations[self.__class__](*self._args_rebuild, evaluate=False)
+        try:
+            negation = negations[self.__class__]
+        except KeyError:
+            raise ValueError(
+                f"Class {self.__class__.__name__} does not have a negation"
+            ) from None
+        return negation(*self._args_rebuild, evaluate=False)
 
 
 # *** GuardFactor
 
 
-class GuardFactor(Guard, CondEq, Pickable):
+class GuardFactor(Guard, Pickable):
 
     """
     A guard for factor-based ConditionalDimensions.
 
-    Given the ConditionalDimension `d` with factor `k`, create the
-    symbolic relational `d.parent % k == 0`.
+    Introduces a constructor where, given the ConditionalDimension `d` with factor `k`,
+    the symbolic relational `d.parent % k == 0` is created.
     """
 
-    __rargs__ = ('d',)
+    __rargs__ = ('lhs', 'rhs')
 
-    def __new__(cls, d, **kwargs):
+    @classmethod
+    def new_from_dim(cls, d, **kwargs):
         assert d.is_Conditional
 
         obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
-        obj.d = d
 
         return obj
 
-    @property
-    def _args_rebuild(self):
-        return (self.d,)
-
 
 class GuardFactorEq(GuardFactor, CondEq):
     pass
@@ -85,9 +87,6 @@ class GuardFactorNe(GuardFactor, CondNe):
     pass
 
-
-GuardFactor = GuardFactorEq
-
 
 # *** GuardBound
 
 
diff --git a/devito/operator/operator.py b/devito/operator/operator.py
index 779106d75a..2ddb1ff1c2 100644
--- a/devito/operator/operator.py
+++ b/devito/operator/operator.py
@@ -22,7 +22,7 @@
 from devito.ir.equations import LoweredEq, concretize_subdims, lower_exprs
 from devito.ir.iet import (
     Callable, CInterface, DeviceFunction, EntryFunction, FindSymbols, MetaCall,
-    derive_parameters, iet_build
+    Specializer, derive_parameters, iet_build
 )
 from devito.ir.stree import stree_build
 from devito.ir.support import AccessMode, SymbolRegistry
@@ -924,6 +924,60 @@ def _enrich_memreport(self, args):
         # Hook for enriching memory report with additional metadata
         return {}
 
+    def specialize(self, **kwargs):
+        """
+        Build a specialized version of the Operator, in which the runtime values
+        of the selected scalar arguments are hard-coded.
+
+        Parameters
+        ----------
+        **kwargs
+            The runtime overrides accepted by `apply`, plus `specialize`: the
+            names of the Operator parameters whose values are to be hard-coded.
+
+        Returns
+        -------
+        tuple
+            The specialized Operator and the `kwargs` that were not specialized.
+            If `specialize` is empty or missing, `self` is returned in place of
+            a specialized Operator.
+        """
+
+        specialize = as_tuple(kwargs.pop('specialize', []))
+
+        if not specialize:
+            return self, kwargs
+
+        # FIXME: Cannot cope with things like sizes/strides yet since it only
+        # looks at the parameters
+
+        # Build the arguments list for specialization
+        with self._profiler.timer_on('specialization'):
+            args = self.arguments(**kwargs)
+            # Uses parameters here since Specializer needs {symbol: sympy value}
+            specialized_values = {p: sympify(args[p.name])
+                                  for p in self.parameters
+                                  if p.name in specialize}
+
+            op = Specializer(specialized_values).visit(self)
+
+        with switch_log_level(comm=args.comm):
+            self._emit_args_profiling('specialization')
+
+        unspecialized_kwargs = {k: v for k, v in kwargs.items()
+                                if k not in specialize}
+
+        return op, unspecialized_kwargs
+
+    def apply_specialize(self, **kwargs):
+        """
+        Specialize the Operator as per `specialize`, then run the specialized
+        Operator with the remaining `kwargs`. Returns whatever `apply` returns.
+        """
+
+        op, unspecialized_kwargs = self.specialize(**kwargs)
+        return op.apply(**unspecialized_kwargs)
+
     def apply(self, **kwargs):
         """
         Execute the Operator.
@@ -986,6 +1040,7 @@ def apply(self, **kwargs):
         >>> op = Operator(Eq(u3.forward, u3 + 1))
         >>> summary = op.apply(time_M=10)
         """
+
         # Compile the operator before building the arguments list
         # to avoid out of memory with greedy compilers
         cfunction = self.cfunction
@@ -996,6 +1051,8 @@ def apply(self, **kwargs):
         with switch_log_level(comm=args.comm):
             self._emit_args_profiling('arguments-preprocess')
 
+        self._emit_arguments(args)
+
         # Invoke kernel function with args
         arg_values = [args[p.name] for p in self.parameters]
         try:
@@ -1030,6 +1087,34 @@ def _emit_args_profiling(self, tag=''):
         tagstr = ' '.join(tag.split('-'))
         debug(f"Operator `{self.name}` {tagstr}: {elapsed:.2f} s")
 
+    def _emit_arguments(self, args):
+        """Log, at DEBUG level, the scalar arguments used to invoke the kernel."""
+        if not is_log_enabled_for('DEBUG'):
+            # Avoid formatting the message (and the MPI collective below) when
+            # nothing would be logged
+            return
+
+        comm = args.comm
+        scalar_args = ", ".join([f"{p.name}={args[p.name]}"
+                                 for p in self.parameters
+                                 if p.is_Symbol])
+
+        rank = f"[rank{comm.Get_rank()}] " if comm is not MPI.COMM_NULL else ""
+
+        msg = f"* {rank}{scalar_args}"
+
+        with switch_log_level(comm=comm):
+            debug(f"Scalar arguments used to invoke `{self.name}`")
+
+        if comm is not MPI.COMM_NULL:
+            # With MPI enabled, we add one entry per rank
+            allmsg = comm.allgather(msg)
+            if comm.Get_rank() == 0:
+                for m in allmsg:
+                    debug(m)
+        else:
+            debug(msg)
+
     def _emit_build_profiling(self):
         if not is_log_enabled_for('PERF'):
             return
diff --git a/devito/operator/profiling.py b/devito/operator/profiling.py
index 6a82928277..6082e16d2b 100644
--- a/devito/operator/profiling.py
+++ b/devito/operator/profiling.py
@@ -426,8 +426,8 @@ def add(self, name, rank, time,
         if not ops or any(not np.isfinite(i) for i in [ops, points, traffic]):
             self[k] = PerfEntry(time, 0.0, 0.0, 0.0, 0, [])
         else:
-            gflops = float(ops)/10**9
-            gpoints = float(points)/10**9
+            gflops = float(ops)/1e9
+            gpoints = float(points)/1e9
             gflopss = gflops/time
             gpointss = gpoints/time
             oi = float(ops/traffic)
diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py
index 7df8f430fc..1780757fc3 100644
--- a/devito/symbolics/extended_sympy.py
+++ b/devito/symbolics/extended_sympy.py
@@ -36,7 +36,8 @@ class CondEq(sympy.Eq):
     """
 
     def __new__(cls, *args, **kwargs):
-        return sympy.Eq.__new__(cls, *args, evaluate=False)
+        kwargs['evaluate'] = False
+        return sympy.Eq.__new__(cls, *args, **kwargs)
 
     @property
     def canonical(self):
diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py
index 57d9314e16..f11ffc8b3c 100644
--- a/devito/symbolics/manipulation.py
+++ b/devito/symbolics/manipulation.py
@@ -152,9 +152,9 @@ def _(mapper, rule):
 @singledispatch
 def _uxreplace_handle(expr, args, kwargs):
     try:
-        return expr.func(*args, evaluate=False)
+        return expr.func(*args, evaluate=False, **kwargs)
     except TypeError:
-        return expr.func(*args)
+        return expr.func(*args, **kwargs)
 
 
 @_uxreplace_handle.register(Min)
diff --git a/devito/types/dimension.py b/devito/types/dimension.py
index 4010472bef..325b07ec55 100644
--- a/devito/types/dimension.py
+++ b/devito/types/dimension.py
@@ -188,6 +188,11 @@ def symbolic_max(self):
         """Symbol defining the maximum point of the Dimension."""
         return Scalar(name=self.max_name, dtype=np.int32, is_const=True)
 
+    @property
+    def symbolic_extrema(self):
+        """Symbols for the minimum and maximum points of the Dimension"""
+        return (self.symbolic_min, self.symbolic_max)
+
     @property
     def symbolic_incr(self):
         """The increment value while iterating over the Dimension."""
diff --git a/tests/test_pickle.py b/tests/test_pickle.py
index 407e0433ff..e159873e79 100644
--- a/tests/test_pickle.py
+++ b/tests/test_pickle.py
@@ -13,7 +13,8 @@
 )
 from devito.data import LEFT, OWNED
 from devito.finite_differences.tools import centered, direct, left, right, transpose
-from devito.ir import Backward, Forward, GuardBound, GuardBoundNext, GuardFactor
+from devito.ir import Backward, Forward, GuardBound, GuardBoundNext
+from devito.ir.support.guards import GuardFactorEq
 from devito.mpi.halo_scheme import Halo
 from devito.mpi.routines import (
     MPIMsgEnriched, MPIRegion, MPIRequestObject, MPIStatusObject
@@ -503,7 +504,7 @@ def test_guard_factor(self, pickle):
         d = Dimension(name='d')
         cd = ConditionalDimension(name='cd', parent=d, factor=4)
 
-        gf = GuardFactor(cd)
+        gf = GuardFactorEq.new_from_dim(cd)
 
         pkl_gf = pickle.dumps(gf)
         new_gf = pickle.loads(pkl_gf)
diff --git a/tests/test_specialization.py b/tests/test_specialization.py
new file mode 100644
index 0000000000..0332fa9031
--- /dev/null
+++ b/tests/test_specialization.py
@@ -0,0 +1,258 @@
+import logging
+
+import numpy as np
+import pytest
+import sympy
+
+from devito import (
+    ConditionalDimension, Dimension, Eq, Function, Grid, Operator, SubDomain,
+    TimeFunction, switchconfig
+)
+from devito.ir.iet.visitors import Specializer
+
+
+class TestSpecializer:
+    """Tests for the Specializer transformer"""
+
+    @pytest.mark.parametrize('pre_gen', [True, False])
+    @pytest.mark.parametrize('expand', [True, False])
+    def test_bounds(self, pre_gen, expand):
+        """Test specialization of dimension bounds"""
+        grid = Grid(shape=(11, 11))
+
+        ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions]
+        time_m = grid.time_dim.symbolic_min
+        minima = (x_m, y_m, time_m)
+        maxima = (x_M, y_M)
+
+        def check_op(mapper, operator):
+            for k, v in mapper.items():
+                assert k not in operator.parameters
+                assert k.name not in str(operator.ccode)
+                # Check that the loop bounds are modified correctly
+                if k in minima:
+                    assert f"{k.name.split('_')[0]} = {v}" in str(operator.ccode)
+                elif k in maxima:
+                    assert f"{k.name.split('_')[0]} <= {v}" in str(operator.ccode)
+
+        f = Function(name='f', grid=grid)
+        g = Function(name='g', grid=grid)
+        h = TimeFunction(name='h', grid=grid)
+
+        eq0 = Eq(f, f + 1)
+        eq1 = Eq(g, f.dx)
+        eq2 = Eq(h.forward, (g + x_m).dy)
+        eq3 = Eq(f, x_M)
+
+        # Check behaviour with expansion since we have a replaced symbol inside a
+        # derivative
+        if expand:
+            kwargs = {'opt': ('advanced', {'expand': True})}
+        else:
+            kwargs = {'opt': ('advanced', {'expand': False})}
+
+        op = Operator([eq0, eq1, eq2, eq3], **kwargs)
+
+        if pre_gen:
+            # Generate C code for the unspecialized Operator - the result should be
+            # the same regardless, but it ensures that the old generated code is
+            # purged and replaced in the specialized Operator
+            _ = op.ccode
+
+        mapper0 = {x_m: sympy.S.Zero}
+        mapper1 = {x_M: sympy.Integer(20), y_m: sympy.S.Zero}
+        mapper2 = {**mapper0, **mapper1}
+        mapper3 = {y_M: sympy.Integer(10), time_m: sympy.Integer(5)}
+
+        mappers = (mapper0, mapper1, mapper2, mapper3)
+        ops = tuple(Specializer(m).visit(op) for m in mappers)
+
+        for m, o in zip(mappers, ops, strict=True):
+            check_op(m, o)
+
+    def test_subdomain(self):
+        """Test that SubDomain thicknesses can be specialized"""
+
+        def check_op(mapper, operator):
+            for k in mapper:
+                assert k not in operator.parameters
+                assert k.name not in str(operator.ccode)
+
+        class SD(SubDomain):
+            name = 'sd'
+
+            def define(self, dimensions):
+                x, y = dimensions
+                return {x: ('middle', 1, 1), y: ('right', 2)}
+
+        grid = Grid(shape=(11, 11))
+        sd = SD(grid=grid)
+
+        f = Function(name='f', grid=grid)
+        g = Function(name='g', grid=sd)
+
+        eqs = [Eq(f, f+1, subdomain=sd),
+               Eq(g, g+1, subdomain=sd)]
+
+        op = Operator(eqs)
+
+        subdims = [d for d in op.dimensions if d.is_Sub]
+        ((xltkn, xrtkn), (_, yrtkn)) = [d.thickness for d in subdims]
+
+        mapper0 = {xltkn: sympy.S.Zero}
+        mapper1 = {xrtkn: sympy.Integer(2), yrtkn: sympy.S.Zero}
+        mapper2 = {**mapper0, **mapper1}
+
+        mappers = (mapper0, mapper1, mapper2)
+        ops = tuple(Specializer(m).visit(op) for m in mappers)
+
+        for m, o in zip(mappers, ops, strict=True):
+            check_op(m, o)
+
+    def test_factor(self):
+        """Test that ConditionalDimensions can have their symbolic factors specialized"""
+        size = 16
+        factor = 4
+        i = Dimension(name='i')
+        ci = ConditionalDimension(name='ci', parent=i, factor=factor)
+
+        g = Function(name='g', shape=(size,), dimensions=(i,))
+        f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))
+
+        op0 = Operator([Eq(f, g)])
+
+        mapper = {ci.symbolic_factor: sympy.Integer(factor)}
+
+        op1 = Specializer(mapper).visit(op0)
+
+        assert ci.symbolic_factor not in op1.parameters
+        assert ci.symbolic_factor.name not in str(op1.ccode)
+        assert "if ((i)%(4) == 0)" in str(op1.ccode)
+
+    def test_spacing(self):
+        """Test that grid spacings can be specialized"""
+        grid = Grid(shape=(11,))
+        f = Function(name='f', grid=grid)
+
+        op0 = Operator(Eq(f, f.dx))
+
+        mapper = {grid.dimensions[0].spacing: sympy.Float(grid.spacing[0])}
+        op1 = Specializer(mapper).visit(op0)
+
+        assert grid.dimensions[0].spacing not in op1.parameters
+        assert grid.dimensions[0].spacing.name not in str(op1.ccode)
+        assert "/1.0e-1F;" in str(op1.ccode)
+
+    def test_sizes(self):
+        """Test that strides generated for linearization can be specialized"""
+        grid = Grid(shape=(11, 11))
+
+        f = TimeFunction(name='f', grid=grid, space_order=2)
+
+        op0 = Operator(Eq(f.forward, f.dx2),
+                       opt=('advanced', {'expand': True, 'linearize': True}))
+
+        mapper = {f.symbolic_shape[1]: sympy.Integer(11),
+                  f.symbolic_shape[2]: sympy.Integer(11)}
+
+        op1 = Specializer(mapper).visit(op0)
+
+        assert "const int x_fsz0 = 11;" in str(op1.ccode)
+        assert "const int y_fsz0 = 11;" in str(op1.ccode)
+
+        # TODO: Should strides get specialized? If so, how?
+
+    def test_apply_basic(self):
+        """
+        Test that a trivial operator runs and returns the same results when
+        specialized.
+        """
+        grid = Grid(shape=(11, 11))
+        ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions]
+        f = Function(name='f', grid=grid, dtype=np.int32)
+
+        op0 = Operator(Eq(f, f+1))
+
+        mapper = {x_m: sympy.Integer(2), x_M: sympy.Integer(7),
+                  y_m: sympy.Integer(3), y_M: sympy.Integer(8)}
+
+        op1 = Specializer(mapper).visit(op0)
+
+        assert op1.cfunction is not op0.cfunction
+
+        op1.apply()
+
+        check = np.array(f.data[:])
+        f.data[:] = 0
+
+        op0.apply(x_m=2, x_M=7, y_m=3, y_M=8)
+
+        assert np.all(check == f.data)
+
+    # TODO: Need a test to check that block sizes can be specialized
+    # TODO: Need to test that tile sizes can be specialized
+
+
+class TestApply:
+    """Tests for specialization of operators at apply time"""
+
+    @pytest.mark.parametrize('override', [False, True])
+    def test_basic(self, caplog, override):
+        grid = Grid(shape=(11, 11))
+        f = Function(name='f', grid=grid, dtype=np.int32)
+        op = Operator(Eq(f, f+1))
+
+        specialize = ('x_m', 'x_M')
+
+        kwargs = {}
+        if override:
+            kwargs['x_m'] = 3
+
+        with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
+            op.apply_specialize(specialize=specialize, **kwargs)
+
+        # Ensure that the specialized operator was run
+        assert all(s not in caplog.text for s in specialize)
+        assert "specialization" in caplog.text
+
+        check = np.array(f.data[:])
+        f.data[:] = 0
+        op.apply(**kwargs)
+
+        assert np.all(check == f.data[:])
+
+    @pytest.mark.parallel(mode=[2, 4])
+    @pytest.mark.parametrize('override', [False, True])
+    def test_basic_mpi(self, caplog, mode, override):
+        self.test_basic(caplog, override)
+
+    @pytest.mark.parametrize('specialize',
+                             [('x_m',),
+                              ('y_M',),
+                              ('time_m',),
+                              ('time_m', 'time_M'),
+                              ('x_m', 'y_M'),
+                              ('x_m', 'x_M', 'y_m', 'y_M')])
+    def test_diffusion_like(self, specialize):
+        grid = Grid(shape=(11, 11))
+
+        dt = 2.5e-5
+
+        f = TimeFunction(name='f', grid=grid, space_order=4)
+        f.data[:, 4:-4, 4:-4] = 1
+
+        op = Operator(Eq(f.forward, f - grid.time_dim.spacing*f.laplace))
+
+        op.apply(t_M=100, dt=dt)
+
+        check = np.array(f.data)
+        f.data[:] = 0
+        f.data[:, 4:-4, 4:-4] = 1
+
+        op.apply_specialize(t_M=100, dt=dt, specialize=specialize)
+
+        assert np.all(np.isclose(check, f.data))
+
+    # Diffusion-like test
+    # Acoustic-like test (with and without source injection)
+    # Elastic-like test (with and without source injection)