Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions swirlc/compiler/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,16 @@
import uuid

from io import BytesIO
from pathlib import Path
from threading import Condition, Event, Thread
from typing import Any, MutableMapping, MutableSequence
"""

global_vars = """

BUF_SIZE = 8192
available_port_data = {}

condition: Condition = Condition()
connections: MutableMapping[str, MutableMapping[str, socket]] = {}
connections: MutableMapping[str, MutableMapping[str, socket.socket]] = {}
ports: MutableMapping[str, Any] = {}
stopping: bool = False

Expand All @@ -64,7 +63,7 @@
logger.propagate = False
"""

accept_function = """def _accept(sock: socket):
accept_function = """def _accept(sock: socket.socket):
while not stopping:
try:
conn, _ = sock.accept()
Expand Down Expand Up @@ -121,7 +120,6 @@

init_dataset_function = """def _init_dataset(port_name: str, data: str):
ports[port_name] = data
available_port_data[port_name] = Event()
available_port_data[port_name].set()
"""

Expand Down Expand Up @@ -173,7 +171,7 @@
logger.debug(f"Received data for port {port} from location {src}")
buf.seek(0)
ports[port] = buf.read().decode("utf-8")
available_port_data.setdefault(port, Event()).set()
available_port_data[port].set()
elif data_type == "file":
filename = connections[src][port].recv(1024).decode()
connections[src][port].send("ack".encode("utf-8"))
Expand All @@ -186,7 +184,7 @@
fd.write(data)
fd.close()
ports[port] = filepath
available_port_data.setdefault(port, Event()).set()
available_port_data[port].set()
logger.debug(f"Received file '{ports[port]}' on port {port}")
elif data_type == "directory":
raise NotImplementedError(f"Recv directories not implemented yet")
Expand Down Expand Up @@ -250,6 +248,7 @@ def __init__(self, outdir: str) -> None:
self.current_location: Location | None = None
self.functions = []
self.function_counter = 0
self.location_ports = set()
self.parallel_step_counter = 0
# If `parathetized` attribute is to True it means that an open bracket has been encountered
# but not yet its corresponding closed bracket
Expand All @@ -270,6 +269,7 @@ def begin_dataset(
):
for port_name, data in dataset:
self.current_location.data[data.name] = data
self.location_ports.add(port_name)
self.programs[self.current_location.name].write(f"""
_init_dataset("{port_name}", "{data.value}")""")

Expand Down Expand Up @@ -328,10 +328,19 @@ def end_location(self) -> None:
for name, location in self.workflow.locations.items()
]
)
ports = ",\n".join(
[
f"'{self.location_ports.pop()}' : Event()"
for _ in range(len(self.location_ports))
]
)
self.programs[self.current_location.name].write(f"""
locations = {{
{locations}
}}
available_port_data = {{
{ports}
}}

OUT_DIR = {out_dir}
SCRATCH_DIR = {scratch_dir}
Expand Down Expand Up @@ -443,15 +452,11 @@ def exec(
for arg in step.arguments
]

outputs = flow[1]
output_port_name = next(iter(outputs))[0] if outputs else ""
if output_port_name := next(iter(flow[1]))[0] if flow[1] else "":
self.location_ports.add(output_port_name)
self.programs[self.current_location.name].write(
f"""
{self._get_indentation()}available_port_data.setdefault("{output_port_name}", Event())
{self._get_indentation()}input_port_names = {[port_name for port_name, _ in flow[0]]}
{self._get_indentation()}for port_name in input_port_names:
{self._get_indentation()} available_port_data.setdefault(port_name, Event())
{self._get_indentation()}_exec("{step.name}", "{step.display_name}", input_port_names, "{output_port_name}", "{step.processors[output_port_name].type if output_port_name else ""}", "{step.processors[output_port_name].glob if output_port_name else ""}", "{step.command}", {arguments})"""
{self._get_indentation()}_exec("{step.name}", "{step.display_name}", {[port_name for port_name, _ in flow[0]]}, "{output_port_name}", "{step.processors[output_port_name].type if output_port_name else ""}", "{step.processors[output_port_name].glob if output_port_name else ""}", "{step.command}", {arguments})"""
)

def par(self) -> None:
Expand All @@ -472,6 +477,7 @@ def f{self.function_counter}():""")
self.function_counter += 1

def recv(self, port: str, data_type: str, src: str, dst: str):
self.location_ports.add(port)
self.programs[self.current_location.name].write(
f"""
{self._get_indentation()}{self._get_thread(self.current_location.name)} = _thread(_recv, "{port}", "{data_type}", "{src}")"""
Expand Down