Fixed transformations #1327

Open
penelopeysm wants to merge 16 commits into breaking from py/transforms

Member

@penelopeysm penelopeysm commented Mar 17, 2026

Closes #1249. I'll write more here later, but you can see the changelog for a good overview of this PR.

Things to do

  • Generalise LinkedVecTransformAccumulator to something that's more like FixedTransformAccumulator
  • Add a convenience function for 'static'-fying an LDF (?)
  • Add tests for new behaviour
  • Add dev docs on FixedTransform

Possibly in a separate PR:


Some scripts to use to benchmark:

using DynamicPPL, Distributions, LogDensityProblems, Random, Chairmarks, ForwardDiff, ADTypes, LinearAlgebra

# @model function esc(J, y, sigma)
#     mu ~ Normal(0, 5)
#     tau ~ truncated(Cauchy(0, 5); lower=0)
#     theta ~ MvNormal(fill(mu, J), tau^2 * I)
#     for i in 1:J
#         y[i] ~ Normal(theta[i], sigma[i])
#     end
# end
# J = 8
# y = [28, 8, -3, 7, -1, 1, 18, 12]
# sigma = [15, 10, 16, 11, 9, 11, 10, 18]
# m = esc(J, y, sigma)

@model function f()
    x ~ product_distribution([Beta(2,2), Uniform(4, 5), Normal()])
end
m = f()

# With fix_transforms=true
ldf1 = LogDensityFunction(m, getlogjoint_internal, LinkAll(); adtype=AutoForwardDiff(), fix_transforms=true);
p = rand(Xoshiro(468), ldf1);
@b LogDensityProblems.logdensity(ldf1, p)
@b LogDensityProblems.logdensity_and_gradient(ldf1, p)

# Without fix_transforms, for comparison
ldf2 = LogDensityFunction(m, getlogjoint_internal, LinkAll(); adtype=AutoForwardDiff());
p = rand(Xoshiro(468), ldf2);
@b LogDensityProblems.logdensity(ldf2, p)
@b LogDensityProblems.logdensity_and_gradient(ldf2, p)

@penelopeysm penelopeysm changed the base branch from main to breaking March 17, 2026 19:19
@penelopeysm penelopeysm reopened this Mar 17, 2026
Contributor

github-actions bot commented Mar 17, 2026

Benchmark Report

  • this PR's head: 328514d3adde9e4ee8ed5caab5f3e5ae821d87df
  • base branch: e01fbd60d5a54a3b314b5bc5ea6654cefcba34a0

Computer Information

Julia Version 1.11.9
Commit 53a02c0720c (2026-02-06 00:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │   true │   290.25 │   281.65 │    1.03 │   7.67 │    8.07 │    0.95 │   2227.42 │   2274.15 │    0.98 │
│                   LDA │    12 │ reversediff │   true │  2502.18 │  2615.12 │    0.96 │   2.18 │    2.10 │    1.04 │   5463.79 │   5504.42 │    0.99 │
│   Loop univariate 10k │ 10000 │    mooncake │   true │ 31061.43 │ 31536.61 │    0.98 │   6.38 │    6.72 │    0.95 │ 198132.80 │ 211866.61 │    0.94 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │   true │  3167.57 │  3245.90 │    0.98 │   6.21 │    6.50 │    0.96 │  19685.67 │  21106.43 │    0.93 │
│      Multivariate 10k │ 10000 │    mooncake │   true │ 33426.98 │ 33122.24 │    1.01 │   9.93 │   14.85 │    0.67 │ 332031.62 │ 491917.53 │    0.67 │
│       Multivariate 1k │  1000 │    mooncake │   true │  3434.58 │  3589.19 │    0.96 │   9.25 │    8.96 │    1.03 │  31761.88 │  32149.59 │    0.99 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │  false │     0.87 │     0.88 │    1.00 │  10.35 │   10.44 │    0.99 │      9.05 │      9.17 │    0.99 │
│           Smorgasbord │   201 │ forwarddiff │  false │   952.99 │   975.98 │    0.98 │ 105.86 │   72.25 │    1.47 │ 100879.30 │  70511.64 │    1.43 │
│           Smorgasbord │   201 │      enzyme │   true │  1291.10 │  1281.81 │    1.01 │   4.88 │    4.93 │    0.99 │   6300.88 │   6320.38 │    1.00 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │   true │  1712.82 │  1314.32 │    1.30 │  52.92 │   69.17 │    0.77 │  90647.11 │  90905.39 │    1.00 │
│           Smorgasbord │   201 │    mooncake │   true │  1299.54 │  1438.22 │    0.90 │   4.59 │    4.26 │    1.08 │   5968.04 │   6133.70 │    0.97 │
│           Smorgasbord │   201 │ reversediff │   true │  1310.46 │  1293.32 │    1.01 │ 125.06 │  128.45 │    0.97 │ 163883.77 │ 166123.35 │    0.99 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│              Submodel │     1 │    mooncake │   true │     0.87 │     0.88 │    1.00 │  26.87 │   26.31 │    1.02 │     23.47 │     23.08 │    1.02 │
└───────────────────────┴───────┴─────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

codecov bot commented Mar 17, 2026

Codecov Report

❌ Patch coverage is 75.87940% with 48 lines in your changes missing coverage. Please review.
✅ Project coverage is 78.64%. Comparing base (e01fbd6) to head (328514d).

Files with missing lines                Patch %   Lines
src/transformed_values.jl               52.54%    28 Missing ⚠️
src/varinfo.jl                          86.53%    7 Missing ⚠️
src/accumulators/vector_values.jl       63.63%    4 Missing ⚠️
src/contexts/default.jl                 40.00%    3 Missing ⚠️
src/abstract_varinfo.jl                 50.00%    2 Missing ⚠️
src/accumulators/fixed_transforms.jl    85.71%    2 Missing ⚠️
src/contexts/init.jl                    88.88%    2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1327      +/-   ##
============================================
+ Coverage     78.19%   78.64%   +0.44%     
============================================
  Files            50       50              
  Lines          3582     3521      -61     
============================================
- Hits           2801     2769      -32     
+ Misses          781      752      -29     

@github-actions
Contributor

DynamicPPL.jl documentation for PR #1327 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1327/

@penelopeysm penelopeysm marked this pull request as ready for review April 2, 2026 14:36
@penelopeysm penelopeysm requested a review from sunxd3 April 2, 2026 14:37
Member Author

penelopeysm commented Apr 2, 2026

@sunxd3 would you be able to take a look? Here's the TLDR and my overall thoughts on this:

  • src/transformed_values.jl -- There used to be UntransformedValue, VectorValue and LinkedVectorValue <: AbstractTransformedValue -- these are renamed to TransformedValue(val, NoTransform()), TransformedValue(val, Unlink()), and TransformedValue(val, DynamicLink()). A lot of the changes in this PR are just boilerplate resulting from this slightly different type hierarchy. Better names are also welcome, although renaming some of this might be breaking at this point.

  • The new type hierarchy lets us extend it to TransformedValue(val, FixedTransform(f)). The new types are breaking, and we could avoid this by instead making a new subtype, FixedTransformedValue(val, f). I'm still a bit uncertain about which way to go on this: I went with the overhaul, but I don't actually think that it's necessary, and part of me thinks that it would just be better to not break the old API. If you agree, I'll try to change it back.

  • The remaining question would be whether the general interface makes sense to you (we discussed this on Monday). I think the docs are the best place to read about the interface https://turinglang.org/DynamicPPL.jl/previews/PR1327/transforms_fixed/. The implementation of the interface is in src/accumulators/fixed_transforms.jl and src/logdensityfunction.jl, but that's all quite standard boilerplate, nothing particularly interesting.
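The type hierarchy described in the bullets above can be pictured with a minimal, self-contained sketch. All definitions below are hypothetical stand-ins written for illustration, assuming only the names mentioned in this comment; they are not DynamicPPL's actual definitions, and the `describe` function is invented for the example.

```julia
# Hypothetical stand-ins for illustration only; NOT DynamicPPL's actual definitions.
abstract type AbstractTransform end

struct NoTransform <: AbstractTransform end      # takes the role of UntransformedValue
struct Unlink <: AbstractTransform end           # takes the role of VectorValue
struct DynamicLink <: AbstractTransform end      # takes the role of LinkedVectorValue
struct FixedTransform{F} <: AbstractTransform    # the new case: carries its own function
    f::F
end

# A single wrapper type, parameterised by the kind of transform it carries.
struct TransformedValue{V,T<:AbstractTransform}
    val::V
    transform::T
end

# Behaviour is selected by dispatching on the transform type parameter.
# `describe` is a made-up function used only to show the dispatch.
describe(::TransformedValue{<:Any,NoTransform}) = "untransformed"
describe(::TransformedValue{<:Any,Unlink}) = "unlink"
describe(::TransformedValue{<:Any,DynamicLink}) = "dynamic link"
describe(::TransformedValue{<:Any,<:FixedTransform}) = "fixed"

tv = TransformedValue([0.0], FixedTransform(exp))
describe(tv)  # → "fixed"
```

Under this layout, adding the fixed case is one more transform struct plus one more method, which is the extension point the second bullet refers to.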

Member

@sunxd3 sunxd3 left a comment

I think the type hierarchy changes all make sense. They are very breaking, though, and even though I don't imagine a lot of people depend on the transformation interface, it might still be worth giving some deprecation buffer (we can talk about this in the meeting).

The general interface also makes sense to me.

Multivariate 10k has some performance regression on Mooncake, I wonder why.
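On the deprecation buffer raised above, one possible shape is to keep an old constructor name as a thin forwarding function that emits a deprecation warning. This is only a sketch under assumptions: the two structs are hypothetical stand-ins rather than DynamicPPL's definitions, and only the `UntransformedValue` case is shown.

```julia
# Hypothetical stand-ins for illustration only; NOT DynamicPPL's actual definitions.
struct NoTransform end

struct TransformedValue{V,T}
    val::V
    transform::T
end

# The old name survives as a forwarding function that warns and then
# builds the new representation.
function UntransformedValue(val)
    Base.depwarn(
        "`UntransformedValue(val)` is deprecated; use `TransformedValue(val, NoTransform())` instead.",
        :UntransformedValue,
    )
    return TransformedValue(val, NoTransform())
end

old = UntransformedValue(1.5)
```

A shim like this only covers construction. Code that dispatches on the old type names (for example a method taking `::UntransformedValue`) would still break and would need a different approach.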

supposed to be *fixed*, i.e., they should not depend on random choices made during model
execution!
"""
function get_fixed_transforms(
Member

An idea: this function could be run n times, and the results compared to check that they are the same.
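A rough sketch of that check, using a hypothetical helper `is_repeatable` and toy stand-ins instead of a real model (none of these names exist in DynamicPPL):

```julia
# Hypothetical helper: call `f` n times and report whether every result
# is equal to the first one.
function is_repeatable(f; n::Int=5)
    first_result = f()
    return all(isequal(first_result, f()) for _ in 2:n)
end

# Toy stand-ins for "collect the transforms of a model".
fixed_bounds() = (lower=0.0, upper=1.0)       # same answer on every call
random_bounds() = (lower=0.0, upper=rand())   # depends on a random draw

is_repeatable(fixed_bounds)
is_repeatable(random_bounds; n=10)
```

Passing such a check does not prove that the transforms are fixed; it can only catch cases where they visibly vary between runs. With real transforms, the comparison itself may also need care, since two equivalent functions or closures need not compare equal.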

Comment on lines 399 to 407
@doc """
getindex(vi::AbstractVarInfo, vn::VarName[, dist::Distribution])
getindex(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution])

Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their)
distribution(s).

If `dist` is specified, the value(s) will be massaged into the representation expected by `dist`.
""" Base.getindex
Member

do we need to remove these too?

vectorised values in `vnt` to have the corresponding transforms from `transforms_vnt`.

This function returns a VarNamedTuple mapping all VarNames to their corresponding
`RangeAndTransform`.
Member

need update?

Suggested change
- `RangeAndTransform`.
+ `TransformedValue`.

?

Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
backend itself to have been loaded (e.g. with `import Backend`).

Finally, the `fix_transform` keyword argument allows you to specify whether the transforms
Member

Suggested change
- Finally, the `fix_transform` keyword argument allows you to specify whether the transforms
+ Finally, the `fix_transforms` keyword argument allows you to specify whether the transforms

Member

sunxd3 commented Apr 5, 2026

The biggest break from this might be the removal of the vi[@varname] syntax. It has existed for quite a long time.

Development

Successfully merging this pull request may close these issues.

Generalising link strategies to support StaticTransformation