Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions magi_compiler/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch import nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

from magi_compiler.config import cache_dump_path, debug_dump_path
from magi_compiler.config import debug_dump_path, inductor_cache_dump_path
from magi_compiler.cuda.cudart import pin_memory_in_place
from magi_compiler.magi_backend.magi_compiler_base import MagiCompileState
from magi_compiler.utils import compilation_counter, envs, magi_logger
Expand Down Expand Up @@ -397,7 +397,7 @@ def _compilation_context(state: MagiCompileState):
from .magi_depyf.inspect import explain_compilation

_debug_dump_path = debug_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag)
_cache_dump_path = cache_dump_path(state.compile_config.cache_root_dir, state.model_idx, state.model_tag)
_inductor_cache_dump_path = inductor_cache_dump_path(state.compile_config.cache_root_dir)

with (
_isolated_dynamo_config(),
Expand All @@ -406,7 +406,7 @@ def _compilation_context(state: MagiCompileState):
patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False),
patch.object(torch._dynamo.config, "enable_aot_compile", True),
_hijack_inline_call_to_collect_traced_files(state),
patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": (_cache_dump_path / "inductor_cache").as_posix()}),
patch.dict(os.environ, {"TORCHINDUCTOR_CACHE_DIR": (_inductor_cache_dump_path).as_posix()}),
explain_compilation(_debug_dump_path.as_posix()),
):
yield
Expand Down
10 changes: 7 additions & 3 deletions magi_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,19 @@ def model_rank_dir_name(model_idx: int, model_tag: str | None) -> str:
return f"model_{model_idx}_rank_{rank}"


def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None) -> Path:
def debug_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = None) -> Path:
    """Build the debug-dump directory path for one model in the current run.

    The result is ``<cache_root_dir>/magi_depyf/run_<YYYYmmdd_HHMMSS>/<dir>``, where
    ``<dir>`` is whatever ``model_rank_dir_name(model_idx, model_tag)`` returns.
    Only the path is computed here; nothing is created on disk.
    """
    from datetime import datetime

    # The run component is stamped from the local clock at call time (second
    # resolution), so calls made in different seconds produce different directories.
    run_id = datetime.now().strftime("run_%Y%m%d_%H%M%S")
    model_dir = model_rank_dir_name(model_idx, model_tag)
    return Path(cache_root_dir).joinpath("magi_depyf", run_id, model_dir)


def cache_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None) -> Path:
return Path(cache_root_dir) / "torch_compile_cache" / model_rank_dir_name(model_idx, model_tag)
def magi_cache_dump_path(cache_root_dir: str, model_idx: int, model_tag: str | None = None) -> Path:
    """Build the per-model Magi cache directory path.

    The result is ``<cache_root_dir>/magi_cache/<dir>``, where ``<dir>`` is whatever
    ``model_rank_dir_name(model_idx, model_tag)`` returns. Only the path is
    computed here; nothing is created on disk.
    """
    model_dir = model_rank_dir_name(model_idx, model_tag)
    return Path(cache_root_dir).joinpath("magi_cache", model_dir)


def inductor_cache_dump_path(cache_root_dir: str, model_idx: int | None = None, model_tag: str | None = None) -> Path:
    """Build the Inductor cache directory path under ``cache_root_dir``.

    The result is always ``<cache_root_dir>/inductor_cache``. ``model_idx`` and
    ``model_tag`` are accepted but never read, so every model and tag resolves to
    the same directory.
    """
    root = Path(cache_root_dir)
    return root.joinpath("inductor_cache")


def inductor_compile_config_hash(inductor_compile_config: dict[str, Any]) -> str:
Expand Down
25 changes: 12 additions & 13 deletions magi_compiler/magi_backend/magi_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torch._guards import detect_fake_mode

