diff --git a/README.md b/README.md index cf843c7..700a1a8 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,9 @@ **P**addle **A**utomatically **Diff** precision toolkits. -## 最近更新(latest 8.21) +## 最近更新(latest 9.2) -### 使用单行命令对齐(当前仅支持前向对齐) +### 使用单行命令对齐(支持前反向对齐) 运行命令前,请运行 `python -m padiff.cli -h` 获取更详细的参数说明。 @@ -18,6 +18,8 @@ python -m padiff.cli \ --pd_cmd "python paddle_project/run.py" \ --pt_model_name "pt_model" \ --pd_model_name "pd_model" \ + --pt_optim_name "pt_optimizer" \ + --pd_optim_name "pd_optimizer" \ --log_dir "./padiff_log" \ --align_depth 1 \ --single_step_mode "forward" \ @@ -37,10 +39,12 @@ yaml 文件样例 ```python # padiff_config.yaml -pt_cmd: "python transformer4sr/train_transformer_ori.py" -pd_cmd: "python paddle_project/train_transformer_ori.py" -pt_model_name: "transformer" -pd_model_name: "transformer" +pt_cmd: "python transformer4sr/train_transformer.py" +pd_cmd: "python paddle_project/train_transformer.py" +pt_model_name: "transformer_pt" +pd_model_name: "transformer_pd" +pt_optim_name: "optimizer_pt" +pd_optim_name: "optimizer_pd" log_dir: "./padiff_log" align_depth: 2 single_step_mode: "forward" diff --git a/padiff/abstracts/hooks/base.py b/padiff/abstracts/hooks/base.py index b7a4bd4..a352c17 100644 --- a/padiff/abstracts/hooks/base.py +++ b/padiff/abstracts/hooks/base.py @@ -13,6 +13,7 @@ # limitations under the License. import contextvars +from typing import Dict # --- Core state management class --- # This is an internal state shared by all Guards and should be placed first @@ -56,6 +57,66 @@ def _traversal(node, bucket): _global_report = None +class _CallsContext: + """ + A global context for managing forward call counts across multiple PaDiffGuard invocations. + This ensures that max_calls is respected even when PaDiffGuard is re-entered. + """ + + _state = contextvars.ContextVar("_calls_context_state", default=None) + + def __init__(self): + self._state.set({"count": 0, "limit": 0, "active": False}) + + @property + def state(self) -> Dict: + s = self._state.get() + if s is None: + s = {"count": 0, "limit": 0, "active": False} + self._state.set(s) + return s + + def set_limit(self, limit: int): + self.state["limit"] = limit + self.state["active"] = True + + def increment(self) -> int: + if not self.state["active"]: + return 0 + self.state["count"] += 1 + return self.state["count"] + + def is_exceeded(self) -> bool: + if not self.state["active"]: + return False + return self.state["count"] >= self.state["limit"] + + def reset(self): + self.state["count"] = 0 + self.state["limit"] = 0 + self.state["active"] = False + + @classmethod + def get_current(cls) -> "_CallsContext": + return cls() + + +_calls_context = None + + +class _CallsComplete(Exception): + """A private exception used by PaDiffGuard to interrupt execution. + + This exception is raised by the internal calls_hook when the maximum number + of calls (max_calls) has been reached. It is caught by PaDiffGuard + to exit the context manager gracefully. + """ + + def __init__(self, message="CallsComplete: maximum number of forward calls reached."): + self.message = message + super().__init__(self.message) + + # --- Public utility functions for external calls --- # These are "accessors" to the Guard's internal state and should be placed after the Guard it depends on @@ -86,3 +147,10 @@ def find_base_report_node(net_id, step_idx): raise RuntimeError(f"Index out of range: net_id={net_id}, step_idx={step_idx}, list length={len(node_list)}") return _context.base[net_id][step_idx] + + +def get_calls_context() -> _CallsContext: + global _calls_context + if _calls_context is None: + _calls_context = _CallsContext() + return _calls_context diff --git a/padiff/abstracts/hooks/guard.py b/padiff/abstracts/hooks/guard.py index c94ed22..18426b5 100644 --- a/padiff/abstracts/hooks/guard.py +++ b/padiff/abstracts/hooks/guard.py @@ -16,12 +16,16 @@ import json import os import sys +import functools +import inspect +import paddle +import torch -from ...utils import set_seed, logger -from .base import _context, _current_report +from ...utils import set_seed, logger, wrap_optimizer_step +from .base import _context, _current_report, _CallsComplete, current_report, get_calls_context from .hook import register_hooker from ..proxy import create_model -from ...tools import dump_report +from ...tools import load_first_input_from_dump, load_init_weights_from_dump _global_report = None @@ -110,24 +114,108 @@ def AlignmentGuard(model, seed=42): pass -class _CallsComplete(Exception): - """A private exception used by PaDiffGuard to interrupt execution. - - This exception is raised by the internal calls_hook when the maximum number - of calls (max_calls) has been reached. It is caught by PaDiffGuard - to exit the context manager gracefully. +@contextlib.contextmanager +def InputCaptureGuard(model, base_dump_path=None, framework=None, load_first_inputs=False): """ + Context manager to capture or inject first input. Cannot be implemented as a hook + because hook registered through 'register_forward_pre_hook' can only capture args but not kwargs. + """ + if hasattr(model, "_padiff_input_captured"): + yield + return + + original_forward = model.forward + + @functools.wraps(original_forward) + def tracked_forward(*args, **kwargs): + report = current_report() + if not args and not kwargs: + logger.warning("Skipped capturing or loading input: both args and kwargs are empty.") + return original_forward(*args, **kwargs) + + if load_first_inputs and base_dump_path and framework and not hasattr(report, "_inputs_loaded"): + assert framework is not None, "'framework' must be setted if 'load_first_inputs' is True" + logger.info("Loading first input from dump") + loaded_inputs = load_first_input_from_dump(base_dump_path, framework) + if loaded_inputs is not None: + loaded_args, loaded_kwargs = loaded_inputs + try: + sig = inspect.signature(original_forward) + valid_arg_names = set(sig.parameters.keys()) + except Exception as e: + logger.warning(f"Failed to get forward signature: {e}. Using all keys.") + valid_arg_names = set(loaded_kwargs.keys()) + + filtered_kwargs = {k: v for k, v in loaded_kwargs.items() if k in valid_arg_names} + dropped_keys = set(loaded_kwargs.keys()) - set(filtered_kwargs.keys()) + if dropped_keys: + logger.debug(f"Dropped keys not in forward signature: {dropped_keys}") + + final_kwargs = {**kwargs, **filtered_kwargs} + report._inputs_loaded = True + return original_forward(*loaded_args, **final_kwargs) + + if report and not hasattr(report, "first_input_captured"): + + def serialize(x): + if isinstance(x, (paddle.Tensor, torch.Tensor)): + return ("Tensor", x.detach().cpu().numpy()) + elif isinstance(x, dict): + return ("dict", {k: serialize(v) for k, v in x.items()}) + elif isinstance(x, (list, tuple)): + return (type(x).__name__, [serialize(item) for item in x]) + else: + return (type(x).__name__, str(x)) + + serialized = { + "args": [serialize(x) for x in args] if args else [], + "kwargs": {k: serialize(v) for k, v in kwargs.items()}, + } + if serialized["args"] or serialized["kwargs"]: + report.first_input = serialized + report.first_input_captured = True + logger.info(f"Captured full input: args={len(args)}, kwargs={list(kwargs.keys())}") + else: + logger.warning("Skipped capturing input: serialized input is empty.") + + return original_forward(*args, **kwargs) + + model.forward = tracked_forward + model._padiff_input_captured = True + + try: + yield + finally: + model.forward = original_forward + + +@contextlib.contextmanager +def MaxCallsGuard(max_calls: int, model): + if max_calls <= 0: + yield + return + + calls_context = get_calls_context() + + def pre_hook(m, input): + if calls_context.is_exceeded(): + logger.warning(f"PaDiffGuard: max_calls={max_calls} reached, raising _CallsComplete") + raise _CallsComplete() + count = calls_context.increment() + logger.info(f"MaxCallsGuard: forward start calling #{count}") - def __init__(self, message="CallsComplete: maximum number of forward calls reached."): - self.message = message - super().__init__(self.message) + handle = model.register_forward_pre_hook(pre_hook) + try: + yield + finally: + handle.remove() @contextlib.contextmanager def PaDiffGuard( model, + optimizer=None, name="model", - auto_dump=True, align_depth="inf", single_step_mode=None, # None, "forward", "backward" load_init_weights=False, @@ -139,73 +227,63 @@ def PaDiffGuard( black_list=None, keys_mapping=None, ): - # create_model - if not hasattr(model, "report"): - proxy_model = create_model(model, name=name) - else: - proxy_model = model - - logger.debug(f"PaDiffGuard: depth of alignment is {align_depth}.") - proxy_model.marker.update_black_list_with_depth(align_depth) - proxy_model.update_black_list_with_name(black_list) + # moniter number of calls + calls_context = get_calls_context() + reset_flag = calls_context.state["count"] == 0 - if load_init_weights or load_first_inputs or (single_step_mode is not None): - assert ( - base_dump_path is not None - ), "'base_dump_path' should not be None, when loading of init weights or/and first inputs is needed or using single_step mode" + if reset_flag: + # set max calls + calls_context.set_limit(max_calls) - # load init weights - if load_init_weights: - from ...tools import load_init_weights_from_dump + logger.info(f"PaDiffGuard: creating proxy model.") + proxy_model = create_model(model, name=name, reset_dir=reset_flag) + model._padiff_proxy = proxy_model - load_init_weights_from_dump(base_dump_path, proxy_model, keys_mapping) + if optimizer is not None and not hasattr(optimizer, "_padiff_proxy_model"): + logger.info(f"PaDiffGuard: wrapping optimizer.step().") + optimizer._padiff_proxy_model = proxy_model + wrap_optimizer_step(optimizer) - # load first inputs - if load_first_inputs: - from ...tools import load_first_input_from_dump + if load_init_weights or load_first_inputs or (single_step_mode is not None): + assert ( + base_dump_path is not None + ), "'base_dump_path' should not be None, when loading of init weights or/and first inputs is needed or using single_step mode" - assert framework is not None, "'framework' must be setted if 'load_first_inputs' is True" - loaded_inputs = load_first_input_from_dump(base_dump_path, framework) - proxy_model.report._loaded_inputs = loaded_inputs + # load init weights + if load_init_weights: + load_init_weights_from_dump(base_dump_path, proxy_model, keys_mapping) - # moniter number of calls - calls_count = 0 + logger.debug(f"PaDiffGuard: depth of alignment is {align_depth}.") + proxy_model.marker.update_black_list_with_depth(align_depth) + proxy_model.update_black_list_with_name(black_list) - def calls_hook(m, input, output): - nonlocal calls_count - calls_count += 1 - logger.debug(f"PaDiffGuard: forward call #{calls_count}") - if calls_count >= max_calls: - logger.warning(f"PaDiffGuard: max_calls={max_calls} reached, raising _CallsComplete") - raise _CallsComplete() + else: + proxy_model = model._padiff_proxy try: # set hooks with contextlib.ExitStack() as stack: + # moniter number of calls + if max_calls > 0: + stack.enter_context(MaxCallsGuard(max_calls, proxy_model)) + stack.enter_context(AlignmentGuard(proxy_model, seed=seed)) stack.enter_context(report_guard(proxy_model.report)) + # load first inputs + if reset_flag: + stack.enter_context(InputCaptureGuard(proxy_model.model, base_dump_path, framework, load_first_inputs)) if single_step_mode is not None: stack.enter_context(SingleStepGuard(single_step_mode, base_dump_path)) stack.enter_context(register_hooker(proxy_model)) - count_handle = proxy_model.register_forward_post_hook(calls_hook) - stack.callback(count_handle.remove) yield model - # dump report - if auto_dump: - dump_report(proxy_model, proxy_model.dump_path) - except _CallsComplete: - logger.info(f"PaDiffGuard: calls completed ({calls_count}/{max_calls})") - # dump report - if auto_dump: - try: - dump_report(proxy_model, proxy_model.dump_path) - except Exception as e: - logger.error(e) + # dump + proxy_model.dump_report(proxy_model.dump_path) + proxy_model.dump_weights(proxy_model.dump_path) sys.exit(0) diff --git a/padiff/abstracts/hooks/hook.py b/padiff/abstracts/hooks/hook.py index 3b95089..7ec297d 100644 --- a/padiff/abstracts/hooks/hook.py +++ b/padiff/abstracts/hooks/hook.py @@ -14,9 +14,7 @@ import sys import contextlib -import functools from functools import partial -import inspect import numpy as np import paddle @@ -29,6 +27,7 @@ for_each_grad_tensor, logger, map_structure, + get_numpy_from_tensor, ) from .base import current_report, find_base_report_node, single_step_state @@ -46,7 +45,6 @@ def register_hooker(model): # register model-level hooks handle_init = model.register_forward_pre_hook(partial(init_weights_hook)) remove_handles.append(handle_init) - inject_input_capture(models[0].model) # register layer-level hooks for mod in models: @@ -76,12 +74,7 @@ def init_weights_hook(model, input): init_weights = {} for name, param in model.named_parameters(): if isinstance(param, (paddle.Tensor, torch.Tensor)): - if param.dtype == torch.bfloat16: - np_array = param.detach().cpu().float().numpy() - elif param.dtype == paddle.bfloat16: - np_array = param.detach().cpu().astype("float32").numpy() - else: - np_array = param.detach().cpu().numpy() + np_array = get_numpy_from_tensor(param) init_weights[name] = np_array report.init_weights = init_weights report.init_weights_saved = True @@ -280,65 +273,3 @@ def inner(input_): return input_ return inner - - -def inject_input_capture(model): - if hasattr(model, "_padiff_input_captured"): - return - original_forward = model.forward - - @functools.wraps(original_forward) - def tracked_forward(*args, **kwargs): - report = current_report() - - if not args and not kwargs: - logger.warning("Skipped capturing or loading input: both args and kwargs are empty.") - return original_forward(*args, **kwargs) - - if hasattr(report, "_loaded_inputs") and report._loaded_inputs is not None: - logger.info("Loading first input from dump") - loaded_args, loaded_kwargs = report._loaded_inputs - - try: - sig = inspect.signature(original_forward) - valid_arg_names = set(sig.parameters.keys()) - except Exception as e: - logger.warning(f"Failed to get forward signature: {e}. Using all keys.") - valid_arg_names = set(loaded_kwargs.keys()) - - filtered_kwargs = {k: v for k, v in loaded_kwargs.items() if k in valid_arg_names} - dropped_keys = set(loaded_kwargs.keys()) - set(filtered_kwargs.keys()) - if dropped_keys: - logger.debug(f"Dropped keys not in forward signature: {dropped_keys}") - - final_kwargs = {**kwargs, **filtered_kwargs} - delattr(report, "_loaded_inputs") - return original_forward(*loaded_args, **final_kwargs) - - if report and not hasattr(report, "first_input_captured"): - - def serialize(x): - if isinstance(x, (paddle.Tensor, torch.Tensor)): - return ("Tensor", x.detach().cpu().numpy()) - elif isinstance(x, dict): - return ("dict", {k: serialize(v) for k, v in x.items()}) - elif isinstance(x, (list, tuple)): - return (type(x).__name__, [serialize(item) for item in x]) - else: - return (type(x).__name__, str(x)) - - serialized = { - "args": [serialize(x) for x in args] if args else [], - "kwargs": {k: serialize(v) for k, v in kwargs.items()}, - } - if serialized["args"] or serialized["kwargs"]: - report.first_input = serialized - report.first_input_captured = True - logger.info(f"Captured full input: args={len(args)}, kwargs={list(kwargs.keys())}") - else: - logger.warning("Skipped capturing input: serialized input is empty.") - - return original_forward(*args, **kwargs) - - model.forward = tracked_forward - model._padiff_input_captured = True diff --git a/padiff/abstracts/proxy/__init__.py b/padiff/abstracts/proxy/__init__.py index 71e6915..068d7fa 100644 --- a/padiff/abstracts/proxy/__init__.py +++ b/padiff/abstracts/proxy/__init__.py @@ -42,13 +42,14 @@ def remove_inplace(model): submodel.inplace = False -def create_model(model, name=None, dump_freq=1): +def create_model(model, name=None, dump_freq=1, reset_dir=True): retval = ProxyModel.create_from(model, name, dump_freq) init_route(retval) - if retval.framework == "paddle" and paddle.distributed.get_rank() % 8 == 0: + if retval.framework == "paddle" and paddle.distributed.get_rank() % 8 == 0 and reset_dir: # Only reset the root path once for each machine, here we assume each machine has 8 GPUs logger.reset_dir(retval.dump_path) if retval.framework == "torch": - logger.reset_dir(retval.dump_path) + if reset_dir: + logger.reset_dir(retval.dump_path) remove_inplace(retval) return retval diff --git a/padiff/abstracts/proxy/params.py b/padiff/abstracts/proxy/params.py index 146b558..3e357e4 100644 --- a/padiff/abstracts/proxy/params.py +++ b/padiff/abstracts/proxy/params.py @@ -52,8 +52,15 @@ class PaddleParam(ProxyParam): def __init__(self, param): super().__init__(param, "paddle") + def _numpy(self, tensor): + if tensor.dtype == paddle.bfloat16: + np_array = tensor.astype("float32").numpy() + else: + np_array = tensor.numpy() + return np_array + def numpy(self): - return self.param.numpy() + return self._numpy(self.param) def set_data(self, np_value): paddle.assign(paddle.to_tensor(np_value, dtype=self.param.dtype), self.param) @@ -63,14 +70,15 @@ def shape(self): def grad(self): if self.param.grad is not None: - return self.param.grad.numpy() + return self._numpy(self.param.grad) else: return None def main_grad(self): if hasattr(self.param, "main_grad") and self.param.main_grad is not None: assert self.param.grad is None - return self.param.main_grad.numpy() + return self._numpy(self.param.main_grad) + else: return None @@ -79,8 +87,15 @@ class TorchParam(ProxyParam): def __init__(self, param): super().__init__(param, "torch") + def _numpy(self, tensor): + if tensor.dtype == torch.bfloat16: + np_array = tensor.cpu().detach().float().numpy() + else: + np_array = tensor.cpu().detach().numpy() + return np_array + def numpy(self): - return self.param.data.detach().cpu().numpy() + return self._numpy(self.param.data) def set_data(self, np_value): self.param.data = torch.as_tensor(np_value).type(self.param.dtype).to(self.param.device) @@ -90,7 +105,7 @@ def shape(self): def grad(self): if self.param.grad is not None: - return self.param.grad.data.detach().cpu().numpy() + return self._numpy(self.param.grad.data) else: return None diff --git a/padiff/ast_injector.py b/padiff/ast_injector.py index 6a46ada..b8e44eb 100644 --- a/padiff/ast_injector.py +++ b/padiff/ast_injector.py @@ -22,14 +22,14 @@ class PaDiffInjector(ast.NodeTransformer): def __init__( self, framework: str, - src_model_name="model", + model_name="model", mode="base", alignment_dir=None, **kwargs, ): self.framework = framework - self.base_name = src_model_name.split(".")[0] # get trainer if trainer.model - self.src_model_name = src_model_name # model(inputs) + self.base_name = model_name.split(".")[0] # get trainer if trainer.model + self.model_name = model_name # model(inputs) self.padiff_model_name = f"model_{framework.lower()}" # "model_paddle" self.proxy_model_name = "proxy_model" # proxy_model = create_model(model) self.mode = mode @@ -67,7 +67,7 @@ def visit_Module(self, node): def visit_Assign(self, node): # # model = SimplePaddle() # for target in node.targets: - # if isinstance(target, ast.Name) and target.id == self.src_model_name: + # if isinstance(target, ast.Name) and target.id == self.model_name: # return self.add_create_model(node) # with PaDiffGuard(proxy_model): @@ -143,7 +143,7 @@ def add_create_model(self, node): targets=[ast.Name(id=self.proxy_model_name, ctx=ast.Store())], value=ast.Call( func=ast.Name(id="create_model", ctx=ast.Load()), - args=[ast.Name(id=self.src_model_name, ctx=ast.Load())], + args=[ast.Name(id=self.model_name, ctx=ast.Load())], keywords=[ast.keyword(arg="name", value=ast.Constant(value=self.padiff_model_name))], ), ) @@ -152,7 +152,7 @@ def add_create_model(self, node): mark_wrapped = ast.Assign( targets=[ ast.Attribute( - value=ast.Name(id=self.src_model_name, ctx=ast.Load()), attr="_padiff_wrapped", ctx=ast.Store() + value=ast.Name(id=self.model_name, ctx=ast.Load()), attr="_padiff_wrapped", ctx=ast.Store() ) ], value=ast.Constant(value=True), @@ -173,7 +173,7 @@ def add_create_model(self, node): op=ast.Not(), operand=ast.Call( func=ast.Name(id="hasattr", ctx=ast.Load()), - args=[ast.Name(id=self.src_model_name, ctx=ast.Load()), ast.Constant(value="_padiff_wrapped")], + args=[ast.Name(id=self.model_name, ctx=ast.Load()), ast.Constant(value="_padiff_wrapped")], keywords=[], ), ), @@ -186,7 +186,7 @@ def add_create_model(self, node): return [node, wrapper] def wrap_with_guard(self, node): - path = self.src_model_name.split(".") + path = self.model_name.split(".") model_node = ast.Name(id=path[0], ctx=ast.Load()) for attr in path[1:]: # if trainer.model model_node = ast.Attribute(value=model_node, attr=attr, ctx=ast.Load()) @@ -194,6 +194,11 @@ def wrap_with_guard(self, node): guard_keywords = [] + # optimizer + if "optimizer" in self.kwargs: + optim_kw = ast.keyword(arg="optimizer", value=ast.Name(id=self.kwargs["optimizer"], ctx=ast.Load())) + guard_keywords.append(optim_kw) + if self.mode == "align": # load_init_weights load_weights_kw = ast.keyword(arg="load_init_weights", value=ast.Constant(value=True)) @@ -282,7 +287,7 @@ def add_dump_report(self, node): def create_injected_script( src_script_path: str, framework: str, - src_model_name: str = "model", + model_name: str = "model", mode: str = "base", alignment_dir: str = None, **kwargs, @@ -300,7 +305,7 @@ def create_injected_script( injector = PaDiffInjector( framework, - src_model_name=src_model_name, + model_name=model_name, mode=mode, alignment_dir=alignment_dir, **kwargs, diff --git a/padiff/cli.py b/padiff/cli.py index 6f6a9d9..000a889 100644 --- a/padiff/cli.py +++ b/padiff/cli.py @@ -39,7 +39,8 @@ def load_yaml_config(config_path): def run_with_padiff( cmd: str, framework: str, - src_model_name: str = "model", + model_name: str = "model", + optim_name=None, mode="base", alignment_dir=None, **kwargs, @@ -56,8 +57,11 @@ def run_with_padiff( logger.error(f"Script not found: {script_path}") sys.exit(1) + if optim_name is not None: + kwargs["optimizer"] = optim_name + # run injected script - injected_script = create_injected_script(script_path, framework, src_model_name, mode, alignment_dir, **kwargs) + injected_script = create_injected_script(script_path, framework, model_name, mode, alignment_dir, **kwargs) injected_filename = os.path.basename(injected_script) new_cmd = ["python", injected_filename] + parts[2:] @@ -116,11 +120,39 @@ def main(): 那么您应该使用: --pd_model_name net - 3. 日志目录参数 (--log_dir): + 3. 优化器名参数 (--pt_optim_name, --pd_optim_name): + 这些参数指定您在脚本中创建优化器实例的**变量名**。 + * 它们不是类名,也不是文件名。 + * 它们是优化器实例化时 `=` 左边的标识符。 + + 示例: + 如果您的 PyTorch 脚本中有: + optim = torch.optim.Adam( + transformer.parameters(), + lr=1.0, + betas=(0.9, 0.98), + eps=1e-9, + ) + 那么您应该使用: + --pt_optim_name optim + + 如果您的 Paddle 脚本中有: + optimizer = paddle.optimizer.Adam( + parameters=transformer.parameters(), + learning_rate=1.0, + epsilon=1e-09, + beta1=0.9, + beta2=0.98, + weight_decay=0.0, + ) + 那么您应该使用: + --pd_optim_name optimizer + + 4. 日志目录参数 (--log_dir): 指定生成报告和日志的目录。 * 默认值: ./padiff_log - 4. 对齐深度参数 (--align_depth): + 5. 对齐深度参数 (--align_depth): 控制对齐的粒度。通过指定一个深度值,可以忽略该深度以下的所有子模块。 * 值为整数: 指定一个具体的深度。例如,--align_depth 1 会忽略深度为1及以下的所有子模块。 * 值为 'inf': (默认) 无限深度,会对齐到最细粒度的层(如 Linear, ReLU)。 @@ -130,13 +162,13 @@ def main(): --align_depth 1 # 对齐到第一层子模块 --align_depth inf # 对齐到最细粒度 - 5. 单步对齐模式参数 (--single_step_mode): + 6. 单步对齐模式参数 (--single_step_mode): 启用逐层对齐模式。 * 可选值: forward, backward, both * 默认值: None (禁用) * 当启用时,工具会从自动加载基准模型的输出,并用其替换对齐模型的相应层输出。 - 6. 结果对比参数: + 7. 结果对比参数: 控制模型输出结果的对比精度和模式。 * --atol: 绝对误差容忍度 (default: 1e-6) * --rtol: 相对误差容忍度 (default: 1e-6) @@ -151,6 +183,8 @@ def main(): --pd_cmd "python paddle_model.py" \\ --pt_model_name "model" \\ --pd_model_name "model" \\ + --pt_optim_name "optimizer" \\ + --pd_optim_name "optimizer" \\ --log_dir "./my_alignment_results" \\ --align_depth 1 \\ --single_step_mode "forward" \\ @@ -177,6 +211,18 @@ def main(): default="model", help="The model name that appears in the paddle script's code (default: 'model')", ) + parser.add_argument( + "--pt_optim_name", + type=str, + default=None, + help="The model name that appears in the pytorch script's code (default: 'model')", + ) + parser.add_argument( + "--pd_optim_name", + type=str, + default=None, + help="The model name that appears in the paddle script's code (default: 'model')", + ) parser.add_argument( "--log_dir", type=str, @@ -256,8 +302,11 @@ def main(): "action_name": args_dict.pop("action_name", "equal"), } - pt_dump_path = run_with_padiff(pt_cmd, "torch", pt_model_name, **args_dict) - pd_dump_path = run_with_padiff(pd_cmd, "paddle", pd_model_name, "align", pt_dump_path, **pd_kwargs) + pt_optim_name = args_dict.pop("pt_optim_name", None) + pd_optim_name = args_dict.pop("pd_optim_name", None) + + pt_dump_path = run_with_padiff(pt_cmd, "torch", pt_model_name, pt_optim_name, **args_dict) + pd_dump_path = run_with_padiff(pd_cmd, "paddle", pd_model_name, pd_optim_name, "align", pt_dump_path, **pd_kwargs) logger.info("Running comparison...") try: diff --git a/padiff/tools/dump.py b/padiff/tools/dump.py index 033d12c..c5e9b3b 100644 --- a/padiff/tools/dump.py +++ b/padiff/tools/dump.py @@ -18,9 +18,8 @@ import numpy import paddle -import torch -from ..utils import Counter, frames_to_string, logger, save_model_struct +from ..utils import Counter, frames_to_string, logger, save_model_struct, get_numpy_from_tensor dump_root_path = os.path.join(sys.path[0], "padiff_dump") @@ -93,12 +92,7 @@ def dump_report_node(wrap_node, tensor_dumper): "stack": frames_to_string(wrap_node.fwd_report.frames), } for tensor in wrap_node.fwd_report.tensors_for_compare(): - if tensor.dtype == torch.bfloat16: - np_array = tensor.detach().float().numpy() - elif tensor.dtype == paddle.bfloat16: - np_array = tensor.detach().astype("float32").numpy() - else: - np_array = tensor.detach().numpy() + np_array = get_numpy_from_tensor(tensor) file_name = tensor_dumper(np_array) node_info["fwd_outputs"].append( { @@ -111,7 +105,8 @@ def dump_report_node(wrap_node, tensor_dumper): ) for tensor in wrap_node.bwd_report.tensors_for_compare(): - file_name = tensor_dumper(tensor.detach().numpy()) + np_array = get_numpy_from_tensor(tensor) + file_name = tensor_dumper(np_array) node_info["bwd_grads"].append( { "path": file_name, diff --git a/padiff/utils/__init__.py b/padiff/utils/__init__.py index 519bd74..e24bf39 100644 --- a/padiff/utils/__init__.py +++ b/padiff/utils/__init__.py @@ -20,3 +20,4 @@ from .data_structures import * from .utils import * from .log import * +from .optim import * diff --git a/padiff/utils/utils.py b/padiff/utils/utils.py index a76097c..fee8541 100644 --- a/padiff/utils/utils.py +++ b/padiff/utils/utils.py @@ -19,6 +19,9 @@ import paddle import torch +import os.path as osp +import traceback + def set_seed(seed=42): np.random.seed(seed) @@ -28,6 +31,16 @@ def set_seed(seed=42): torch.cuda.manual_seed_all(seed) +def get_numpy_from_tensor(tensor): + if tensor.dtype == torch.bfloat16: + np_array = tensor.cpu().detach().float().numpy() + elif tensor.dtype == paddle.bfloat16: + np_array = tensor.cpu().detach().astype("float32").numpy() + else: + np_array = tensor.cpu().detach().numpy() + return np_array + + """ clone tensor """ @@ -253,10 +266,6 @@ def assert_tensor_equal(tensor1, tensor2, cfg): """ -import os.path as osp -import traceback - - def _is_system_package(filename): exclude = [ "lib/python",