Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions babs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def _apply_config(self) -> None:
self.pipeline = config_yaml.get('pipeline', None)
if self.pipeline is not None:
self._validate_pipeline_config()
self.container_images = self.get_container_image_paths(config_yaml)

# Check the output RIA:
self.wtf_key_info(flag_output_ria_only=True)
Expand Down Expand Up @@ -235,6 +236,26 @@ def _validate_pipeline_config(self) -> None:

print('Pipeline configuration validation complete!')

@staticmethod
def container_image_path(container_name: str) -> str:
"""Return the analysis-relative image path for a DataLad container."""
return op.join('containers', '.datalad', 'environments', container_name, 'image')

def get_container_image_paths(self, config_yaml: dict) -> list[str]:
    """Get analysis-relative container image paths used by participant jobs.

    An explicit ``container_images`` entry in ``config_yaml`` wins. When it
    is absent (or ``None``), the paths are derived from the container names:
    one per pipeline step if a pipeline is configured, otherwise the single
    ``self.container['name']``.

    Parameters
    ----------
    config_yaml : dict
        Parsed project configuration. ``container_images`` may be a single
        path string or a list/tuple of path strings.

    Returns
    -------
    list[str]
        Image paths in first-seen order with duplicates removed.

    Raises
    ------
    TypeError
        If ``container_images`` is neither a string nor a list/tuple, or
        if any of its entries is not a string.
    """
    container_images = config_yaml.get('container_images')
    if container_images is None:
        # Read `pipeline` defensively: this method is also called with an
        # empty config on instances where only `container` is known to be
        # assigned, so a missing attribute is treated like `None`.
        pipeline = getattr(self, 'pipeline', None)
        if pipeline is not None:
            container_names = [step['container_name'] for step in pipeline]
        else:
            container_names = [self.container['name']]
        container_images = [self.container_image_path(name) for name in container_names]
    elif isinstance(container_images, str):
        container_images = [container_images]
    elif not isinstance(container_images, (list, tuple)):
        # A mapping would otherwise be silently reduced to its keys.
        raise TypeError(
            "'container_images' must be a string or a list of strings, got: "
            f'{type(container_images).__name__}'
        )

    invalid_entries = [path for path in container_images if not isinstance(path, str)]
    if invalid_entries:
        raise TypeError(
            f"'container_images' entries must be strings, got: {invalid_entries!r}"
        )

    # Preserve order while avoiding duplicate datalad get calls.
    return list(dict.fromkeys(container_images))

