Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ python -m padiff.cli \
或将命令写入 .yaml 文件后,运行

```sh
python -m padiff.cli --padiff_config.yaml
python -m padiff.cli --config padiff_config.yaml
```

yaml 文件样例
Expand Down
106 changes: 66 additions & 40 deletions padiff/abstracts/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import sys
import contextlib
import functools
from functools import partial
import inspect

import numpy as np
import paddle
Expand All @@ -26,7 +28,6 @@
flatten,
for_each_grad_tensor,
logger,
to_sequence,
map_structure,
)
from .base import current_report, find_base_report_node, single_step_state
Expand All @@ -44,8 +45,8 @@ def register_hooker(model):

# register model-level hooks
handle_init = model.register_forward_pre_hook(partial(init_weights_hook))
handle_input = model.register_forward_pre_hook(partial(first_input_hook))
remove_handles.extend([handle_init, handle_input])
remove_handles.append(handle_init)
inject_input_capture(models[0].model)

# register layer-level hooks
for mod in models:
Expand Down Expand Up @@ -87,43 +88,6 @@ def init_weights_hook(model, input):
return None


def first_input_hook(model, input):
report = current_report()
if report is None:
logger.debug("first_input_hook: current_report is None")
return None

if hasattr(report, "_loaded_inputs") and report._loaded_inputs is not None:
logger.debug("first_input_hook: loading first input")
loaded_inputs = report._loaded_inputs
if isinstance(loaded_inputs, list):
return tuple(loaded_inputs)
return loaded_inputs

if not hasattr(report, "first_input_captured"):
logger.debug("first_input_hook: capturing first input")

def serialize(x):
if isinstance(x, (paddle.Tensor, torch.Tensor)):
return ("Tensor", x.detach().cpu().numpy())
elif isinstance(x, dict):
return ("dict", {k: serialize(v) for k, v in x.items()})
elif isinstance(x, (list, tuple)):
return (type(x).__name__, [serialize(item) for item in x])
else:
return (type(x).__name__, str(x))

try:
serialized = [serialize(x) for x in to_sequence(input)]
if len(serialized) == 0:
logger.warning(f"No first input captured")
report.first_input = serialized
report.first_input_captured = True
except Exception as e:
logger.warning(f"Failed to capture first input: {e}")
return None


"""
hooks used to build module structure
"""
Expand Down Expand Up @@ -316,3 +280,65 @@ def inner(input_):
return input_

return inner


def inject_input_capture(model):
if hasattr(model, "_padiff_input_captured"):
return
original_forward = model.forward

@functools.wraps(original_forward)
def tracked_forward(*args, **kwargs):
report = current_report()

if not args and not kwargs:
logger.warning("Skipped capturing or loading input: both args and kwargs are empty.")
return original_forward(*args, **kwargs)

if hasattr(report, "_loaded_inputs") and report._loaded_inputs is not None:
logger.info("Loading first input from dump")
loaded_args, loaded_kwargs = report._loaded_inputs

try:
sig = inspect.signature(original_forward)
valid_arg_names = set(sig.parameters.keys())
except Exception as e:
logger.warning(f"Failed to get forward signature: {e}. Using all keys.")
valid_arg_names = set(loaded_kwargs.keys())

filtered_kwargs = {k: v for k, v in loaded_kwargs.items() if k in valid_arg_names}
dropped_keys = set(loaded_kwargs.keys()) - set(filtered_kwargs.keys())
if dropped_keys:
logger.debug(f"Dropped keys not in forward signature: {dropped_keys}")

final_kwargs = {**kwargs, **filtered_kwargs}
delattr(report, "_loaded_inputs")
return original_forward(*loaded_args, **final_kwargs)

if report and not hasattr(report, "first_input_captured"):

def serialize(x):
if isinstance(x, (paddle.Tensor, torch.Tensor)):
return ("Tensor", x.detach().cpu().numpy())
elif isinstance(x, dict):
return ("dict", {k: serialize(v) for k, v in x.items()})
elif isinstance(x, (list, tuple)):
return (type(x).__name__, [serialize(item) for item in x])
else:
return (type(x).__name__, str(x))

serialized = {
"args": [serialize(x) for x in args] if args else [],
"kwargs": {k: serialize(v) for k, v in kwargs.items()},
}
if serialized["args"] or serialized["kwargs"]:
report.first_input = serialized
report.first_input_captured = True
logger.info(f"Captured full input: args={len(args)}, kwargs={list(kwargs.keys())}")
else:
logger.warning("Skipped capturing input: serialized input is empty.")

return original_forward(*args, **kwargs)

model.forward = tracked_forward
model._padiff_input_captured = True
16 changes: 11 additions & 5 deletions padiff/ast_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
**kwargs,
):
self.framework = framework
self.base_name = src_model_name.split(".")[0] # get trainer if trainer.model
self.src_model_name = src_model_name # model(inputs)
self.padiff_model_name = f"model_{framework.lower()}" # "model_paddle"
self.proxy_model_name = "proxy_model" # proxy_model = create_model(model)
Expand Down Expand Up @@ -90,13 +91,13 @@ def is_model_call(self, node):
return False
func = node.func
# model(inp)
if isinstance(func, ast.Name) and func.id == self.src_model_name:
if isinstance(func, ast.Name) and func.id == self.base_name:
return True
# model.forward(inp), model.submodule(inp)
if isinstance(func, ast.Attribute):
if func.attr in self.exclude_methods:
if self.base_name != "trainer" and func.attr in self.exclude_methods:
return False
return self.is_model_attribute(func, self.src_model_name)
return self.is_model_attribute(func, self.base_name)
return False

