Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
<!-- Demo: https://github.com/PaddlePaddle/PaDiff/pull/2 -->

### PR types

<!-- One of [ New features | Bug fixes | Function optimization | Performance optimization | Breaking changes | Others ] -->

### PR changes

<!-- One of [ APIs | Docs | Others ] -->

### Description

<!-- Describe what this PR does -->
4 changes: 4 additions & 0 deletions .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ on:
- 'docs/**'
- '**.md'
- '**.rst'
paths-ignore:
- 'PULL_REQUEST_TEMPLATE.md'
pull_request:
paths:
- 'docs/**'
- '**.md'
- '**.rst'
paths-ignore:
- 'PULL_REQUEST_TEMPLATE.md'

jobs:
docs-check:
Expand Down
38 changes: 38 additions & 0 deletions padiff/abstracts/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import contextvars
from typing import Dict
from ...utils import logger

# --- Core state management class ---
# This is an internal state shared by all Guards and should be placed first
Expand Down Expand Up @@ -154,3 +155,40 @@ def get_calls_context() -> _CallsContext:
if _calls_context is None:
_calls_context = _CallsContext()
return _calls_context


def check_configuration(single_step_mode, max_calls):
# check the impact of single_step_mode
if single_step_mode is not None:
logger.warning(
f"\n ⚠️ Single-step alignment WARNING: 'single_step_mode={single_step_mode}'. "
"This halts real backpropagation, resulting in empty 'grads' directory and invalid 'loss.backward()'."
"\n 📌 When 'single_step_mode' in ('backward', 'both'), instead, the grad of outputs to input "
"(but not param.grad) manually calculated and then dumped to the 'tensor' directory."
"\n 💡 Set 'single_step_mode=None' if normal grad updates are needed."
)

# check compatibility of single_step_mode and max_calls
if max_calls != 1:
raise ValueError(
f"\n ❌ Configuration Conflict: 'single_step_mode'={single_step_mode} is incompatible with 'max_calls={max_calls}' (must be 1)."
f"\n 📌 The 'single_step_mode' is designed to replace layer outputs with pre-saved values from a single forward/backward pass."
f"\n 📌 Using it with multiple calls will lead to undefined behavior, such as shape mismatches."
f"\n 💡 To resolve this:"
f"\n - Set 'max_calls=1' for single-step alignment, or"
f"\n - Set 'single_step_mode=None' for multi-call scenarios."
)

# check potential risks of max_calls
elif max_calls > 1:
logger.warning(
f"\n ⚠️ Multi-call WARNING: 'max_calls={max_calls}' which > 1."
"\n 📌 This feature is intended for comparing multiple forward passes on the same input sequence."
"\n 📌 To ensure valid results, you MUST guarantee that the input data order is IDENTICAL "
"between the base and raw models, otherwise it may cause alignment failure, such as shape mismatch."
"\n 📌 This means(at least but not only):"
"\n - The dataset should NOT be shuffled."
"\n - The data loader should use a fixed seed."
"\n - The input sequence must be strictly preserved."
"\n 💡 If your goal is to check alignment on a single input, consider setting 'max_calls=1'."
)
53 changes: 33 additions & 20 deletions padiff/abstracts/hooks/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
import torch

from ...utils import set_seed, logger, wrap_optimizer_step
from .base import _context, _current_report, _CallsComplete, current_report, get_calls_context
from .base import (
_context,
_current_report,
_CallsComplete,
current_report,
get_calls_context,
check_configuration,
)
from .hook import register_hooker
from ..proxy import create_model
from ...tools import load_first_input_from_dump, load_init_weights_from_dump
Expand Down Expand Up @@ -82,19 +89,19 @@ def SingleStepGuard(diff_phase, base_dump_path):
if not os.path.exists(report_json_path):
logger.error(f"report.json not found at '{report_json_path}'.")

_context.phase = diff_phase
report_json_path = os.path.join(base_dump_path, "report.json")
with open(report_json_path, "r") as f:
base_report_data = json.load(f)
_context.base = _context._split_by_net_id(base_report_data)
except (FileNotFoundError, ValueError, json.JSONDecodeError) as e:
logger.error(f"SingleStepGuard failed to initialize: {type(e).__name__}: {str(e)}")
raise

_context.phase = diff_phase
_context.base = _context._split_by_net_id(base_report_data)

try:
yield

