Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions services/analysis-engine/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,9 +926,7 @@ def put(self, item: tuple[str, object]) -> None:
assert archive["stem_bass"].shape == (4,)


def test_stem_separation_process_helper_maps_worker_results(tmp_path) -> None:
"""Ensure parent-side process helper maps worker result envelopes."""

def _make_fake_multiprocessing_context(item: tuple[str, object]) -> object:
class FakeQueue:
def __init__(self, item: tuple[str, object]) -> None:
self.item = item
Expand Down Expand Up @@ -965,12 +963,20 @@ def Queue(self, maxsize: int) -> FakeQueue:
assert maxsize == 1
return FakeQueue(self.item)

return FakeContext(item)


def test_stem_separation_process_helper_maps_ok_envelope() -> None:
"""Ensure parent-side process helper maps ok worker result envelopes."""
with patch(
"bandscope_analysis.api._multiprocessing_context",
return_value=FakeContext(("ok", {"stems": {}})),
return_value=_make_fake_multiprocessing_context(("ok", {"stems": {}})),
):
assert _run_stem_separation_with_timeout("/tmp/audio.wav") == {"stems": {}}


def test_stem_separation_process_helper_maps_ok_file_envelope(tmp_path) -> None:
"""Ensure parent-side process helper maps ok_file worker result envelopes."""
arrays_path = tmp_path / "worker-stems.npz"
np.savez_compressed(arrays_path, stem_bass=np.ones(4))
file_payload = {
Expand All @@ -982,14 +988,17 @@ def Queue(self, maxsize: int) -> FakeQueue:
}
with patch(
"bandscope_analysis.api._multiprocessing_context",
return_value=FakeContext(("ok_file", file_payload)),
return_value=_make_fake_multiprocessing_context(("ok_file", file_payload)),
):
loaded = _run_stem_separation_with_timeout("/tmp/audio.wav")
assert loaded["sr"] == 22050
assert loaded["stems"]["bass"].shape == (4,)
assert loaded["stem_role_types"] == {"bass": "instrument"}
assert not arrays_path.with_suffix(".json").exists()


def test_stem_separation_process_helper_rejects_invalid_file_payloads(tmp_path) -> None:
"""Ensure parent-side process helper rejects invalid ok_file payloads."""
invalid_file_payloads = [
("not-a-dict", "Stem separation returned invalid metadata."),
(
Expand All @@ -1016,7 +1025,7 @@ def Queue(self, maxsize: int) -> FakeQueue:
for payload, expected_message in invalid_file_payloads:
with patch(
"bandscope_analysis.api._multiprocessing_context",
return_value=FakeContext(("ok_file", payload)),
return_value=_make_fake_multiprocessing_context(("ok_file", payload)),
):
try:
_run_stem_separation_with_timeout("/tmp/audio.wav")
Expand All @@ -1025,6 +1034,9 @@ def Queue(self, maxsize: int) -> FakeQueue:
else:
raise AssertionError("Expected RuntimeError")


def test_stem_separation_process_helper_maps_error_envelopes() -> None:
"""Ensure parent-side process helper maps error envelopes to exceptions."""
error_cases = [
(("file_not_found", "missing"), FileNotFoundError),
(("value_error", "bad media"), ValueError),
Expand All @@ -1033,7 +1045,7 @@ def Queue(self, maxsize: int) -> FakeQueue:
for item, expected_error in error_cases:
with patch(
"bandscope_analysis.api._multiprocessing_context",
return_value=FakeContext(item),
return_value=_make_fake_multiprocessing_context(item),
):
try:
_run_stem_separation_with_timeout("/tmp/audio.wav")
Expand Down
Loading