def is_model_attribute(self, node, root="model"):
Expand Down Expand Up @@ -185,7 +186,12 @@ def add_create_model(self, node):
return [node, wrapper]

def wrap_with_guard(self, node):
guard_args = [ast.Name(id=self.src_model_name, ctx=ast.Load())]
path = self.src_model_name.split(".")
model_node = ast.Name(id=path[0], ctx=ast.Load())
for attr in path[1:]: # if trainer.model
model_node = ast.Attribute(value=model_node, attr=attr, ctx=ast.Load())
guard_args = [model_node]

guard_keywords = []

if self.mode == "align":
Expand All @@ -195,7 +201,7 @@ def wrap_with_guard(self, node):
logger.warning(
"The current injection does not include the 'keys_mapping' parameter of loading init weights. "
"If the model parameter names are inconsistent, please manually modify the injected script "
f"'debug_inject_{framework}.py' and pass 'keys_mapping' to 'PaDiffGuard(...)'"
f"'debug_inject_{self.framework}.py' and pass 'keys_mapping' to 'PaDiffGuard(...)'"
)

# load_first_inputs
Expand Down
2 changes: 1 addition & 1 deletion padiff/comparison/checker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@

from .base import check_dataloader
from .params import check_grads, check_params, check_weights
from .report import check_report
from .reports import check_report
from ...configs import global_compare_configs, update_configs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def _check_report_impl(report_path_0, report_path_1, cfg=None, diff_phase="both"


def check_forward(nodes, reports, cfg):
logger.debug(f"Checking forward of {nodes[0]['name']}")
action_name = cfg.get("action_name", None)
act = get_action(reports[0], nodes[0], reports[1], nodes[1], name=action_name)
try:
act(nodes[0]["fwd_outputs"], nodes[1]["fwd_outputs"], cfg)
logger.debug(f"Checking forward success of {nodes[0]['name']}")
return True
except Exception as e:
compare_info = e
Expand All @@ -105,7 +105,10 @@ def check_forward(nodes, reports, cfg):
return False

# sublayers is compared ok, but diff found at father layer
msg = f"Sublayers of {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but diff found at their output!"
msg = (
f"\n ⚠️ Sublayers of {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but diff found at their output! "
"\n 💡 This might be reasonable since errors accumulate if single_step mode is enabled."
)
print_report_info(nodes, reports, compare_info, "Forward", msg)
return False

Expand Down
54 changes: 43 additions & 11 deletions padiff/tools/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,52 @@ def dump_first_input(report, path):
os.makedirs(first_input_path, exist_ok=True)

input_idx = 0
for i, (typ, data) in enumerate(first_input):
if typ == "Tensor" and isinstance(data, numpy.ndarray):
numpy.save(os.path.join(first_input_path, f"input_{input_idx}.npy"), data)
else:
meta_file = os.path.join(first_input_path, f"input_{input_idx}.json")
try:
json.dump({"type": typ, "data": data}, open(meta_file, "w"), indent=2, default=str)
except Exception as e:
with open(meta_file, "w") as f:
f.write(f"type: {typ}\nvalue: {str(data)}")
input_idx += 1
meta_info = []
if "args" in first_input:
for item in first_input["args"]:
typ, data = item
file_base = f"arg_{input_idx}"
if typ == "Tensor" and isinstance(data, numpy.ndarray):
npy_path = os.path.join(first_input_path, f"{file_base}.npy")
numpy.save(npy_path, data)
meta_info.append({"type": "Tensor", "path": f"{file_base}.npy"})
else:
json_path = os.path.join(first_input_path, f"{file_base}.json")
try:
with open(json_path, "w") as f:
json.dump({"type": typ, "data": data}, f, indent=2, default=str)
except Exception as e:
with open(json_path, "w") as f:
f.write(f"type: {typ}\nvalue: {str(data)}")
meta_info.append({"type": typ, "path": f"{file_base}.json"})
input_idx += 1

if "kwargs" in first_input:
for key, item in first_input["kwargs"].items():
typ, data = item
file_base = f"kw_{key}_{input_idx}"
if typ == "Tensor" and isinstance(data, numpy.ndarray):
npy_path = os.path.join(first_input_path, f"{file_base}.npy")
numpy.save(npy_path, data)
meta_info.append({"type": "Tensor", "path": f"{file_base}.npy", "key": key})
else:
json_path = os.path.join(first_input_path, f"{file_base}.json")
try:
with open(json_path, "w") as f:
json.dump({"type": typ, "data": data}, f, indent=2, default=str)
except Exception as e:
with open(json_path, "w") as f:
f.write(f"type: {typ}\nvalue: {str(data)}")
meta_info.append({"type": typ, "path": f"{file_base}.json", "key": key})
input_idx += 1

meta_file = os.path.join(first_input_path, "meta.json")
with open(meta_file, "w") as f:
json.dump(meta_info, f, indent=2)

return {
"has_first_input": True,
"first_input_dir": "first_input",
"first_input_count": input_idx,
"first_input_meta": "first_input/meta.json",
}
Loading
Loading