@@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt
33using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
44using MarginalLogDensities: MarginalLogDensities
55
6- # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
7- # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
8- # below.
9- struct LogDensityFunctionWrapper{
10- L<: DynamicPPL.LogDensityFunction ,V<: DynamicPPL.AbstractVarInfo
11- }
12- logdensity:: L
13- # This field is used only to reconstruct the VarInfo later on; it's not needed for the
14- # actual log-density evaluation.
15- varinfo:: V
16- end
17- function (lw:: LogDensityFunctionWrapper )(x, _)
18- return LogDensityProblems. logdensity (lw. logdensity, x)
6+ # Make LogDensityFunction directly callable with the two-argument interface expected by
7+ # MarginalLogDensities. The second argument is the gradient and is unused here because
8+ # MarginalLogDensities handles differentiation separately.
9+ function (ldf:: DynamicPPL.LogDensityFunction )(x, _)
10+ return LogDensityProblems. logdensity (ldf, x)
1911end
2012
2113"""
@@ -53,7 +45,6 @@ log-density.
5345 constructor.
5446
5547## Example
56-
5748```jldoctest
5849julia> using DynamicPPL, Distributions, MarginalLogDensities
5950
@@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0)
8071 marginal log-density can be performed in unconstrained space. However, care must be
8172 taken if the model contains variables where the link transformation depends on a
8273 marginalized variable. For example:
83-
8474 ```julia
85- @model function f()
86- x ~ Normal()
87- y ~ truncated(Normal(); lower=x)
88- end
75+ @model function f()
76+ x ~ Normal()
77+ y ~ truncated(Normal(); lower=x)
78+ end
8979 ```
9080
9181 Here, the support of `y`, and hence the link transformation used, depends on the value
@@ -101,7 +91,7 @@ function DynamicPPL.marginalize(
10191 method:: MarginalLogDensities.AbstractMarginalizer = MarginalLogDensities. LaplaceApprox (),
10292 kwargs... ,
10393)
104- # Construct the marginal log-density model.
94+ # Construct the log-density function directly from the model and varinfo .
10595 ldf = DynamicPPL. LogDensityFunction (model, getlogprob, varinfo)
10696 # Determine the indices for the variables to marginalise out.
10797 varindices = mapreduce (vcat, marginalized_varnames) do vn
@@ -110,121 +100,49 @@ function DynamicPPL.marginalize(
110100 (ldf. _varname_ranges[vn]:: RangeAndLinked ). range
111101 end
112102 mld = MarginalLogDensities. MarginalLogDensity (
113- LogDensityFunctionWrapper (ldf, varinfo),
114- varinfo[:],
115- varindices,
116- (),
117- method;
118- kwargs... ,
103+ ldf, varinfo[:], varindices, (), method; kwargs...
119104 )
120105 return mld
121106end
122107
123108"""
124- VarInfo (
125- mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper },
109+ InitFromVector (
110+ mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction },
126111 unmarginalized_params::Union{AbstractVector,Nothing}=nothing
127112 )
128113
129- Retrieve the `VarInfo` object used in the marginalisation process.
130-
131- If a Laplace approximation was used for the marginalisation, the values of the marginalized
132- parameters are also set to their mode (note that this only happens if the `mld` object has
133- been used to compute the marginal log-density at least once, so that the mode has been
134- computed).
114+ Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose
115+ parameter vector reflects the state of `mld`.
135116
136- If a vector of `unmarginalized_params` is specified , the values for the corresponding
137- parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
138- performing an optimization of the marginal log-density .
117+ If a Laplace approximation was used for marginalisation , the marginalized parameters are set
118+ to their modal values (note that this requires `mld` to have been evaluated at least once,
119+ so that the mode has been found) .
139120
140- All other aspects of the VarInfo, such as link status, are preserved from the original
141- VarInfo used in the marginalisation.
142-
143- !!! note
144-
145- The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
146- updated. If you wish to obtain updated log-probabilities, you should re-evaluate the
147- model with the values inside the returned VarInfo, for example using:
148-
149- ```julia
150- init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing)
151- oavi = DynamicPPL.OnlyAccsVarInfo((
152- DynamicPPL.LogPriorAccumulator(),
153- DynamicPPL.LogLikelihoodAccumulator(),
154- DynamicPPL.RawValueAccumulator(false),
155- # ... whatever else you need
156- ))
157- _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
158- ```
159-
160- You can then extract all the updated data from `oavi`.
161-
162- ## Example
163-
164- ```jldoctest
165- julia> using DynamicPPL, Distributions, MarginalLogDensities
166-
167- julia> @model function demo()
168- x ~ Normal()
169- y ~ Beta(2, 2)
170- end
171- demo (generic function with 2 methods)
121+ If `unmarginalized_params` is provided, those values are used for the non-marginalized
122+ parameters. This vector may be obtained e.g. by optimizing the marginal log-density.
172123
173- julia> # Note that by default `marginalize` uses a linked VarInfo.
174- mld = marginalize(demo(), [@varname(x)]);
175-
176- julia> using MarginalLogDensities: Optimization, OptimizationOptimJL
177-
178- julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
179- y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
180- OptimizationProblem. In-place: true
181- u0: 1-element Vector{Float64}:
182- 2.0
183-
184- julia> # This tells us the optimal (linked) value of `y` is around 0.
185- opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
186- retcode: Success
187- u: 1-element Vector{Float64}:
188- 4.88281250001733e-5
189-
190- julia> # Get the VarInfo corresponding to the mode of `y`.
191- vi = VarInfo(mld, opt_solution.u);
192-
193- julia> # `x` is set to its mode (which for `Normal()` is zero).
194- vi[@varname(x)]
195- 0.0
196-
197- julia> # `y` is set to the optimal value we found above.
198- DynamicPPL.getindex_internal(vi, @varname(y))
199- 1-element Vector{Float64}:
200- 4.88281250001733e-5
201-
202- julia> # To obtain values in the original constrained space, we can either
203- # use `getindex`:
204- vi[@varname(y)]
205- 0.5000122070312476
206-
207- julia> # Or invlink the entire VarInfo object using the model:
208- vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
209- 2-element Vector{Float64}:
210- 0.0
211- 0.5000122070312476
124+ To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
125+ status — use the returned strategy to re-evaluate the model:
126+ ```julia
127+ init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
128+ ldf = mld.logdensity
129+ _, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
212130```
213131"""
214- function DynamicPPL. VarInfo (
215- mld:: MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper } ,
132+ function DynamicPPL. InitFromVector (
133+ mld:: MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction } ,
216134 unmarginalized_params:: Union{AbstractVector,Nothing} = nothing ,
217135)
218- # Extract the original VarInfo. Its contents will in general be junk.
219- original_vi = mld. logdensity. varinfo
220- # Extract the stored parameters, which includes the modes for any marginalized
221- # parameters
136+ # Retrieve the full cached parameter vector (includes modal values for marginalized
137+ # parameters if a Laplace approximation has been run).
222138 full_params = MarginalLogDensities. cached_params (mld)
223- # We can then ( if needed) set the values for any non-marginalized parameters
139+ # Overwrite the non-marginalized entries if the caller supplied them.
224140 if unmarginalized_params != = nothing
225141 full_params[MarginalLogDensities. ijoint (mld)] = unmarginalized_params
226142 end
227- return DynamicPPL. unflatten!! (original_vi, full_params)
143+ # Use the convenience constructor that reads varname_ranges and transform_strategy
144+ # directly from the LogDensityFunction stored inside mld.
145+ return DynamicPPL. InitFromVector (full_params, mld. logdensity)
228146end
229147
230148end
0 commit comments