diff --git a/magi_compiler/_api.py b/magi_compiler/_api.py index 442b67c..0487671 100644 --- a/magi_compiler/_api.py +++ b/magi_compiler/_api.py @@ -25,7 +25,7 @@ from torch import nn from torch._dynamo.symbolic_convert import InliningInstructionTranslator -from magi_compiler.config import cache_dump_path, debug_dump_path +from magi_compiler.config import debug_dump_path, inductor_cache_dump_path from magi_compiler.cuda.cudart import pin_memory_in_place from magi_compiler.magi_backend.magi_compiler_base import MagiCompileState from magi_compiler.utils import compilation_counter, envs, magi_logger @@ -397,7 +397,7 @@ def _compilation_context(state: MagiCompileState): from .magi_depyf.inspect import explain_compilation _debug_dump_path = debug_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag) - _cache_dump_path = cache_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag) + _inductor_cache_dump_path = inductor_cache_dump_path(state.compile_config.cache_root_dir) with ( _isolated_dynamo_config(), @@ -406,7 +406,7 @@ def _compilation_context(state: MagiCompileState): patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False), patch.object(torch._dynamo.config, "enable_aot_compile", True), _hijack_inline_call_to_collect_traced_files(state), - patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": (_cache_dump_path / "inductor_cache").as_posix()}), + patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": (_inductor_cache_dump_path).as_posix()}), explain_compilation(_debug_dump_path.as_posix()), ): yield diff --git a/magi_compiler/config.py b/magi_compiler/config.py index 902014e..c5edf38 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -258,15 +258,19 @@ def model_rank_dir_name(model_idx: int, model_tag: str | None) -> str: return f"model_{model_idx}_rank_{rank}" -def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None) -> Path: +def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = None) -> Path: from datetime import datetime run_id = datetime.now().strftime("run_%Y%m%d_%H%M%S") return Path(cache_root_dir) / "magi_depyf" / run_id / model_rank_dir_name(model_idx, model_tag) -def cache_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None) -> Path: - return Path(cache_root_dir) / "torch_compile_cache" / model_rank_dir_name(model_idx, model_tag) +def magi_cache_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = None) -> Path: + return Path(cache_root_dir) / "magi_cache" / model_rank_dir_name(model_idx, model_tag) + + +def inductor_cache_dump_path(cache_root_dir: str, model_idx: int | None = None, model_tag: str | None = None) -> Path: + return Path(cache_root_dir) / "inductor_cache" def inductor_compile_config_hash(inductor_compile_config: dict[str, Any]) -> str: diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 631f46d..c00bad2 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -30,7 +30,7 @@ from torch._guards import detect_fake_mode import magi_compiler.utils.envs as envs -from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, cache_dump_path, inductor_compile_config_hash +from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, inductor_compile_config_hash, magi_cache_dump_path from magi_compiler.magi_depyf.timeline import observe_lifecycle, observe_lifecycle_context from magi_compiler.offload.offload_warpper import OffloadWrapper from magi_compiler.passes import CustomJointGraphPartitionFn, FullGraphPassManager, PostGradPassManager, pass_context @@ -104,23 +104,23 @@ def compile_context(self, runtime_shape: int | None = None, graph_index: int | N with pass_context(runtime_shape, graph_index): yield - def initialize_cache(self, cache_dir: Path, prefix: str = ""): + def initialize_cache(self, cache_dir: Path): """ Initialize the cache directory for the compiler. The organization of the cache directory is as follows: - cache_dir=/path/to/torch_compile_cache/rank_i_j/hash_str/prefix/ + cache_dir=/path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/[prefix/] inside cache_dir, there will be: - - magi_compile_cache.py + - subgraph_indices.py - computation_graph.py for multiple prefixes, they can share the same base cache dir of - /path/to/torch_compile_cache/rank_i_j/hash_str/ to store some + /path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/ to store some common compilation artifacts. """ self.cache_dir: Path = cache_dir - self.cache_file_path: Path = cache_dir / "magi_compile_cache.py" + self.cache_file_path: Path = cache_dir / "subgraph_indices.py" if self.disable_cache: magi_logger.info("MagiCompiler's cache is disabled.") @@ -139,7 +139,7 @@ def initialize_cache(self, cache_dir: Path, prefix: str = ""): cache_handle = CacheHandle(*handle) self.cache[cache_entry] = cache_handle - self.compiler.initialize_cache(cache_dir=self.cache_dir, prefix=prefix) + self.compiler.initialize_cache(cache_dir=self.cache_dir) def save_to_file(self): if self.disable_cache: @@ -482,13 +482,12 @@ def _init_cache(self) -> str: ] ) - # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/{model_tag}/ - self.local_cache_dir: Path = ( - cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key / self.model_tag + # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/ + self.local_magi_cache_path: Path = ( + magi_cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key ) - self.local_cache_dir.mkdir(parents=True, exist_ok=True) - - self.compiler_manager.initialize_cache(self.local_cache_dir, self.model_tag) + self.local_magi_cache_path.mkdir(parents=True, exist_ok=True) + self.compiler_manager.initialize_cache(self.local_magi_cache_path) @observe_lifecycle("graph_split") def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[SplitItem]]: diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index a7d5093..81d6126 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -19,7 +19,6 @@ from typing import Any import torch -import torch._inductor.compile_fx import torch.fx as fx from magi_compiler.magi_depyf.timeline import observe_lifecycle @@ -64,7 +63,7 @@ class CompilerInterface: name: str @abstractmethod - def initialize_cache(self, cache_dir: Path, prefix: str = ""): + def initialize_cache(self, cache_dir: Path): """ when the MagiCompiler process uses `cache_dir` as the cache directory, the compiler should initialize itself with the cache directory, @@ -160,7 +159,7 @@ def hash(self) -> str: factors: list[Any] = [CacheBase.get_system(), torch_key()] return compute_hash(factors) - def initialize_cache(self, cache_dir: Path, prefix: str = ""): + def initialize_cache(self, cache_dir: Path): self.cache_dir: Path = cache_dir @observe_lifecycle("compiler_compile") diff --git a/tests/feature_tests/test_restart_analysis_cache.py b/tests/feature_tests/test_restart_analysis_cache.py index f470223..8e69c80 100644 --- a/tests/feature_tests/test_restart_analysis_cache.py +++ b/tests/feature_tests/test_restart_analysis_cache.py @@ -42,7 +42,7 @@ def test_restart_analysis_cache_handle_marks_and_skips_artifact_load(tmp_path: P assert p2.returncode == 0, f"worker2 failed\nstdout:\n{p2.stdout}\nstderr:\n{p2.stderr}" assert "too many values to unpack" not in p2.stderr - cache_files = list(cache_root.rglob("magi_compile_cache.py")) + cache_files = list(cache_root.rglob("subgraph_indices.py")) assert cache_files, "no cache file generated" any_marked = False for cache_file in cache_files: diff --git a/tests/torch_native_tests/test_inductor_cache_reuse.py b/tests/torch_native_tests/test_inductor_cache_reuse.py new file mode 100644 index 0000000..a7830b2 --- /dev/null +++ b/tests/torch_native_tests/test_inductor_cache_reuse.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TorchInductor cache.""" + +import os +import tempfile +from dataclasses import dataclass +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F +from torch._dynamo.utils import counters + +from tests.model_definition import TransformerConfig, create_transformer_model_with_initial_params + + +@dataclass(frozen=True) +class CounterDelta: + """Store cache counter deltas for one test run.""" + + autograd_hit: int + autograd_miss: int + inductor_hit: int + inductor_miss: int + + +# NOTE: may be different on different machines, and this config is suitable for CI machine +EXPECTED = { + "train": CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0), + "eval": CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0), +} + + +@pytest.fixture(autouse=True) +def model_with_clean_cache(): + """Build a transformer model with an isolated Magi cache directory.""" + + with tempfile.TemporaryDirectory() as tmp_dir: + with patch.dict(os.environ, {"MAGI_COMPILE_CACHE_ROOT_DIR": tmp_dir}): + transformer_config = TransformerConfig( + vocab_size=128256, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + max_position_embeddings=8192, + rms_norm_eps=1e-6, + params_dtype=torch.bfloat16, + ) + model, _ = create_transformer_model_with_initial_params(transformer_config, device="cuda") + yield model, transformer_config + + +@pytest.fixture +def cache_snapshot(): + """Return a callable that snapshots current Dynamo cache counters.""" + + def _snapshot() -> dict[str, dict[str, int]]: + return { + "autograd": { + "hit": counters["aot_autograd"]["autograd_cache_hit"], + "miss": counters["aot_autograd"]["autograd_cache_miss"], + }, + "inductor": { + "hit": counters["inductor"]["inductor_cache_hit"], + "miss": counters["inductor"]["inductor_cache_miss"], + }, + } + + return _snapshot + + +def _device() -> torch.device: + """Select CUDA when available, otherwise CPU.""" + + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _delta(before: dict, after: dict) -> CounterDelta: + """Compute per-backend cache hit/miss deltas between two snapshots.""" + + return CounterDelta( + autograd_hit=after["autograd"]["hit"] - before["autograd"]["hit"], + autograd_miss=after["autograd"]["miss"] - before["autograd"]["miss"], + inductor_hit=after["inductor"]["hit"] - before["inductor"]["hit"], + inductor_miss=after["inductor"]["miss"] - before["inductor"]["miss"], + ) + + +def _assert_delta(actual: CounterDelta, expected: CounterDelta): + """Assert cache counter deltas against expected values.""" + + assert actual == expected, f"counter delta mismatch, got={actual}, expected={expected}" + + +class TestTorchInductorCache: + """Validate TorchInductor cache behavior in train/eval flows.""" + + def test_training_mode(self, model_with_clean_cache, cache_snapshot): + """Verify cache counter deltas for repeated training iterations.""" + model, model_config = model_with_clean_cache + model = model.to(_device()) + model.train() + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2) + max_batch_size = 2 + num_epochs = 10 + + before = cache_snapshot() + for _ in range(num_epochs): + dummy_input = torch.randint( + 0, model_config.vocab_size, (max_batch_size, model_config.max_position_embeddings), device=_device() + ) + dummy_label = torch.randint( + 0, model_config.vocab_size, (max_batch_size, model_config.max_position_embeddings), device=_device() + ) + + optimizer.zero_grad() + logits = model.forward(dummy_input) + loss = F.cross_entropy(logits.view(-1, model_config.vocab_size), dummy_label.view(-1)) + loss.backward() + optimizer.step() + + after = cache_snapshot() + _assert_delta(_delta(before, after), EXPECTED["train"]) + + def test_evaluation_mode(self, model_with_clean_cache, cache_snapshot): + """Verify cache counter deltas for a no-grad evaluation forward.""" + + model, model_config = model_with_clean_cache + model = model.to(_device()) + model.eval() + + max_batch_size = 2 + before = cache_snapshot() + with torch.no_grad(): + dummy_input = torch.randint( + 0, model_config.vocab_size, (max_batch_size, model_config.max_position_embeddings), device=_device() + ) + model.forward(dummy_input) + after = cache_snapshot() + + _assert_delta(_delta(before, after), EXPECTED["eval"])