From cb4c573a7e1e242e96d8c57a5a4f65c66a2bb9f6 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Mon, 1 Sep 2025 09:40:35 +0000 Subject: [PATCH] feat:support saving and loading of kwargs and fix some bugs --- README.md | 2 +- padiff/abstracts/hooks/hook.py | 106 +++++++++++------- padiff/ast_injector.py | 16 ++- padiff/comparison/checker/__init__.py | 2 +- .../checker/{report.py => reports.py} | 7 +- padiff/tools/dump.py | 54 +++++++-- padiff/tools/load.py | 103 +++++++++-------- padiff/utils/log.py | 2 +- 8 files changed, 179 insertions(+), 113 deletions(-) rename padiff/comparison/checker/{report.py => reports.py} (95%) diff --git a/README.md b/README.md index 6a9442a..cf843c7 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ python -m padiff.cli \ 或将命令写入 .yaml 文件后,运行 ```sh -python -m padiff.cli --padiff_config.yaml +python -m padiff.cli --config padiff_config.yaml ``` yaml 文件样例 diff --git a/padiff/abstracts/hooks/hook.py b/padiff/abstracts/hooks/hook.py index aac87f8..3b95089 100644 --- a/padiff/abstracts/hooks/hook.py +++ b/padiff/abstracts/hooks/hook.py @@ -14,7 +14,9 @@ import sys import contextlib +import functools from functools import partial +import inspect import numpy as np import paddle @@ -26,7 +28,6 @@ flatten, for_each_grad_tensor, logger, - to_sequence, map_structure, ) from .base import current_report, find_base_report_node, single_step_state @@ -44,8 +45,8 @@ def register_hooker(model): # register model-level hooks handle_init = model.register_forward_pre_hook(partial(init_weights_hook)) - handle_input = model.register_forward_pre_hook(partial(first_input_hook)) - remove_handles.extend([handle_init, handle_input]) + remove_handles.append(handle_init) + inject_input_capture(models[0].model) # register layer-level hooks for mod in models: @@ -87,43 +88,6 @@ def init_weights_hook(model, input): return None -def first_input_hook(model, input): - report = current_report() - if report is None: - logger.debug("first_input_hook: current_report is None") - return None - - if hasattr(report, "_loaded_inputs") and report._loaded_inputs is not None: - logger.debug("first_input_hook: loading first input") - loaded_inputs = report._loaded_inputs - if isinstance(loaded_inputs, list): - return tuple(loaded_inputs) - return loaded_inputs - - if not hasattr(report, "first_input_captured"): - logger.debug("first_input_hook: capturing first input") - - def serialize(x): - if isinstance(x, (paddle.Tensor, torch.Tensor)): - return ("Tensor", x.detach().cpu().numpy()) - elif isinstance(x, dict): - return ("dict", {k: serialize(v) for k, v in x.items()}) - elif isinstance(x, (list, tuple)): - return (type(x).__name__, [serialize(item) for item in x]) - else: - return (type(x).__name__, str(x)) - - try: - serialized = [serialize(x) for x in to_sequence(input)] - if len(serialized) == 0: - logger.warning(f"No first input captured") - report.first_input = serialized - report.first_input_captured = True - except Exception as e: - logger.warning(f"Failed to capture first input: {e}") - return None - - """ hooks used to build module structure """ @@ -316,3 +280,65 @@ def inner(input_): return input_ return inner + + +def inject_input_capture(model): + if hasattr(model, "_padiff_input_captured"): + return + original_forward = model.forward + + @functools.wraps(original_forward) + def tracked_forward(*args, **kwargs): + report = current_report() + + if not args and not kwargs: + logger.warning("Skipped capturing or loading input: both args and kwargs are empty.") + return original_forward(*args, **kwargs) + + if hasattr(report, "_loaded_inputs") and report._loaded_inputs is not None: + logger.info("Loading first input from dump") + loaded_args, loaded_kwargs = report._loaded_inputs + + try: + sig = inspect.signature(original_forward) + valid_arg_names = set(sig.parameters.keys()) + except Exception as e: + logger.warning(f"Failed to get forward signature: {e}. Using all keys.") + valid_arg_names = set(loaded_kwargs.keys()) + + filtered_kwargs = {k: v for k, v in loaded_kwargs.items() if k in valid_arg_names} + dropped_keys = set(loaded_kwargs.keys()) - set(filtered_kwargs.keys()) + if dropped_keys: + logger.debug(f"Dropped keys not in forward signature: {dropped_keys}") + + final_kwargs = {**kwargs, **filtered_kwargs} + delattr(report, "_loaded_inputs") + return original_forward(*loaded_args, **final_kwargs) + + if report and not hasattr(report, "first_input_captured"): + + def serialize(x): + if isinstance(x, (paddle.Tensor, torch.Tensor)): + return ("Tensor", x.detach().cpu().numpy()) + elif isinstance(x, dict): + return ("dict", {k: serialize(v) for k, v in x.items()}) + elif isinstance(x, (list, tuple)): + return (type(x).__name__, [serialize(item) for item in x]) + else: + return (type(x).__name__, str(x)) + + serialized = { + "args": [serialize(x) for x in args] if args else [], + "kwargs": {k: serialize(v) for k, v in kwargs.items()}, + } + if serialized["args"] or serialized["kwargs"]: + report.first_input = serialized + report.first_input_captured = True + logger.info(f"Captured full input: args={len(args)}, kwargs={list(kwargs.keys())}") + else: + logger.warning("Skipped capturing input: serialized input is empty.") + + return original_forward(*args, **kwargs) + + model.forward = tracked_forward + model._padiff_input_captured = True diff --git a/padiff/ast_injector.py b/padiff/ast_injector.py index 1cfab93..6a46ada 100644 --- a/padiff/ast_injector.py +++ b/padiff/ast_injector.py @@ -28,6 +28,7 @@ def __init__( **kwargs, ): self.framework = framework + self.base_name = src_model_name.split(".")[0] # get trainer if trainer.model self.src_model_name = src_model_name # model(inputs) self.padiff_model_name = f"model_{framework.lower()}" # "model_paddle" self.proxy_model_name = "proxy_model" # proxy_model = create_model(model) @@ -90,13 +91,13 @@ def is_model_call(self, node): return False func = node.func # model(inp) - if isinstance(func, ast.Name) and func.id == self.src_model_name: + if isinstance(func, ast.Name) and func.id == self.base_name: return True # model.forward(inp), model.submodule(inp) if isinstance(func, ast.Attribute): - if func.attr in self.exclude_methods: + if self.base_name != "trainer" and func.attr in self.exclude_methods: return False - return self.is_model_attribute(func, self.src_model_name) + return self.is_model_attribute(func, self.base_name) return False def is_model_attribute(self, node, root="model"): @@ -185,7 +186,12 @@ def add_create_model(self, node): return [node, wrapper] def wrap_with_guard(self, node): - guard_args = [ast.Name(id=self.src_model_name, ctx=ast.Load())] + path = self.src_model_name.split(".") + model_node = ast.Name(id=path[0], ctx=ast.Load()) + for attr in path[1:]: # if trainer.model + model_node = ast.Attribute(value=model_node, attr=attr, ctx=ast.Load()) + guard_args = [model_node] + guard_keywords = [] if self.mode == "align": @@ -195,7 +201,7 @@ def wrap_with_guard(self, node): logger.warning( "The current injection does not include the 'keys_mapping' parameter of loading init weights. " "If the model parameter names are inconsistent, please manually modify the injected script " - f"'debug_inject_{framework}.py' and pass 'keys_mapping' to 'PaDiffGuard(...)'" + f"'debug_inject_{self.framework}.py' and pass 'keys_mapping' to 'PaDiffGuard(...)'" ) # load_first_inputs diff --git a/padiff/comparison/checker/__init__.py b/padiff/comparison/checker/__init__.py index 97f8655..60dc00c 100644 --- a/padiff/comparison/checker/__init__.py +++ b/padiff/comparison/checker/__init__.py @@ -14,5 +14,5 @@ from .base import check_dataloader from .params import check_grads, check_params, check_weights -from .report import check_report +from .reports import check_report from ...configs import global_compare_configs, update_configs diff --git a/padiff/comparison/checker/report.py b/padiff/comparison/checker/reports.py similarity index 95% rename from padiff/comparison/checker/report.py rename to padiff/comparison/checker/reports.py index fbe13b8..1164cdd 100644 --- a/padiff/comparison/checker/report.py +++ b/padiff/comparison/checker/reports.py @@ -75,11 +75,11 @@ def _check_report_impl(report_path_0, report_path_1, cfg=None, diff_phase="both" def check_forward(nodes, reports, cfg): + logger.debug(f"Checking forward of {nodes[0]['name']}") action_name = cfg.get("action_name", None) act = get_action(reports[0], nodes[0], reports[1], nodes[1], name=action_name) try: act(nodes[0]["fwd_outputs"], nodes[1]["fwd_outputs"], cfg) - logger.debug(f"Checking forward success of {nodes[0]['name']}") return True except Exception as e: compare_info = e @@ -105,7 +105,10 @@ def check_forward(nodes, reports, cfg): return False # sublayers is compared ok, but diff found at father layer - msg = f"Sublayers of {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but diff found at their output!" + msg = ( + f"\n ⚠️ Sublayers of {nodes[0]['name']} and {nodes[1]['name']} are corresponded, but diff found at their output! " + "\n 💡 This might be reasonable since errors accumulate if single_step mode is enabled." + ) print_report_info(nodes, reports, compare_info, "Forward", msg) return False diff --git a/padiff/tools/dump.py b/padiff/tools/dump.py index d3191b9..033d12c 100644 --- a/padiff/tools/dump.py +++ b/padiff/tools/dump.py @@ -242,20 +242,52 @@ def dump_first_input(report, path): os.makedirs(first_input_path, exist_ok=True) input_idx = 0 - for i, (typ, data) in enumerate(first_input): - if typ == "Tensor" and isinstance(data, numpy.ndarray): - numpy.save(os.path.join(first_input_path, f"input_{input_idx}.npy"), data) - else: - meta_file = os.path.join(first_input_path, f"input_{input_idx}.json") - try: - json.dump({"type": typ, "data": data}, open(meta_file, "w"), indent=2, default=str) - except Exception as e: - with open(meta_file, "w") as f: - f.write(f"type: {typ}\nvalue: {str(data)}") - input_idx += 1 + meta_info = [] + if "args" in first_input: + for item in first_input["args"]: + typ, data = item + file_base = f"arg_{input_idx}" + if typ == "Tensor" and isinstance(data, numpy.ndarray): + npy_path = os.path.join(first_input_path, f"{file_base}.npy") + numpy.save(npy_path, data) + meta_info.append({"type": "Tensor", "path": f"{file_base}.npy"}) + else: + json_path = os.path.join(first_input_path, f"{file_base}.json") + try: + with open(json_path, "w") as f: + json.dump({"type": typ, "data": data}, f, indent=2, default=str) + except Exception as e: + with open(json_path, "w") as f: + f.write(f"type: {typ}\nvalue: {str(data)}") + meta_info.append({"type": typ, "path": f"{file_base}.json"}) + input_idx += 1 + + if "kwargs" in first_input: + for key, item in first_input["kwargs"].items(): + typ, data = item + file_base = f"kw_{key}_{input_idx}" + if typ == "Tensor" and isinstance(data, numpy.ndarray): + npy_path = os.path.join(first_input_path, f"{file_base}.npy") + numpy.save(npy_path, data) + meta_info.append({"type": "Tensor", "path": f"{file_base}.npy", "key": key}) + else: + json_path = os.path.join(first_input_path, f"{file_base}.json") + try: + with open(json_path, "w") as f: + json.dump({"type": typ, "data": data}, f, indent=2, default=str) + except Exception as e: + with open(json_path, "w") as f: + f.write(f"type: {typ}\nvalue: {str(data)}") + meta_info.append({"type": typ, "path": f"{file_base}.json", "key": key}) + input_idx += 1 + + meta_file = os.path.join(first_input_path, "meta.json") + with open(meta_file, "w") as f: + json.dump(meta_info, f, indent=2) return { "has_first_input": True, "first_input_dir": "first_input", "first_input_count": input_idx, + "first_input_meta": "first_input/meta.json", } diff --git a/padiff/tools/load.py b/padiff/tools/load.py index a5f453d..c60b3c5 100644 --- a/padiff/tools/load.py +++ b/padiff/tools/load.py @@ -31,64 +31,63 @@ def load_first_input_from_dump(report_path, tar_framework): return None input_dir = os.path.join(report_path, "first_input") - all_files = sorted( - [f for f in os.listdir(input_dir) if f.startswith("input_")], key=lambda x: int(x.split("_")[1].split(".")[0]) - ) - if not all_files: - logger.warning(f"Not found any 'input_*' file in {input_dir}. Please check the path.") + meta_file = os.path.join(input_dir, "meta.json") + + if not os.path.exists(meta_file): + logger.warning(f"'meta.json' not found in {input_dir}. Please check the path.") return None - reconstructed_inputs = [] - for file_name in all_files: - file_path = os.path.join(input_dir, file_name) + try: + with open(meta_file, "r") as f: + meta_info = json.load(f) + except Exception as e: + logger.error(f"Failed to load 'meta.json': {e}") + return None - if file_name.endswith(".npy"): - numpy_array = np.load(file_path) - if tar_framework == "paddle": - tensor = paddle.to_tensor(numpy_array) - tensor.stop_gradient = False - elif tar_framework == "torch": - tensor = torch.tensor(numpy_array) - tensor.requires_grad_ = True - reconstructed_inputs.append(tensor) - - elif file_name.endswith(".json"): - try: - with open(file_path) as f: - meta_data = json.load(f) - data_type = meta_data["type"] - data_value = meta_data["data"] - - if data_type == "dict": + args = [] + kwargs = {} + + for item in meta_info: + file_path = os.path.join(input_dir, item["path"]) + key = item.get("key") + + try: + if item["type"] == "Tensor": + numpy_array = np.load(file_path) + if tar_framework == "paddle": + tensor = paddle.to_tensor(numpy_array) + tensor.stop_gradient = False + elif tar_framework == "torch": + tensor = torch.tensor(numpy_array) + tensor.requires_grad_(True) + else: + raise ValueError(f"Unsupported framework: {tar_framework}") + + if key is None: + args.append(tensor) + else: + kwargs[key] = tensor + else: + if item["type"] == "dict": reconstructed_dict = {} - for key, (item_type, item_value) in data_value.items(): - if item_type == "Tensor": - raise RuntimeError( - f"Including Tensor types in meta JSON is not supported. Please check the serialize logic." - ) - else: - reconstructed_dict[key] = item_value - reconstructed_inputs.append(reconstructed_dict) - elif data_type in ["list", "tuple"]: - reconstructed_list = [] - for item_type, item_value in data_value: - if item_type == "Tensor": - raise RuntimeError(f"Including Tensor types in meta JSON is not supported.") - else: - reconstructed_list.append(item_value) - if data_type == "tuple": - reconstructed_list = tuple(reconstructed_list) - reconstructed_inputs.append(reconstructed_list) + for k, v in item["data"].items(): + reconstructed_dict[k] = v + value = reconstructed_dict + elif item["type"] in ["list", "tuple"]: + reconstructed_list = [v for v in item["data"]] + value = tuple(reconstructed_list) if item["type"] == "tuple" else reconstructed_list else: - reconstructed_inputs.append(data_value) - except Exception as e: - logger.error(f"Error loading metadata file {file_name}: {e}") - raise + value = item["data"] - else: - logger.warning(f"Ignore unknown files: {file_name}") - continue - return reconstructed_inputs + if key is None: + args.append(value) + else: + kwargs[key] = value + except Exception as e: + logger.error(f"Error loading metadata file {file_path}: {e}") + raise + + return (args, kwargs) def load_init_weights_from_dump( diff --git a/padiff/utils/log.py b/padiff/utils/log.py index 4291635..8645651 100644 --- a/padiff/utils/log.py +++ b/padiff/utils/log.py @@ -117,7 +117,7 @@ def print_report_info(nodes, reports, exc, stage, msg=None): if msg is not None: logger.warning("ADDITIONAL MESSAGE:") - logger.warning(msg.strip() + " \n") + logger.warning(msg + " \n") retstr = struct_info_log(reports, [node["origin_node"] for node in nodes], "report") logger.info(retstr)