Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions padiff/abstracts/hooks/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def PaDiffGuard(
# dump
proxy_model.dump_report(proxy_model.dump_path)
proxy_model.dump_weights(proxy_model.dump_path)
if optimizer is None:
proxy_model.dump_grads(proxy_model.dump_path)

sys.exit(0)

Expand Down
28 changes: 26 additions & 2 deletions padiff/abstracts/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
logger,
map_structure,
get_numpy_from_tensor,
is_require_grad,
)
from .base import current_report, find_base_report_node, single_step_state

Expand All @@ -46,16 +47,19 @@ def register_hooker(model):
handle_init = model.register_forward_pre_hook(partial(init_weights_hook))
remove_handles.append(handle_init)

param_handles = creat_param_handles(model)
remove_handles.extend(param_handles)

# register layer-level hooks
for mod in models:
pre_handle = mod.register_forward_pre_hook(partial(pre_structure_hook))
if mod.model not in marker.black_list:
logger.debug(f"Register(info_hook): {mod.model.__class__.__name__}(net_id={idx})")
logger.debug(f"Register(info_hook): {mod.class_name}(net_id={idx})")
handle = mod.register_forward_post_hook(partial(info_hook, net_id=idx))
remove_handles.append(handle)
idx += 1
else:
logger.debug(f"Skip(info_hook): {mod.model.__class__.__name__}(blacklisted)")
logger.debug(f"Skip(info_hook): {mod.class_name}(blacklisted)")
post_handle = mod.register_forward_post_hook(partial(post_structure_hook))
remove_handles.extend([pre_handle, post_handle])
yield
Expand Down Expand Up @@ -228,6 +232,26 @@ def tensor_hook(x_grad, bwd_item, nth_tensor, net_id):
return x_grad


def creat_param_handles(model):
handles = []

for name, proxy_param in model.named_parameters(recursively=True):
if is_require_grad(proxy_param.param):

def make_hook(param, param_name):
def hook(grad):
logger.debug(f"Grad hook triggered for {param_name}")
param._collected_grad = grad
return grad

return hook

handle = proxy_param.param.register_hook(make_hook(proxy_param.param, name))
handles.append(handle)

return handles


"""
utils
"""
Expand Down
4 changes: 2 additions & 2 deletions padiff/abstracts/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os

from ..utils import Counter, for_each_grad_tensor, for_each_tensor
from ..utils import Counter, for_each_grad_tensor, for_each_tensor, for_each_grad_tensor_no_require


class LayerStack:
Expand Down Expand Up @@ -188,7 +188,7 @@ def tensors_for_compare(self):
if self.type == "forward":
return [t for (t,) in for_each_tensor(self.output)]
if self.type == "backward":
return [t for (t,) in for_each_grad_tensor(self.input_grads)]
return [t for (t,) in for_each_grad_tensor_no_require(self.input_grads)]

def __repr__(self):
return self.__str__()
Expand Down
1 change: 1 addition & 0 deletions padiff/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def main():
这些参数指定您在脚本中创建优化器实例的**变量名**。
* 它们不是类名,也不是文件名。
* 它们是优化器实例化时 `=` 左边的标识符。
* 默认值: None (不传递优化器)

示例:
如果您的 PyTorch 脚本中有:
Expand Down
16 changes: 8 additions & 8 deletions padiff/tools/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ def dump_grads(model, path):
grad_dumper = numpy_dumper(path + "/grads", "grads")

def _dump(param_name, param, param_info):
if param.main_grad() is not None:
file_name = grad_dumper(param.main_grad())
param_info["grads"][param_name] = file_name
elif param.grad() is not None:
file_name = grad_dumper(param.grad())
param_info["grads"][param_name] = file_name
else:
param_info["grads"][param_name] = None
grad = param.main_grad()
if grad is None:
grad = param.grad()
if grad is None and hasattr(param.param, "_collected_grad"):
grad = param.param._collected_grad
grad = get_numpy_from_tensor(grad) if grad is not None else None

param_info["grads"][param_name] = grad_dumper(grad) if grad is not None else None

dump_param_prototype(model, _dump, f"{path}/grads.json")

Expand Down
6 changes: 6 additions & 0 deletions padiff/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def filter_fn(ts):
yield ts


def for_each_grad_tensor_no_require(*structure):
for ts in for_each_tensor(*structure):
if is_tensors(*ts):
yield ts


def map_structure_and_replace_key(func, structure1, structure2):
"""
Apply `func` to each entry in `structure` and return a new structure.
Expand Down
Loading