Skip to content

Fix tangent_type for Union{NoRData, RData{...}}#1133

Open
yebai wants to merge 7 commits intomainfrom
fix-1130-tangent-type-rdata
Open

Fix tangent_type for Union{NoRData, RData{...}}#1133
yebai wants to merge 7 commits intomainfrom
fix-1130-tangent-type-rdata

Conversation

@yebai
Copy link
Copy Markdown
Member

@yebai yebai commented Apr 8, 2026

Summary

tangent_type(NoFData, R) previously only handled R <: Union{NoRData, IEEEFloat}, causing a missing-method error when R contained an RData{...} branch. This extends the dispatch to cover RData and adds _validate_rdata_union to catch invalid branches early.

Split from #1132 to make it easy to review.

Closes #1130.

CI Summary — GitHub Actions

Documentation Preview

Mooncake.jl documentation for PR #1133 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1133/

Performance

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                      Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                     String │   String │   String │      String │  String │      String │ String │
├────────────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│                   sum_1000 │ 180.0 ns │     1.51 │        1.61 │   0.778 │        3.56 │   1.39 │
│                  _sum_1000 │  1.07 μs │      6.8 │        1.02 │  4680.0 │        32.3 │   1.08 │
│               sum_sin_1000 │  7.42 μs │     2.58 │        1.18 │     1.7 │        9.06 │   1.85 │
│              _sum_sin_1000 │  5.11 μs │     3.57 │        2.47 │   296.0 │        13.3 │    2.8 │
│                   kron_sum │ 184.0 μs │     12.9 │        3.66 │    24.8 │       508.0 │   18.0 │
│              kron_view_sum │ 251.0 μs │     12.3 │        4.99 │    15.0 │       363.0 │   11.4 │
│      naive_map_sin_cos_exp │   2.3 μs │     2.69 │        1.48 │ missing │        7.19 │   2.07 │
│            map_sin_cos_exp │  2.26 μs │     3.12 │        1.53 │    2.43 │        5.95 │   2.67 │
│      broadcast_sin_cos_exp │  2.35 μs │     2.89 │        1.46 │    5.01 │        1.41 │   2.01 │
│                 simple_mlp │ 212.0 μs │     6.22 │        2.84 │    1.69 │        12.5 │   5.13 │
│                     gp_lml │ 400.0 μs │     5.91 │        1.87 │    6.57 │     missing │   2.63 │
│ turing_broadcast_benchmark │  2.05 ms │     5.05 │        3.32 │ missing │        32.6 │   1.79 │
│         large_single_block │ 450.0 ns │      5.7 │        1.98 │  3860.0 │        28.3 │   2.16 │
└────────────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

yebai and others added 2 commits April 8, 2026 22:00
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 8, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Copy link
Copy Markdown
Collaborator

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for initiating this!

Comment thread src/tangents/fwds_rvs_data.jl Outdated
if R isa Union
_validate_rdata_union(R.a)
_validate_rdata_union(R.b)
elseif R != NoRData && fdata_type(tangent_type(NoFData, R)) != NoFData
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fear the tangent_type here can produce similar error as line 931, should we catch it?

Comment thread src/test_resources.jl Outdated
end
make_P_lohi_union() = LoHi(1.0, 2.0)::Union{Nothing,LoHi}
make_P_lohi_container() = LoHiContainer(LoHi(1.0, 2.0))
make_P_nothing_or_vector() = [1.0, 2.0]::Union{Nothing,Vector{Float64}}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems that Julia will still specialize the type instead of taking the Union type as hinted

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

missing tangent_type method

3 participants