Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ lint-all:
# # # # # # # # # # # # # # # Test Block # # # # # # # # # # # # # # #

.PHONY: test
test: unit-test
test: unit-test unit-test-special coverage-report

unit-test:
@echo "Running unit tests with coverage..."
PYTHONPATH=. coverage run --source=. tests/padiff_unittests.py
PADIFF_SILENT=1 PYTHONPATH="$(shell pwd):$(PYTHONPATH)" coverage run --source=. tests/padiff_unittests.py

unit-test-special:
@echo "Running test_api_to_Layer.py with PADIFF_API_CHECK=ON"
PADIFF_SILENT=1 PADIFF_API_CHECK=ON PYTHONPATH="$(shell pwd):$(PYTHONPATH)" coverage run --source=. --append tests/test_api_to_Layer.py

coverage-report:
@echo ""
@echo "Coverage Report:"
coverage report -m
Expand Down
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
**P**addle **A**utomatically **Diff** precision toolkits.


## 最近更新(latest 9.2)
## 最近更新(latest 9.8)

### 使用单行命令对齐(支持前反向对齐)

Expand Down Expand Up @@ -54,9 +54,15 @@ compare_mode: "mean"
action_name: "equal"
```

### 开启 debug 模式(获取更多 log 信息)
### log 设置

设置环境变量 `export PADIFF_DEBUG=1`,或使用命令运行 `PADIFF_DEBUG=1 python -m padiff.cli ...`
#### 开启 debug 模式

为了获取更多 log 信息,可以设置环境变量 `export PADIFF_LOG_LEVEL=DEBUG`,或使用命令运行 `PADIFF_LOG_LEVEL=DEBUG python -m padiff.cli ...`

#### 开启静默模式

为了保持控制台信息简洁,可以设置环境变量 `PADIFF_SILENT=1`,此时仅保存 log 文件,不在控制台输出 log 信息。


## 简介
Expand Down
2 changes: 1 addition & 1 deletion padiff/abstracts/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def find_base_report_node(net_id, step_idx):
raise RuntimeError(f"Cannot find net_id={net_id} in base report.")

node_list = _context.base[net_id]
if step_idx < 0 or step_idx >= len(node_list):
if step_idx >= len(node_list):
raise RuntimeError(f"Index out of range: net_id={net_id}, step_idx={step_idx}, list length={len(node_list)}")

return _context.base[net_id][step_idx]
Expand Down
23 changes: 12 additions & 11 deletions padiff/abstracts/hooks/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,12 @@ def PaDiffGuard(
# set max calls
calls_context.set_limit(max_calls)

logger.info(f"PaDiffGuard: creating proxy model.")
proxy_model = create_model(model, name=name, reset_dir=reset_flag)
model._padiff_proxy = proxy_model
logger.debug(f"PaDiffGuard: creating proxy model.")

if optimizer is not None and not hasattr(optimizer, "_padiff_proxy_model"):
logger.info(f"PaDiffGuard: wrapping optimizer.step().")
logger.debug(f"PaDiffGuard: wrapping optimizer.step().")
optimizer._padiff_proxy_model = proxy_model
wrap_optimizer_step(optimizer)

Expand Down Expand Up @@ -280,15 +280,16 @@ def PaDiffGuard(

yield model

except _CallsComplete:
# dump
proxy_model.dump_report(proxy_model.dump_path)
proxy_model.dump_weights(proxy_model.dump_path)
if optimizer is None:
proxy_model.dump_grads(proxy_model.dump_path)

sys.exit(0)

except SystemExit as e:
logger.info("PaDiffGuard: SystemExit received, skipping dump_report.")
raise

finally:
try:
proxy_model.dump_report(proxy_model.dump_path)
proxy_model.dump_weights(proxy_model.dump_path)
if optimizer is None:
proxy_model.dump_grads(proxy_model.dump_path)
except Exception as e:
logger.error(f"Failed to dump: {e}")
sys.exit(0)
76 changes: 44 additions & 32 deletions padiff/abstracts/hooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,38 +167,7 @@ def info_hook(model, input, output, net_id):
if single_step_state() == "forward" and net_id != -1:
# two report_item with same id, the step_idx should be corresponded
step_idx = len(list(filter(lambda x: x.type == "forward" and x.net_id == net_id, report.items))) - 1

try:
base_report_node = find_base_report_node(net_id, step_idx)
except (IndexError, RuntimeError) as e:
error_msg = str(e)
base_max_calls = "unknown"
if "list length=" in error_msg:
try:
base_max_calls = int(error_msg.split("list length=")[1].split()[0])
except:
pass
current_calls = step_idx + 1
route = getattr(model, "route", "unknown")
logger.error(
f"\n ❌ Single-step alignment FAILED: Execution path mismatch!"
f"\n 📌 Layer '{route}' called {current_calls} times (current) vs {base_max_calls} times (base)."
f"\n 📌 Check the forward logic in both models around this layer."
)
sys.exit(1)

if base_report_node["name"] != _model.__class__.__name__:
warning_msg = (
f"\n ⚠️ Single-step alignment FAILED: Layer with net_id={net_id} mismatch!"
f"\n 📌 Mismatch Layer: {base_report_node['name']}(base) vs {_model.__class__.__name__}(raw)"
f"\n 💡 Suggestion: Models have different architectures or initialization order. "
"Please check the model implementation or decrease 'align_depth' to reduce the alignment "
"granularity, or add layers that do not require alignment to the blacklist."
)
logger.warning(warning_msg)
else:
logger.debug(f"Single Step: {_model.__class__.__name__}(net_id={net_id})")

base_report_node = single_step_check(report, net_id, step_idx, _model.__class__.__name__, "forward")
retval = map_structure(replace_forward_output(base_report_node), output)
__in_info_hook__ = False
return retval
Expand Down Expand Up @@ -297,3 +266,46 @@ def inner(input_):
return input_

return inner


def single_step_check(report, net_id, step_idx, current_name, node_type, bwd_item=None):
    """Fetch the base report node for a single-step alignment and check its name.

    The node is looked up with ``find_base_report_node(net_id, step_idx)``.
    A name mismatch between the base node and ``current_name`` is only logged
    as a warning; the node is still returned. A failed lookup is treated as an
    execution-path mismatch: an error is logged and the process exits.

    Args:
        report: Current report. Only read when the lookup fails, to obtain
            ``report.stack._top().net.route`` for the error message.
        net_id: Identifier of the net whose base node is requested.
        step_idx: Zero-based call index for ``net_id``; ``step_idx + 1`` is
            reported as the current call count on failure.
        current_name: Name compared against ``base_report_node["name"]``.
        node_type: Label interpolated into the log messages.
        bwd_item: Optional item whose ``net.route`` is preferred over the
            report stack when building the error message.

    Returns:
        The node returned by ``find_base_report_node``.

    Raises:
        SystemExit: With exit code 1 when the lookup raises ``IndexError`` or
            ``RuntimeError``.
    """
    try:
        # Guard only the lookup, so that an unrelated error raised while
        # comparing names is never misreported as an execution-path mismatch.
        base_report_node = find_base_report_node(net_id, step_idx)
    except (IndexError, RuntimeError) as e:
        error_msg = str(e)
        base_max_calls = "unknown"
        if "list length=" in error_msg:
            try:
                base_max_calls = int(error_msg.split("list length=")[1].split()[0])
            except (ValueError, IndexError):
                # Best effort: the message did not carry a parsable length.
                pass
        current_calls = step_idx + 1

        # Resolve a printable layer route without letting a missing attribute
        # raise from inside this handler and hide the real diagnostic.
        route = "unknown"
        bwd_net = getattr(bwd_item, "net", None) if bwd_item else None
        if hasattr(bwd_net, "route"):
            route = bwd_net.route
        else:
            stack = getattr(report, "stack", None)
            top = stack._top() if stack is not None else None
            top_net = getattr(top, "net", None)
            if hasattr(top_net, "route"):
                route = top_net.route

        logger.error(
            f"\n ❌ Single-step alignment FAILED: Execution path mismatch in {node_type}!"
            f"\n 📌 Layer '{route}' called {current_calls} times (current) vs {base_max_calls} times (base)."
            f"\n 📌 Check the {node_type} logic in both models around this layer."
        )
        sys.exit(1)
    else:
        if base_report_node["name"] != current_name:
            warning_msg = (
                f"\n ⚠️ Single-step alignment FAILED: {node_type} with net_id={net_id} mismatch!\n"
                f" 📌 Mismatch {node_type.capitalize()}: {base_report_node['name']}(base) vs {current_name}(raw)\n"
                f" 💡 Suggestion: Models have different architectures or initialization order. "
                "Please check the model implementation or decrease 'align_depth' to reduce the alignment "
                "granularity, or add layers that do not require alignment to the blacklist."
            )
            logger.warning(warning_msg)
        else:
            logger.debug(f"Single Step: {current_name}(net_id={net_id})")

        return base_report_node
2 changes: 1 addition & 1 deletion padiff/comparison/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __call__(self, file_list_0, file_list_1, cfg):
tensor_0 = load_numpy(info_0["path"])
tensor_1 = load_numpy(info_1["path"])

if cfg["transpose"]:
if "transpose" in cfg and cfg["transpose"]:
tensor_1 = np.transpose(tensor_1)

if tensor_0.size == 0 or tensor_1.size == 0:
Expand Down
69 changes: 39 additions & 30 deletions padiff/comparison/manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,44 +33,53 @@ def compare_dumps(dump_path1, dump_path2, cfg=None, diff_phase="both"):
# check grads
grads_success = None
if os.path.exists(f"{dump_path1}/grads.json") and os.path.exists(f"{dump_path2}/grads.json"):
logger.info("🔍 Start comparison grads (check_grads)...")
try:
grads_success = check_grads(dump_path1, dump_path2, cfg=cfg)
if grads_success:
logger.info("✅ check_grads: SUCCESS !!!\n")
else:
logger.error("❌ check_grads: FAILED !!!\n")
except Exception as e:
logger.error(f"❌ check_grads: FAILED with error: {e}\n")
grads_success = False
if len(os.listdir(f"{dump_path1}/grads")) == 0 or len(os.listdir(f"{dump_path2}/grads")) == 0:
logger.warning(f" ⚠️ Grads dir is empty of {dump_path1} or/and {dump_path2}\n")
else:
logger.info("🔍 Start comparison grads (check_grads)...")
try:
grads_success = check_grads(dump_path1, dump_path2, cfg=cfg)
if grads_success:
logger.info("✅ check_grads: SUCCESS !!!\n")
else:
logger.error("❌ check_grads: FAILED !!!\n")
except Exception as e:
logger.error(f"❌ check_grads: FAILED with error: {e}\n")
grads_success = False

# check weights
weights_success = None
if os.path.exists(f"{dump_path1}/weights.json") and os.path.exists(f"{dump_path2}/weights.json"):
logger.info("🔍 Start comparison weights (check_weights)...")
try:
weights_success = check_weights(dump_path1, dump_path2, cfg=cfg)
if weights_success:
logger.info("✅ check_weights: SUCCESS !!!\n")
else:
logger.error("❌ check_weights: FAILED !!!\n")
except Exception as e:
logger.error(f"❌ check_weights: FAILED with error: {e}\n")
weights_success = False
if len(os.listdir(f"{dump_path1}/weights")) == 0 or len(os.listdir(f"{dump_path2}/weights")) == 0:
logger.warning(f" ⚠️ Weights dir is empty of {dump_path1} or/and {dump_path2}\n")
else:
logger.info("🔍 Start comparison weights (check_weights)...")
try:
weights_success = check_weights(dump_path1, dump_path2, cfg=cfg)
if weights_success:
logger.info("✅ check_weights: SUCCESS !!!\n")
else:
logger.error("❌ check_weights: FAILED !!!\n")
except Exception as e:
logger.error(f"❌ check_weights: FAILED with error: {e}\n")
weights_success = False

# check params
params_success = None
if os.path.exists(f"{dump_path1}/params.json") and os.path.exists(f"{dump_path2}/params.json"):
logger.info("🔍 Start comparison all parameters (check_params)...")
try:
params_success = check_params(dump_path1, dump_path2, cfg=cfg)
if params_success:
logger.info("✅ check_params: SUCCESS !!!\n")
else:
logger.error("❌ check_params: FAILED !!!\n")
except Exception as e:
logger.error(f"❌ check_params: FAILED with error: {e}\n")
params_success = False
if len(os.listdir(f"{dump_path1}/params")) == 0 or len(os.listdir(f"{dump_path2}/params")) == 0:
logger.warning(f" ⚠️ Params dir is empty of {dump_path1} or/and {dump_path2}\n")
else:
logger.info("🔍 Start comparison all parameters (check_params)...")
try:
params_success = check_params(dump_path1, dump_path2, cfg=cfg)
if params_success:
logger.info("✅ check_params: SUCCESS !!!\n")
else:
logger.error("❌ check_params: FAILED !!!\n")
except Exception as e:
logger.error(f"❌ check_params: FAILED with error: {e}\n")
params_success = False

# final result
success = report_success
Expand Down
50 changes: 33 additions & 17 deletions padiff/tools/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def load_first_input_from_dump(report_path, tar_framework):
args = []
kwargs = {}

NATIVE_TYPES = (int, float, str, bool, type(None))

for item in meta_info:
file_path = os.path.join(input_dir, item["path"])
key = item.get("key")
Expand All @@ -62,31 +64,45 @@ def load_first_input_from_dump(report_path, tar_framework):
tensor.requires_grad_(True)
else:
raise ValueError(f"Unsupported framework: {tar_framework}")
value = tensor

if key is None:
args.append(tensor)
else:
kwargs[key] = tensor
else:
with open(file_path, "r") as f:
full_item = json.load(f)

if item["type"] == "dict":
reconstructed_dict = {}
for k, v in item["data"].items():
reconstructed_dict[k] = v
value = reconstructed_dict
elif item["type"] in ["list", "tuple"]:
reconstructed_list = [v for v in item["data"]]
value = tuple(reconstructed_list) if item["type"] == "tuple" else reconstructed_list
value = {k: v for k, v in full_item["data"].items()}
elif item["type"] == "list":
value = [v for v in full_item["data"]]
elif item["type"] == "tuple":
value = tuple(v for v in full_item["data"])
elif item["type"] == "int":
value = int(full_item["data"])
elif item["type"] == "float":
value = float(full_item["data"])
elif item["type"] == "bool":
value = full_item["data"].lower() == "true"
elif item["type"] == "NoneType":
value = None
elif item["type"] == "str":
value = full_item["data"]
else:
value = item["data"]
logger.warning(f"Skipping unsupported input type '{item['type']}' for input(key={key}).")
continue

if key is None:
args.append(value)
else:
kwargs[key] = value

if key is None:
args.append(value)
else:
kwargs[key] = value
except Exception as e:
logger.error(f"Error loading metadata file {file_path}: {e}")
logger.error(f"Error loading input(key={key}) in {file_path}: {e}")
raise

if not args and not kwargs:
logger.warning("No valid inputs were loaded from the dump.")
return None

return (args, kwargs)


Expand Down
Loading