From e0ab2786b60d70cdc0a9489d57766a1e37650e1d Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Sat, 11 Apr 2026 18:00:52 +0800 Subject: [PATCH 1/9] [Bugfix] Restore Inductor Cache Mechanism --- magi_compiler/_api.py | 4 +- magi_compiler/config.py | 13 +- magi_compiler/magi_backend/magi_backend.py | 21 +-- .../magi_backend/piecewise_compiler.py | 8 +- .../test_torch_inductor_cache.py | 162 ++++++++++++++++++ 5 files changed, 188 insertions(+), 20 deletions(-) create mode 100644 tests/torch_native_tests/test_torch_inductor_cache.py diff --git a/magi_compiler/_api.py b/magi_compiler/_api.py index 7a23298..d93782e 100644 --- a/magi_compiler/_api.py +++ b/magi_compiler/_api.py @@ -416,14 +416,14 @@ def _compilation_context(state: MagiCompileState): from .magi_depyf.inspect import explain_compilation _debug_dump_path = debug_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag) - _cache_dump_path = cache_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag) + _inductor_cache_dump_path = cache_dump_path(state.compile_config.cache_root_dir) with ( patch.object(torch._dynamo.config, "assume_static_by_default", False), patch.object(torch._dynamo.config, "enable_cpp_symbolic_shape_guards", False), patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False), _hijack_inline_call_to_collect_traced_files(state), - patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": (_cache_dump_path / "inductor_cache").as_posix()}), + patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": (_inductor_cache_dump_path).as_posix()}), explain_compilation(_debug_dump_path.as_posix()), ): yield diff --git a/magi_compiler/config.py b/magi_compiler/config.py index 902014e..866fb80 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -258,15 +258,22 @@ def model_rank_dir_name(model_idx: int, model_tag: str | None) -> str: return f"model_{model_idx}_rank_{rank}" -def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None) -> Path: +def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = None) -> Path: from datetime import datetime run_id = datetime.now().strftime("run_%Y%m%d_%H%M%S") return Path(cache_root_dir) / "magi_depyf" / run_id / model_rank_dir_name(model_idx, model_tag) -def cache_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None) -> Path: - return Path(cache_root_dir) / "torch_compile_cache" / model_rank_dir_name(model_idx, model_tag) +def cache_dump_path(cache_root_dir: str, \ + model_idx: int | None = None, model_tag: str | None = None) -> Path: + if not model_idx and not model_tag: + # Inductor cache path + return Path(cache_root_dir) / "inductor_cache" + else: + # Magi cache path + assert model_idx and model_tag, "model_idx, model_tag are required for magi_cache path" + return Path(cache_root_dir) / "magi_cache" / model_rank_dir_name(model_idx, model_tag) def inductor_compile_config_hash(inductor_compile_config: dict[str, Any]) -> str: diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 631f46d..04ac777 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -104,23 +104,23 @@ def compile_context(self, runtime_shape: int | None = None, graph_index: int | N with pass_context(runtime_shape, graph_index): yield - def initialize_cache(self, cache_dir: Path, prefix: str = ""): + def initialize_cache(self, cache_dir: Path, prefix: str | None = None): """ Initialize the cache directory for the compiler. The organization of the cache directory is as follows: - cache_dir=/path/to/torch_compile_cache/rank_i_j/hash_str/prefix/ + cache_dir=/path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/[prefix/] inside cache_dir, there will be: - - magi_compile_cache.py + - magi_cache_indices.py - computation_graph.py for multiple prefixes, they can share the same base cache dir of - /path/to/torch_compile_cache/rank_i_j/hash_str/ to store some + /path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/ to store some common compilation artifacts. """ self.cache_dir: Path = cache_dir - self.cache_file_path: Path = cache_dir / "magi_compile_cache.py" + self.cache_file_path: Path = cache_dir / "magi_cache_indices.py" if self.disable_cache: magi_logger.info("MagiCompiler's cache is disabled.") @@ -482,13 +482,10 @@ def _init_cache(self) -> str: ] ) - # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/{model_tag}/ - self.local_cache_dir: Path = ( - cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key / self.model_tag - ) - self.local_cache_dir.mkdir(parents=True, exist_ok=True) - - self.compiler_manager.initialize_cache(self.local_cache_dir, self.model_tag) + # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/ + self.local_magi_cache_path: Path = cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key + self.local_magi_cache_path.mkdir(parents=True, exist_ok=True) + self.compiler_manager.initialize_cache(self.local_magi_cache_path) @observe_lifecycle("graph_split") def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[SplitItem]]: diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index a7d5093..b4b953d 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -19,7 +19,6 @@ from typing import Any import torch -import torch._inductor.compile_fx import torch.fx as fx from magi_compiler.magi_depyf.timeline import observe_lifecycle @@ -160,8 +159,11 @@ def hash(self) -> str: factors: list[Any] = [CacheBase.get_system(), torch_key()] return compute_hash(factors) - def initialize_cache(self, cache_dir: Path, prefix: str = ""): - self.cache_dir: Path = cache_dir + def initialize_cache(self, cache_dir: Path, prefix: str | None = None): + if prefix: + self.cache_dir: Path = cache_dir / prefix + else: + self.cache_dir: Path = cache_dir @observe_lifecycle("compiler_compile") def compile( diff --git a/tests/torch_native_tests/test_torch_inductor_cache.py b/tests/torch_native_tests/test_torch_inductor_cache.py new file mode 100644 index 0000000..b43abd1 --- /dev/null +++ b/tests/torch_native_tests/test_torch_inductor_cache.py @@ -0,0 +1,162 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TorchInductor cache.""" + +import os +import tempfile +from dataclasses import dataclass +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F +from torch._dynamo.utils import counters + +from tests.model_definition import TransformerConfig, create_transformer_model_with_initial_params + + +@dataclass(frozen=True) +class CounterDelta: + """Store cache counter deltas for one test run.""" + + autograd_hit: int + autograd_miss: int + inductor_hit: int + inductor_miss: int + + +EXPECTED = { + "train": CounterDelta(autograd_hit=31, autograd_miss=2, inductor_hit=0, inductor_miss=0), + "eval": CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0), +} + + +@pytest.fixture(autouse=True) +def model_with_clean_cache(): + """Build a transformer model with an isolated Magi cache directory.""" + + with tempfile.TemporaryDirectory() as tmp_dir: + with patch.dict(os.environ, {"MAGI_COMPILE_CACHE_ROOT_DIR": tmp_dir}): + transformer_config = TransformerConfig( + vocab_size=128256, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + max_position_embeddings=8192, + rms_norm_eps=1e-6, + params_dtype=torch.bfloat16, + ) + model, _ = create_transformer_model_with_initial_params(transformer_config, device="cuda") + yield model, transformer_config + + +@pytest.fixture +def cache_snapshot(): + """Return a callable that snapshots current Dynamo cache counters.""" + + def _snapshot() -> dict[str, dict[str, int]]: + return { + "autograd": { + "hit": counters["aot_autograd"]["autograd_cache_hit"], + "miss": counters["aot_autograd"]["autograd_cache_miss"], + }, + "inductor": { + "hit": counters["inductor"]["inductor_cache_hit"], + "miss": counters["inductor"]["inductor_cache_miss"], + }, + } + + return _snapshot + + +def _device() -> torch.device: + """Select CUDA when available, otherwise CPU.""" + + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _delta(before: dict, after: dict) -> CounterDelta: + """Compute per-backend cache hit/miss deltas between two snapshots.""" + + return CounterDelta( + autograd_hit=after["autograd"]["hit"] - before["autograd"]["hit"], + autograd_miss=after["autograd"]["miss"] - before["autograd"]["miss"], + inductor_hit=after["inductor"]["hit"] - before["inductor"]["hit"], + inductor_miss=after["inductor"]["miss"] - before["inductor"]["miss"], + ) + + +def _assert_delta(actual: CounterDelta, expected: CounterDelta): + """Assert cache counter deltas against expected values.""" + + assert actual == expected, f"counter delta mismatch, got={actual}, expected={expected}" + + +class TestTorchInductorCache: + """Validate TorchInductor cache behavior in train/eval flows.""" + + def test_training_mode(self, model_with_clean_cache, cache_snapshot): + """Verify cache counter deltas for repeated training iterations.""" + model, model_config = model_with_clean_cache + model = model.to(_device()) + model.train() + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2) + max_batch_size = 2 + num_epochs = 10 + + before = cache_snapshot() + for _ in range(num_epochs): + dummy_input = torch.randint( + 0, model_config.vocab_size, + (max_batch_size, model_config.max_position_embeddings), + device=_device(), + ) + dummy_label = torch.randint( + 0, model_config.vocab_size, + (max_batch_size, model_config.max_position_embeddings), + device=_device(), + ) + + optimizer.zero_grad() + logits = model.forward(dummy_input) + loss = F.cross_entropy(logits.view(-1, model_config.vocab_size), dummy_label.view(-1)) + loss.backward() + optimizer.step() + + after = cache_snapshot() + _assert_delta(_delta(before, after), EXPECTED["train"]) + + def test_evaluation_mode(self, model_with_clean_cache, cache_snapshot): + """Verify cache counter deltas for a no-grad evaluation forward.""" + + model, model_config = model_with_clean_cache + model = model.to(_device()) + model.eval() + + max_batch_size = 2 + before = cache_snapshot() + with torch.no_grad(): + dummy_input = torch.randint( + 0, model_config.vocab_size, + (max_batch_size, model_config.max_position_embeddings), + device=_device(), + ) + model.forward(dummy_input) + after = cache_snapshot() + + _assert_delta(_delta(before, after), EXPECTED["eval"]) \ No newline at end of file From 12c91b557ca807912c86e42140f8b5164d03a4fa Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 13 Apr 2026 17:59:48 +0800 Subject: [PATCH 2/9] [chore] fix ci --- magi_compiler/config.py | 3 +-- .../test_torch_inductor_cache.py | 14 ++++---------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/magi_compiler/config.py b/magi_compiler/config.py index 866fb80..e215dea 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -265,8 +265,7 @@ def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = return Path(cache_root_dir) / "magi_depyf" / run_id / model_rank_dir_name(model_idx, model_tag) -def cache_dump_path(cache_root_dir: str, \ - model_idx: int | None = None, model_tag: str | None = None) -> Path: +def cache_dump_path(cache_root_dir: str, model_idx: int | None = None, model_tag: str | None = None) -> Path: if not model_idx and not model_tag: # Inductor cache path return Path(cache_root_dir) / "inductor_cache" diff --git a/tests/torch_native_tests/test_torch_inductor_cache.py b/tests/torch_native_tests/test_torch_inductor_cache.py index b43abd1..69bef6c 100644 --- a/tests/torch_native_tests/test_torch_inductor_cache.py +++ b/tests/torch_native_tests/test_torch_inductor_cache.py @@ -122,14 +122,10 @@ def test_training_mode(self, model_with_clean_cache, cache_snapshot): before = cache_snapshot() for _ in range(num_epochs): dummy_input = torch.randint( - 0, model_config.vocab_size, - (max_batch_size, model_config.max_position_embeddings), - device=_device(), + 0, model_config.vocab_size, (max_batch_size, model_config.max_position_embeddings), device=_device() ) dummy_label = torch.randint( - 0, model_config.vocab_size, - (max_batch_size, model_config.max_position_embeddings), - device=_device(), + 0, model_config.vocab_size, (max_batch_size, model_config.max_position_embeddings), device=_device() ) optimizer.zero_grad() @@ -152,11 +148,9 @@ def test_evaluation_mode(self, model_with_clean_cache, cache_snapshot): before = cache_snapshot() with torch.no_grad(): dummy_input = torch.randint( - 0, model_config.vocab_size, - (max_batch_size, model_config.max_position_embeddings), - device=_device(), + 0, model_config.vocab_size, (max_batch_size, model_config.max_position_embeddings), device=_device() ) model.forward(dummy_input) after = cache_snapshot() - _assert_delta(_delta(before, after), EXPECTED["eval"]) \ No newline at end of file + _assert_delta(_delta(before, after), EXPECTED["eval"]) From 7e5b133e66537f4318a657f6bf4def360713488f Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 13 Apr 2026 18:10:20 +0800 Subject: [PATCH 3/9] [chore] fix code style --- magi_compiler/magi_backend/magi_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 04ac777..1c15ae5 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -483,7 +483,9 @@ def _init_cache(self) -> str: ) # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/ - self.local_magi_cache_path: Path = cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key + self.local_magi_cache_path: Path = ( + cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key + ) self.local_magi_cache_path.mkdir(parents=True, exist_ok=True) self.compiler_manager.initialize_cache(self.local_magi_cache_path) From c13ece7fc71b5a8cdd09114ca586b4577bc6313c Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 13 Apr 2026 19:27:01 +0800 Subject: [PATCH 4/9] [chroe] update unittests --- tests/feature_tests/test_restart_analysis_cache.py | 2 +- tests/torch_native_tests/test_torch_inductor_cache.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/feature_tests/test_restart_analysis_cache.py b/tests/feature_tests/test_restart_analysis_cache.py index 828d532..0ecfb19 100644 --- a/tests/feature_tests/test_restart_analysis_cache.py +++ b/tests/feature_tests/test_restart_analysis_cache.py @@ -41,7 +41,7 @@ def test_restart_analysis_cache_handle_marks_and_skips_artifact_load(tmp_path: P assert p2.returncode == 0, f"worker2 failed\nstdout:\n{p2.stdout}\nstderr:\n{p2.stderr}" assert "too many values to unpack" not in p2.stderr - cache_files = list(cache_root.rglob("magi_compile_cache.py")) + cache_files = list(cache_root.rglob("magi_cache_indices.py")) assert cache_files, "no cache file generated" any_marked = False for cache_file in cache_files: diff --git a/tests/torch_native_tests/test_torch_inductor_cache.py b/tests/torch_native_tests/test_torch_inductor_cache.py index 69bef6c..a7830b2 100644 --- a/tests/torch_native_tests/test_torch_inductor_cache.py +++ b/tests/torch_native_tests/test_torch_inductor_cache.py @@ -37,8 +37,9 @@ class CounterDelta: inductor_miss: int +# NOTE: may be different on different machines, and this config is suitable for CI machine EXPECTED = { - "train": CounterDelta(autograd_hit=31, autograd_miss=2, inductor_hit=0, inductor_miss=0), + "train": CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0), "eval": CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0), } From 3f449f2691b0085124ad047bbb682fd95a11b73f Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 13 Apr 2026 19:40:24 +0800 Subject: [PATCH 5/9] [chore] fix ci --- magi_compiler/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/magi_compiler/config.py b/magi_compiler/config.py index e215dea..e70d2b2 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -266,12 +266,12 @@ def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = def cache_dump_path(cache_root_dir: str, model_idx: int | None = None, model_tag: str | None = None) -> Path: - if not model_idx and not model_tag: + if model_idx is None and model_tag is None: # Inductor cache path return Path(cache_root_dir) / "inductor_cache" else: # Magi cache path - assert model_idx and model_tag, "model_idx, model_tag are required for magi_cache path" + assert model_idx is not None and model_tag is not None, "model_idx, model_tag are required for magi_cache path" return Path(cache_root_dir) / "magi_cache" / model_rank_dir_name(model_idx, model_tag) From c0527470f880d68ecca3e48226ea5b5a36e8aa4c Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 14 Apr 2026 11:33:50 +0800 Subject: [PATCH 6/9] [chores] rerun ci --- ...{test_torch_inductor_cache.py => test_inductor_cache_reuse.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/torch_native_tests/{test_torch_inductor_cache.py => test_inductor_cache_reuse.py} (100%) diff --git a/tests/torch_native_tests/test_torch_inductor_cache.py b/tests/torch_native_tests/test_inductor_cache_reuse.py similarity index 100% rename from tests/torch_native_tests/test_torch_inductor_cache.py rename to tests/torch_native_tests/test_inductor_cache_reuse.py From a9712640bf131c9b2966b9f8623419c670e5def9 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 14 Apr 2026 16:09:52 +0800 Subject: [PATCH 7/9] [chores] refactor cache structure --- magi_compiler/_api.py | 4 ++-- magi_compiler/config.py | 14 ++++++-------- magi_compiler/magi_backend/magi_backend.py | 4 ++-- magi_compiler/magi_backend/piecewise_compiler.py | 7 ++----- tests/feature_tests/test_restart_analysis_cache.py | 2 +- 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/magi_compiler/_api.py b/magi_compiler/_api.py index f9cc1b4..0487671 100644 --- a/magi_compiler/_api.py +++ b/magi_compiler/_api.py @@ -25,7 +25,7 @@ from torch import nn from torch._dynamo.symbolic_convert import InliningInstructionTranslator -from magi_compiler.config import cache_dump_path, debug_dump_path +from magi_compiler.config import debug_dump_path, inductor_cache_dump_path from magi_compiler.cuda.cudart import pin_memory_in_place from magi_compiler.magi_backend.magi_compiler_base import MagiCompileState from magi_compiler.utils import compilation_counter, envs, magi_logger @@ -397,7 +397,7 @@ def _compilation_context(state: MagiCompileState): from .magi_depyf.inspect import explain_compilation _debug_dump_path = debug_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag) - _inductor_cache_dump_path = cache_dump_path(state.compile_config.cache_root_dir) + _inductor_cache_dump_path = inductor_cache_dump_path(state.compile_config.cache_root_dir) with ( _isolated_dynamo_config(), diff --git a/magi_compiler/config.py b/magi_compiler/config.py index e70d2b2..c5edf38 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -265,14 +265,12 @@ def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = return Path(cache_root_dir) / "magi_depyf" / run_id / model_rank_dir_name(model_idx, model_tag) -def cache_dump_path(cache_root_dir: str, model_idx: int | None = None, model_tag: str | None = None) -> Path: - if model_idx is None and model_tag is None: - # Inductor cache path - return Path(cache_root_dir) / "inductor_cache" - else: - # Magi cache path - assert model_idx is not None and model_tag is not None, "model_idx, model_tag are required for magi_cache path" - return Path(cache_root_dir) / "magi_cache" / model_rank_dir_name(model_idx, model_tag) +def magi_cache_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = None) -> Path: + return Path(cache_root_dir) / "magi_cache" / model_rank_dir_name(model_idx, model_tag) + + +def inductor_cache_dump_path(cache_root_dir: str, model_idx: int | None = None, model_tag: str | None = None) -> Path: + return Path(cache_root_dir) / "inductor_cache" def inductor_compile_config_hash(inductor_compile_config: dict[str, Any]) -> str: diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 1c15ae5..ba69b0c 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -30,7 +30,7 @@ from torch._guards import detect_fake_mode import magi_compiler.utils.envs as envs -from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, cache_dump_path, inductor_compile_config_hash +from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, inductor_compile_config_hash, magi_cache_dump_path from magi_compiler.magi_depyf.timeline import observe_lifecycle, observe_lifecycle_context from magi_compiler.offload.offload_warpper import OffloadWrapper from magi_compiler.passes import CustomJointGraphPartitionFn, FullGraphPassManager, PostGradPassManager, pass_context @@ -484,7 +484,7 @@ def _init_cache(self) -> str: # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/ self.local_magi_cache_path: Path = ( - cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key + magi_cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key ) self.local_magi_cache_path.mkdir(parents=True, exist_ok=True) self.compiler_manager.initialize_cache(self.local_magi_cache_path) diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index b4b953d..cc64738 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -159,11 +159,8 @@ def hash(self) -> str: factors: list[Any] = [CacheBase.get_system(), torch_key()] return compute_hash(factors) - def initialize_cache(self, cache_dir: Path, prefix: str | None = None): - if prefix: - self.cache_dir: Path = cache_dir / prefix - else: - self.cache_dir: Path = cache_dir + def initialize_cache(self, cache_dir: Path): + self.cache_dir: Path = cache_dir @observe_lifecycle("compiler_compile") def compile( diff --git a/tests/feature_tests/test_restart_analysis_cache.py b/tests/feature_tests/test_restart_analysis_cache.py index e2e9e36..8e69c80 100644 --- a/tests/feature_tests/test_restart_analysis_cache.py +++ b/tests/feature_tests/test_restart_analysis_cache.py @@ -42,7 +42,7 @@ def test_restart_analysis_cache_handle_marks_and_skips_artifact_load(tmp_path: P assert p2.returncode == 0, f"worker2 failed\nstdout:\n{p2.stdout}\nstderr:\n{p2.stderr}" assert "too many values to unpack" not in p2.stderr - cache_files = list(cache_root.rglob("magi_cache_indices.py")) + cache_files = list(cache_root.rglob("subgraph_indices.py")) assert cache_files, "no cache file generated" any_marked = False for cache_file in cache_files: From 613437d111b5e72c0e990b1d4ee7352d9401cb83 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 14 Apr 2026 16:25:32 +0800 Subject: [PATCH 8/9] [chores] rerun ci --- magi_compiler/magi_backend/magi_backend.py | 4 ++-- magi_compiler/magi_backend/piecewise_compiler.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index ba69b0c..c2a76e7 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -104,7 +104,7 @@ def compile_context(self, runtime_shape: int | None = None, graph_index: int | N with pass_context(runtime_shape, graph_index): yield - def initialize_cache(self, cache_dir: Path, prefix: str | None = None): + def initialize_cache(self, cache_dir: Path): """ Initialize the cache directory for the compiler. @@ -139,7 +139,7 @@ def initialize_cache(self, cache_dir: Path, prefix: str | None = None): cache_handle = CacheHandle(*handle) self.cache[cache_entry] = cache_handle - self.compiler.initialize_cache(cache_dir=self.cache_dir, prefix=prefix) + self.compiler.initialize_cache(cache_dir=self.cache_dir) def save_to_file(self): if self.disable_cache: diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index cc64738..81d6126 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -63,7 +63,7 @@ class CompilerInterface: name: str @abstractmethod - def initialize_cache(self, cache_dir: Path, prefix: str = ""): + def initialize_cache(self, cache_dir: Path): """ when the MagiCompiler process uses `cache_dir` as the cache directory, the compiler should initialize itself with the cache directory, From 7d11d3bc1e3763b65b75a946013ac797700815e4 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 14 Apr 2026 16:42:35 +0800 Subject: [PATCH 9/9] [chores] rename index file --- magi_compiler/magi_backend/magi_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index c2a76e7..c00bad2 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -111,7 +111,7 @@ def initialize_cache(self, cache_dir: Path): The organization of the cache directory is as follows: cache_dir=/path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/[prefix/] inside cache_dir, there will be: - - magi_cache_indices.py + - subgraph_indices.py - computation_graph.py for multiple prefixes, they can share the same base cache dir of @@ -120,7 +120,7 @@ def initialize_cache(self, cache_dir: Path): """ self.cache_dir: Path = cache_dir - self.cache_file_path: Path = cache_dir / "magi_cache_indices.py" + self.cache_file_path: Path = cache_dir / "subgraph_indices.py" if self.disable_cache: magi_logger.info("MagiCompiler's cache is disabled.")