
Commit fcca4c7

Remove LogDensityFunctionWrapper and replace VarInfo(mld, ...) with InitFromVector
1 parent a31e3e3 commit fcca4c7

6 files changed

Lines changed: 63 additions & 127 deletions


HISTORY.md

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
+# 0.41
+
+Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)`
+from the MarginalLogDensities extension. Users should now use
+`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy
+and pass it to `init!!` to get a consistent `VarInfo`.
+
 # 0.40.14

 Fixed `check_model()` erroneously failing for models such as `x[1:2] .~ univariate_dist`.
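
As a rough sketch of the migration this changelog entry describes, code that previously called `VarInfo(mld, ...)` could be rewritten along the following lines. The calls mirror the docstring and tests changed in this commit; the model `demo` and the evaluation point `[0.5]` are illustrative assumptions, and the snippet assumes DynamicPPL 0.41 with MarginalLogDensities loaded.

```julia
using DynamicPPL, Distributions, MarginalLogDensities

# Illustrative model (an assumption, chosen to match the doctest removed in this commit).
@model function demo()
    x ~ Normal()
    y ~ Beta(2, 2)
end

# Marginalise out `x`. Per the removed docstring, this uses a linked VarInfo by default.
mld = marginalize(demo(), [@varname(x)])

# Evaluate once so that the Laplace approximation runs and the mode of `x` is cached.
mld([0.5])

# Before 0.41: vi = VarInfo(mld, [0.5])
# From 0.41: build an initialisation strategy, then re-evaluate the model with it.
strategy = DynamicPPL.InitFromVector(mld, [0.5])
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy)

# `x` should sit near the mode of Normal(), as the updated tests check.
x_mode = vi[@varname(x)]
```

The extra `init!!` step is what makes the resulting `VarInfo` consistent: the model is re-run with the chosen parameter values rather than having values written into a stale `VarInfo`.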

docs/make.jl

Lines changed: 0 additions & 3 deletions
@@ -16,9 +16,6 @@ using AbstractMCMC: AbstractMCMC
 using MarginalLogDensities: MarginalLogDensities
 using Random

-# Need this to document a method which uses a type inside the extension...
-DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)
-
 # Doctest setup
 DocMeta.setdocmeta!(
     DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true

docs/src/accs/threadsafe.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ model = setthreadsafe(g(y), true)

 This is accomplished by creating one copy of each accumulator per thread (using `DynamicPPL.split`), and then after the model evaluation is complete, merging the result of each thread's accumulator with `DynamicPPL.combine`.

-**This means that if you are implementing your own accumulator, you will need to implement the `split` and `combine` methods for it in order for it work correctly in thread-safe mode.**
+**This means that if you are implementing your own accumulator, you will need to implement the `split` and `combine` methods for it in order for it to work correctly in thread-safe mode.**

 Each accumulator sees only the tilde-statements that were executed on its own thread.
 However, the intent is that after merging the results from all threads, the final accumulator should be equivalent to what would have been obtained by a single-threaded evaluation (modulo ordering).
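
The split/combine contract described in this passage can be illustrated with a toy example in plain Julia. This is only a sketch: `CountAcc`, `toy_split`, `toy_combine` and `observe` are hypothetical names invented here, not DynamicPPL's accumulator interface. It shows the property the docs ask for, namely that merging the per-chunk results reproduces the single-threaded result.

```julia
# Toy accumulator that counts tilde-statements (hypothetical; not DynamicPPL's interface).
struct CountAcc
    n::Int
end

toy_split(::CountAcc) = CountAcc(0)                            # fresh, empty copy for one chunk
toy_combine(a::CountAcc, b::CountAcc) = CountAcc(a.n + b.n)    # merge two partial results
observe(acc::CountAcc) = CountAcc(acc.n + 1)                   # record one statement

function run_chunked(acc::CountAcc, nstatements::Int, nchunks::Int)
    chunks = collect(Iterators.partition(1:nstatements, cld(nstatements, nchunks)))
    per_chunk = [toy_split(acc) for _ in chunks]
    # Each iteration touches only its own slot, so there is no shared mutable state.
    Threads.@threads for i in eachindex(chunks)
        for _ in chunks[i]
            per_chunk[i] = observe(per_chunk[i])
        end
    end
    # Merge the partial results back into the original accumulator.
    return foldl(toy_combine, per_chunk; init=acc)
end

serial = foldl((a, _) -> observe(a), 1:10; init=CountAcc(0))
merged = run_chunked(CountAcc(0), 10, 4)
```

For a counter, `toy_combine` is commutative, so ordering does not matter; an accumulator that stores values in sequence would only match the serial result up to ordering.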

docs/src/api.md

Lines changed: 2 additions & 2 deletions
@@ -169,10 +169,10 @@ marginalize
 ```

 A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability.
-To retrieve a VarInfo object from it, you can use:
+To obtain an initialisation strategy reflecting the state of the marginalisation, you can use:

 ```@docs
-VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
+InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, ::Union{AbstractVector,Nothing})
 ```

 ## Models within models

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 34 additions & 116 deletions
@@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt
 using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
 using MarginalLogDensities: MarginalLogDensities

-# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
-# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
-# below.
-struct LogDensityFunctionWrapper{
-    L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
-}
-    logdensity::L
-    # This field is used only to reconstruct the VarInfo later on; it's not needed for the
-    # actual log-density evaluation.
-    varinfo::V
-end
-function (lw::LogDensityFunctionWrapper)(x, _)
-    return LogDensityProblems.logdensity(lw.logdensity, x)
+# Make LogDensityFunction directly callable with the two-argument interface expected by
+# MarginalLogDensities. The second argument is the `data` argument, which is unused here
+# (`marginalize` passes an empty tuple `()` for it).
+function (ldf::DynamicPPL.LogDensityFunction)(x, _)
+    return LogDensityProblems.logdensity(ldf, x)
 end

 """
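
The change in this hunk replaces a wrapper struct with a call method defined directly on the log-density type. A minimal sketch of that pattern in plain Julia follows; `ToyLogDensity`, `toy_logdensity` and `evaluate` are hypothetical names used only for illustration and are not part of DynamicPPL or MarginalLogDensities.

```julia
# Hypothetical stand-in for a log-density object (not DynamicPPL.LogDensityFunction).
struct ToyLogDensity
    mu::Float64
end

# One-argument evaluation: unnormalised Gaussian log-density with unit variance.
toy_logdensity(t::ToyLogDensity, x::AbstractVector) = -0.5 * sum(abs2, x .- t.mu)

# Make the object itself callable with a two-argument signature; the second
# argument is accepted but ignored, as in the method added by this commit.
(t::ToyLogDensity)(x, _) = toy_logdensity(t, x)

# A consumer that only knows the two-argument calling convention.
evaluate(f, u, data=()) = f(u, data)

t = ToyLogDensity(1.0)
value = evaluate(t, [1.0, 3.0])
```

Because the object is callable directly, no intermediate type is needed just to adapt the signature, which is what allows the later dispatch on `MarginalLogDensity{<:DynamicPPL.LogDensityFunction}`.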
@@ -53,7 +45,6 @@ log-density.
 constructor.

 ## Example
-
 ```jldoctest
 julia> using DynamicPPL, Distributions, MarginalLogDensities

@@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0)
 marginal log-density can be performed in unconstrained space. However, care must be
 taken if the model contains variables where the link transformation depends on a
 marginalized variable. For example:
-
 ```julia
-@model function f()
-    x ~ Normal()
-    y ~ truncated(Normal(); lower=x)
-end
+@model function f()
+    x ~ Normal()
+    y ~ truncated(Normal(); lower=x)
+end
 ```

 Here, the support of `y`, and hence the link transformation used, depends on the value
@@ -101,7 +91,7 @@ function DynamicPPL.marginalize(
     method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
     kwargs...,
 )
-    # Construct the marginal log-density model.
+    # Construct the log-density function directly from the model and varinfo.
     ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
     # Determine the indices for the variables to marginalise out.
     varindices = mapreduce(vcat, marginalized_varnames) do vn
@@ -110,121 +100,49 @@ function DynamicPPL.marginalize(
         (ldf._varname_ranges[vn]::RangeAndLinked).range
     end
     mld = MarginalLogDensities.MarginalLogDensity(
-        LogDensityFunctionWrapper(ldf, varinfo),
-        varinfo[:],
-        varindices,
-        (),
-        method;
-        kwargs...,
+        ldf, varinfo[:], varindices, (), method; kwargs...
     )
     return mld
 end

 """
-    VarInfo(
-        mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
+    InitFromVector(
+        mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
         unmarginalized_params::Union{AbstractVector,Nothing}=nothing
     )

-Retrieve the `VarInfo` object used in the marginalisation process.
-
-If a Laplace approximation was used for the marginalisation, the values of the marginalized
-parameters are also set to their mode (note that this only happens if the `mld` object has
-been used to compute the marginal log-density at least once, so that the mode has been
-computed).
+Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose
+parameter vector reflects the state of `mld`.

-If a vector of `unmarginalized_params` is specified, the values for the corresponding
-parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
-performing an optimization of the marginal log-density.
+If a Laplace approximation was used for marginalisation, the marginalized parameters are set
+to their modal values (note that this requires `mld` to have been evaluated at least once,
+so that the mode has been found).

-All other aspects of the VarInfo, such as link status, are preserved from the original
-VarInfo used in the marginalisation.
-
-!!! note
-
-    The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
-    updated. If you wish to obtain updated log-probabilities, you should re-evaluate the
-    model with the values inside the returned VarInfo, for example using:
-
-    ```julia
-    init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing)
-    oavi = DynamicPPL.OnlyAccsVarInfo((
-        DynamicPPL.LogPriorAccumulator(),
-        DynamicPPL.LogLikelihoodAccumulator(),
-        DynamicPPL.RawValueAccumulator(false),
-        # ... whatever else you need
-    ))
-    _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
-    ```
-
-    You can then extract all the updated data from `oavi`.
-
-## Example
-
-```jldoctest
-julia> using DynamicPPL, Distributions, MarginalLogDensities
-
-julia> @model function demo()
-           x ~ Normal()
-           y ~ Beta(2, 2)
-       end
-demo (generic function with 2 methods)
+If `unmarginalized_params` is provided, those values are used for the non-marginalized
+parameters. This vector may be obtained e.g. by optimizing the marginal log-density.

-julia> # Note that by default `marginalize` uses a linked VarInfo.
-       mld = marginalize(demo(), [@varname(x)]);
-
-julia> using MarginalLogDensities: Optimization, OptimizationOptimJL
-
-julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
-       y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
-OptimizationProblem. In-place: true
-u0: 1-element Vector{Float64}:
- 2.0
-
-julia> # This tells us the optimal (linked) value of `y` is around 0.
-       opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
-retcode: Success
-u: 1-element Vector{Float64}:
- 4.88281250001733e-5
-
-julia> # Get the VarInfo corresponding to the mode of `y`.
-       vi = VarInfo(mld, opt_solution.u);
-
-julia> # `x` is set to its mode (which for `Normal()` is zero).
-       vi[@varname(x)]
-0.0
-
-julia> # `y` is set to the optimal value we found above.
-       DynamicPPL.getindex_internal(vi, @varname(y))
-1-element Vector{Float64}:
- 4.88281250001733e-5
-
-julia> # To obtain values in the original constrained space, we can either
-       # use `getindex`:
-       vi[@varname(y)]
-0.5000122070312476
-
-julia> # Or invlink the entire VarInfo object using the model:
-       vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
-2-element Vector{Float64}:
- 0.0
- 0.5000122070312476
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
125+
status — use the returned strategy to re-evaluate the model:
126+
```julia
127+
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
128+
ldf = mld.logdensity
129+
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
212130
```
213131
"""
214-
function DynamicPPL.VarInfo(
215-
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
132+
function DynamicPPL.InitFromVector(
133+
mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
216134
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
217135
)
218-
# Extract the original VarInfo. Its contents will in general be junk.
219-
original_vi = mld.logdensity.varinfo
220-
# Extract the stored parameters, which includes the modes for any marginalized
221-
# parameters
136+
# Retrieve the full cached parameter vector (includes modal values for marginalized
137+
# parameters if a Laplace approximation has been run).
222138
full_params = MarginalLogDensities.cached_params(mld)
223-
# We can then (if needed) set the values for any non-marginalized parameters
139+
# Overwrite the non-marginalized entries if the caller supplied them.
224140
if unmarginalized_params !== nothing
225141
full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
226142
end
227-
return DynamicPPL.unflatten!!(original_vi, full_params)
143+
# Use the convenience constructor that reads varname_ranges and transform_strategy
144+
# directly from the LogDensityFunction stored inside mld.
145+
return DynamicPPL.InitFromVector(full_params, mld.logdensity)
228146
end
229147

230148
end
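
The overwrite step inside the new `InitFromVector` method can be sketched in isolation with plain vectors. In the sketch below, `merge_params` is a hypothetical helper and the vectors are made-up values; it assumes only what the diff shows, namely that the cached vector holds all parameters and that an index vector (playing the role of `ijoint(mld)`) selects the non-marginalized entries.

```julia
# Hypothetical helper mirroring the overwrite step; not part of either package.
function merge_params(
    cached::AbstractVector, ijoint::AbstractVector{<:Integer}, unmarginalized
)
    # Copy so the cached vector is left untouched (a choice made for this sketch).
    full = copy(cached)
    if unmarginalized !== nothing
        length(unmarginalized) == length(ijoint) || throw(
            DimensionMismatch(
                "expected $(length(ijoint)) values, got $(length(unmarginalized))"
            ),
        )
        # Overwrite only the non-marginalized positions.
        full[ijoint] = unmarginalized
    end
    return full
end

cached = [0.0, 0.3, -1.2]   # made-up values: entry 1 marginalized, entries 2 and 3 not
ijoint = [2, 3]             # positions of the non-marginalized parameters
merged = merge_params(cached, ijoint, [0.5, 0.7])
unchanged = merge_params(cached, ijoint, nothing)
```

The explicit length check is an addition for illustration; it turns a mismatched `unmarginalized_params` vector into a clear error before any entries are written.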

test/ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 19 additions & 5 deletions
@@ -79,20 +79,34 @@ using ADTypes: AutoForwardDiff

     @testset "unlinked VarInfo" begin
         mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked)
-        mx([0.5]) # evaluate at some point to force calculation of Laplace approx
-        vi = VarInfo(mx)
+        mx([0.5]) # evaluate to force the Laplace approximation to run and cache modal values
+        strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
+        ldf = mx.logdensity
+        _, vi = DynamicPPL.init!!(
+            ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
+        )
         @test vi[@varname(x)] ≈ mode(Normal())
-        vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked
+        strategy = DynamicPPL.InitFromVector(mx, [0.5]) # same, but override the unmarginalized parameter with 0.5
+        _, vi = DynamicPPL.init!!(
+            ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
+        )
         @test vi[@varname(x)] ≈ mode(Normal())
         @test vi[@varname(y)] ≈ 0.5
     end

     @testset "linked VarInfo" begin
         mx = marginalize(model, [@varname(x)]; varinfo=vi_linked)
         mx([0.5]) # evaluate at some point to force calculation of Laplace approx
-        vi = VarInfo(mx)
+        strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
+        ldf = mx.logdensity
+        _, vi = DynamicPPL.init!!(
+            ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
+        )
         @test vi[@varname(x)] ≈ mode(Normal())
-        vi = VarInfo(mx, [0.5]) # this 0.5 is linked
+        strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is a linked value for the unmarginalized parameter y
+        _, vi = DynamicPPL.init!!(
+            ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
+        )
         binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
         @test vi[@varname(x)] ≈ mode(Normal())
         # when using getindex it always returns unlinked values