except _CallsComplete:
raise
except Exception as e:
logger.error(f"SingleStepGuard failed to initialize: {e}")
raise
finally:
_context.phase = old_phase
_context.base = old_base
Expand Down Expand Up @@ -244,6 +251,9 @@ def PaDiffGuard(
model._padiff_proxy = proxy_model
logger.debug(f"PaDiffGuard: creating proxy model.")

# check single step mode
check_configuration(single_step_mode, max_calls)

if optimizer is not None and not hasattr(optimizer, "_padiff_proxy_model"):
logger.debug(f"PaDiffGuard: wrapping optimizer.step().")
optimizer._padiff_proxy_model = proxy_model
Expand All @@ -265,6 +275,15 @@ def PaDiffGuard(
else:
proxy_model = model._padiff_proxy

def perform_final_dump():
try:
proxy_model.dump_report(proxy_model.dump_path)
proxy_model.dump_weights(proxy_model.dump_path)
if optimizer is None:
proxy_model.dump_grads(proxy_model.dump_path)
except Exception as dump_e:
logger.error(f"PaDiffGuard: failed to dump! {dump_e}")

try:
# set hooks
with contextlib.ExitStack() as stack:
Expand All @@ -285,19 +304,13 @@ def PaDiffGuard(

yield model

perform_final_dump()

except _CallsComplete as e:
pass
perform_final_dump()
sys.exit(0)

except Exception as e:
logger.error(f"PaDiffGuard: failed! {e}")
raise

finally:
try:
proxy_model.dump_report(proxy_model.dump_path)
proxy_model.dump_weights(proxy_model.dump_path)
if optimizer is None:
proxy_model.dump_grads(proxy_model.dump_path)
except Exception as e:
logger.error(f"PaDiffGuard: failed to dump! {e}")
sys.exit(0)
perform_final_dump()
sys.exit(1)
21 changes: 16 additions & 5 deletions padiff/abstracts/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ def info_hook(model, input, output, net_id):
# if under single step forward guard
if single_step_state() in ("forward", "both") and net_id != -1:
# two report_item with same id, the step_idx should be corresponded
model_name = _model.__class__.__name__
step_idx = len(list(filter(lambda x: x.type == "forward" and x.net_id == net_id, report.items))) - 1
base_report_node = single_step_check(report, net_id, step_idx, _model.__class__.__name__, "forward")
retval = map_structure(replace_forward_output(base_report_node), output)
base_report_node = single_step_check(report, net_id, step_idx, model_name, "forward")
retval = map_structure(replace_forward_output(base_report_node, model_name), output)
__in_info_hook__ = False
return retval
else:
Expand Down Expand Up @@ -259,17 +260,27 @@ def __init__(self, net):
self.__api__ = net.__api__


def replace_forward_output(node):
def replace_forward_output(node, current_name=None):
numpy_file_list = node["fwd_outputs"]
cur_idx = 0

def inner(input_):
nonlocal cur_idx
if isinstance(input_, (paddle.Tensor, torch.Tensor)):
if cur_idx >= len(numpy_file_list):
raise RuntimeError(
"In single step mode, try to replace tensor by dumpped numpy value, but the number of tensors and numpy is not equal. Maybe the models are not corresponded."
f"\n ⚠️ Single-step alignment FAILED: the {cur_idx + 1}st output is requested, "
f"but only {len(numpy_file_list)} pre-saved numpy files are available."
f"\n 📌 Layer Name: {current_name}(raw)"
"\n 💡 Possible Causes and Solutions:"
"\n - The number of outputs from the current layer in the raw model does not match "
"that of its corresponding layer in the base model."
"\n - Verify that both models have identical architectures for this layer."
"\n - If the corresponding relationship of the current layer is correct, "
"please disable single step mode, or add the layer to blacklist to skip the check of this layer."
)
value = np.load(numpy_file_list[cur_idx]["path"])
cur_idx += 1
if isinstance(input_, paddle.Tensor):
return paddle.to_tensor(value, dtype=input_.dtype)
else:
Expand All @@ -288,7 +299,7 @@ def single_step_check(report, net_id, step_idx, current_name, node_type, bwd_ite
warning_msg = (
f"\n ⚠️ Single-step alignment WARNING: {node_type} with net_id={net_id} mismatch!\n"
f" 📌 Mismatch {node_type.capitalize()}: {base_report_node['name']}(base) vs {current_name}(raw)\n"
f" 💡 Suggestion: Models have different architectures or initialization order. "
" 💡 Suggestion: Models have different architectures or initialization order. "
"Please check the model implementation or decrease 'align_depth' to reduce the alignment "
"granularity, or add layers that do not require alignment to the blacklist."
)
Expand Down
12 changes: 11 additions & 1 deletion padiff/abstracts/marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def update_white_list(self, layers, mode="self"):
self.white_list_recursively.update(set(layers))
self.use_white_list = True

def update_unassigned_weights_list(self, layers, mode="all"):
def update_unassigned_weights_list(self, layers, mode="all", include_black_list=False):
assert mode in ("self", "sublayers", "all")
if isinstance(layers, (paddle.nn.Layer, torch.nn.Module)):
layers = [layers]
Expand All @@ -62,6 +62,16 @@ def update_unassigned_weights_list(self, layers, mode="all"):
if mode in ("sublayers", "all"):
self.unassigned_weights_list_recursively.update(set(layers))

if include_black_list:
self.sync_unassigned_with_black_list(mode)

def sync_unassigned_with_black_list(self, mode="all"):
assert mode in ("self", "sublayers", "all")
if mode in ("self", "all"):
self.unassigned_weights_list.update(self.black_list)
if mode in ("sublayers", "all"):
self.unassigned_weights_list_recursively.update(self.black_list_recursively)

def set_layer_map(self, layer_map):
_layer_map = []
for layer in self.traversal_for_layer_map():
Expand Down
4 changes: 2 additions & 2 deletions padiff/abstracts/proxy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def update_black_list(self, layers, mode="all"):
def update_white_list(self, layers, mode="self"):
self.marker.update_white_list(layers, mode)

def update_unassigned_weights_list(self, layers, mode="self"):
self.marker.update_unassigned_weights_list(layers, mode)
def update_unassigned_weights_list(self, layers, mode="self", include_black_list=False):
self.marker.update_unassigned_weights_list(layers, mode, include_black_list)

def update_black_list_with_class(self, layer_class, mode="self"):
all_sub_layers = []
Expand Down
6 changes: 3 additions & 3 deletions padiff/comparison/manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def compare_dumps(dump_path1, dump_path2, cfg=None, diff_phase="both"):
grads_success = None
if os.path.exists(f"{dump_path1}/grads.json") and os.path.exists(f"{dump_path2}/grads.json"):
if len(os.listdir(f"{dump_path1}/grads")) == 0 or len(os.listdir(f"{dump_path2}/grads")) == 0:
logger.warning(f" ⚠️ Grads dir is empty of {dump_path1} or/and {dump_path2}\n")
logger.warning(f"⚠️ Grads dir is empty of {dump_path1} or/and {dump_path2}\n")
else:
logger.info("🔍 Start comparison grads (check_grads)...")
try:
Expand All @@ -54,7 +54,7 @@ def compare_dumps(dump_path1, dump_path2, cfg=None, diff_phase="both"):
weights_success = None
if os.path.exists(f"{dump_path1}/weights.json") and os.path.exists(f"{dump_path2}/weights.json"):
if len(os.listdir(f"{dump_path1}/weights")) == 0 or len(os.listdir(f"{dump_path2}/weights")) == 0:
logger.warning(f" ⚠️ Weights dir is empty of {dump_path1} or/and {dump_path2}\n")
logger.warning(f"⚠️ Weights dir is empty of {dump_path1} or/and {dump_path2}\n")
else:
logger.info("🔍 Start comparison weights (check_weights)...")
try:
Expand All @@ -71,7 +71,7 @@ def compare_dumps(dump_path1, dump_path2, cfg=None, diff_phase="both"):
params_success = None
if os.path.exists(f"{dump_path1}/params.json") and os.path.exists(f"{dump_path2}/params.json"):
if len(os.listdir(f"{dump_path1}/params")) == 0 or len(os.listdir(f"{dump_path2}/params")) == 0:
logger.warning(f" ⚠️ Params dir is empty of {dump_path1} or/and {dump_path2}\n")
logger.warning(f"⚠️ Params dir is empty of {dump_path1} or/and {dump_path2}\n")
else:
logger.info("🔍 Start comparison all parameters (check_params)...")
try:
Expand Down
3 changes: 2 additions & 1 deletion padiff/tools/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ def _dump(param_name, param, param_info):
dump_param_prototype(model, _dump, f"{path}/params.json")


def dump_weights(model, path):
def dump_weights(model, path, exclude_blacklist=True):
weight_dumper = numpy_dumper(path + "/weights", "weights")
model.update_unassigned_weights_list([], mode="all", include_black_list=exclude_blacklist)

def _dump(param_name, param, param_info):
file_name = weight_dumper(param.numpy())
Expand Down
4 changes: 2 additions & 2 deletions padiff/tools/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def load_init_weights_from_dump(
param_key = param_name

if param_key not in loaded_weights:
logger.warning(f"param {param_key}({param_name}) not found, skip it.")
logger.debug(f"param {param_key}({param_name}) not found, skip it.")
continue
np_value = loaded_weights[param_key]

Expand Down Expand Up @@ -204,7 +204,7 @@ def load_init_weights_from_dump(
if success_count == all_count:
logger.info(f"Loading success: all {all_count} init_weights loaded. ")
else:
logger.warning(f"Loading might fail! {all_count} init_weights in total but only {success_count} loaded!")
logger.warning(f"Loading might FAILED! {all_count} init_weights in total but only {success_count} loaded!")
return True
except Exception as e:
logger.error(f"{type(e).__name__}: {e}")
Expand Down
6 changes: 4 additions & 2 deletions padiff/utils/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import functools

from ..tools import dump_grads
from .log import logger


def wrap_optimizer_step(optimizer):
if hasattr(optimizer, "_original_step"):
return

logger.debug("wrap_optimizer_step: wrap optimizer.step() in first call")
original_step = optimizer.step

@functools.wraps(original_step)
Expand All @@ -29,7 +30,8 @@ def wrapped_step():

proxy_model = getattr(optimizer, "_padiff_proxy_model", None)
if proxy_model is not None:
dump_grads(proxy_model, proxy_model.dump_path)
logger.debug(f"wrap_optimizer_step: Dump grads after optimizer.step()")
proxy_model.dump_grads(proxy_model.dump_path)

optimizer.step = wrapped_step
optimizer._original_step = original_step
Loading