import magi_compiler.utils.envs as envs
from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, cache_dump_path, inductor_compile_config_hash
from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode, inductor_compile_config_hash, magi_cache_dump_path
from magi_compiler.magi_depyf.timeline import observe_lifecycle, observe_lifecycle_context
from magi_compiler.offload.offload_warpper import OffloadWrapper
from magi_compiler.passes import CustomJointGraphPartitionFn, FullGraphPassManager, PostGradPassManager, pass_context
Expand Down Expand Up @@ -104,23 +104,23 @@ def compile_context(self, runtime_shape: int | None = None, graph_index: int | N
with pass_context(runtime_shape, graph_index):
yield

def initialize_cache(self, cache_dir: Path, prefix: str = ""):
def initialize_cache(self, cache_dir: Path):
"""
Initialize the cache directory for the compiler.

The organization of the cache directory is as follows:
cache_dir=/path/to/torch_compile_cache/rank_i_j/hash_str/prefix/
cache_dir=/path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/[prefix/]
inside cache_dir, there will be:
- magi_compile_cache.py
- subgraph_indices.py
- computation_graph.py

for multiple prefixes, they can share the same base cache dir of
/path/to/torch_compile_cache/rank_i_j/hash_str/ to store some
/path/to/magi_cache/model_{idx}[_{tag}]_rank_{rank}/hash_str/ to store some
common compilation artifacts.
"""

self.cache_dir: Path = cache_dir
self.cache_file_path: Path = cache_dir / "magi_compile_cache.py"
self.cache_file_path: Path = cache_dir / "subgraph_indices.py"

if self.disable_cache:
magi_logger.info("MagiCompiler's cache is disabled.")
Expand All @@ -139,7 +139,7 @@ def initialize_cache(self, cache_dir: Path, prefix: str = ""):
cache_handle = CacheHandle(*handle)
self.cache[cache_entry] = cache_handle

self.compiler.initialize_cache(cache_dir=self.cache_dir, prefix=prefix)
self.compiler.initialize_cache(cache_dir=self.cache_dir)

def save_to_file(self):
if self.disable_cache:
Expand Down Expand Up @@ -482,13 +482,12 @@ def _init_cache(self) -> str:
]
)

# Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/{model_tag}/
self.local_cache_dir: Path = (
cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key / self.model_tag
# Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/
self.local_magi_cache_path: Path = (
magi_cache_dump_path(self.compile_config.cache_root_dir, self.model_idx, self.model_tag) / hash_key
)
self.local_cache_dir.mkdir(parents=True, exist_ok=True)

self.compiler_manager.initialize_cache(self.local_cache_dir, self.model_tag)
self.local_magi_cache_path.mkdir(parents=True, exist_ok=True)
self.compiler_manager.initialize_cache(self.local_magi_cache_path)

@observe_lifecycle("graph_split")
def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[SplitItem]]:
Expand Down
5 changes: 2 additions & 3 deletions magi_compiler/magi_backend/piecewise_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any

import torch
import torch._inductor.compile_fx
import torch.fx as fx

from magi_compiler.magi_depyf.timeline import observe_lifecycle
Expand Down Expand Up @@ -64,7 +63,7 @@ class CompilerInterface:
name: str

@abstractmethod
def initialize_cache(self, cache_dir: Path, prefix: str = ""):
def initialize_cache(self, cache_dir: Path):
"""
when the MagiCompiler process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
Expand Down Expand Up @@ -160,7 +159,7 @@ def hash(self) -> str:
factors: list[Any] = [CacheBase.get_system(), torch_key()]
return compute_hash(factors)

def initialize_cache(self, cache_dir: Path, prefix: str = ""):
def initialize_cache(self, cache_dir: Path):
self.cache_dir: Path = cache_dir

@observe_lifecycle("compiler_compile")
Expand Down
2 changes: 1 addition & 1 deletion tests/feature_tests/test_restart_analysis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_restart_analysis_cache_handle_marks_and_skips_artifact_load(tmp_path: P
assert p2.returncode == 0, f"worker2 failed\nstdout:\n{p2.stdout}\nstderr:\n{p2.stderr}"
assert "too many values to unpack" not in p2.stderr

cache_files = list(cache_root.rglob("magi_compile_cache.py"))
cache_files = list(cache_root.rglob("subgraph_indices.py"))
assert cache_files, "no cache file generated"
any_marked = False
for cache_file in cache_files:
Expand Down
157 changes: 157 additions & 0 deletions tests/torch_native_tests/test_inductor_cache_reuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test TorchInductor cache."""

import os
import tempfile
from dataclasses import dataclass
from unittest.mock import patch

import pytest
import torch
import torch.nn.functional as F
from torch._dynamo.utils import counters

from tests.model_definition import TransformerConfig, create_transformer_model_with_initial_params


@dataclass(frozen=True)
class CounterDelta:
    """Immutable record of how far the cache counters moved during one test run.

    Each field holds an ``after - before`` difference for a single counter.
    """

    # Movement of the "autograd_cache_hit" / "autograd_cache_miss" counters.
    autograd_hit: int
    autograd_miss: int

    # Movement of the "inductor_cache_hit" / "inductor_cache_miss" counters.
    inductor_hit: int
    inductor_miss: int


# NOTE: These expected deltas were calibrated on the CI machine; other machines
# may produce different hit/miss counts.
EXPECTED = dict(
    train=CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0),
    eval=CounterDelta(autograd_hit=31, autograd_miss=1, inductor_hit=0, inductor_miss=0),
)


