diff --git a/CHANGELOG.md b/CHANGELOG.md index 140d9bf..589c82b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Support links that are explicitly optional (`required=False`). +- Support input caching from optional links (`cache_if_optional=True`). + +### Changed + +- Buffer inputs from optional links that arrive before the first execution. + +### Fixed + +- Protect input cache against parallel execution of the same `InputMergeActor` instance. + ## [2.0.2] - 2026-02-15 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index 01fbe80..5b6881c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,8 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ - "ewokscore >=4.0.1", - "pypushflow >=2.0.0rc1" + "ewokscore >=5.0.0rc1", + "pypushflow >=2.0.0rc1", ] [project.urls] diff --git a/src/ewoksppf/bindings.py b/src/ewoksppf/bindings.py index 5044d84..a6f454d 100644 --- a/src/ewoksppf/bindings.py +++ b/src/ewoksppf/bindings.py @@ -1,6 +1,8 @@ import os +import threading import warnings from contextlib import contextmanager +from typing import Dict from typing import Generator from typing import List from typing import Optional @@ -172,6 +174,7 @@ def __init__( name="Name mapper", trigger_on_error=False, required=False, + cache_if_optional=False, **kw, ): super().__init__(name=name, **kw) @@ -179,11 +182,12 @@ def __init__( self.map_all_data = map_all_data self.trigger_on_error = trigger_on_error self.required = required + self.cache_if_optional = cache_if_optional def connect(self, actor): super().connect(actor) if isinstance(actor, InputMergeActor): - actor.require_input_from_actor(self) + actor.register_input_actor(self) def _execute(self, inData: dict, _scope_id: Optional[str] = None) -> None: is_error = "WorkflowExceptionInstance" in inData and inData.get( @@ -215,56 +219,163 @@ def _execute(self, inData: dict, _scope_id: Optional[str] = None) -> None: class InputMergeActor(AbstractActor): - """Requires triggers from some input actors before triggering - the downstream actors. - - It remembers the last input from the required uptstream actors. - Only the last non-required input is remembered. + """Requires triggers from some input actors before triggering the downstream actors. + Optional triggers are cached or buffered (before first execution) and the last one is retained. """ def __init__(self, parent=None, name="Input merger", **kw): super().__init__(parent=parent, name=name, **kw) - self.startInData = list() - self.requiredInData = dict() - self.nonrequiredInData = dict() - def require_input_from_actor(self, actor): + # List of input dicts provided by the graph startargs (not part of the Ewoks SPEC) + self._cached_start_triggers: List[dict] = list() + + # Map actor to input dict provided by that actor + self._cached_required_triggers: Dict[AbstractActor, dict] = dict() + self._cached_optional_triggers: Dict[AbstractActor, dict] = dict() + + # List of input dicts provided by optional links without caching + # that arrived before all required triggers arrived + self._buffer_optional_triggers: List[dict] = list() + self._buffering = True + + # Retain only one input dict provided by optional links without caching + # after all required triggers arrived + self._retained_optional_trigger: Optional[dict] = None + + self._lock = threading.Lock() + + def register_input_actor(self, actor: Optional[AbstractActor]): if actor.required: - self.requiredInData[actor] = None + info = "(required): cache inputs" + self._cached_required_triggers[actor] = None + elif actor.cache_if_optional: + info = "(optional): cache inputs" + self._cached_optional_triggers[actor] = None + else: + info = "(optional): buffer inputs before first execution and then retain the last one" + # see self._buffer_optional_triggers + self.logger.info("%s %s", actor.name, info) def _execute( - self, inData: dict, _scope_id: Optional[str] = None, source=None + self, + inData: dict, + _scope_id: Optional[str] = None, + source: Optional[AbstractActor] = None, ) -> None: - self.setStarted() - self.setFinished() - if source is None: - self.startInData.append(inData) + with self._lock: + self.setStarted() + try: + self._cache_inputs(source, inData) + finally: + self.setFinished() + + if not self._has_all_required_triggers(): + return + + self._propagate_cached_inputs() + + def _propagate_cached_inputs(self) -> None: + if not self._buffering: + # Execute with the retained inputs from the last trigger + # of an optional link without caching. Might be `None` + # when there is none. + buffer = [self._retained_optional_trigger] else: - if source in self.requiredInData: - self.requiredInData[source] = inData + if self._buffer_optional_triggers: + # Execute for each retained inputs from optional links without caching. + buffer = list(self._buffer_optional_triggers) + else: + # Execute once without any retained inputs. + buffer = [None] + + for i, retained_inputs in enumerate(buffer): + try: + self._trigger_downstream(retained_inputs) + except Exception: + if self._buffering: + # Keep the inputs not successfully propagated. + self._buffer_optional_triggers = buffer[i:] + raise + + if self._buffering: + if buffer: + # Retain the last one for the next trigger. + # Might be `None` when there is none. + self._retained_optional_trigger = buffer[-1] else: - self.nonrequiredInData = inData - missing = {k: v for k, v in self.requiredInData.items() if v is None} - if missing: + self._retained_optional_trigger = None + + # No more buffering, only retain one. + self._buffering = False + + # No longer needed so do not keep references. + self._buffer_optional_triggers.clear() + + def _cache_inputs(self, source: Optional[AbstractActor], inData: dict) -> None: + if source is None: + self._cached_start_triggers.append(inData) + return + + if source in self._cached_required_triggers: + # Cache inputs from required link + self._cached_required_triggers[source] = inData + elif source in self._cached_optional_triggers: + # Cache inputs from optional link + self._cached_optional_triggers[source] = inData + elif self._buffering: + # Did not execute yet + self._buffer_optional_triggers.append(inData) + else: + # Executed at least once + self._retained_optional_trigger = inData + + def _has_all_required_triggers(self) -> bool: + missing_required = { + k: v for k, v in self._cached_required_triggers.items() if v is None + } + if missing_required: self.logger.info( "not triggering downstream actors because missing inputs from actors %s", - [actor.name for actor in missing], + [actor.name for actor in missing_required], ) - return - self.logger.info( - "triggering downstream actors (%d start inputs, %d required inputs, %d optional inputs)", - len(self.startInData), - len(self.requiredInData), - int(bool(self.nonrequiredInData)), - ) - newInData = dict() - for data in self.startInData: - newInData.update(data) - for data in self.requiredInData.values(): - newInData.update(data) - newInData.update(self.nonrequiredInData) + return False + return True + + def _trigger_downstream(self, retained_inputs: Optional[dict]): + merged_inputs = self._downstream_inputs(retained_inputs) for actor in self.listDownStreamActor: - actor.trigger(newInData) + actor.trigger(merged_inputs) + + def _downstream_inputs(self, retained_inputs: Optional[dict]) -> dict: + self.logger.debug( + "Trigger downstream actor with merged inputs from\n " + "%d graph start triggers\n " + "%d cached required links\n " + "%d cached optional links\n " + "%d retained optional links", + len(self._cached_start_triggers), + len(self._cached_required_triggers), + len(self._cached_optional_triggers), + int(retained_inputs is not None), + ) + + merged_inputs = dict() + for data in self._cached_start_triggers: + merged_inputs.update(data) + + for data in self._cached_required_triggers.values(): + merged_inputs.update(data) + + for data in self._cached_optional_triggers.values(): + if data is None: + # Optional link not triggered yet + continue + merged_inputs.update(data) + + if retained_inputs: + merged_inputs.update(retained_inputs) + + return merged_inputs class EwoksWorkflow(Workflow): @@ -444,25 +555,35 @@ def _create_name_mapper( self, taskgraph: TaskGraph, source_id: NodeIdType, target_id: NodeIdType ) -> NameMapperActor: link_attrs = taskgraph.graph[source_id][target_id] + + # Data mapping map_all_data = link_attrs.get("map_all_data", False) data_mapping = link_attrs.get("data_mapping", list()) data_mapping = { item["target_input"]: item["source_output"] for item in data_mapping } + + # Conditional link on_error = link_attrs.get("on_error", False) + cache_if_optional = link_attrs.get("cache_if_optional", False) + + # Required link required = analysis.link_is_required(taskgraph.graph, source_id, target_id) + source_label = ppfname(source_id) target_label = ppfname(target_id) if on_error: name = f"Name mapper <{source_label} -only on error- {target_label}>" else: name = f"Name mapper <{source_label} - {target_label}>" + return NameMapperActor( name=name, namemap=data_mapping, map_all_data=map_all_data, trigger_on_error=on_error, required=required, + cache_if_optional=cache_if_optional, **self._actor_arguments, ) diff --git a/src/ewoksppf/tests/test_ppf_workflow21.py b/src/ewoksppf/tests/test_ppf_workflow21.py index 1ad15d7..be010da 100644 --- a/src/ewoksppf/tests/test_ppf_workflow21.py +++ b/src/ewoksppf/tests/test_ppf_workflow21.py @@ -99,8 +99,10 @@ def submodel21_on_error(): def workflow21(on_error): if on_error: submodel21 = submodel21_on_error + out1_required = False else: submodel21 = submodel21_conditions + out1_required = None nodes = [ {"id": "in", "task_type": "method", "task_identifier": qualname(passthrough)}, @@ -132,6 +134,7 @@ def workflow21(on_error): { "source": "out1", "target": "out", + "required": out1_required, "data_mapping": [{"source_output": "return_value", "target_input": "a"}], }, { diff --git a/src/ewoksppf/tests/test_ppf_workflow25.py b/src/ewoksppf/tests/test_ppf_workflow25.py new file mode 100644 index 0000000..f4d5ceb --- /dev/null +++ b/src/ewoksppf/tests/test_ppf_workflow25.py @@ -0,0 +1,172 @@ +import itertools +import time + +import pytest +from ewokscore.task import Task +from ewoksutils.import_utils import qualname + +from ..bindings import execute_graph + + +class Required(Task, input_names=["compute_time"], output_names=["required"]): + def run(self): + time.sleep(self.inputs.compute_time) + self.outputs.required = True + + +class Optional(Task, input_names=["compute_time"], output_names=["optional"]): + def run(self): + time.sleep(self.inputs.compute_time) + self.outputs.optional = True + + +class Gather( + Task, + input_names=["required1", "required2"], + optional_input_names=["optional1", "optional2", "retained1", "retained2"], + output_names=["cached"], +): + def run(self): + global _GATHER_CACHE + cached = self.get_input_values() + _GATHER_CACHE = cached + print(f"\nDecider executed with inputs: {cached}") + self.outputs.cached = cached + + +def workflow(): + nodes = [ + { + "id": "required1", + "task_type": "class", + "task_identifier": qualname(Required), + }, + { + "id": "required2", + "task_type": "class", + "task_identifier": qualname(Required), + }, + { + "id": "optional1", + "task_type": "class", + "task_identifier": qualname(Optional), + }, + { + "id": "optional2", + "task_type": "class", + "task_identifier": qualname(Optional), + }, + { + "id": "retained1", + "task_type": "class", + "task_identifier": qualname(Optional), + }, + { + "id": "retained2", + "task_type": "class", + "task_identifier": qualname(Optional), + }, + { + "id": "gather", + "task_type": "class", + "task_identifier": qualname(Gather), + }, + ] + links = [ + { + "source": "required1", + "target": "gather", + "data_mapping": [ + {"source_output": "required", "target_input": "required1"} + ], + }, + { + "source": "required2", + "target": "gather", + "data_mapping": [ + {"source_output": "required", "target_input": "required2"} + ], + }, + { + "source": "optional1", + "target": "gather", + "required": False, + "cache_if_optional": True, + "data_mapping": [ + {"source_output": "optional", "target_input": "optional1"} + ], + }, + { + "source": "optional2", + "target": "gather", + "required": False, + "cache_if_optional": True, + "data_mapping": [ + {"source_output": "optional", "target_input": "optional2"} + ], + }, + { + "source": "retained1", + "target": "gather", + "required": False, + "cache_if_optional": False, + "data_mapping": [ + {"source_output": "optional", "target_input": "retained1"} + ], + }, + { + "source": "retained2", + "target": "gather", + "required": False, + "cache_if_optional": False, + "data_mapping": [ + {"source_output": "optional", "target_input": "retained2"} + ], + }, + ] + return {"graph": {"id": "workflow"}, "nodes": nodes, "links": links} + + +def get_inputs(required, optional, retained): + return [ + {"id": "required1", "name": "compute_time", "value": required}, + {"id": "required2", "name": "compute_time", "value": required}, + {"id": "optional1", "name": "compute_time", "value": optional}, + {"id": "optional2", "name": "compute_time", "value": optional}, + {"id": "retained1", "name": "compute_time", "value": retained}, + {"id": "retained2", "name": "compute_time", "value": retained}, + ] + + +_ORDER = list(itertools.permutations(["required", "optional", "retained"])) + + +@pytest.mark.parametrize("order", _ORDER, ids=["-".join(keys) for keys in _ORDER]) +def test_ppf_workflow25(ppf_log_config, order): + """Test input caching for different types of links executed in different orders.""" + global _GATHER_CACHE + _GATHER_CACHE = None + compute_times = [0, 0.5, 1] + inputs = get_inputs(**dict(zip(order, compute_times))) + + # result = execute_graph(workflow(), inputs=inputs) + # cached = set(result["cached"]) + # + # When + # + # order = ('retained', 'required', 'optional') + # + # the last two calls to "Gather" could be for example + # + # {'required1': True, 'required2': True, 'optional1': True, 'retained2': True} + # {'required1': True, 'required2': True, 'optional1': True, 'optional2': True, 'retained2': True} + # + # Since these calls happen in parallel and there is nothing in the workflow + # that guarantees we get one or the other as the final workflow result we + # cannot use the result to test the caching. + + _ = execute_graph(workflow(), pool_type="thread", inputs=inputs) + cached = set(_GATHER_CACHE) + cached1 = {"required1", "required2", "optional1", "optional2", "retained1"} + cached2 = {"required1", "required2", "optional1", "optional2", "retained2"} + assert cached == cached1 or cached == cached2, cached