diff --git a/padiff/abstracts/hooks/guard.py b/padiff/abstracts/hooks/guard.py index 18426b5..f97c230 100644 --- a/padiff/abstracts/hooks/guard.py +++ b/padiff/abstracts/hooks/guard.py @@ -284,6 +284,8 @@ def PaDiffGuard( # dump proxy_model.dump_report(proxy_model.dump_path) proxy_model.dump_weights(proxy_model.dump_path) + if optimizer is None: + proxy_model.dump_grads(proxy_model.dump_path) sys.exit(0) diff --git a/padiff/abstracts/hooks/hook.py b/padiff/abstracts/hooks/hook.py index 7ec297d..1b7af6c 100644 --- a/padiff/abstracts/hooks/hook.py +++ b/padiff/abstracts/hooks/hook.py @@ -28,6 +28,7 @@ logger, map_structure, get_numpy_from_tensor, + is_require_grad, ) from .base import current_report, find_base_report_node, single_step_state @@ -46,16 +47,19 @@ def register_hooker(model): handle_init = model.register_forward_pre_hook(partial(init_weights_hook)) remove_handles.append(handle_init) + param_handles = creat_param_handles(model) + remove_handles.extend(param_handles) + # register layer-level hooks for mod in models: pre_handle = mod.register_forward_pre_hook(partial(pre_structure_hook)) if mod.model not in marker.black_list: - logger.debug(f"Register(info_hook): {mod.model.__class__.__name__}(net_id={idx})") + logger.debug(f"Register(info_hook): {mod.class_name}(net_id={idx})") handle = mod.register_forward_post_hook(partial(info_hook, net_id=idx)) remove_handles.append(handle) idx += 1 else: - logger.debug(f"Skip(info_hook): {mod.model.__class__.__name__}(blacklisted)") + logger.debug(f"Skip(info_hook): {mod.class_name}(blacklisted)") post_handle = mod.register_forward_post_hook(partial(post_structure_hook)) remove_handles.extend([pre_handle, post_handle]) yield @@ -228,6 +232,26 @@ def tensor_hook(x_grad, bwd_item, nth_tensor, net_id): return x_grad +def creat_param_handles(model): + handles = [] + + for name, proxy_param in model.named_parameters(recursively=True): + if is_require_grad(proxy_param.param): + + def make_hook(param, param_name): + def hook(grad): + logger.debug(f"Grad hook triggered for {param_name}") + param._collected_grad = grad + return grad + + return hook + + handle = proxy_param.param.register_hook(make_hook(proxy_param.param, name)) + handles.append(handle) + + return handles + + """ utils """ diff --git a/padiff/abstracts/report.py b/padiff/abstracts/report.py index 33627ad..eab6440 100644 --- a/padiff/abstracts/report.py +++ b/padiff/abstracts/report.py @@ -14,7 +14,7 @@ import os -from ..utils import Counter, for_each_grad_tensor, for_each_tensor +from ..utils import Counter, for_each_grad_tensor, for_each_tensor, for_each_grad_tensor_no_require class LayerStack: @@ -188,7 +188,7 @@ def tensors_for_compare(self): if self.type == "forward": return [t for (t,) in for_each_tensor(self.output)] if self.type == "backward": - return [t for (t,) in for_each_grad_tensor(self.input_grads)] + return [t for (t,) in for_each_grad_tensor_no_require(self.input_grads)] def __repr__(self): return self.__str__() diff --git a/padiff/cli.py b/padiff/cli.py index 000a889..c2baeb0 100644 --- a/padiff/cli.py +++ b/padiff/cli.py @@ -124,6 +124,7 @@ def main(): 这些参数指定您在脚本中创建优化器实例的**变量名**。 * 它们不是类名,也不是文件名。 * 它们是优化器实例化时 `=` 左边的标识符。 + * 默认值: None (不传递优化器) 示例: 如果您的 PyTorch 脚本中有: diff --git a/padiff/tools/dump.py b/padiff/tools/dump.py index c5e9b3b..92dd2ca 100644 --- a/padiff/tools/dump.py +++ b/padiff/tools/dump.py @@ -194,14 +194,14 @@ def dump_grads(model, path): grad_dumper = numpy_dumper(path + "/grads", "grads") def _dump(param_name, param, param_info): - if param.main_grad() is not None: - file_name = grad_dumper(param.main_grad()) - param_info["grads"][param_name] = file_name - elif param.grad() is not None: - file_name = grad_dumper(param.grad()) - param_info["grads"][param_name] = file_name - else: - param_info["grads"][param_name] = None + grad = param.main_grad() + if grad is None: + grad = param.grad() + if grad is None and hasattr(param.param, "_collected_grad"): + grad = param.param._collected_grad + grad = get_numpy_from_tensor(grad) if grad is not None else None + + param_info["grads"][param_name] = grad_dumper(grad) if grad is not None else None dump_param_prototype(model, _dump, f"{path}/grads.json") diff --git a/padiff/utils/utils.py b/padiff/utils/utils.py index fee8541..40ef696 100644 --- a/padiff/utils/utils.py +++ b/padiff/utils/utils.py @@ -228,6 +228,12 @@ def filter_fn(ts): yield ts +def for_each_grad_tensor_no_require(*structure): + for ts in for_each_tensor(*structure): + if is_tensors(*ts): + yield ts + + def map_structure_and_replace_key(func, structure1, structure2): """ Apply `func` to each entry in `structure` and return a new structure.