@pytest.fixture(autouse=True)
def model_with_clean_cache():
    """Yield ``(model, transformer_config)`` built under a throwaway cache root.

    ``MAGI_COMPILE_CACHE_ROOT_DIR`` points at a fresh temporary directory while the
    fixture is active; the environment override and the directory are both undone
    once the test finishes.
    """

    with tempfile.TemporaryDirectory() as scratch_root, patch.dict(
        os.environ, {"MAGI_COMPILE_CACHE_ROOT_DIR": scratch_root}
    ):
        config_kwargs = dict(
            vocab_size=128256,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=32,
            num_attention_heads=32,
            num_key_value_heads=8,
            max_position_embeddings=8192,
            rms_norm_eps=1e-6,
            params_dtype=torch.bfloat16,
        )
        transformer_config = TransformerConfig(**config_kwargs)

        # NOTE(review): the model is created on "cuda" unconditionally, whereas the
        # tests later move it to _device(), which can fall back to CPU — confirm
        # whether a CUDA-less run is meant to be supported or should be skipped.
        model, _ = create_transformer_model_with_initial_params(transformer_config, device="cuda")
        yield model, transformer_config


@pytest.fixture
def cache_snapshot():
    """Provide a zero-argument callable that reads the current cache counters.

    Each call returns a fresh nested dict of the form
    ``{"autograd": {"hit", "miss"}, "inductor": {"hit", "miss"}}`` taken from
    ``torch._dynamo.utils.counters`` at the moment of the call.
    """

    def _snapshot() -> dict[str, dict[str, int]]:
        aot_counters = counters["aot_autograd"]
        inductor_counters = counters["inductor"]
        autograd_view = {
            "hit": aot_counters["autograd_cache_hit"],
            "miss": aot_counters["autograd_cache_miss"],
        }
        inductor_view = {
            "hit": inductor_counters["inductor_cache_hit"],
            "miss": inductor_counters["inductor_cache_miss"],
        }
        return {"autograd": autograd_view, "inductor": inductor_view}

    return _snapshot


def _device() -> torch.device:
    """Return the CUDA device when torch reports CUDA as available, else the CPU device."""

    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _delta(before: dict, after: dict) -> CounterDelta:
    """Subtract the ``before`` snapshot from the ``after`` snapshot, counter by counter.

    Both arguments are nested dicts keyed by ``"autograd"`` / ``"inductor"`` and then
    by ``"hit"`` / ``"miss"``; the four differences are returned as a ``CounterDelta``.
    """

    def moved(backend: str, outcome: str) -> int:
        return after[backend][outcome] - before[backend][outcome]

    return CounterDelta(
        autograd_hit=moved("autograd", "hit"),
        autograd_miss=moved("autograd", "miss"),
        inductor_hit=moved("inductor", "hit"),
        inductor_miss=moved("inductor", "miss"),
    )


def _assert_delta(actual: CounterDelta, expected: CounterDelta):
    """Fail with a message showing both values unless ``actual`` equals ``expected``."""

    assert actual == expected, (
        f"counter delta mismatch, got={actual}, expected={expected}"
    )


class TestTorchInductorCache:
    """Check the cache hit/miss counter movement produced by train and eval runs."""

    def test_training_mode(self, model_with_clean_cache, cache_snapshot):
        """Run ten optimizer steps and compare the counter deltas with EXPECTED["train"]."""
        model, model_config = model_with_clean_cache
        device = _device()
        model = model.to(device)
        model.train()

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
        max_batch_size = 2
        num_steps = 10
        vocab_size = model_config.vocab_size
        token_shape = (max_batch_size, model_config.max_position_embeddings)

        # Only the work between the two snapshots contributes to the measured delta.
        before = cache_snapshot()
        for _ in range(num_steps):
            dummy_input = torch.randint(0, vocab_size, token_shape, device=device)
            dummy_label = torch.randint(0, vocab_size, token_shape, device=device)

            optimizer.zero_grad()
            logits = model.forward(dummy_input)
            flat_logits = logits.view(-1, vocab_size)
            loss = F.cross_entropy(flat_logits, dummy_label.view(-1))
            loss.backward()
            optimizer.step()
        after = cache_snapshot()

        _assert_delta(_delta(before, after), EXPECTED["train"])

    def test_evaluation_mode(self, model_with_clean_cache, cache_snapshot):
        """Run one forward pass under no_grad and compare the counter deltas with EXPECTED["eval"]."""
        model, model_config = model_with_clean_cache
        device = _device()
        model = model.to(device)
        model.eval()

        max_batch_size = 2
        token_shape = (max_batch_size, model_config.max_position_embeddings)

        before = cache_snapshot()
        with torch.no_grad():
            dummy_input = torch.randint(0, model_config.vocab_size, token_shape, device=device)
            model.forward(dummy_input)
        after = cache_snapshot()

        _assert_delta(_delta(before, after), EXPECTED["eval"])