def _update_inclusion_dataframe(
self, initial_inclusion_df: pd.DataFrame | None = None
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions babs/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def babs_bootstrap(
# Create `babs_proj_config.yaml` file: ----------------------
print('Save BABS project configurations in a YAML file ...')
print("Path to this yaml file will be: 'analysis/code/babs_proj_config.yaml'")
self.container = {'name': container_name}
container_images = self.get_container_image_paths({})

env = Environment(
loader=PackageLoader('babs', 'templates'),
Expand All @@ -170,6 +172,7 @@ def babs_bootstrap(
input_ds=self.input_datasets,
container_name=container_name,
container_ds=container_ds,
container_images=container_images,
)
)
self.datalad_save(
Expand Down
44 changes: 44 additions & 0 deletions babs/interaction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""This is the main module."""

import os.path as op
import sys
import time

import datalad.api as dlapi
import numpy as np

from babs.base import BABS
Expand All @@ -18,6 +20,46 @@
class BABSInteraction(BABS):
"""Implement interactions with a BABS project - submitting jobs and checking status."""

def ensure_container_images_available(self) -> None:
    """Make sure every configured container image exists locally.

    Each entry of ``self.container_images`` is resolved against
    ``self.analysis_path`` unless it is already absolute. Images that
    already exist are left alone; the others are requested with
    ``dlapi.get`` from the ``containers`` dataset under the analysis folder.

    Raises
    ------
    FileNotFoundError
        If ``<analysis_path>/containers/.datalad/config`` does not exist,
        or if an image is still missing after the ``get`` call.
    RuntimeError
        If any returned status record is neither ``'ok'`` nor
        ``'notneeded'``.
    """
    containers_ds = op.join(self.analysis_path, 'containers')
    ds_config = op.join(containers_ds, '.datalad', 'config')
    if not op.exists(ds_config):
        raise FileNotFoundError(
            'There is no containers DataLad dataset in folder: ' + containers_ds
        )

    print('\nEnsuring container image(s) are available locally...')
    for image_path in self.container_images:
        image_path_abs = (
            image_path if op.isabs(image_path) else op.join(self.analysis_path, image_path)
        )

        if op.exists(image_path_abs):
            print(f'Container image already available: {image_path}')
            continue

        print(f'Running `datalad get {image_path}`...')
        results = dlapi.get(path=image_path_abs, dataset=containers_ds)
        # Normalise the return value to a list of status records.
        if results is None:
            results = []
        elif isinstance(results, dict):
            results = [results]

        problems = []
        for record in results:
            if record.get('status') not in {'ok', 'notneeded'}:
                problems.append(record)
        if problems:
            raise RuntimeError(
                'Unable to retrieve container image before job submission: '
                f'{image_path}\nDataLad status: {problems}'
            )

        # The status records alone are not trusted: re-check the path itself.
        if not op.exists(image_path_abs):
            raise FileNotFoundError(
                'Container image is still not available after `datalad get`: ' + image_path_abs
            )

def babs_submit(self, count=None, submit_df=None, skip_failed=False, skip_running_jobs=False):
"""
This function submits jobs that don't have results yet and prints out job status.
Expand Down Expand Up @@ -113,6 +155,8 @@ def babs_submit(self, count=None, submit_df=None, skip_failed=False, skip_runnin
print(f'Submitting the first {count} jobs')
df_needs_submit = df_needs_submit.head(min(count, df_needs_submit.shape[0]))

self.ensure_container_images_available()

# We know task_id ahead of time, so we can add it to the dataframe
df_needs_submit['task_id'] = np.arange(1, df_needs_submit.shape[0] + 1)
# Columns to write before we know the job_id (pre-submit)
Expand Down
7 changes: 6 additions & 1 deletion babs/templates/babs_proj_config.yaml.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,9 @@ input_datasets:
# container ds:
container:
name: '{{ container_name }}'
path_in: '{{ container_ds }}'
path_in: '{{ container_ds }}'

container_images:
{% for image_path in container_images %}
- '{{ image_path }}'
{% endfor %}
9 changes: 1 addition & 8 deletions docs/walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,6 @@ and results and provenance are saved. An example command of ``babs init`` is as
--queue slurm \
"${HOME}/babs_demo/my_BABS_project"

Retrieve the container:

.. code-block:: console

$ cd "${HOME}/babs_demo/my_BABS_project/analysis"
$ datalad get -r containers

.. note::
**Optional: Throttling array jobs**: If you have many jobs and want to limit how many
run simultaneously, you can add ``--throttle <number>`` to the command above.
Expand Down Expand Up @@ -745,4 +738,4 @@ You'll see printed messages like this:
0 04-17-2025 21:13 fmriprep_anat/sourcedata/freesurfer/0001/label/aparc.annot.a2009s.ctab


these are a bunch of empty files that mirror the outputs you'd get with an ``fmriprep`` run with ``--anat-only``.
These are empty placeholder files that mirror the outputs you'd get from an ``fmriprep`` run with ``--anat-only``.
15 changes: 0 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,19 +388,6 @@ def gather_slurm_job_diagnostics(
return '\n'.join(lines)


def ensure_container_image(project_root, container_name='simbids-0-0-3'):
    """Fetch the container image of a test project so jobs can find it at runtime.

    Nothing happens when ``<project_root>/analysis/containers`` does not
    exist. Otherwise the image is requested with ``dlapi.get``.

    Parameters
    ----------
    project_root : str or path-like
        Root folder of the BABS project.
    container_name : str
        Name of the container whose image should be fetched.

    Raises
    ------
    RuntimeError
        If the ``get`` call fails and the image path does not exist.
    """
    containers_dir = Path(project_root) / 'analysis' / 'containers'
    image_file = containers_dir.joinpath('.datalad', 'environments', container_name, 'image')
    if not containers_dir.exists():
        return

    try:
        dlapi.get(path=str(image_file), dataset=str(containers_dir))
    except Exception as err:
        # Best effort: a failed `get` is tolerated when the image is present anyway.
        if not image_file.exists():
            raise RuntimeError(f'Failed to get container image for tests: {err}') from err


def get_babs_project(
tmp_path_factory,
templateflow_home,
Expand Down Expand Up @@ -450,8 +437,6 @@ def get_babs_project(
initial_inclusion_df=None,
)

ensure_container_image(project_root, container_name)

if return_path:
return project_root
else:
Expand Down
8 changes: 0 additions & 8 deletions tests/e2e-slurm/container/walkthrough-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ babs init \

echo "PASSED: babs init"

pushd "test_project/analysis"
datalad get "containers/.datalad/environments/simbids-0-0-3/image"
popd

pushd "${PWD}/test_project"

echo "Check setup, without job"
Expand Down Expand Up @@ -107,10 +103,6 @@ babs init \
--keep-if-failed \
"${PWD}/${TEST2_NAME}"

pushd "${PWD}/${TEST2_NAME}/analysis"
datalad get "containers/.datalad/environments/simbids-0-0-3/image"
popd

pushd "${PWD}/${TEST2_NAME}"

babs check-setup
Expand Down
3 changes: 0 additions & 3 deletions tests/test_babs_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pandas as pd
import pytest
from conftest import (
ensure_container_image,
gather_slurm_job_diagnostics,
get_config_simbids_path,
update_yaml_for_run,
Expand Down Expand Up @@ -119,8 +118,6 @@ def test_babs_init_raw_bids(
assert not job.submitted
assert not job.is_failed

ensure_container_image(project_root, container_name)

# babs submit:
babs_submit_opts = argparse.Namespace(
project_root=project_root,
Expand Down
46 changes: 46 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,52 @@ def test_pipeline_config_details(babs_project_sessionlevel):
babs_proj._validate_pipeline_config()


def test_container_image_path():
    """Check the image path built for a named container."""
    expected = 'containers/.datalad/environments/fmriprep-1-2-3/image'
    assert BABS.container_image_path('fmriprep-1-2-3') == expected


def test_get_container_image_paths_from_explicit_string():
    """A single configured path string comes back wrapped in a list."""
    # Bypass __init__ so no project on disk is needed.
    project = object.__new__(BABS)
    config = {'container_images': 'containers/custom/image'}

    assert project.get_container_image_paths(config) == ['containers/custom/image']


def test_get_container_image_paths_deduplicates_explicit_list():
    """Repeated entries in a configured list are dropped, keeping first-seen order."""
    # Bypass __init__ so no project on disk is needed.
    project = object.__new__(BABS)
    configured = ['containers/a/image', 'containers/a/image', 'containers/b/image']

    result = project.get_container_image_paths({'container_images': configured})

    assert result == ['containers/a/image', 'containers/b/image']


def test_get_container_image_paths_from_single_container():
    """Without a pipeline or explicit paths, the single container's image is used."""
    # Bypass __init__ and set only the attributes the method reads.
    project = object.__new__(BABS)
    project.pipeline = None
    project.container = {'name': 'simbids-0-0-3'}

    expected = ['containers/.datalad/environments/simbids-0-0-3/image']
    assert project.get_container_image_paths({}) == expected


def test_get_container_image_paths_from_pipeline():
    """With a pipeline configured, one image path per step is returned in step order."""
    step_names = ['nordic-0-0-1', 'fmriprep-25-0-0']
    # Bypass __init__ and set only the attribute the method reads.
    project = object.__new__(BABS)
    project.pipeline = [{'container_name': name} for name in step_names]

    expected = [
        'containers/.datalad/environments/nordic-0-0-1/image',
        'containers/.datalad/environments/fmriprep-25-0-0/image',
    ]
    assert project.get_container_image_paths({}) == expected


def test_update_inclusion_empty_combine(babs_project_sessionlevel):
"""Test _update_inclusion_dataframe when combined dataframe is empty."""
babs_proj = BABS(babs_project_sessionlevel)
Expand Down
Loading
Loading