
add autocudagraph #896

Open

DrRyanHuang wants to merge 2 commits into PaddlePaddle:develop from cattidea:autocudagraph

Conversation

@DrRyanHuang DrRyanHuang commented Apr 29, 2026

Description

Add autocudagraph decorator for automated CUDAGraph acceleration of forward and backward passes.

🌟 Key Features & Implementation

src/paddlefleet/cudagraph.py: core infrastructure

  • Core context (CUDAGraphContext): a dataclass that centrally manages the static tensor buffers, forward/backward graph handles, and the PyLayer runner.
  • Generic IO parsing (get_tensors / set_tensors): recursively traverses nested structures (Tensor / list / tuple / dict) to extract tensors and write results back.
  • Fully automatic capture flow (autocudagraph): covers the complete warmup -> capture -> replay lifecycle.
  • Memory optimization: forward and backward passes are captured as separate CUDAGraphs that share a single unique memory pool.
  • Dynamic routing & fallback: a dispatch_key_fn routes inputs to graphs; once the max_graphs cap is exceeded, execution falls back seamlessly to eager mode.
  • Gradient accumulation safety: automatically snapshots and restores nn.Layer parameter gradients, so the dummy backward pass cannot pollute the real cross-step accumulated gradients.
  • Backward bridging: built on paddle.autograd.PyLayer to correctly route gradients for both input tensors and weight parameters.
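The generic IO parsing described above can be sketched in pure Python. This is a minimal illustration of the recursive extract/write-back idea, not the PR's actual code: `is_tensor` is a hypothetical stand-in for an `isinstance(obj, paddle.Tensor)` check, and plain values play the role of tensors.

```python
def is_tensor(obj):
    # Stand-in for `isinstance(obj, paddle.Tensor)` (assumption):
    # anything that is not a container is treated as a leaf tensor.
    return not isinstance(obj, (list, tuple, dict))

def get_tensors(struct):
    """Recursively collect leaf tensors from a nested structure."""
    if is_tensor(struct):
        return [struct]
    if isinstance(struct, dict):
        return [t for v in struct.values() for t in get_tensors(v)]
    return [t for item in struct for t in get_tensors(item)]

def set_tensors(struct, flat):
    """Rebuild `struct` with its leaves replaced by `flat` values, in order."""
    it = iter(flat)
    def rebuild(node):
        if is_tensor(node):
            return next(it)
        if isinstance(node, dict):
            return {k: rebuild(v) for k, v in node.items()}
        return type(node)(rebuild(item) for item in node)
    return rebuild(struct)
```

Because extraction and write-back visit leaves in the same deterministic order, the flat list can be copied into the graph's static input buffers and results copied back into an identically shaped output structure.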

🧪 Test Coverage

tests/single_card_tests/test_autocudagraph.py

| Test Class | Coverage |
| --- | --- |
| TestPureFunctions | pure functions, multiple inputs, mixed stop_gradient scenarios |
| TestOOPAndSubmodules | OOP-style calls, weight-pointer drift checks, multi-instance isolation |
| TestConfigurations | dispatch_key_fn dynamic graph dispatch, max_graphs fallback mechanism |
| TestComplexStructures | nested dict/list IO, alignment under a no_grad context |
| TestFatModel | multi-layer FFN simulation, end-to-end parameter gradient alignment |
| TestAdvancedMechanics | cross-step gradient accumulation correctness, AMP mixed-precision support |
| TestEdgeMemoryOps | complex memory-view operations such as tensor slice/stride |
| TestGraphLimitationsAndFeatures | dropout randomness preservation (Philox RNG compatible) |
| TestDispatchNumber / TestNoGrad | max_graphs cache boundary limits, alignment under no_grad |
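The dynamic-dispatch and fallback behavior exercised by TestConfigurations can be sketched as a small per-key cache. Everything here is an illustrative assumption about the mechanism, not the PR's implementation; the "captured graph" is stubbed by the eager function itself.

```python
class GraphRouter:
    """Minimal sketch: route calls to per-key graphs, capped at max_graphs."""

    def __init__(self, fn, dispatch_key_fn, max_graphs=2):
        self.fn = fn
        self.dispatch_key_fn = dispatch_key_fn
        self.max_graphs = max_graphs
        self.graphs = {}          # key -> "captured graph" (stubbed here)
        self.fallback_calls = 0   # counts eager fallbacks, for visibility

    def __call__(self, *args):
        key = self.dispatch_key_fn(*args)
        if key not in self.graphs:
            if len(self.graphs) >= self.max_graphs:
                # Cache is full: fall back to eager execution, no crash.
                self.fallback_calls += 1
                return self.fn(*args)
            # In the real decorator, warmup and CUDAGraph capture
            # would happen here; we just cache the eager function.
            self.graphs[key] = self.fn
        return self.graphs[key](*args)
```

A typical dispatch key would be derived from input shapes/dtypes, so each distinct input signature gets its own replayed graph until the cap is hit.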

📝 Test Plan

  • All unit tests pass: test_autocudagraph.py runs on a single GPU card with every case passing.
  • Numerical alignment: across different warmup_steps settings, CUDAGraph and eager mode match exactly in both forward and backward passes (rtol=0, atol=0).
  • Fallback robustness: when the input combinations exceed the max_graphs limit, execution continues without crashing and falls back smoothly to eager mode.
  • Gradient safety: under multi-step gradient accumulation, the accumulated gradients are protected from being polluted or zeroed by the internal backward pass during capture.
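The gradient-safety point above can be illustrated with a snapshot/restore sketch: before the dummy backward used during capture, save each parameter's accumulated grad; afterwards, restore it. The `Param` class and function names are hypothetical stand-ins for Paddle parameters, not the PR's API.

```python
class Param:
    """Stand-in for a paddle parameter with a .grad attribute (assumption)."""
    def __init__(self):
        self.grad = 0.0

def capture_with_grad_protection(params, dummy_backward):
    """Run a grad-polluting dummy backward, then restore the real grads."""
    snapshot = [p.grad for p in params]   # save the true accumulated grads
    dummy_backward(params)                # pollutes .grad during capture
    for p, g in zip(params, snapshot):    # restore the real accumulation
        p.grad = g
```

In the real decorator the snapshot would cover every nn.Layer parameter touched by the captured backward graph, so cross-step accumulation continues as if the capture-phase backward never happened.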

@DrRyanHuang DrRyanHuang requested a review from lshpku April 29, 2026 13:18
@DrRyanHuang DrRyanHuang self-assigned this Apr 29, 2026
```python
loss.backward()

losses_eager.append(loss.item())
grads_eager.append(model_eager.head2.weight.grad.clone().cpu())
```
Author

@DrRyanHuang DrRyanHuang Apr 30, 2026


TODO: documenting a Paddle bug, quite bizarre. If the `.cpu()` is dropped here,

            grads_eager.append(model_eager.head2.weight.grad.clone())

then at a fixed point in every run, the eager grad becomes all zeros! The root cause still needs to be investigated.

@codecov-commenter

Codecov Report

❌ Patch coverage is 98.11321% with 3 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@94e7941). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/paddlefleet/cudagraph.py | 98.11% | 1 Missing and 2 partials ⚠️ |
Additional details and impacted files

Impacted file tree graph

```diff
@@            Coverage Diff             @@
##             develop     #896   +/-   ##
==========================================
  Coverage           ?   98.11%
==========================================
  Files              ?        1
  Lines              ?      159
  Branches           ?       33
==========================================
  Hits               ?      156
  Misses             ?        1
  Partials           ?        2
```
| Flag | Coverage Δ |
| --- | --- |
| coverage_combine | 98.11% <98.11%> (?) |

Flags with carried forward coverage won't be shown. Click here to find out more.

| Files with missing lines | Coverage Δ |
| --- | --- |
| src/paddlefleet/cudagraph.py | 98.11% <98.11%> (ø) |

