# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Merging utility for Orbax checkpoints."""

import asyncio
from collections.abc import Sequence

from absl import app
from absl import flags
from etils import epath
import jax
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
from orbax.checkpoint.experimental.v1._src.partial import merging


FLAGS = flags.FLAGS

_IN_PATHS = flags.DEFINE_multi_string(
    'in_paths',
    None,
    'Paths of checkpoints to merge.',
)
_OUT_PATH = flags.DEFINE_string(
    'out_path',
    None,
    'Output checkpoint path.',
)
_PER_HOST_MEMORY_LIMIT_BYTES = flags.DEFINE_integer(
    'per_host_memory_limit_bytes',
    None,
    'Memory limit in bytes per CPU host for partial loading and saving.'
    ' Non-uniform memory limits are not supported.',
)


async def _validate_input_checkpoints(in_paths: Sequence[str]) -> None:
  """Checks that every input path exists and passes `OrbaxLayout.validate`.

  Paths are checked one at a time, in order, so the first bad path is the one
  reported.

  Args:
    in_paths: Checkpoint paths given via --in_paths.

  Raises:
    FileNotFoundError: If a path does not exist.
    ValueError: If `OrbaxLayout.validate` raises for a path; the original
      exception is chained as the cause.
  """
  layout = orbax_layout.OrbaxLayout()
  for path_str in in_paths:
    path = epath.Path(path_str)
    if not path.exists():
      raise FileNotFoundError(f'Input path {path_str} does not exist.')
    try:
      await layout.validate(path)
    except Exception as e:
      # Deliberately broad: whatever validation raises is reported uniformly
      # as an invalid checkpoint, with the real cause kept via chaining.
      raise ValueError(
          f'Input path {path_str} is not a valid checkpoint.'
      ) from e


def _validate_output_path(out_path_str: str) -> epath.Path:
  """Checks that the output path is either absent or an empty directory.

  Args:
    out_path_str: Output path given via --out_path.

  Returns:
    The output path as an `epath.Path`.

  Raises:
    ValueError: If the path exists and is not a directory, or is a non-empty
      directory.
  """
  out_path = epath.Path(out_path_str)
  if out_path.exists():
    if not out_path.is_dir():
      raise ValueError(
          f'Output path {out_path_str} exists but is not a directory.'
      )
    # Stop at the first entry instead of listing the whole directory.
    if any(True for _ in out_path.iterdir()):
      raise ValueError(f'Output path {out_path_str} exists and is not empty.')
  return out_path


def main(argv: Sequence[str]) -> None:
  """Validates flags and paths, then merges the input checkpoints.

  Args:
    argv: Positional command-line arguments; only the program name is
      accepted.

  Raises:
    app.UsageError: If extra positional arguments are given or a required
      flag is missing or empty.
    ValueError: If --per_host_memory_limit_bytes is not positive, an input is
      not a valid checkpoint, or the output path is unusable.
    FileNotFoundError: If an input path does not exist.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if not _IN_PATHS.value:
    raise app.UsageError('Flag --in_paths must be specified.')
  # Reject an empty string as well as an unset flag.
  if not _OUT_PATH.value:
    raise app.UsageError('Flag --out_path must be specified.')
  if _PER_HOST_MEMORY_LIMIT_BYTES.value is None:
    raise app.UsageError(
        'Flag --per_host_memory_limit_bytes must be specified.'
    )

  if _PER_HOST_MEMORY_LIMIT_BYTES.value <= 0:
    raise ValueError('per_host_memory_limit_bytes must be positive.')

  # OrbaxLayout.validate is async; run all input checks on one event loop.
  asyncio.run(_validate_input_checkpoints(_IN_PATHS.value))

  out_path = _validate_output_path(_OUT_PATH.value)

  # NOTE(review): every process runs the checks above, but only process 0
  # creates the directory and no barrier is visible here. Confirm that
  # merge_checkpoints synchronizes processes before anything is written.
  if jax.process_index() == 0:
    out_path.mkdir(parents=True, exist_ok=True)

  merging.merge_checkpoints(
      _IN_PATHS.value,
      _OUT_PATH.value,
      _PER_HOST_MEMORY_LIMIT_BYTES.value,
  )


if __name__ == '__main__':
  app.run(main)
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the checkpoint merging command-line entry point."""

import os
from unittest import mock

from absl.testing import absltest
from absl.testing import flagsaver
from etils import epath
import jax
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
from orbax.checkpoint.experimental.v1._src.partial import merging
from orbax.checkpoint.experimental.v1._src.partial import run_merging


class RunMergingTest(absltest.TestCase):
  """Exercises `run_merging.main` with validation and merging mocked out."""

  def setUp(self):
    super().setUp()
    self.out_path = self.create_tempdir().full_path
    self.in_paths = [self.create_tempdir().full_path]

    # Patch for every test, so a failed validation can never reach the real
    # merge implementation.
    self.mock_validate = self.enter_context(
        mock.patch.object(
            orbax_layout.OrbaxLayout, 'validate', new_callable=mock.AsyncMock
        )
    )
    self.mock_merge = self.enter_context(
        mock.patch.object(merging, 'merge_checkpoints', autospec=True)
    )
    self.enter_context(
        mock.patch.object(jax, 'process_index', return_value=0)
    )

  def _run_main(self, **flag_overrides):
    """Runs `main` with valid default flags, overridden by `flag_overrides`."""
    flag_values = {
        'in_paths': self.in_paths,
        'out_path': self.out_path,
        'per_host_memory_limit_bytes': 1024,
    }
    flag_values.update(flag_overrides)
    with flagsaver.flagsaver(**flag_values):
      run_merging.main([])

  def test_main_success(self):
    self._run_main()

    # The validate coroutine must be awaited, not merely called.
    self.mock_validate.assert_awaited_once_with(epath.Path(self.in_paths[0]))
    self.mock_merge.assert_called_once_with(
        self.in_paths, self.out_path, 1024
    )

  def test_main_creates_missing_output_dir(self):
    new_out_path = os.path.join(self.out_path, 'new_checkpoint')

    self._run_main(out_path=new_out_path)

    self.assertTrue(epath.Path(new_out_path).is_dir())
    self.mock_merge.assert_called_once_with(self.in_paths, new_out_path, 1024)

  def test_main_invalid_output_not_empty(self):
    out_path = epath.Path(self.out_path)
    (out_path / 'some_file').write_text('content')

    with self.assertRaisesRegex(ValueError, 'not empty'):
      self._run_main()
    self.mock_merge.assert_not_called()

  def test_main_invalid_output_not_a_directory(self):
    out_file = self.create_tempfile().full_path

    with self.assertRaisesRegex(ValueError, 'not a directory'):
      self._run_main(out_path=out_file)
    self.mock_merge.assert_not_called()

  def test_main_invalid_input(self):
    self.mock_validate.side_effect = ValueError('Invalid checkpoint')

    with self.assertRaisesRegex(ValueError, 'is not a valid checkpoint'):
      self._run_main()
    self.mock_merge.assert_not_called()

  def test_main_missing_input_path(self):
    missing = os.path.join(self.in_paths[0], 'does_not_exist')

    with self.assertRaisesRegex(FileNotFoundError, 'does not exist'):
      self._run_main(in_paths=[missing])
    self.mock_validate.assert_not_awaited()
    self.mock_merge.assert_not_called()

  def test_main_non_positive_memory_limit(self):
    with self.assertRaisesRegex(ValueError, 'must be positive'):
      self._run_main(per_host_memory_limit_bytes=0)
    self.mock_merge.assert_not_called()


if __name__ == '__main__':
  absltest.main()