Closed
7 changes: 7 additions & 0 deletions HISTORY.md
@@ -1,3 +1,10 @@
# 0.41

Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)`
from the MarginalLogDensities extension. Users should now use
`DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy
and pass it to `init!!` to get a consistent `VarInfo`.
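
The migration described above can be sketched as follows. This is a minimal, unofficial example that assumes the `demo` model from the removed `VarInfo` doctest and reuses the `init!!` call shown in the new `InitFromVector` docstring in this PR. The evaluation point `[0.0]` is an arbitrary illustrative choice, not a recommended value.

```julia
using DynamicPPL, Distributions, MarginalLogDensities

# Illustrative model (taken from the removed doctest).
@model function demo()
    x ~ Normal()
    y ~ Beta(2, 2)
end

# Marginalise out `x`. By default `marginalize` uses a linked VarInfo and a Laplace
# approximation.
mld = marginalize(demo(), [@varname(x)])

# Evaluate once so that the mode of `x` is computed and cached.
mld([0.0])

# Before 0.41 (removed):  vi = VarInfo(mld, [0.0])
# From 0.41: build an initialisation strategy, then re-evaluate the model with it.
init_strategy = DynamicPPL.InitFromVector(mld, [0.0])
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(
    ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy
)

vi[@varname(x)]  # `x` at its cached mode
vi[@varname(y)]  # `y` in the original (constrained) space
```

Because the model is re-evaluated by `init!!`, the resulting `VarInfo` should carry log-probabilities that match its values, which the old `VarInfo(mld, ...)` method did not guarantee.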

# 0.40.20

Added a public function, `DynamicPPL.extract_prefixes(::AbstractContext)`, to more generally handle the removal of `PrefixContext` entries from the context stack.
3 changes: 0 additions & 3 deletions docs/make.jl
@@ -16,9 +16,6 @@ using AbstractMCMC: AbstractMCMC
using MarginalLogDensities: MarginalLogDensities
using Random

# Need this to document a method which uses a type inside the extension...
DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)

# Doctest setup
DocMeta.setdocmeta!(
    DynamicPPL, :DocTestSetup, :(using DynamicPPL, MCMCChains); recursive=true
6 changes: 1 addition & 5 deletions docs/src/api.md
@@ -170,11 +170,7 @@ marginalize
```

A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability.
To retrieve a VarInfo object from it, you can use:

```@docs
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
```
To obtain an initialisation strategy reflecting the state of the marginalisation, you can use [`InitFromVector`](@ref).
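
As a rough illustration of the "acts as a function" statement above, the sketch below marginalises one variable of a small model and evaluates the result at a point. The `demo` model and the input `[0.0]` are assumptions chosen for illustration only.

```julia
using DynamicPPL, Distributions, MarginalLogDensities

@model function demo()
    x ~ Normal()
    y ~ Beta(2, 2)
end

# Marginalise out `x`; `y` is the only remaining (non-marginalised) parameter.
mld = marginalize(demo(), [@varname(x)])

# Calling `mld` on a vector of non-marginalised parameter values returns the marginal
# log-density. `marginalize` uses a linked VarInfo by default, so the input here is the
# value of `y` in unconstrained (linked) space.
lp = mld([0.0])
```

Evaluating `mld` at least once also matters for `InitFromVector`, since the modal values of the marginalised parameters are only cached after an evaluation.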

## Models within models

150 changes: 34 additions & 116 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
@@ -3,19 +3,11 @@ module DynamicPPLMarginalLogDensitiesExt
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
using MarginalLogDensities: MarginalLogDensities

# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
# below.
struct LogDensityFunctionWrapper{
    L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
}
    logdensity::L
    # This field is used only to reconstruct the VarInfo later on; it's not needed for the
    # actual log-density evaluation.
    varinfo::V
end
function (lw::LogDensityFunctionWrapper)(x, _)
    return LogDensityProblems.logdensity(lw.logdensity, x)
# Make LogDensityFunction directly callable with the two-argument `(u, data)` interface
# expected by MarginalLogDensities. The second argument is the `data` argument (passed as
# `()` in `marginalize` below); it is unused here because the model and its data are
# already stored inside the LogDensityFunction.
function (ldf::DynamicPPL.LogDensityFunction)(x, _)
    return LogDensityProblems.logdensity(ldf, x)
end

"""
@@ -53,7 +45,6 @@ log-density.
constructor.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

@@ -80,12 +71,11 @@ julia> logpdf(Normal(2.0), 1.0)
marginal log-density can be performed in unconstrained space. However, care must be
taken if the model contains variables where the link transformation depends on a
marginalized variable. For example:

```julia
@model function f()
    x ~ Normal()
    y ~ truncated(Normal(); lower=x)
end
    @model function f()
        x ~ Normal()
        y ~ truncated(Normal(); lower=x)
    end
Comment on lines -85 to +78
Member

The indent isn't necessary here

```

Here, the support of `y`, and hence the link transformation used, depends on the value
@@ -101,7 +91,7 @@ function DynamicPPL.marginalize(
    method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
    kwargs...,
)
    # Construct the marginal log-density model.
    # Construct the log-density function directly from the model and varinfo.
    ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
    # Determine the indices for the variables to marginalise out.
    varindices = mapreduce(vcat, marginalized_varnames) do vn
@@ -110,121 +100,49 @@ function DynamicPPL.marginalize(
        (ldf._varname_ranges[vn]::RangeAndLinked).range
    end
    mld = MarginalLogDensities.MarginalLogDensity(
        LogDensityFunctionWrapper(ldf, varinfo),
        varinfo[:],
        varindices,
        (),
        method;
        kwargs...,
        ldf, varinfo[:], varindices, (), method; kwargs...
    )
    return mld
end

"""
    VarInfo(
        mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
    InitFromVector(
        mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
        unmarginalized_params::Union{AbstractVector,Nothing}=nothing
    )

Retrieve the `VarInfo` object used in the marginalisation process.

If a Laplace approximation was used for the marginalisation, the values of the marginalized
parameters are also set to their mode (note that this only happens if the `mld` object has
been used to compute the marginal log-density at least once, so that the mode has been
computed).
Return an [`InitFromVector`](@ref DynamicPPL.InitFromVector) initialisation strategy whose
parameter vector reflects the state of `mld`.

If a vector of `unmarginalized_params` is specified, the values for the corresponding
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
performing an optimization of the marginal log-density.
If a Laplace approximation was used for marginalisation, the marginalized parameters are set
to their modal values (note that this requires `mld` to have been evaluated at least once,
so that the mode has been found).

All other aspects of the VarInfo, such as link status, are preserved from the original
VarInfo used in the marginalisation.

!!! note

    The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
    updated. If you wish to obtain updated log-probabilities, you should re-evaluate the
    model with the values inside the returned VarInfo, for example using:

    ```julia
    init_strategy = DynamicPPL.InitFromParams(varinfo.values, nothing)
    oavi = DynamicPPL.OnlyAccsVarInfo((
        DynamicPPL.LogPriorAccumulator(),
        DynamicPPL.LogLikelihoodAccumulator(),
        DynamicPPL.RawValueAccumulator(false),
        # ... whatever else you need
    ))
    _, oavi = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll())
    ```

    You can then extract all the updated data from `oavi`.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
           x ~ Normal()
           y ~ Beta(2, 2)
       end
demo (generic function with 2 methods)
If `unmarginalized_params` is provided, those values are used for the non-marginalized
parameters. This vector may be obtained e.g. by optimizing the marginal log-density.

julia> # Note that by default `marginalize` uses a linked VarInfo.
       mld = marginalize(demo(), [@varname(x)]);

julia> using MarginalLogDensities: Optimization, OptimizationOptimJL

julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
       y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
OptimizationProblem. In-place: true
u0: 1-element Vector{Float64}:
 2.0

julia> # This tells us the optimal (linked) value of `y` is around 0.
       opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
retcode: Success
u: 1-element Vector{Float64}:
 4.88281250001733e-5

julia> # Get the VarInfo corresponding to the mode of `y`.
       vi = VarInfo(mld, opt_solution.u);

julia> # `x` is set to its mode (which for `Normal()` is zero).
       vi[@varname(x)]
0.0

julia> # `y` is set to the optimal value we found above.
       DynamicPPL.getindex_internal(vi, @varname(y))
1-element Vector{Float64}:
 4.88281250001733e-5

julia> # To obtain values in the original constrained space, we can either
       # use `getindex`:
       vi[@varname(y)]
0.5000122070312476

julia> # Or invlink the entire VarInfo object using the model:
       vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
2-element Vector{Float64}:
 0.0
 0.5000122070312476
To obtain a fully consistent `VarInfo` — with updated log-probabilities and correct link
status — use the returned strategy to re-evaluate the model:
```julia
init_strategy = DynamicPPL.InitFromVector(mld, opt_solution.u)
ldf = mld.logdensity
_, vi = DynamicPPL.init!!(ldf.model, DynamicPPL.VarInfo(), init_strategy, ldf.transform_strategy)
```
"""
function DynamicPPL.VarInfo(
    mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
function DynamicPPL.InitFromVector(
    mld::MarginalLogDensities.MarginalLogDensity{<:DynamicPPL.LogDensityFunction},
    unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
)
    # Extract the original VarInfo. Its contents will in general be junk.
    original_vi = mld.logdensity.varinfo
    # Extract the stored parameters, which includes the modes for any marginalized
    # parameters
    # Retrieve the full cached parameter vector (includes modal values for marginalized
    # parameters if a Laplace approximation has been run).
    full_params = MarginalLogDensities.cached_params(mld)
    # We can then (if needed) set the values for any non-marginalized parameters
    # Overwrite the non-marginalized entries if the caller supplied them.
    if unmarginalized_params !== nothing
        full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
    end
    return DynamicPPL.unflatten!!(original_vi, full_params)
    # Use the convenience constructor that reads varname_ranges and transform_strategy
    # directly from the LogDensityFunction stored inside mld.
    return DynamicPPL.InitFromVector(full_params, mld.logdensity)
end

end
24 changes: 19 additions & 5 deletions test/ext/DynamicPPLMarginalLogDensitiesExt.jl
@@ -79,20 +79,34 @@ using ADTypes: AutoForwardDiff

@testset "unlinked VarInfo" begin
    mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked)
    mx([0.5]) # evaluate at some point to force calculation of Laplace approx
    vi = VarInfo(mx)
    mx([0.5]) # evaluate to force the Laplace approximation to run and cache modal values
    strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
    ldf = mx.logdensity
    _, vi = DynamicPPL.init!!(
        ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
    )
    @test vi[@varname(x)] ≈ mode(Normal())
    vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked
    strategy = DynamicPPL.InitFromVector(mx, [0.5]) # same, but override the unmarginalized parameter with 0.5
    _, vi = DynamicPPL.init!!(
        ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
    )
    @test vi[@varname(x)] ≈ mode(Normal())
    @test vi[@varname(y)] ≈ 0.5
end

@testset "linked VarInfo" begin
    mx = marginalize(model, [@varname(x)]; varinfo=vi_linked)
    mx([0.5]) # evaluate at some point to force calculation of Laplace approx
    vi = VarInfo(mx)
    strategy = DynamicPPL.InitFromVector(mx) # build init strategy from cached modal values
    ldf = mx.logdensity
    _, vi = DynamicPPL.init!!(
        ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
    )
    @test vi[@varname(x)] ≈ mode(Normal())
    vi = VarInfo(mx, [0.5]) # this 0.5 is linked
    strategy = DynamicPPL.InitFromVector(mx, [0.5]) # this 0.5 is a linked value for the unmarginalized parameter y
    _, vi = DynamicPPL.init!!(
        ldf.model, DynamicPPL.VarInfo(), strategy, ldf.transform_strategy
    )
    binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2)))
    @test vi[@varname(x)] ≈ mode(Normal())
    # when using getindex it always returns unlinked values