Skip to content

Commit b2d8a8a

Browse files
committed
Remove LogDensityFunctionWrapper and replace VarInfo(mld, ...) with InitFromVector
1 parent 34b8230 commit b2d8a8a

5 files changed

Lines changed: 62 additions & 126 deletions

File tree

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# 0.41
2+
3+
Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)`
4+
from the MarginalLogDensities extension. Users should now use
5+
`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy
6+
and pass it to `init!!` to get a consistent `VarInfo`.
7+
18
# 0.40.14
29

310
Fixed `check_model()` erroneously failing for models such as `x[1:2] .~ univariate_dist`.

docs/make.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ using AbstractMCMC: AbstractMCMC
1616
using MarginalLogDensities: MarginalLogDensities
1717
using Random
1818

19-
# Need this to document a method which uses a type inside the extension...
20-
DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)
21-
2219
# Doctest setup
2320
DocMeta.setdocmeta!(
2421
DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,10 @@ marginalize
169169
```
170170

171171
A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability.
172-
To retrieve a VarInfo object from it, you can use:
172+
To obtain an initialisation strategy reflecting the state of the marginalisation, you can use:
173173

174174
```@docs
175-
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
175+
InitFromVector(::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction}, ::Union{AbstractVector,Nothing})
176176
```
177177

178178
## Models within models

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 34 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt
33
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
44
using MarginalLogDensities: MarginalLogDensities
55

6-
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
7-
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
8-
# below.
9-
struct LogDensityFunctionWrapper{
10-
L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
11-
}
12-
logdensity::L
13-
# This field is used only to reconstruct the VarInfo later on; it's not needed for the
14-
# actual log-density evaluation.
15-
varinfo::V
16-
end
17-
function (lw::LogDensityFunctionWrapper)(x, _)
18-
return LogDensityProblems.logdensity(lw.logdensity, x)
6+
# Make LogDensityFunction directly callable with the two-argument interface expected by
7+
# MarginalLogDensities. The second argument is the gradient and is unused here because
8+
# MarginalLogDensities handles differentiation separately.
9+
function (ldf::DynamicPPL.LogDensityFunction)(x, _)
10+
return LogDensityProblems.logdensity(ldf, x)
1911
end
2012

2113
"""
@@ -53,7 +45,6 @@ log-density.
5345
constructor.
5446
5547
## Example
56-
5748
```jldoctest
5849
julia> using DynamicPPL, Distributions, MarginalLogDensities
5950
@@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0)
8071
marginal log-density can be performed in unconstrained space. However, care must be
8172
taken if the model contains variables where the link transformation depends on a
8273
marginalized variable. For example:
83-
8474
```julia
85-
@model function f()
86-
x ~ Normal()
87-
y ~ truncated(Normal(); lower=x)
88-
end
75+
@model function f()
76+
x ~ Normal()
77+
y ~ truncated(Normal(); lower=x)
78+
end
8979
```
9080
9181
Here, the support of `y`, and hence the link transformation used, depends on the value
@@ -101,7 +91,7 @@ function DynamicPPL.marginalize(
10191
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
10292
kwargs...,
10393
)
104-
# Construct the marginal log-density model.
94+
# Construct the log-density function directly from the model and varinfo.
10595
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
10696
# Determine the indices for the variables to marginalise out.
10797
varindices = mapreduce(vcat, marginalized_varnames) do vn
@@ -110,121 +100,49 @@ function DynamicPPL.marginalize(
110100
(ldf._varname_ranges[vn]::RangeAndLinked).range
111101
end
112102
mld = MarginalLogDensities.MarginalLogDensity(
113-
LogDensityFunctionWrapper(ldf, varinfo),
114-
varinfo[:],
115-
varindices,
116-
(),
117-
method;
118-
kwargs...,
103+
ldf, varinfo[:], varindices, (), method; kwargs...
119104
)
120105
return mld
121106
end
122107

123108
"""
124-
VarInfo(
125-
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
109+
InitFromVector(
110+
mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
126111
unmarginalized_params::Union{AbstractVector,Nothing}=nothing
127112
)
128113
129-
Retrieve the `VarInfo` object used in the marginalisation process.
130-
131-
If a Laplace approximation was used for the marginalisation, the values of the marginalized
132-
parameters are also set to their mode (note that this only happens if the `mld` object has
133-
been used to compute the marginal log-density at least once, so that the mode has been
134-
computed).
114+
Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose
115+
parameter vector reflects the state of `mld`.
135116
136-
If a vector of `unmarginalized_params` is specified, the values for the corresponding
137-
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
138-
performing an optimization of the marginal log-density.
117+
If a Laplace approximation was used for marginalisation, the marginalized parameters are set
118+
to their modal values (note that this requires `mld` to have been evaluated at least once,
119+
so that the mode has been found).
139120
140-
All other aspects of the VarInfo, such as link status, are preserved from the original
141-
VarInfo used in the marginalisation.
142-
143-
!!! note
144-
145-
The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
146-
updated. If you wish to obtain updated log-probabilities, you should re-evaluate the
147-
model with the values inside the returned VarInfo, for example using:
148-
149-
```julia
150-
init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing)
151-
oavi = DynamicPPL.OnlyAccsVarInfo((
152-
DynamicPPL.LogPriorAccumulator(),
153-
DynamicPPL.LogLikelihoodAccumulator(),
154-
DynamicPPL.RawValueAccumulator(false),
155-
# ... whatever else you need
156-
))
157-
_, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
158-
```
159-
160-
You can then extract all the updated data from `oavi`.
161-
162-
## Example
163-
164-
```jldoctest
165-
julia> using DynamicPPL, Distributions, MarginalLogDensities
166-
167-
julia> @model function demo()
168-
x ~ Normal()
169-
y ~ Beta(2, 2)
170-
end
171-
demo (generic function with 2 methods)
121+
If `unmarginalized_params` is provided, those values are used for the non-marginalized
122+
parameters. This vector may be obtained e.g. by optimizing the marginal log-density.
172123
173-
julia> # Note that by default `marginalize` uses a linked VarInfo.
174-
mld = marginalize(demo(), [@varname(x)]);
175-
176-
julia> using MarginalLogDensities: Optimization, OptimizationOptimJL
177-
178-
julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
179-
y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
180-
OptimizationProblem. In-place: true
181-
u0: 1-element Vector{Float64}:
182-
2.0
183-
184-
julia> # This tells us the optimal (linked) value of `y` is around 0.
185-
opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
186-
retcode: Success
187-
u: 1-element Vector{Float64}:
188-
4.88281250001733e-5
189-
190-
julia> # Get the VarInfo corresponding to the mode of `y`.
191-
vi = VarInfo(mld, opt_solution.u);
192-
193-
julia> # `x` is set to its mode (which for `Normal()` is zero).
194-
vi[@varname(x)]
195-
0.0
196-
197-
julia> # `y` is set to the optimal value we found above.
198-
DynamicPPL.getindex_internal(vi, @varname(y))
199-
1-element Vector{Float64}:
200-
4.88281250001733e-5
201-
202-
julia> # To obtain values in the original constrained space, we can either
203-
# use `getindex`:
204-
vi[@varname(y)]
205-
0.5000122070312476
206-
207-
julia> # Or invlink the entire VarInfo object using the model:
208-
vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
209-
2-element Vector{Float64}:
210-
0.0
211-
0.5000122070312476
124+
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
125+
status — use the returned strategy to re-evaluate the model:
126+
```julia
127+
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
128+
ldf = mld.logdensity
129+
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
212130
```
213131
"""
214-
function DynamicPPL.VarInfo(
215-
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
132+
function DynamicPPL.InitFromVector(
133+
mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
216134
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
217135
)
218-
# Extract the original VarInfo. Its contents will in general be junk.
219-
original_vi = mld.logdensity.varinfo
220-
# Extract the stored parameters, which includes the modes for any marginalized
221-
# parameters
136+
# Retrieve the full cached parameter vector (includes modal values for marginalized
137+
# parameters if a Laplace approximation has been run).
222138
full_params = MarginalLogDensities.cached_params(mld)
223-
# We can then (if needed) set the values for any non-marginalized parameters
139+
# Overwrite the non-marginalized entries if the caller supplied them.
224140
if unmarginalized_params !== nothing
225141
full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
226142
end
227-
return DynamicPPL.unflatten!!(original_vi, full_params)
143+
# Use the convenience constructor that reads varname_ranges and transform_strategy
144+
# directly from the LogDensityFunction stored inside mld.
145+
return DynamicPPL.InitFromVector(full_params, mld.logdensity)
228146
end
229147

230148
end

test/ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,34 @@ using ADTypes: AutoForwardDiff
7979

8080
@testset "unlinked VarInfo" begin
8181
mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked)
82-
mx([0.5]) # evaluate at some point to force calculation of Laplace approx
83-
vi = VarInfo(mx)
82+
mx([0.5]) # evaluate to force the Laplace approximation to run and cache modal values
83+
strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
84+
ldf = mx.logdensity
85+
_, vi = DynamicPPL.init!!(
86+
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
87+
)
8488
@test vi[@varname(x)] mode(Normal())
85-
vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked
89+
strategy = DynamicPPL.InitFromVector(mx, [0.5]) # same, but override the unmarginalized parameter with 0.5
90+
_, vi = DynamicPPL.init!!(
91+
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
92+
)
8693
@test vi[@varname(x)] mode(Normal())
8794
@test vi[@varname(y)] 0.5
8895
end
8996

9097
@testset "linked VarInfo" begin
9198
mx = marginalize(model, [@varname(x)]; varinfo=vi_linked)
9299
mx([0.5]) # evaluate at some point to force calculation of Laplace approx
93-
vi = VarInfo(mx)
100+
strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
101+
ldf = mx.logdensity
102+
_, vi = DynamicPPL.init!!(
103+
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
104+
)
94105
@test vi[@varname(x)] mode(Normal())
95-
vi = VarInfo(mx, [0.5]) # this 0.5 is linked
106+
strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is a linked value for the unmarginalized parameter y
107+
_, vi = DynamicPPL.init!!(
108+
ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
109+
)
96110
binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
97111
@test vi[@varname(x)] mode(Normal())
98112
# when using getindex it always returns unlinked values

0 commit comments

Comments
 (0)