Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
**P**addle **A**utomatically **Diff** precision toolkits.


## 最近更新(latest 8.21)
## 最近更新(latest 9.2)

### 使用单行命令对齐(当前仅支持前向对齐)
### 使用单行命令对齐(支持前反向对齐)

运行命令前,请运行 `python -m padiff.cli -h` 获取更详细的参数说明。

Expand All @@ -18,6 +18,8 @@ python -m padiff.cli \
--pd_cmd "python paddle_project/run.py" \
--pt_model_name "pt_model" \
--pd_model_name "pd_model" \
--pt_optim_name "pt_optimizer" \
--pd_optim_name "pd_optimizer" \
--log_dir "./padiff_log" \
--align_depth 1 \
--single_step_mode "forward" \
Expand All @@ -37,10 +39,12 @@ yaml 文件样例

```python
# padiff_config.yaml
pt_cmd: "python transformer4sr/train_transformer_ori.py"
pd_cmd: "python paddle_project/train_transformer_ori.py"
pt_model_name: "transformer"
pd_model_name: "transformer"
pt_cmd: "python transformer4sr/train_transformer.py"
pd_cmd: "python paddle_project/train_transformer.py"
pt_model_name: "transformer_pt"
pd_model_name: "transformer_pd"
pt_optim_name: "optimizer_pt"
pd_optim_name: "optimizer_pd"
log_dir: "./padiff_log"
align_depth: 2
single_step_mode: "forward"
Expand Down
68 changes: 68 additions & 0 deletions padiff/abstracts/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import contextvars
from typing import Dict

# --- Core state management class ---
# This is an internal state shared by all Guards and should be placed first
Expand Down Expand Up @@ -56,6 +57,66 @@ def _traversal(node, bucket):
_global_report = None


class _CallsContext:
"""
A global context for managing forward call counts across multiple PaDiffGuard invocations.
This ensures that max_calls is respected even when PaDiffGuard is re-entered.
"""

_state = contextvars.ContextVar("_calls_context_state", default=None)

def __init__(self):
self._state.set({"count": 0, "limit": 0, "active": False})

@property
def state(self) -> Dict:
s = self._state.get()
if s is None:
s = {"count": 0, "limit": 0, "active": False}
self._state.set(s)
return s

def set_limit(self, limit: int):
self.state["limit"] = limit
self.state["active"] = True

def increment(self) -> int:
if not self.state["active"]:
return 0
self.state["count"] += 1
return self.state["count"]

def is_exceeded(self) -> bool:
if not self.state["active"]:
return False
return self.state["count"] >= self.state["limit"]

def reset(self):
self.state["count"] = 0
self.state["limit"] = 0
self.state["active"] = False

@classmethod
def get_current(cls) -> "_CallsContext":
return cls()


_calls_context = None


class _CallsComplete(Exception):
"""A private exception used by PaDiffGuard to interrupt execution.

This exception is raised by the internal calls_hook when the maximum number
of calls (max_calls) has been reached. It is caught by PaDiffGuard
to exit the context manager gracefully.
"""

def __init__(self, message="CallsComplete: maximum number of forward calls reached."):
self.message = message
super().__init__(self.message)


# --- Public utility functions for external calls ---
# These are "accessors" to the Guard's internal state and should be placed after the Guard it depends on

Expand Down Expand Up @@ -86,3 +147,10 @@ def find_base_report_node(net_id, step_idx):
raise RuntimeError(f"Index out of range: net_id={net_id}, step_idx={step_idx}, list length={len(node_list)}")

return _context.base[net_id][step_idx]


def get_calls_context() -> _CallsContext:
    """Return the module-level ``_CallsContext``, creating it on first use."""
    global _calls_context
    if _calls_context is not None:
        return _calls_context
    _calls_context = _CallsContext()
    return _calls_context
194 changes: 136 additions & 58 deletions padiff/abstracts/hooks/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
import json
import os
import sys
import functools
import inspect
import paddle
import torch

from ...utils import set_seed, logger
from .base import _context, _current_report
from ...utils import set_seed, logger, wrap_optimizer_step
from .base import _context, _current_report, _CallsComplete, current_report, get_calls_context
from .hook import register_hooker
from ..proxy import create_model
from ...tools import dump_report
from ...tools import load_first_input_from_dump, load_init_weights_from_dump


_global_report = None
Expand Down Expand Up @@ -110,24 +114,108 @@ def AlignmentGuard(model, seed=42):
pass


class _CallsComplete(Exception):
"""A private exception used by PaDiffGuard to interrupt execution.

This exception is raised by the internal calls_hook when the maximum number
of calls (max_calls) has been reached. It is caught by PaDiffGuard
to exit the context manager gracefully.
@contextlib.contextmanager
def InputCaptureGuard(model, base_dump_path=None, framework=None, load_first_inputs=False):
    """
    Context manager to capture or inject first input. Cannot be implemented as a hook
    because hook registered through 'register_forward_pre_hook' can only capture args but not kwargs.

    While the guard is active ``model.forward`` is replaced by a wrapper that either
    substitutes the inputs returned by ``load_first_input_from_dump`` (when
    ``load_first_inputs`` is set) or stores a serialized copy of the call's inputs on
    the current report as ``first_input``. The original ``forward`` is restored on exit.

    Args:
        model: Object whose ``forward`` attribute is wrapped.
        base_dump_path: Path passed to ``load_first_input_from_dump``; loading is
            skipped when it is falsy.
        framework: Framework identifier passed to ``load_first_input_from_dump``.
        load_first_inputs: When True, inject the dumped inputs instead of capturing.

    Raises:
        ValueError: If ``load_first_inputs`` is True but ``framework`` is None.
    """
    # Validate up front. The original check lived inside a branch that already
    # required a truthy ``framework``, so it could never fire and a missing
    # framework silently disabled loading.
    if load_first_inputs and framework is None:
        raise ValueError("'framework' must be setted if 'load_first_inputs' is True")

    # NOTE: the marker attribute set below is not removed on exit, so a later
    # guard on the same model takes this no-op path.
    if hasattr(model, "_padiff_input_captured"):
        yield
        return

    original_forward = model.forward

    def select_loadable_kwargs(loaded_kwargs):
        """Keep only the loaded kwargs that the wrapped forward can accept."""
        try:
            params = inspect.signature(original_forward).parameters
        except Exception as e:
            logger.warning(f"Failed to get forward signature: {e}. Using all keys.")
            return dict(loaded_kwargs)

        # A ``**kwargs`` parameter accepts any keyword, so nothing must be dropped.
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return dict(loaded_kwargs)

        kept = {k: v for k, v in loaded_kwargs.items() if k in params}
        dropped_keys = set(loaded_kwargs.keys()) - set(kept.keys())
        if dropped_keys:
            logger.debug(f"Dropped keys not in forward signature: {dropped_keys}")
        return kept

    def serialize(x):
        """Convert an input value into a ``(type_name, payload)`` pair."""
        if isinstance(x, (paddle.Tensor, torch.Tensor)):
            return ("Tensor", x.detach().cpu().numpy())
        elif isinstance(x, dict):
            return ("dict", {k: serialize(v) for k, v in x.items()})
        elif isinstance(x, (list, tuple)):
            return (type(x).__name__, [serialize(item) for item in x])
        else:
            return (type(x).__name__, str(x))

    @functools.wraps(original_forward)
    def tracked_forward(*args, **kwargs):
        report = current_report()
        if not args and not kwargs:
            logger.warning("Skipped capturing or loading input: both args and kwargs are empty.")
            return original_forward(*args, **kwargs)

        # Inject dumped inputs once per report (tracked via ``_inputs_loaded``).
        if load_first_inputs and base_dump_path and framework and not hasattr(report, "_inputs_loaded"):
            logger.info("Loading first input from dump")
            loaded_inputs = load_first_input_from_dump(base_dump_path, framework)
            if loaded_inputs is not None:
                loaded_args, loaded_kwargs = loaded_inputs
                # Loaded values take precedence over the caller's kwargs.
                final_kwargs = {**kwargs, **select_loadable_kwargs(loaded_kwargs)}
                report._inputs_loaded = True
                return original_forward(*loaded_args, **final_kwargs)

        # Capture once per report (tracked via ``first_input_captured``).
        if report and not hasattr(report, "first_input_captured"):
            serialized = {
                "args": [serialize(x) for x in args] if args else [],
                "kwargs": {k: serialize(v) for k, v in kwargs.items()},
            }
            if serialized["args"] or serialized["kwargs"]:
                report.first_input = serialized
                report.first_input_captured = True
                logger.info(f"Captured full input: args={len(args)}, kwargs={list(kwargs.keys())}")
            else:
                logger.warning("Skipped capturing input: serialized input is empty.")

        return original_forward(*args, **kwargs)

    model.forward = tracked_forward
    model._padiff_input_captured = True

    try:
        yield
    finally:
        model.forward = original_forward


@contextlib.contextmanager
def MaxCallsGuard(max_calls: int, model):
    """
    Context manager that counts forward calls of ``model`` through a forward pre-hook.

    Each forward call first checks the shared calls context: if it reports that the
    limit is reached, ``_CallsComplete`` is raised; otherwise the count is incremented
    and logged. The hook is removed when the context exits. When ``max_calls`` is not
    positive, no hook is registered and the guard does nothing.
    """
    if max_calls > 0:
        tracker = get_calls_context()

        def _before_forward(layer, inputs):
            if not tracker.is_exceeded():
                call_no = tracker.increment()
                logger.info(f"MaxCallsGuard: forward start calling #{call_no}")
                return
            logger.warning(f"PaDiffGuard: max_calls={max_calls} reached, raising _CallsComplete")
            raise _CallsComplete()

        hook_handle = model.register_forward_pre_hook(_before_forward)
        try:
            yield
        finally:
            hook_handle.remove()
    else:
        yield


@contextlib.contextmanager
def PaDiffGuard(
model,
optimizer=None,
name="model",
auto_dump=True,
align_depth="inf",
single_step_mode=None, # None, "forward", "backward"
load_init_weights=False,
Expand All @@ -139,73 +227,63 @@ def PaDiffGuard(
black_list=None,
keys_mapping=None,
):
# create_model
if not hasattr(model, "report"):
proxy_model = create_model(model, name=name)
else:
proxy_model = model

logger.debug(f"PaDiffGuard: depth of alignment is {align_depth}.")
proxy_model.marker.update_black_list_with_depth(align_depth)
proxy_model.update_black_list_with_name(black_list)
# moniter number of calls
calls_context = get_calls_context()
reset_flag = calls_context.state["count"] == 0

if load_init_weights or load_first_inputs or (single_step_mode is not None):
assert (
base_dump_path is not None
), "'base_dump_path' should not be None, when loading of init weights or/and first inputs is needed or using single_step mode"
if reset_flag:
# set max calls
calls_context.set_limit(max_calls)

# load init weights
if load_init_weights:
from ...tools import load_init_weights_from_dump
logger.info(f"PaDiffGuard: creating proxy model.")
proxy_model = create_model(model, name=name, reset_dir=reset_flag)
model._padiff_proxy = proxy_model

load_init_weights_from_dump(base_dump_path, proxy_model, keys_mapping)
if optimizer is not None and not hasattr(optimizer, "_padiff_proxy_model"):
logger.info(f"PaDiffGuard: wrapping optimizer.step().")
optimizer._padiff_proxy_model = proxy_model
wrap_optimizer_step(optimizer)

# load first inputs
if load_first_inputs:
from ...tools import load_first_input_from_dump
if load_init_weights or load_first_inputs or (single_step_mode is not None):
assert (
base_dump_path is not None
), "'base_dump_path' should not be None, when loading of init weights or/and first inputs is needed or using single_step mode"

assert framework is not None, "'framework' must be setted if 'load_first_inputs' is True"
loaded_inputs = load_first_input_from_dump(base_dump_path, framework)
proxy_model.report._loaded_inputs = loaded_inputs
# load init weights
if load_init_weights:
load_init_weights_from_dump(base_dump_path, proxy_model, keys_mapping)

# moniter number of calls
calls_count = 0
logger.debug(f"PaDiffGuard: depth of alignment is {align_depth}.")
proxy_model.marker.update_black_list_with_depth(align_depth)
proxy_model.update_black_list_with_name(black_list)

def calls_hook(m, input, output):
nonlocal calls_count
calls_count += 1
logger.debug(f"PaDiffGuard: forward call #{calls_count}")
if calls_count >= max_calls:
logger.warning(f"PaDiffGuard: max_calls={max_calls} reached, raising _CallsComplete")
raise _CallsComplete()
else:
proxy_model = model._padiff_proxy

try:
# set hooks
with contextlib.ExitStack() as stack:
# moniter number of calls
if max_calls > 0:
stack.enter_context(MaxCallsGuard(max_calls, proxy_model))

stack.enter_context(AlignmentGuard(proxy_model, seed=seed))
stack.enter_context(report_guard(proxy_model.report))
# load first inputs
if reset_flag:
stack.enter_context(InputCaptureGuard(proxy_model.model, base_dump_path, framework, load_first_inputs))

if single_step_mode is not None:
stack.enter_context(SingleStepGuard(single_step_mode, base_dump_path))

stack.enter_context(register_hooker(proxy_model))
count_handle = proxy_model.register_forward_post_hook(calls_hook)
stack.callback(count_handle.remove)

yield model

# dump report
if auto_dump:
dump_report(proxy_model, proxy_model.dump_path)

except _CallsComplete:
logger.info(f"PaDiffGuard: calls completed ({calls_count}/{max_calls})")
# dump report
if auto_dump:
try:
dump_report(proxy_model, proxy_model.dump_path)
except Exception as e:
logger.error(e)
# dump
proxy_model.dump_report(proxy_model.dump_path)
proxy_model.dump_weights(proxy_model.dump_path)

sys.exit(0)

Expand Down
Loading
Loading