From 7883d497f430c03182598705f6e63ac858e57300 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Mon, 5 May 2025 23:57:36 +0800 Subject: [PATCH 01/19] Add Riemannian manifold HMC --- docs/src/api.md | 1 + research/tests/runtests.jl | 2 - src/AdvancedHMC.jl | 21 ++- src/riemannian/hamiltonian.jl | 298 ++++------------------------------ src/riemannian/metric.jl | 63 +++++++ src/sampler.jl | 6 +- src/trajectory.jl | 7 +- test/Project.toml | 2 + test/demo.jl | 10 +- test/integrator.jl | 5 +- test/riemannian.jl | 120 ++++++++++++-- test/trajectory.jl | 65 +++----- 12 files changed, 268 insertions(+), 332 deletions(-) create mode 100644 src/riemannian/metric.jl diff --git a/docs/src/api.md b/docs/src/api.md index 54b5939dc..a1c488fb8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,6 +8,7 @@ This modularity means that different HMC variants can be easily constructed by c - Unit metric: `UnitEuclideanMetric(dim)` - Diagonal metric: `DiagEuclideanMetric(dim)` - Dense metric: `DenseEuclideanMetric(dim)` + - Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)` where `dim` is the dimensionality of the sampling space. 
diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index da95548df..0633bc593 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -5,11 +5,9 @@ Pkg.add(; url="https://github.com/chalk-lab/MCMCLogDensityProblems.jl.git"); # include the source code for experimental HMC include("../src/relativistic_hmc.jl") -include("../src/riemannian_hmc.jl") # include the tests for experimental HMC include("relativistic_hmc.jl") -include("riemannian_hmc.jl") Comonicon.@main function runtests(patterns...; dry::Bool=false) return retest(patterns...; dry=dry, verbose=Inf) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index b25710d5f..41d934e60 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -2,7 +2,19 @@ module AdvancedHMC using Statistics: mean, var, middle using LinearAlgebra: - Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling + Symmetric, + UpperTriangular, + mul!, + ldiv!, + dot, + I, + diag, + cholesky, + UniformScaling, + logdet, + tr, + eigen, + diagm using StatsFuns: logaddexp, logsumexp, loghalf using Random: Random, AbstractRNG using ProgressMeter: ProgressMeter @@ -40,7 +52,7 @@ struct GaussianKinetic <: AbstractKinetic end export GaussianKinetic include("metric.jl") -export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric +export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric include("hamiltonian.jl") export Hamiltonian @@ -50,6 +62,11 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog include("riemannian/integrator.jl") export GeneralizedLeapfrog +include("riemannian/metric.jl") +export IdentityMap, SoftAbsMap, DenseRiemannianMetric + +include("riemannian/hamiltonian.jl") + include("trajectory.jl") export Trajectory, HMCKernel, diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index feddb4114..f8acc7971 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,257 +1,16 @@ 
-using Random - -### integrator.jl - -import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step -using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size - -""" -$(TYPEDEF) - -Generalized leapfrog integrator with fixed step size `ϵ`. - -# Fields - -$(TYPEDFIELDS) -""" -struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} - "Step size." - ϵ::T - n::Int -end -function Base.show(io::IO, l::GeneralizedLeapfrog) - return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") -end - -# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv -end - -# TODO Make sure vectorization works -# TODO Check if tempering is valid -function step( - lf::GeneralizedLeapfrog{T}, - h::Hamiltonian, - z::P, - n_steps::Int=1; - fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 - full_trajectory::Val{FullTraj}=Val(false), -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} - n_steps = abs(n_steps) # to support `n_steps < 0` cases - - ϵ = fwd ? step_size(lf) : -step_size(lf) - ϵ = ϵ' - - res = if FullTraj - Vector{P}(undef, n_steps) - else - z - end - - for i in 1:n_steps - θ_init, r_init = z.θ, z.r - # Tempering - #r = temper(lf, r, (i=i, is_half=true), n_steps) - #! Eq (16) of Girolami & Calderhead (2011) - r_half = copy(r_init) - local cache - for j in 1:(lf.n) - # Reuse cache for the first iteration - if j == 1 - (; value, gradient) = z.ℓπ - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) - (; value, gradient) = retval - else # reuse cache - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) - end - r_half = r_init - ϵ / 2 * gradient - # println("r_half: ", r_half) - end - #! 
Eq (17) of Girolami & Calderhead (2011) - θ_full = copy(θ_init) - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop - for j in 1:(lf.n) - θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) - # println("θ_full :", θ_full) - end - #! Eq (18) of Girolami & Calderhead (2011) - (; value, gradient) = ∂H∂θ(h, θ_full, r_half) - r_full = r_half - ϵ / 2 * gradient - # println("r_full: ", r_full) - # Tempering - #r = temper(lf, r, (i=i, is_half=false), n_steps) - # Create a new phase point by caching the logdensity and gradient - z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) - # Update result - if FullTraj - res[i] = z - else - res = z - end - if !isfinite(z) - # Remove undef - if FullTraj - res = res[isassigned.(Ref(res), 1:n_steps)] - end - break - end - # @assert false - end - return res -end - -# TODO Make the order of θ and r consistent with neg_energy -∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) -∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) - -### hamiltonian.jl - -import AdvancedHMC: refresh, phasepoint -using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric - -# To change L180 of hamiltonian.jl -function phasepoint( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - θ::AbstractVecOrMat{T}, - h::Hamiltonian, -) where {T<:Real} - return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) -end - -# To change L191 of hamiltonian.jl -function refresh( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - ::FullMomentumRefreshment, - h::Hamiltonian, - z::PhasePoint, -) - return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) -end - -# To change L215 of hamiltonian.jl -function refresh( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - ref::PartialMomentumRefreshment, - h::Hamiltonian, - z::PhasePoint, -) - return phasepoint( - h, - z.θ, - ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, 
h.metric, h.kinetic, z.θ), - ) -end - -### metric.jl - -import AdvancedHMC: _rand -using AdvancedHMC: AbstractMetric -using LinearAlgebra: eigen, cholesky, Symmetric - -abstract type AbstractRiemannianMetric <: AbstractMetric end - -abstract type AbstractHessianMap end - -struct IdentityMap <: AbstractHessianMap end - -(::IdentityMap)(x) = x - -struct SoftAbsMap{T} <: AbstractHessianMap - α::T -end - -# TODO Register softabs with ReverseDiff -#! The definition of SoftAbs from Page 3 of Betancourt (2012) -function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` - Q = hcat(F.vectors) - λ = F.values - softabsλ = λ .* coth.(α * λ) - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ -end - -(map::SoftAbsMap)(x) = softabs(x, map.α)[1] - -struct DenseRiemannianMetric{ - T, - TM<:AbstractHessianMap, - A<:Union{Tuple{Int},Tuple{Int,Int}}, - AV<:AbstractVecOrMat{T}, - TG, - T∂G∂θ, -} <: AbstractRiemannianMetric - size::A - G::TG # TODO store G⁻¹ here instead - ∂G∂θ::T∂G∂θ - map::TM - _temp::AV -end - -# TODO Make dense mass matrix support matrix-mode parallel -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} - _temp = Vector{Float64}(undef, size[1]) - return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) -end -# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) -# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) -# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) -# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) - -# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) - -Base.size(e::DenseRiemannianMetric) = e.size -Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] -Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") - -function rand_momentum( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, 
- metric::DenseRiemannianMetric{T}, - kinetic, +#! Eq (14) of Girolami & Calderhead (2011) +function ∂H∂r( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, θ::AbstractVecOrMat, -) where {T} - r = _randn(rng, T, size(metric)...) - G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) - return r -end - -### hamiltonian.jl - -import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r -using LinearAlgebra: logabsdet, tr - -# QUES Do we want to change everything to position dependent by default? -# Add θ to ∂H∂r for DenseRiemannianMetric -function phasepoint( - h::Hamiltonian{<:DenseRiemannianMetric}, - θ::T, - r::T; - ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), -) where {T<:AbstractVecOrMat} - return PhasePoint(θ, r, ℓπ, ℓκ) -end - -# Negative kinetic energy -#! Eq (13) of Girolami & Calderhead (2011) -function neg_energy( - h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T -) where {T<:AbstractVecOrMat} - G = h.metric.map(h.metric.G(θ)) - D = size(G, 1) - # Need to consider the normalizing term as it is no longer same for different θs - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined - mul!(h.metric._temp, inv(G), r) - return -logZ - dot(r, h.metric._temp) / 2 + r::AbstractVecOrMat, +) + H = h.metric.G(θ) + G = h.metric.map(H) + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't end -# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}, ) where {T} @@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} end function ∂H∂θ( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, + 
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}, ) where {T} return ∂H∂θ_cache(h, θ, r) end function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, θ::AbstractVecOrMat{T}, r::AbstractVecOrMat{T}; return_cache=false, @@ -342,17 +101,26 @@ function ∂H∂θ_cache( return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv end -#! Eq (14) of Girolami & Calderhead (2011) -function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat -) - H = h.metric.G(θ) - # if !all(isfinite, H) - # println("θ: ", θ) - # println("H: ", H) - # end - G = h.metric.map(H) - # return inv(G) * r - # println("G \ r: ", G \ r) - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't +# QUES Do we want to change everything to position dependent by default? +# Add θ to ∂H∂r for DenseRiemannianMetric +function phasepoint( + h::Hamiltonian{<:DenseRiemannianMetric}, + θ::T, + r::T; + ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), +) where {T<:AbstractVecOrMat} + return PhasePoint(θ, r, ℓπ, ℓκ) +end + +#! 
Eq (13) of Girolami & Calderhead (2011) +function neg_energy( + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T +) where {T<:AbstractVecOrMat} + G = h.metric.map(h.metric.G(θ)) + D = size(G, 1) + # Need to consider the normalizing term as it is no longer same for different θs + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined + mul!(h.metric._temp, inv(G), r) + return -logZ - dot(r, h.metric._temp) / 2 end diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl new file mode 100644 index 000000000..41d11127c --- /dev/null +++ b/src/riemannian/metric.jl @@ -0,0 +1,63 @@ +abstract type AbstractRiemannianMetric <: AbstractMetric end + +abstract type AbstractHessianMap end + +struct IdentityMap <: AbstractHessianMap end + +(::IdentityMap)(x) = x + +struct SoftAbsMap{T} <: AbstractHessianMap + α::T +end + +function softabs(X, α=20.0) + F = eigen(X) # ReverseDiff cannot diff through `eigen` + Q = hcat(F.vectors) + λ = F.values + softabsλ = λ .* coth.(α * λ) + return Q * diagm(softabsλ) * Q', Q, λ, softabsλ +end + +(map::SoftAbsMap)(x) = softabs(x, map.α)[1] + +# TODO Register softabs with ReverseDiff +#! 
The definition of SoftAbs from Page 3 of Betancourt (2012) +struct DenseRiemannianMetric{ + T, + TM<:AbstractHessianMap, + A<:Union{Tuple{Int},Tuple{Int,Int}}, + AV<:AbstractVecOrMat{T}, + TG, + T∂G∂θ, +} <: AbstractRiemannianMetric + size::A + G::TG # TODO store G⁻¹ here instead + ∂G∂θ::T∂G∂θ + map::TM + _temp::AV +end + +# TODO Make dense mass matrix support matrix-mode parallel +function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) + _temp = Vector{Float64}(undef, first(size)) + return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) +end + +Base.size(e::DenseRiemannianMetric) = e.size +Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] +function Base.show(io::IO, drm::DenseRiemannianMetric) + return print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") +end + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::DenseRiemannianMetric{T}, + kinetic, + θ::AbstractVecOrMat, +) where {T} + r = _randn(rng, T, size(metric)...) + G⁻¹ = inv(metric.map(metric.G(θ))) + chol = cholesky(Symmetric(G⁻¹)) + ldiv!(chol.U, r) + return r +end diff --git a/src/sampler.jl b/src/sampler.jl index c0a426814..e0138819c 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -117,7 +117,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=pm_next!, + (pm_next!)::Function=(pm_next!), ) return sample( Random.default_rng(), @@ -130,7 +130,7 @@ function sample( drop_warmup=drop_warmup, verbose=verbose, progress=progress, - (pm_next!)=pm_next!, + (pm_next!)=(pm_next!), ) end @@ -168,7 +168,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=pm_next!, + (pm_next!)::Function=(pm_next!), ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." 
# Prepare containers to store sampling results diff --git a/src/trajectory.jl b/src/trajectory.jl index a76807605..aa8c90cae 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -133,8 +133,9 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -SliceTS(rng::AbstractRNG, z0::PhasePoint) = +function SliceTS(rng::AbstractRNG, z0::PhasePoint) SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) +end """ $(TYPEDEF) @@ -278,7 +279,7 @@ function transition( hamiltonian_energy=H, hamiltonian_energy_error=H - H0, # check numerical error in proposed phase point. - numerical_error=!all(isfinite, H′), + numerical_error=(!all(isfinite, H′)), ), stat(τ.integrator), ) @@ -717,7 +718,7 @@ function transition( ( n_steps=tree.nα, is_accept=true, - acceptance_rate=tree.sum_α / tree.nα, + acceptance_rate=(tree.sum_α / tree.nα), log_density=zcand.ℓπ.value, hamiltonian_energy=H, hamiltonian_energy_error=H - H0, diff --git a/test/Project.toml b/test/Project.toml index f38214815..3e2a793b2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,12 +7,14 @@ Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MCMCLogDensityProblems = "8a639fad-7908-4fe4-8003-906e9297f002" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/test/demo.jl b/test/demo.jl index 98315daa3..c9010a7f6 100644 --- 
a/test/demo.jl +++ b/test/demo.jl @@ -10,8 +10,9 @@ using LinearAlgebra, ADTypes LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ) LogDensityProblems.dimension(p::DemoProblem) = p.dim - LogDensityProblems.capabilities(::Type{DemoProblem}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{ + 0 + }() # Choose parameter dimensionality and initial parameter value D = 10 @@ -66,8 +67,9 @@ end return -((1 - p.μ) / p.σ)^2 end LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2 - LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{ + 0 + }() ℓπ = DemoProblemComponentArrays() diff --git a/test/integrator.jl b/test/integrator.jl index b9eb14076..f5a3dbea4 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -112,8 +112,9 @@ using Statistics: mean LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2 LogDensityProblems.dimension(d::NegU) = d.dim - LogDensityProblems.capabilities(::Type{NegU}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{ + 0 + }() negU = NegU(1) diff --git a/test/riemannian.jl b/test/riemannian.jl index 67b1cad08..0cfcb8233 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -1,28 +1,63 @@ -using ReTest, AdvancedHMC - -include("../src/riemannian_hmc.jl") -include("../src/riemannian_hmc_utility.jl") - +using ReTest, Random +using AdvancedHMC, ForwardDiff, AbstractMCMC +using LinearAlgebra +using MCMCLogDensityProblems using FiniteDiff: finite_difference_gradient, finite_difference_hessian, finite_difference_jacobian -using Distributions: MvNormal -using AdvancedHMC: neg_energy, energy +using AdvancedHMC: neg_energy, energy, ∂H∂θ, ∂H∂r + +# Fisher information 
metric +function gen_∂G∂θ_fwd(Vfunc, x; f=identity) + _Hfunc = gen_hess_fwd(Vfunc, x) + Hfunc = x -> _Hfunc(x)[3] + # QUES What's the best output format of this function? + cfg = ForwardDiff.JacobianConfig(Hfunc, x) + d = length(x) + out = zeros(eltype(x), d^2, d) + return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) + return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] +end + +function gen_hess_fwd(func, x::AbstractVector) + function hess(x::AbstractVector) + return nothing, nothing, ForwardDiff.hessian(func, x) + end + return hess +end + +function reshape_∂G∂θ(H) + d = size(H, 2) + return cat((H[((i - 1) * d + 1):(i * d), :] for i in 1:d)...; dims=3) +end -# Taken from https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/test/finitedifftests.jl -δ(a, b) = maximum(abs.(a - b)) +function prepare_sample(ℓπ, initial_θ, λ) + Vfunc = x -> -ℓπ(x) + _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ) # x -> (value, gradient, hessian) + Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug -@testset "Riemannian" begin - hps = (; λ=1e-2, α=20.0, ϵ=0.1, n=6, L=8) + fstabilize = H -> H + λ * I + Gfunc = x -> begin + H = fstabilize(Hfunc(x)[3]) + all(isfinite, H) ? 
H : diagm(ones(length(x))) + end + _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize) + ∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x)) + + return Vfunc, Hfunc, Gfunc, ∂G∂θfunc +end +@testset "Constructors tests" begin + δ(a, b) = maximum(abs.(a - b)) @testset "$(nameof(typeof(target)))" for target in [HighDimGaussian(2), Funnel()] rng = MersenneTwister(1110) + λ = 1e-2 θ₀ = rand(rng, dim(target)) ℓπ = MCMCLogDensityProblems.gen_logpdf(target) ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀) - Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample_target(hps, θ₀, ℓπ) + Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample(ℓπ, θ₀, λ) D = dim(target) # ==2 for this test x = zeros(D) # randn(rng, D) @@ -36,7 +71,7 @@ using AdvancedHMC: neg_energy, energy end @testset "$(nameof(typeof(hessmap)))" for hessmap in - [IdentityMap(), SoftAbsMap(hps.α)] + [IdentityMap(), SoftAbsMap(20.0)] metric = DenseRiemannianMetric((D,), Gfunc, ∂G∂θfunc, hessmap) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) @@ -67,3 +102,62 @@ using AdvancedHMC: neg_energy, energy end end end + +@testset "Multi variate Normal with Riemannian HMC" begin + # Set the number of samples to draw and warmup iterations + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + # Define a Hamiltonian system + metric = DenseRiemannianMetric((D,), G, ∂G∂θ) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + # Define a leapfrog solver, with the initial step size chosen heuristically + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) + + # Define an HMC sampler with the following components + # - multinomial sampling scheme, + # - generalised No-U-Turn criteria, and + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + # Run the sampler to draw samples from the specified Gaussian, where + # - `samples` will store the 
samples + # - `stats` will store diagnostic statistics for each sample + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) + @test length(samples) == n_samples + @test length(stats) == n_samples +end + +@testset "Multi variate Normal with Riemannian HMC softabs metric" begin + # Set the number of samples to draw and warmup iterations + n_samples = 2_000 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + + # Define a Hamiltonian system + metric = DenseRiemannianMetric((D,), G, ∂G∂θ, λSoftAbsMap(20.0)) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + # Define a leapfrog solver, with the initial step size chosen heuristically + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) + + # Define an HMC sampler with the following components + # - multinomial sampling scheme, + # - generalised No-U-Turn criteria, and + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + # Run the sampler to draw samples from the specified Gaussian, where + # - `samples` will store the samples + # - `stats` will store diagnostic statistics for each sample + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) + @test length(samples) == n_samples + @test length(stats) == n_samples +end diff --git a/test/trajectory.jl b/test/trajectory.jl index 403fd4463..4bf0ac4d1 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -257,46 +257,35 @@ end traj_r = hcat(map(z -> z.r, traj_z)...) 
rho = cumsum(traj_r; dims=2) - ts_hand_isturn_fwd = - hand_isturn.( - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) - ts_ahmc_isturn_fwd = - ahmc_isturn.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_fwd = hand_isturn.( + Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) + ) + ts_ahmc_isturn_fwd = ahmc_isturn.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_hand_isturn_generalised_fwd = - hand_isturn_generalised.( - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) - ts_ahmc_isturn_generalised_fwd = - ahmc_isturn_generalised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_generalised_fwd = hand_isturn_generalised.( + Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) + ) + ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_ahmc_isturn_strictgeneralised_fwd = - ahmc_isturn_strictgeneralised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strictgeneralised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) check_subtree_u_turns.( Ref(h), Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)] From 1fcdd0982625ee4e04416fcf237c82d22a1870ac Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sat, 10 May 2025 17:57:47 +0800 Subject: [PATCH 02/19] Format --- src/trajectory.jl | 2 +- test/demo.jl | 10 +++---- test/integrator.jl | 5 ++-- test/trajectory.jl | 65 +++++++++++++++++++++++++++------------------- 4 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/trajectory.jl b/src/trajectory.jl index 
aa8c90cae..2764993df 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -134,7 +134,7 @@ Slice sampler for the starting single leaf tree. Slice variable is initialized. """ function SliceTS(rng::AbstractRNG, z0::PhasePoint) - SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) + return SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) end """ diff --git a/test/demo.jl b/test/demo.jl index c9010a7f6..98315daa3 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -10,9 +10,8 @@ using LinearAlgebra, ADTypes LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ) LogDensityProblems.dimension(p::DemoProblem) = p.dim - LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{DemoProblem}) = + LogDensityProblems.LogDensityOrder{0}() # Choose parameter dimensionality and initial parameter value D = 10 @@ -67,9 +66,8 @@ end return -((1 - p.μ) / p.σ)^2 end LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2 - LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = + LogDensityProblems.LogDensityOrder{0}() ℓπ = DemoProblemComponentArrays() diff --git a/test/integrator.jl b/test/integrator.jl index f5a3dbea4..b9eb14076 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -112,9 +112,8 @@ using Statistics: mean LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2 LogDensityProblems.dimension(d::NegU) = d.dim - LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{NegU}) = + LogDensityProblems.LogDensityOrder{0}() negU = NegU(1) diff --git a/test/trajectory.jl b/test/trajectory.jl index 4bf0ac4d1..403fd4463 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -257,35 +257,46 @@ end traj_r = hcat(map(z -> z.r, 
traj_z)...) rho = cumsum(traj_r; dims=2) - ts_hand_isturn_fwd = hand_isturn.( - Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) - ) - ts_ahmc_isturn_fwd = ahmc_isturn.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_fwd = + hand_isturn.( + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) + ts_ahmc_isturn_fwd = + ahmc_isturn.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_hand_isturn_generalised_fwd = hand_isturn_generalised.( - Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) - ) - ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_generalised_fwd = + hand_isturn_generalised.( + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) + ts_ahmc_isturn_generalised_fwd = + ahmc_isturn_generalised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strictgeneralised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_ahmc_isturn_strictgeneralised_fwd = + ahmc_isturn_strictgeneralised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) check_subtree_u_turns.( Ref(h), Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)] From 567e2a82dfd91c2482f84e7f3ea4743b16cdc9c4 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Mon, 30 Jun 2025 02:23:44 +0800 Subject: [PATCH 03/19] format --- src/sampler.jl | 6 +++--- src/trajectory.jl | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index e0138819c..c0a426814 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -117,7 +117,7 @@ function sample( 
drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=(pm_next!), + (pm_next!)::Function=pm_next!, ) return sample( Random.default_rng(), @@ -130,7 +130,7 @@ function sample( drop_warmup=drop_warmup, verbose=verbose, progress=progress, - (pm_next!)=(pm_next!), + (pm_next!)=pm_next!, ) end @@ -168,7 +168,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=(pm_next!), + (pm_next!)::Function=pm_next!, ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." # Prepare containers to store sampling results diff --git a/src/trajectory.jl b/src/trajectory.jl index bcbbc25f0..2a4eb98df 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -141,9 +141,8 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -function SliceTS(rng::AbstractRNG, z0::PhasePoint) - return SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) -end +SliceTS(rng::AbstractRNG, z0::PhasePoint) = + SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) """ $(TYPEDEF) From c9e6b0af7dfc273f6310a403869698c2510d2078 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Fri, 21 Nov 2025 23:10:34 +0800 Subject: [PATCH 04/19] Include Riemannian HMC tests --- test/riemannian.jl | 2 +- test/runtests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index 0cfcb8233..f11522151 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -141,7 +141,7 @@ end _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) # Define a Hamiltonian system - metric = DenseRiemannianMetric((D,), G, ∂G∂θ, λSoftAbsMap(20.0)) + metric = DenseRiemannianMetric((D,), G, ∂G∂θ, SoftAbsMap(20.0)) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) diff --git a/test/runtests.jl b/test/runtests.jl index 
d0fb6ea88..fa816e8b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ if GROUP == "All" || GROUP == "AdvancedHMC" include("abstractmcmc.jl") include("mcmcchains.jl") include("constructors.jl") + include("riemannian.jl") retest(; dry=false, verbose=Inf) end From 280ca158ada69f8c54bf2270a508db840f7c0eb5 Mon Sep 17 00:00:00 2001 From: Nikolas Siccha Date: Thu, 18 Dec 2025 16:39:39 +0100 Subject: [PATCH 05/19] start minimal refactor for merging into main --- docs/src/api.md | 12 ++- src/riemannian/hamiltonian.jl | 138 ++++++++++++++++++---------------- 2 files changed, 83 insertions(+), 67 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index e7caf2d0c..1a39d5ba4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,9 +8,17 @@ This modularity means that different HMC variants can be easily constructed by c - Unit metric: `UnitEuclideanMetric(dim)` - Diagonal metric: `DiagEuclideanMetric(dim)` - Dense metric: `DenseEuclideanMetric(dim)` - - Dense Riemannian metric: `DenseRiemannianMetric(size, G, ∂G∂θ)` -where `dim` is the dimensionality of the sampling space. +where `dim` is the dimension of the sampling space. 
+ +Furthermore, there is now an experimental dense Riemannian metric implementation, specifiable as `DenseRiemannianMetric(dim, premetric, premetric_sensitivities, metric_map=IdentityMap())`, with + + - `dim`: again the dimension of the sampling space, + - `premetric`: a function which, for a given posterior position `pos`, computes either + a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or + b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`), + - `premetric_sensitivities`: a function which, again for a given posterior position `pos`, computes the sensitivities with respect to this position of the **`premetric`** function, + - `metric_map=IdentityMap()`: a function which takes in `premetric(pos)` and returns a symmetric positive definite matrix. Provided options are `IdentityMap()` or `SoftAbsMap(alpha)`, with the `SoftAbsMap` type allowing one to work directly with the `premetric` returning the Hessian of the log density function, which generally is not guaranteed to be positive definite. ### [Integrator (`integrator`)](@id integrator) diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index f8acc7971..11266b5bb 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,49 +1,74 @@ #! Eq (14) of Girolami & Calderhead (2011) +"The gradient of the Hamiltonian with respect to the momentum." 
function ∂H∂r( h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, - θ::AbstractVecOrMat, - r::AbstractVecOrMat, + θ::AbstractVector, + r::AbstractVector, ) H = h.metric.G(θ) G = h.metric.map(H) - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't + return G \ r end +""" +Computes `tr(A*B)` for square n x n matrices `A` and `B` in O(n^2) without computing `A*B`, which would be O(n^3). + +Doesn't actually check that A and B are both n x n matrices. +""" +tr_product(A::AbstractMatrix, B::AbstractMatrix) = sum(Base.broadcasted(*, A', B)) +"Computes `tr(A*v*v')`, i.e. dot(v,A,v)." +tr_product(A::AbstractMatrix, v::AbstractVector) = sum(Base.broadcasted(*, v, A, v')) + + function ∂H∂θ( + h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, + θ::AbstractVector, + r::AbstractVector, +) + return first(∂H∂θ_cache(h, θ, r)) +end +""" + +""" +@views function ∂H∂θ_cache( h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, - θ::AbstractVecOrMat{T}, - r::AbstractVecOrMat{T}, + θ::AbstractVector{T}, + r::AbstractVector{T}; + cache=nothing ) where {T} - ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) - G = h.metric.map(h.metric.G(θ)) - invG = inv(G) - ∂G∂θ = h.metric.∂G∂θ(θ) - d = length(∂ℓπ∂θ) + cache = @something cache begin + log_density, log_density_gradient = h.∂ℓπ∂θ(θ) + # h.metric.map is the IdentityMap + metric = h.metric.G(θ) + # The metric is inverted to be able to compute `tr_product(inv_metric, ...)` efficiently - + # but this may still be a bad idea! + inv_metric = inv(metric) + metric_sensitivities = h.metric.∂G∂θ(θ) + rv1 = map(eachindex(log_density_gradient)) do i + -log_density_gradient[i] + .5 * tr_product(inv_metric, metric_sensitivities[:, :, i]) + end + (;log_density, inv_metric, metric_sensitivities, rv1) + end + # (;log_density, inv_metric_r, metric_sensitivities, rv1) = cache + inv_metric_r = cache.inv_metric * r return DualValue( - ℓπ, + cache.log_density, #! 
Eq (15) of Girolami & Calderhead (2011) - -mapreduce(vcat, 1:d) do i - ∂G∂θᵢ = ∂G∂θ[:, :, i] - ∂ℓπ∂θ[i] - 1 / 2 * tr(invG * ∂G∂θᵢ) + 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r - # Gr = G \ r - # ∂ℓπ∂θ[i] - 1 / 2 * tr(G \ ∂G∂θᵢ) + 1 / 2 * Gr' * ∂G∂θᵢ * Gr - # 1 / 2 * tr(invG * ∂G∂θᵢ) - # 1 / 2 * r' * invG * ∂G∂θᵢ * invG * r - end, - ) + cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i + .5 * tr_product(cache.metric_sensitivities[:, :, i], inv_metric_r) + end + ), cache end -# Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 -#! Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" -dsoftabsdλ(α, λ) = coth(α * λ) + λ * α * -csch(λ * α)^2 - #! J as defined in middle of the right column of Page 3 of Betancourt (2012) function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} d = length(λ) J = Matrix{T}(undef, d, d) for i in 1:d, j in 1:d J[i, j] = if (λ[i] == λ[j]) - dsoftabsdλ(α, λ[i]) + # Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 + #! 
Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" + coth(α * λ[i]) + λ[i] * α * -csch(λ[i] * α)^2 else ((λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])) end @@ -51,54 +76,37 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} return J end -function ∂H∂θ( +@views function ∂H∂θ_cache( h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, - θ::AbstractVecOrMat{T}, - r::AbstractVecOrMat{T}, -) where {T} - return ∂H∂θ_cache(h, θ, r) -end -function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, - θ::AbstractVecOrMat{T}, - r::AbstractVecOrMat{T}; - return_cache=false, + θ::AbstractVector{T}, + r::AbstractVector{T}; cache=nothing, ) where {T} - # Terms that only dependent on θ can be cached in θ-unchanged loops - if isnothing(cache) - ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) - H = h.metric.G(θ) - ∂H∂θ = h.metric.∂G∂θ(θ) - - G, Q, λ, softabsλ = softabs(H, h.metric.map.α) - - R = diagm(1 ./ softabsλ) - - # softabsΛ = diagm(softabsλ) - # M = inv(softabsΛ) * Q' * r - # M = R * Q' * r # equiv to above but avoid inv - + cache = @something cache begin + log_density, log_density_gradient = h.∂ℓπ∂θ(θ) + premetric = h.metric.G(θ) + premetric_sensitivities = h.metric.∂G∂θ(θ) + metric, Q, λ, softabsλ = softabs(premetric, h.metric.map.α) J = make_J(λ, h.metric.map.α) #! 
Based on the two equations from the right column of Page 3 of Betancourt (2012) - term_1_cached = Q * (R .* J) * Q' - else - ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached = cache - end - d = length(∂ℓπ∂θ) - D = diagm((Q' * r) ./ softabsλ) - term_2_cached = Q * D * J * D * Q' - g = - -mapreduce(vcat, 1:d) do i - ∂H∂θᵢ = ∂H∂θ[:, :, i] - # ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1) - # NOTE Some further optimization can be done here: cache the 1st product all together - ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly + tmpv = diag(J) ./ softabsλ + tmpm = Q * Diagonal(tmpv) * Q' + + rv1 = map(eachindex(log_density_gradient)) do i + -log_density_gradient[i] + .5 * tr_product(tmpm, premetric_sensitivities[:, :, i]) end + (;log_density, Q, softabsλ, tmpv, tmpm, rv1) + end + cache.tmpv .= (cache.Q' * r) ./ cache.softabsλ + cache.tmpm .= Q * (J .* cache.tmpv .* cache.tmpv') * Q' - dv = DualValue(ℓπ, g) - return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv + return DualValue( + cache.log_density, + cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i + .5 * tr_product(cache.tmpm, cache.premetric_sensitivities[:, :, i]) + end + ), cache end # QUES Do we want to change everything to position dependent by default? 
From 64213106df1a5dd4b54f091420952e6ef4779cbc Mon Sep 17 00:00:00 2001 From: THargreaves Date: Wed, 7 Jan 2026 16:59:02 +0000 Subject: [PATCH 06/19] Implement unified Riemannian metric --- src/AdvancedHMC.jl | 10 +- src/riemannian/hamiltonian.jl | 311 ++++++++++++++++++++++----------- src/riemannian/integrator.jl | 35 ++-- src/riemannian/metric.jl | 317 +++++++++++++++++++++++++++++++--- test/riemannian.jl | 276 ++++++++++++++++++++++++----- 5 files changed, 756 insertions(+), 193 deletions(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 4aa3f145f..ca2f96487 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -4,6 +4,7 @@ using Statistics: mean, var, middle using LinearAlgebra: Symmetric, UpperTriangular, + Diagonal, mul!, ldiv!, dot, @@ -63,6 +64,8 @@ include("riemannian/integrator.jl") export GeneralizedLeapfrog include("riemannian/metric.jl") +export RiemannianMetric, SoftAbsRiemannianMetric +# Deprecated exports (for backward compatibility) export IdentityMap, SoftAbsMap, DenseRiemannianMetric include("riemannian/hamiltonian.jl") @@ -89,7 +92,12 @@ export find_good_eps include("adaptation/Adaptation.jl") using .Adaptation import .Adaptation: - StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint + StepSizeAdaptor, + MassMatrixAdaptor, + StanHMCAdaptor, + NesterovDualAveraging, + NoAdaptation, + PositionOrPhasePoint # Helpers for initializing adaptors via AHMC structs diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 11266b5bb..3422fc2a7 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -1,134 +1,235 @@ -#! Eq (14) of Girolami & Calderhead (2011) -"The gradient of the Hamiltonian with respect to the momentum." -function ∂H∂r( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, +""" + tr_product(A, B) + +Compute `tr(A * B)` for square matrices in O(n²) without forming the product. 
+Uses the identity: tr(A * B) = sum(A' .* B) +""" +tr_product(A::AbstractMatrix, B::AbstractMatrix) = sum(Base.broadcasted(*, A', B)) + +""" + tr_product(A, v) + +Compute `tr(A * v * v')` = v' * A * v efficiently. +""" +tr_product(A::AbstractMatrix, v::AbstractVector) = dot(v, A * v) + +#### +#### Gradient cache for θ-dependent computations +#### + +""" + RiemannianGradCache{T, TG, TP} + +Cache for θ-dependent computations in Riemannian HMC gradient calculation. +This allows reusing expensive eigendecomposition/factorization across fixed-point iterations. + +# Fields +- `G_eval`: Evaluated metric (SoftAbsEval or matrix) +- `∂P∂θ`: Pre-metric sensitivities, shape (d, d, d) +- `ℓπ`: Log density value at θ +- `∂ℓπ∂θ`: Log density gradient at θ +- `logdet_terms`: Precomputed 0.5 * tr(M_logdet * ∂P∂θ[:,:,i]) for each i +""" +struct RiemannianGradCache{T,TG,TP} + G_eval::TG + ∂P∂θ::TP + ℓπ::T + ∂ℓπ∂θ::Vector{T} + logdet_terms::Vector{T} +end + +""" + build_grad_cache(h::Hamiltonian{<:AbstractRiemannianMetric}, θ) + +Build cache for gradient computation at position θ. +Computes all θ-dependent quantities that can be reused across r values. +""" +function build_grad_cache( + h::Hamiltonian{<:AbstractRiemannianMetric}, θ::AbstractVector{T} +) where {T} + # Evaluate log density and gradient + ℓπ, ∂ℓπ∂θ = h.∂ℓπ∂θ(θ) + + # Evaluate metric and sensitivities + G_eval = metric_eval(h.metric, θ) + ∂P∂θ = metric_sensitivity(h.metric, θ) + + # Get logdet gradient matrix and precompute logdet gradient terms + M_logdet = logdet_grad_matrix(G_eval) + d = size(∂P∂θ, 3) + logdet_terms = Vector{T}(undef, d) + @inbounds for i in 1:d + ∂Pᵢ = @view ∂P∂θ[:, :, i] + logdet_terms[i] = T(0.5) * tr_product(M_logdet, ∂Pᵢ) + end + + return RiemannianGradCache(G_eval, ∂P∂θ, ℓπ, ∂ℓπ∂θ, logdet_terms) +end + +""" + ∂H∂θ_from_cache(cache::RiemannianGradCache, r) + +Compute Hamiltonian gradient ∂H/∂θ using cached θ-dependent values. 
+Only performs r-dependent computation (kinetic gradient matrix and trace products). +""" +function ∂H∂θ_from_cache(cache::RiemannianGradCache{T}, r::AbstractVector) where {T} + # Compute kinetic gradient matrix (r-dependent) + M_kinetic = kinetic_grad_matrix(cache.G_eval, r) + + # Compute full gradient + d = length(cache.∂ℓπ∂θ) + grad = Vector{T}(undef, d) + + @inbounds for i in 1:d + ∂Pᵢ = @view cache.∂P∂θ[:, :, i] + # ∂H/∂θᵢ = -∂ℓπ/∂θᵢ + 0.5*tr(M_logdet*∂P/∂θᵢ) - 0.5*tr(M_kinetic*∂P/∂θᵢ) + kinetic_term = T(0.5) * tr_product(M_kinetic, ∂Pᵢ) + grad[i] = -cache.∂ℓπ∂θ[i] + cache.logdet_terms[i] - kinetic_term + end + + return DualValue(cache.ℓπ, grad) +end + +#### +#### Main gradient interface +#### + +""" + ∂H∂θ(h::Hamiltonian{<:AbstractRiemannianMetric}, θ, r) + +Compute the gradient of the Hamiltonian with respect to position θ. +Returns a DualValue containing (log_density, gradient). + +Ref: Eq (15) of Girolami & Calderhead (2011) +""" +function ∂H∂θ( + h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, θ::AbstractVector, r::AbstractVector, ) - H = h.metric.G(θ) - G = h.metric.map(H) - return G \ r + cache = build_grad_cache(h, θ) + return ∂H∂θ_from_cache(cache, r) end """ -Computes `tr(A*B)` for square n x n matrices `A` and `B` in O(n^2) without computing `A*B`, which would be O(n^3). + ∂H∂θ_cache(h, θ, r; cache=nothing) + +Compute ∂H/∂θ with optional caching for fixed-point iterations. +Returns (DualValue, cache) tuple. -Doesn't actually check that A and B are both n x n matrices. +When cache is provided, reuses θ-dependent computations (eigendecomposition, +logdet gradient terms) and only recomputes r-dependent terms. """ -tr_product(A::AbstractMatrix, B::AbstractMatrix) = sum(Base.broadcasted(*, A', B)) -"Computes `tr(A*v*v')`, i.e. dot(v,A,v)." 
-tr_product(A::AbstractMatrix, v::AbstractVector) = sum(Base.broadcasted(*, v, A, v')) +function ∂H∂θ_cache( + h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, + θ::AbstractVector, + r::AbstractVector; + cache=nothing, +) + cache = @something cache build_grad_cache(h, θ) + return ∂H∂θ_from_cache(cache, r), cache +end +#### +#### Momentum gradient ∂H/∂r +#### -function ∂H∂θ( +""" + ∂H∂r(h::Hamiltonian{<:AbstractRiemannianMetric}, θ, r; G_eval=nothing) + +Compute the gradient of the Hamiltonian with respect to momentum r. +For Riemannian metrics: ∂H/∂r = G(θ)⁻¹ * r + +If `G_eval` is provided, uses it directly instead of recomputing the metric. + +Ref: Eq (14) of Girolami & Calderhead (2011) +""" +function ∂H∂r( + h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, + θ::AbstractVector, + r::AbstractVector; + G_eval=nothing, +) + G = @something G_eval metric_eval(h.metric, θ) + return G \ r +end + +# Non-keyword version for backward compatibility with integrator +function ∂H∂r( h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, θ::AbstractVector, r::AbstractVector, ) - return first(∂H∂θ_cache(h, θ, r)) + G_eval = metric_eval(h.metric, θ) + return G_eval \ r end + +#### +#### Negative energy (log probability) +#### + """ + neg_energy(h::Hamiltonian{<:AbstractRiemannianMetric}, r, θ; G_eval=nothing) + +Compute the negative kinetic energy for Riemannian metrics. +Includes the log-determinant normalization term since G depends on θ. + +If `G_eval` is provided, uses it directly instead of recomputing the metric. 
+K(r, θ) = 0.5 * (D*log(2π) + log|G(θ)| + r'G(θ)⁻¹r) +neg_energy = -K = -0.5 * (D*log(2π) + log|G(θ)| + r'G(θ)⁻¹r) + +Ref: Eq (13) of Girolami & Calderhead (2011) """ -@views function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, - θ::AbstractVector{T}, - r::AbstractVector{T}; - cache=nothing -) where {T} - cache = @something cache begin - log_density, log_density_gradient = h.∂ℓπ∂θ(θ) - # h.metric.map is the IdentityMap - metric = h.metric.G(θ) - # The metric is inverted to be able to compute `tr_product(inv_metric, ...)` efficiently - - # but this may still be a bad idea! - inv_metric = inv(metric) - metric_sensitivities = h.metric.∂G∂θ(θ) - rv1 = map(eachindex(log_density_gradient)) do i - -log_density_gradient[i] + .5 * tr_product(inv_metric, metric_sensitivities[:, :, i]) - end - (;log_density, inv_metric, metric_sensitivities, rv1) - end - # (;log_density, inv_metric_r, metric_sensitivities, rv1) = cache - inv_metric_r = cache.inv_metric * r - return DualValue( - cache.log_density, - #! Eq (15) of Girolami & Calderhead (2011) - cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i - .5 * tr_product(cache.metric_sensitivities[:, :, i], inv_metric_r) - end - ), cache -end +function neg_energy( + h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, + r::AbstractVector, + θ::AbstractVector; + G_eval=nothing, +) + G = @something G_eval metric_eval(h.metric, θ) + D = length(r) -#! J as defined in middle of the right column of Page 3 of Betancourt (2012) -function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} - d = length(λ) - J = Matrix{T}(undef, d, d) - for i in 1:d, j in 1:d - J[i, j] = if (λ[i] == λ[j]) - # Ref: https://www.wolframalpha.com/input?i=derivative+of+x+*+coth%28a+*+x%29 - #! 
Based on middle of the right column of Page 3 of Betancourt (2012) "Note that whenλi=λj, such as for the diagonal elementsor degenerate eigenvalues, this becomes the derivative" - coth(α * λ[i]) + λ[i] * α * -csch(λ[i] * α)^2 - else - ((λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j])) - end - end - return J + # Quadratic form: r' * G⁻¹ * r + G_inv_r = G \ r + quadform = dot(r, G_inv_r) + + # Log normalization constant (position-dependent) + logZ = (D * log(2π) + logdet(G)) / 2 + + return -logZ - quadform / 2 end -@views function ∂H∂θ_cache( - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, - θ::AbstractVector{T}, - r::AbstractVector{T}; - cache=nothing, -) where {T} - cache = @something cache begin - log_density, log_density_gradient = h.∂ℓπ∂θ(θ) - premetric = h.metric.G(θ) - premetric_sensitivities = h.metric.∂G∂θ(θ) - metric, Q, λ, softabsλ = softabs(premetric, h.metric.map.α) - J = make_J(λ, h.metric.map.α) - - #! Based on the two equations from the right column of Page 3 of Betancourt (2012) - tmpv = diag(J) ./ softabsλ - tmpm = Q * Diagonal(tmpv) * Q' - - rv1 = map(eachindex(log_density_gradient)) do i - -log_density_gradient[i] + .5 * tr_product(tmpm, premetric_sensitivities[:, :, i]) - end - (;log_density, Q, softabsλ, tmpv, tmpm, rv1) - end - cache.tmpv .= (cache.Q' * r) ./ cache.softabsλ - cache.tmpm .= Q * (J .* cache.tmpv .* cache.tmpv') * Q' - - return DualValue( - cache.log_density, - cache.rv1 .- Base.broadcasted(eachindex(cache.rv1)) do i - .5 * tr_product(cache.tmpm, cache.premetric_sensitivities[:, :, i]) - end - ), cache +# Non-keyword version for backward compatibility +function neg_energy( + h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, + r::AbstractVector, + θ::AbstractVector, +) + G_eval = metric_eval(h.metric, θ) + return neg_energy(h, r, θ; G_eval=G_eval) end -# QUES Do we want to change everything to position dependent by default? 
-# Add θ to ∂H∂r for DenseRiemannianMetric +#### +#### Phase point construction +#### + +""" +Create a PhasePoint for Riemannian metrics, computing position-dependent kinetic energy. +Shares the metric evaluation between neg_energy and ∂H∂r to avoid redundant computation. +""" function phasepoint( - h::Hamiltonian{<:DenseRiemannianMetric}, + h::Hamiltonian{<:AbstractRiemannianMetric}, θ::T, r::T; - ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), + ℓπ=∂H∂θ(h, θ, r), + G_eval=nothing, + ℓκ=nothing, ) where {T<:AbstractVecOrMat} + if isnothing(ℓκ) + # Compute G_eval once and share between neg_energy and ∂H∂r + G = @something G_eval metric_eval(h.metric, θ) + ℓκ = DualValue(neg_energy(h, r, θ; G_eval=G), ∂H∂r(h, θ, r; G_eval=G)) + end return PhasePoint(θ, r, ℓπ, ℓκ) end - -#! Eq (13) of Girolami & Calderhead (2011) -function neg_energy( - h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T -) where {T<:AbstractVecOrMat} - G = h.metric.map(h.metric.G(θ)) - D = size(G, 1) - # Need to consider the normalizing term as it is no longer same for different θs - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined - mul!(h.metric._temp, inv(G), r) - return -logZ - dot(r, h.metric._temp) / 2 -end diff --git a/src/riemannian/integrator.jl b/src/riemannian/integrator.jl index a6d2de5fc..b582fc1c8 100644 --- a/src/riemannian/integrator.jl +++ b/src/riemannian/integrator.jl @@ -15,18 +15,14 @@ $(TYPEDFIELDS) struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} "Step size." ϵ::T + "Number of fixed-point iterations for implicit steps." 
n::Int end + function Base.show(io::IO, l::GeneralizedLeapfrog) return print(io, "GeneralizedLeapfrog(ϵ=", round.(l.ϵ; sigdigits=3), ", n=", l.n, ")") end -# fallback to ignore return_cache & cache kwargs for other ∂H∂θ -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) - dv = ∂H∂θ(h, θ, r) - return return_cache ? (dv, nothing) : dv -end - # TODO(Kai) make sure vectorization works # TODO(Kai) check if tempering is valid # TODO(Kai) abstract out the 3 main steps and merge with `step` in `integrator.jl` @@ -55,42 +51,43 @@ function step( for i in 1:n_steps θ_init, r_init = z.θ, z.r - # Tempering - #r = temper(lf, r, (i=i, is_half=true), n_steps) - # eq (16) of Girolami & Calderhead (2011) + + # Eq (16) of Girolami & Calderhead (2011) - implicit momentum half-step r_half = r_init local cache = nothing for j in 1:(lf.n) - # Reuse cache for the first iteration if j == 1 + # First iteration: use cached values from phase point (; value, gradient) = z.ℓπ - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) + else + # Subsequent iterations: build/reuse cache for θ-dependent computations + retval, cache = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) (; value, gradient) = retval - else # reuse cache - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) end r_half = r_init - ϵ / 2 * gradient end - # eq (17) of Girolami & Calderhead (2011) + + # Eq (17) of Girolami & Calderhead (2011) - implicit position step θ_full = θ_init - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop + term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop for j in 1:(lf.n) θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) end - # eq (18) of Girolami & Calderhead (2011) + + # Eq (18) of Girolami & Calderhead (2011) - explicit momentum half-step (; value, gradient) = ∂H∂θ(h, θ_full, r_half) r_full = r_half - ϵ / 2 * gradient - # Tempering - #r = 
temper(lf, r, (i=i, is_half=false), n_steps) + # Create a new phase point by caching the logdensity and gradient z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) + # Update result if FullTraj res[i] = z else res = z end + if !isfinite(z) # Remove undef if FullTraj diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index 41d11127c..da79c453c 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -1,9 +1,265 @@ +#### +#### Riemannian Metric Types +#### + +""" +Abstract type for Riemannian (position-dependent) metrics. + +Subtypes must implement: +- `metric_eval(metric, θ)` - evaluate metric at position θ +- `metric_sensitivity(metric, θ)` - compute ∂P/∂θ where P is the "pre-metric" + (G itself for RiemannianMetric, or H the Hessian for SoftAbsRiemannianMetric) +""" abstract type AbstractRiemannianMetric <: AbstractMetric end +#### +#### SoftAbsEval - cached eigendecomposition for SoftAbs metrics +#### + +""" + SoftAbsEval{T} + +Cached result of evaluating a SoftAbs metric at a position θ. +Stores eigendecomposition and precomputed matrices for efficient gradient computation. + +# Fields +- `Q`: Eigenvectors (orthogonal matrix) +- `softabsλ`: Transformed eigenvalues: λᵢ * coth(α * λᵢ) +- `J`: Jacobian matrix encoding the derivative of softabs (divided difference formula) +- `M_logdet`: Precomputed matrix Q * (R .* J) * Q' for logdet gradient +""" +struct SoftAbsEval{T<:AbstractFloat} + Q::Matrix{T} + softabsλ::Vector{T} + J::Matrix{T} + M_logdet::Matrix{T} +end + +# Standard operations for SoftAbsEval +function Base.:\(G::SoftAbsEval, p::AbstractVector) + return G.Q * ((G.Q' * p) ./ G.softabsλ) +end + +function LinearAlgebra.logdet(G::SoftAbsEval) + return sum(log, G.softabsλ) +end + +""" + unwhiten(G::SoftAbsEval, z) + +Transform z ~ N(0, I) to sample from N(0, G). 
+""" +function unwhiten(G::SoftAbsEval, z::AbstractVector) + return G.Q * (sqrt.(G.softabsλ) .* z) +end + +#### +#### RiemannianMetric - for user-provided PD metrics +#### + +""" + RiemannianMetric{TG, T∂G} + +Riemannian metric where the user provides a function returning a positive-definite +matrix (or AbstractPDMat subtype). + +# Fields +- `size`: Tuple{Int} giving the dimension +- `calc_G`: Function θ → G(θ), returns a positive-definite matrix +- `calc_∂G∂θ`: Function θ → ∂G/∂θ, returns Array{T,3} of shape (d, d, d) + +# Example +```julia +# Simple Fisher information metric +calc_G = θ -> PDMat(fisher_information(θ)) +calc_∂G∂θ = θ -> ForwardDiff.jacobian(θ -> vec(fisher_information(θ)), θ) |> reshape_∂G∂θ +metric = RiemannianMetric((d,), calc_G, calc_∂G∂θ) +``` +""" +struct RiemannianMetric{TG,T∂G} <: AbstractRiemannianMetric + size::Tuple{Int} + calc_G::TG # θ → Matrix or AbstractPDMat + calc_∂G∂θ::T∂G # θ → Array{T,3} +end + +Base.size(m::RiemannianMetric) = m.size +Base.size(m::RiemannianMetric, dim::Int) = m.size[dim] + +function Base.show(io::IO, m::RiemannianMetric) + return print(io, "RiemannianMetric(size=", m.size, ")") +end + +# Interface implementations for RiemannianMetric +metric_eval(m::RiemannianMetric, θ) = m.calc_G(θ) +metric_sensitivity(m::RiemannianMetric, θ) = m.calc_∂G∂θ(θ) + +#### +#### SoftAbsRiemannianMetric - for Hessian-based metrics with SoftAbs regularization +#### + +""" + SoftAbsRiemannianMetric{T, TH, T∂H} + +Riemannian metric based on the SoftAbs transformation of a Hessian. +The Hessian may not be positive-definite; the SoftAbs transformation +G = Q * diag(λ * coth(α*λ)) * Q' guarantees positive-definiteness. + +# Fields +- `size`: Tuple{Int} giving the dimension +- `calc_H`: Function θ → H(θ), returns the Hessian matrix (the "pre-metric") +- `calc_∂H∂θ`: Function θ → ∂H/∂θ, returns Array{T,3} of shape (d, d, d) +- `α`: SoftAbs regularization parameter (larger = closer to |λ|) + +# References +- Betancourt, M. 
"A general metric for Riemannian manifold Hamiltonian Monte Carlo" (2012) +""" +struct SoftAbsRiemannianMetric{T<:AbstractFloat,TH,T∂H} <: AbstractRiemannianMetric + size::Tuple{Int} + calc_H::TH # θ → Hessian matrix (pre-metric) + calc_∂H∂θ::T∂H # θ → Array{T,3} + α::T +end + +Base.size(m::SoftAbsRiemannianMetric) = m.size +Base.size(m::SoftAbsRiemannianMetric, dim::Int) = m.size[dim] +Base.eltype(::SoftAbsRiemannianMetric{T}) where {T} = T + +function Base.show(io::IO, m::SoftAbsRiemannianMetric) + return print(io, "SoftAbsRiemannianMetric(size=", m.size, ", α=", m.α, ")") +end + +""" + make_J(λ, α) + +Construct the J matrix for softabs gradient computation. +J encodes the derivative of the softabs transformation using the divided difference formula. + +For i ≠ j: J[i,j] = (softabs(λᵢ) - softabs(λⱼ)) / (λᵢ - λⱼ) +For i = j: J[i,i] = d/dλ [λ coth(αλ)] = coth(αλ) - αλ csch²(αλ) + +# References +- Betancourt (2012) +""" +function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat} + d = length(λ) + J = Matrix{T}(undef, d, d) + @inbounds for i in 1:d, j in 1:d + if λ[i] == λ[j] + # Derivative case (diagonal or degenerate eigenvalues) + # d/dλ [λ coth(αλ)] = coth(αλ) - αλ csch²(αλ) + J[i, j] = coth(α * λ[i]) - α * λ[i] * csch(α * λ[i])^2 + else + # Divided difference + J[i, j] = (λ[i] * coth(α * λ[i]) - λ[j] * coth(α * λ[j])) / (λ[i] - λ[j]) + end + end + return J +end + +""" + metric_eval(m::SoftAbsRiemannianMetric, θ) + +Evaluate SoftAbs metric at position θ, returning a `SoftAbsEval` with cached matrices. 
+""" +function metric_eval(m::SoftAbsRiemannianMetric{T}, θ) where {T} + H = m.calc_H(θ) + F = eigen(Symmetric(H)) + λ = F.values + Q = F.vectors + + # SoftAbs transformation: G = Q * diag(softabsλ) * Q' + softabsλ = λ .* coth.(m.α .* λ) + + # Compute J matrix for gradient chain rule + J = make_J(λ, m.α) + + # Precompute M_logdet = Q * (R .* J) * Q' where R = diag(1 ./ softabsλ) + # This is used for: ∂log|G|/∂θᵢ = 0.5 * tr(M_logdet * ∂H/∂θᵢ) + R = Diagonal(one(T) ./ softabsλ) + M_logdet = Q * (R .* J) * Q' + + return SoftAbsEval(Q, softabsλ, J, M_logdet) +end + +metric_sensitivity(m::SoftAbsRiemannianMetric, θ) = m.calc_∂H∂θ(θ) + +#### +#### Gradient matrices for unified computation +#### + +""" + logdet_grad_matrix(G) + +Return the matrix M such that ∂log|G|/∂θᵢ = 0.5 * tr(M * ∂P/∂θᵢ), where P is the +"pre-metric" (G itself for RiemannianMetric, or H the Hessian for SoftAbsRiemannianMetric). + +For dense matrices: M = G⁻¹ +For SoftAbsEval: M = Q * (R .* J) * Q' (precomputed in metric_eval) + +The J matrix in SoftAbsEval absorbs the chain rule through the softabs transformation, +so the same formula works with ∂H/∂θ instead of ∂G/∂θ. +""" +logdet_grad_matrix(G::SoftAbsEval) = G.M_logdet +logdet_grad_matrix(G::AbstractMatrix) = inv(G) + +""" + kinetic_grad_matrix(G, r) + +Return the matrix M such that ∂(r'G⁻¹r)/∂θᵢ = -tr(M * ∂P/∂θᵢ), where P is the +"pre-metric" (G itself for RiemannianMetric, or H the Hessian for SoftAbsRiemannianMetric). + +For dense matrices: M = (G⁻¹r)(G⁻¹r)' (rank-1 outer product) +For SoftAbsEval: M = Q * D * J * D * Q' where D = diag((Q'r) ./ softabsλ) + +For SoftAbsEval, the J matrix absorbs the chain rule through softabs, allowing +the gradient to be computed with respect to ∂H/∂θ rather than ∂G/∂θ. This avoids +recomputing J for each value of r during fixed-point iterations. 
+""" +function kinetic_grad_matrix(G::SoftAbsEval, r::AbstractVector) + # D = diag((Q'r) ./ softabsλ) + d = (G.Q' * r) ./ G.softabsλ + D = Diagonal(d) + return G.Q * D * G.J * D * G.Q' +end + +function kinetic_grad_matrix(G::AbstractMatrix, r::AbstractVector) + v = G \ r + return v * v' # Rank-1 outer product +end + +#### +#### Momentum sampling +#### + +function rand_momentum( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + metric::AbstractRiemannianMetric, + ::GaussianKinetic, + θ::AbstractVecOrMat, +) + G = metric_eval(metric, θ) + T = eltype(metric) === Any ? eltype(θ) : eltype(metric) + z = _randn(rng, T, size(metric)...) + return unwhiten(G, z) +end + +# unwhiten for regular matrices (PDMat or dense) +function unwhiten(G::AbstractMatrix, z::AbstractVector) + # G = L * L', so sample = L * z where L = chol(G).L + chol = cholesky(Symmetric(G)) + return chol.L * z +end + +# eltype for RiemannianMetric (needed for rand_momentum) +Base.eltype(::RiemannianMetric) = Any # Will use eltype(θ) as fallback + +#### +#### Deprecated types (for backward compatibility) +#### + abstract type AbstractHessianMap end struct IdentityMap <: AbstractHessianMap end - (::IdentityMap)(x) = x struct SoftAbsMap{T} <: AbstractHessianMap @@ -11,17 +267,20 @@ struct SoftAbsMap{T} <: AbstractHessianMap end function softabs(X, α=20.0) - F = eigen(X) # ReverseDiff cannot diff through `eigen` - Q = hcat(F.vectors) + F = eigen(Symmetric(X)) + Q = F.vectors λ = F.values softabsλ = λ .* coth.(α * λ) - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ + return Q * Diagonal(softabsλ) * Q', Q, λ, softabsλ end (map::SoftAbsMap)(x) = softabs(x, map.α)[1] -# TODO Register softabs with ReverseDiff -#! The definition of SoftAbs from Page 3 of Betancourt (2012) +""" + DenseRiemannianMetric (deprecated) + +Use `RiemannianMetric` or `SoftAbsRiemannianMetric` instead. 
+""" struct DenseRiemannianMetric{ T, TM<:AbstractHessianMap, @@ -31,33 +290,51 @@ struct DenseRiemannianMetric{ T∂G∂θ, } <: AbstractRiemannianMetric size::A - G::TG # TODO store G⁻¹ here instead + G::TG ∂G∂θ::T∂G∂θ map::TM _temp::AV end -# TODO Make dense mass matrix support matrix-mode parallel function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) + Base.depwarn( + "DenseRiemannianMetric is deprecated. Use RiemannianMetric (for IdentityMap) or SoftAbsRiemannianMetric (for SoftAbsMap) instead.", + :DenseRiemannianMetric, + ) _temp = Vector{Float64}(undef, first(size)) return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) end Base.size(e::DenseRiemannianMetric) = e.size Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] +Base.eltype(::DenseRiemannianMetric{T}) where {T} = T + function Base.show(io::IO, drm::DenseRiemannianMetric) - return print(io, "DenseRiemannianMetric$(drm.size) with $(drm.map) metric") + return print( + io, + "DenseRiemannianMetric", + drm.size, + " with ", + nameof(typeof(drm.map)), + " (deprecated)", + ) end -function rand_momentum( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - metric::DenseRiemannianMetric{T}, - kinetic, - θ::AbstractVecOrMat, -) where {T} - r = _randn(rng, T, size(metric)...) 
- G⁻¹ = inv(metric.map(metric.G(θ))) - chol = cholesky(Symmetric(G⁻¹)) - ldiv!(chol.U, r) - return r +# metric_eval and metric_sensitivity for deprecated DenseRiemannianMetric +function metric_eval(m::DenseRiemannianMetric{T,<:IdentityMap}, θ) where {T} + return m.G(θ) end + +function metric_eval(m::DenseRiemannianMetric{T,<:SoftAbsMap}, θ) where {T} + H = m.G(θ) + F = eigen(Symmetric(H)) + λ = F.values + Q = F.vectors + softabsλ = λ .* coth.(m.map.α .* λ) + J = make_J(λ, m.map.α) + R = Diagonal(one(T) ./ softabsλ) + M_logdet = Q * (R .* J) * Q' + return SoftAbsEval(Q, softabsλ, J, M_logdet) +end + +metric_sensitivity(m::DenseRiemannianMetric, θ) = m.∂G∂θ(θ) diff --git a/test/riemannian.jl b/test/riemannian.jl index f11522151..22d331d92 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -1,28 +1,41 @@ using ReTest, Random using AdvancedHMC, ForwardDiff, AbstractMCMC using LinearAlgebra +using Distributions: MvNormal, logpdf using MCMCLogDensityProblems using FiniteDiff: finite_difference_gradient, finite_difference_hessian, finite_difference_jacobian -using AdvancedHMC: neg_energy, energy, ∂H∂θ, ∂H∂r +using AdvancedHMC: + neg_energy, + energy, + ∂H∂θ, + ∂H∂r, + metric_eval, + metric_sensitivity, + logdet_grad_matrix, + kinetic_grad_matrix, + SoftAbsEval, + RiemannianMetric, + SoftAbsRiemannianMetric + +#### +#### Test utilities +#### + +function gen_hess_fwd(func, x::AbstractVector) + function hess(x::AbstractVector) + return nothing, nothing, ForwardDiff.hessian(func, x) + end + return hess +end -# Fisher information metric function gen_∂G∂θ_fwd(Vfunc, x; f=identity) _Hfunc = gen_hess_fwd(Vfunc, x) Hfunc = x -> _Hfunc(x)[3] - # QUES What's the best output format of this function? cfg = ForwardDiff.JacobianConfig(Hfunc, x) d = length(x) out = zeros(eltype(x), d^2, d) return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) - return out # default output shape [∂H∂x₁; ∂H∂x₂; ...] 
-end - -function gen_hess_fwd(func, x::AbstractVector) - function hess(x::AbstractVector) - return nothing, nothing, ForwardDiff.hessian(func, x) - end - return hess end function reshape_∂G∂θ(H) @@ -32,8 +45,8 @@ end function prepare_sample(ℓπ, initial_θ, λ) Vfunc = x -> -ℓπ(x) - _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ) # x -> (value, gradient, hessian) - Hfunc = x -> copy.(_Hfunc(x)) # _Hfunc do in-place computation, copy to avoid bug + _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ) + Hfunc = x -> copy.(_Hfunc(x)) fstabilize = H -> H + λ * I Gfunc = x -> begin @@ -46,8 +59,113 @@ function prepare_sample(ℓπ, initial_θ, λ) return Vfunc, Hfunc, Gfunc, ∂G∂θfunc end -@testset "Constructors tests" begin - δ(a, b) = maximum(abs.(a - b)) +δ(a, b) = maximum(abs.(a - b)) + +#### +#### Tests for unified API (RiemannianMetric, SoftAbsRiemannianMetric) +#### + +@testset "New Riemannian API" begin + @testset "$(nameof(typeof(target)))" for target in [HighDimGaussian(2), Funnel()] + rng = MersenneTwister(1110) + λ = 1e-2 + + θ₀ = rand(rng, dim(target)) + + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀) + + _, _, Gfunc, ∂G∂θfunc = prepare_sample(ℓπ, θ₀, λ) + + D = dim(target) + x = zeros(D) + r = randn(rng, D) + + @testset "RiemannianMetric (PDMat-style)" begin + metric = RiemannianMetric((D,), Gfunc, ∂G∂θfunc) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + # Test metric_eval returns a matrix + G_eval = metric_eval(metric, x) + @test G_eval isa AbstractMatrix + @test size(G_eval) == (D, D) + + # Test metric_sensitivity + ∂G = metric_sensitivity(metric, x) + @test size(∂G) == (D, D, D) + + # Test gradient matrices + M_logdet = logdet_grad_matrix(G_eval) + @test size(M_logdet) == (D, D) + + M_kinetic = kinetic_grad_matrix(G_eval, r) + @test size(M_kinetic) == (D, D) + + # Test ∂H∂θ against finite differences + Hamifunc = (x, r) -> energy(hamiltonian, 
r, x) + energy(hamiltonian, x) + Hamifuncx = x -> Hamifunc(x, r) + @test δ( + finite_difference_gradient(Hamifuncx, x), ∂H∂θ(hamiltonian, x, r).gradient + ) < 1e-4 + + # Test ∂H∂r against finite differences + Hamifuncr = r -> Hamifunc(x, r) + @test δ(finite_difference_gradient(Hamifuncr, r), ∂H∂r(hamiltonian, x, r)) < + 1e-4 + end + + @testset "SoftAbsRiemannianMetric" begin + α = 20.0 + metric = SoftAbsRiemannianMetric((D,), Gfunc, ∂G∂θfunc, α) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + # Test metric_eval returns SoftAbsEval + G_eval = metric_eval(metric, x) + @test G_eval isa SoftAbsEval + @test size(G_eval.Q) == (D, D) + @test length(G_eval.softabsλ) == D + @test size(G_eval.J) == (D, D) + @test size(G_eval.M_logdet) == (D, D) + + # Test standard operations on SoftAbsEval + v = randn(rng, D) + @test length(G_eval \ v) == D + @test logdet(G_eval) isa Real + + # Test gradient matrices + M_logdet = logdet_grad_matrix(G_eval) + @test M_logdet === G_eval.M_logdet # Should be cached + + M_kinetic = kinetic_grad_matrix(G_eval, r) + @test size(M_kinetic) == (D, D) + + # Test kinetic energy matches MvNormal logpdf + G_matrix = G_eval.Q * Diagonal(G_eval.softabsλ) * G_eval.Q' + @test neg_energy(hamiltonian, r, x) ≈ + logpdf(MvNormal(zeros(D), Symmetric(G_matrix)), r) + + # Test ∂H∂θ against finite differences + Hamifunc = (x, r) -> energy(hamiltonian, r, x) + energy(hamiltonian, x) + Hamifuncx = x -> Hamifunc(x, r) + @test δ( + finite_difference_gradient(Hamifuncx, x), ∂H∂θ(hamiltonian, x, r).gradient + ) < 1e-4 + + # Test ∂H∂r against finite differences + Hamifuncr = r -> Hamifunc(x, r) + @test δ(finite_difference_gradient(Hamifuncr, r), ∂H∂r(hamiltonian, x, r)) < + 1e-4 + end + end +end + +#### +#### Tests for deprecated API (DenseRiemannianMetric) +#### + +@testset "Deprecated DenseRiemannianMetric (backward compatibility)" begin @testset "$(nameof(typeof(target)))" for target in [HighDimGaussian(2), Funnel()] rng = 
MersenneTwister(1110) λ = 1e-2 @@ -59,25 +177,24 @@ end Vfunc, Hfunc, Gfunc, ∂G∂θfunc = prepare_sample(ℓπ, θ₀, λ) - D = dim(target) # ==2 for this test - x = zeros(D) # randn(rng, D) + D = dim(target) + x = zeros(D) r = randn(rng, D) - @testset "Autodiff" begin + @testset "Autodiff utilities" begin @test δ(finite_difference_gradient(ℓπ, x), ∂ℓπ∂θ(x)[end]) < 1e-4 @test δ(finite_difference_hessian(Vfunc, x), Hfunc(x)[end]) < 1e-4 - # finite_difference_jacobian returns shape of (4, 2), reshape_∂G∂θ turns it into (2, 2, 2) @test δ(reshape_∂G∂θ(finite_difference_jacobian(Gfunc, x)), ∂G∂θfunc(x)) < 1e-4 end @testset "$(nameof(typeof(hessmap)))" for hessmap in [IdentityMap(), SoftAbsMap(20.0)] - metric = DenseRiemannianMetric((D,), Gfunc, ∂G∂θfunc, hessmap) + # Suppress deprecation warning + metric = @test_deprecated DenseRiemannianMetric((D,), Gfunc, ∂G∂θfunc, hessmap) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) - if hessmap isa SoftAbsMap || # only test kinetic energy for SoftAbsMap as that of IdentityMap can be non-PD - all(iszero, x) # or for x==0 that I know it's PD + if hessmap isa SoftAbsMap || all(iszero, x) @testset "Kinetic energy" begin Σ = hamiltonian.metric.map(hamiltonian.metric.G(x)) @test neg_energy(hamiltonian, r, x) ≈ logpdf(MvNormal(zeros(D), Σ), r) @@ -103,61 +220,124 @@ end end end -@testset "Multi variate Normal with Riemannian HMC" begin - # Set the number of samples to draw and warmup iterations - n_samples = 2_000 +#### +#### Integration tests with sampling +#### + +@testset "Sampling with unified RiemannianMetric" begin + n_samples = 100 rng = MersenneTwister(1110) initial_θ = rand(rng, D) λ = 1e-2 _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) - # Define a Hamiltonian system - metric = DenseRiemannianMetric((D,), G, ∂G∂θ) + + metric = RiemannianMetric((D,), G, ∂G∂θ) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) - # Define a leapfrog solver, with the initial step size 
chosen heuristically initial_ϵ = 0.01 integrator = GeneralizedLeapfrog(initial_ϵ, 6) + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) + + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=false) + @test length(samples) == n_samples + @test length(stats) == n_samples +end + +@testset "Sampling with SoftAbsRiemannianMetric" begin + n_samples = 100 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) - # Define an HMC sampler with the following components - # - multinomial sampling scheme, - # - generalised No-U-Turn criteria, and + metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 20.0) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) - # Run the sampler to draw samples from the specified Gaussian, where - # - `samples` will store the samples - # - `stats` will store diagnostic statistics for each sample - samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=false) @test length(samples) == n_samples @test length(stats) == n_samples end -@testset "Multi variate Normal with Riemannian HMC softabs metric" begin - # Set the number of samples to draw and warmup iterations - n_samples = 2_000 +@testset "Sampling with deprecated DenseRiemannianMetric (IdentityMap)" begin + n_samples = 100 rng = MersenneTwister(1110) initial_θ = rand(rng, D) λ = 1e-2 _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) - # Define a Hamiltonian system - metric = DenseRiemannianMetric((D,), G, ∂G∂θ, SoftAbsMap(20.0)) + metric = @test_deprecated DenseRiemannianMetric((D,), G, ∂G∂θ) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) - # Define a leapfrog solver, 
with the initial step size chosen heuristically initial_ϵ = 0.01 integrator = GeneralizedLeapfrog(initial_ϵ, 6) + kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) - # Define an HMC sampler with the following components - # - multinomial sampling scheme, - # - generalised No-U-Turn criteria, and + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=false) + @test length(samples) == n_samples + @test length(stats) == n_samples +end + +@testset "Sampling with deprecated DenseRiemannianMetric (SoftAbsMap)" begin + n_samples = 100 + rng = MersenneTwister(1110) + initial_θ = rand(rng, D) + λ = 1e-2 + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + + metric = @test_deprecated DenseRiemannianMetric((D,), G, ∂G∂θ, SoftAbsMap(20.0)) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 6) kernel = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(8))) - # Run the sampler to draw samples from the specified Gaussian, where - # - `samples` will store the samples - # - `stats` will store diagnostic statistics for each sample - samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=true) + samples, stats = sample(rng, hamiltonian, kernel, initial_θ, n_samples; progress=false) @test length(samples) == n_samples @test length(stats) == n_samples end + +#### +#### Energy conservation tests +#### + +@testset "Energy conservation" begin + rng = MersenneTwister(42) + D_test = 2 + target = HighDimGaussian(D_test) + θ₀ = rand(rng, D_test) + λ = 1e-2 + + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, θ₀) + _, _, G, ∂G∂θ = prepare_sample(ℓπ, θ₀, λ) + + @testset "SoftAbsRiemannianMetric energy conservation" begin + metric = SoftAbsRiemannianMetric((D_test,), G, ∂G∂θ, 20.0) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, 
∂ℓπ∂θ) + + # Small step size for better energy conservation + integrator = GeneralizedLeapfrog(0.001, 10) + + # Create initial phase point + θ_init = zeros(D_test) + r_init = randn(rng, D_test) + z0 = AdvancedHMC.phasepoint(hamiltonian, θ_init, r_init) + H0 = -AdvancedHMC.neg_energy(z0) + + # Take 10 leapfrog steps + z1 = AdvancedHMC.step(integrator, hamiltonian, z0, 10) + H1 = -AdvancedHMC.neg_energy(z1) + + # Energy should be approximately conserved + @test abs(H1 - H0) < 0.1 + end +end From 840cd2d9dfd725a459e4da6efef4555f704884a3 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Tue, 13 Jan 2026 20:14:32 +0800 Subject: [PATCH 07/19] Fix for compilation --- src/riemannian/hamiltonian.jl | 20 -------------------- src/riemannian/metric.jl | 2 ++ 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/riemannian/hamiltonian.jl b/src/riemannian/hamiltonian.jl index 3422fc2a7..52b7dc632 100644 --- a/src/riemannian/hamiltonian.jl +++ b/src/riemannian/hamiltonian.jl @@ -154,16 +154,6 @@ function ∂H∂r( return G \ r end -# Non-keyword version for backward compatibility with integrator -function ∂H∂r( - h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, - θ::AbstractVector, - r::AbstractVector, -) - G_eval = metric_eval(h.metric, θ) - return G_eval \ r -end - #### #### Negative energy (log probability) #### @@ -200,16 +190,6 @@ function neg_energy( return -logZ - quadform / 2 end -# Non-keyword version for backward compatibility -function neg_energy( - h::Hamiltonian{<:AbstractRiemannianMetric,<:GaussianKinetic}, - r::AbstractVector, - θ::AbstractVector, -) - G_eval = metric_eval(h.metric, θ) - return neg_energy(h, r, θ; G_eval=G_eval) -end - #### #### Phase point construction #### diff --git a/src/riemannian/metric.jl b/src/riemannian/metric.jl index da79c453c..73c6af776 100644 --- a/src/riemannian/metric.jl +++ b/src/riemannian/metric.jl @@ -1,3 +1,5 @@ +import LinearAlgebra + #### #### Riemannian Metric Types #### From 
cbba891883047b302cd1abb2cbac527620f757e1 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Tue, 13 Jan 2026 21:55:49 +0800 Subject: [PATCH 08/19] Fix dHdr to allow Generalised NUTS --- src/hamiltonian.jl | 4 ++-- src/trajectory.jl | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index c782e1a24..24f744ca6 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -101,7 +101,7 @@ function Base.similar(z::PhasePoint{<:AbstractVecOrMat{T}}) where {T<:AbstractFl end function phasepoint( - h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)) + h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)) ) where {T<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end @@ -115,7 +115,7 @@ function phasepoint( _r::T2; r=safe_rsimilar(θ, _r), ℓπ=∂H∂θ(h, θ), - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r)), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), ) where {T1<:AbstractVecOrMat,T2<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end diff --git a/src/trajectory.jl b/src/trajectory.jl index 2e3c1d550..52fdd9fa6 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -552,7 +552,7 @@ function isterminated(::ClassicNoUTurn, h::Hamiltonian, t::BinaryTree) # z0 is starting point and z1 is ending point z0, z1 = t.zleft, t.zright Δθ = z1.θ - z0.θ - s = (dot(Δθ, ∂H∂r(h, -z0.r)) >= 0) || (dot(-Δθ, ∂H∂r(h, z1.r)) >= 0) + s = (dot(Δθ, ∂H∂r(h, z0.θ, -z0.r)) >= 0) || (dot(-Δθ, ∂H∂r(h, z1.θ, z1.r)) >= 0) return Termination(s, false) end @@ -565,7 +565,7 @@ Ref: https://arxiv.org/abs/1701.02434 """ function isterminated(::GeneralisedNoUTurn, h::Hamiltonian, t::BinaryTree) rho = t.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.r), ∂H∂r(h, t.zright.r)) + s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, t.zright.θ, t.zright.r)) return Termination(s, false) end @@ -595,7 +595,7 @@ phase point of `tright`, 
the right subtree. """ function check_left_subtree(h::Hamiltonian, t::T, tleft::T, tright::T) where {T<:BinaryTree} rho = tleft.ts.rho + tright.zleft.r - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.r), ∂H∂r(h, tright.zleft.r)) + s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, tright.zleft.θ, tright.zleft.r)) return Termination(s, false) end @@ -608,7 +608,7 @@ function check_right_subtree( h::Hamiltonian, t::T, tleft::T, tright::T ) where {T<:BinaryTree} rho = tleft.zright.r + tright.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, tleft.zright.r), ∂H∂r(h, t.zright.r)) + s = generalised_uturn_criterion(rho, ∂H∂r(h, tleft.zright.θ, tleft.zright.r), ∂H∂r(h, t.zright.θ, t.zright.r)) return Termination(s, false) end From 0c9e825b40c0acf24bb43787e34e2c8de3a970c9 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Tue, 13 Jan 2026 22:07:43 +0800 Subject: [PATCH 09/19] Add basic validity test for RHMC --- test/riemannian.jl | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/test/riemannian.jl b/test/riemannian.jl index 22d331d92..c6dad39d6 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -17,6 +17,7 @@ using AdvancedHMC: SoftAbsEval, RiemannianMetric, SoftAbsRiemannianMetric +using Statistics #### #### Test utilities @@ -341,3 +342,69 @@ end @test abs(H1 - H0) < 0.1 end end + +#### +#### Validation tests +#### + +@testset "Validation testing" begin + target = HighDimGaussian(2) + rng = MersenneTwister(125) + λ = 1e-2 + + initial_θ = rand(rng, dim(target)) + + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ) + + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + + D = dim(target) + x = zeros(D) + r = randn(rng, D) + + n_samples = 100 + n_adapts = 50 + + mean_tol = 3 / sqrt(n_samples) + var_tol = 1.5 * sqrt(2 / (n_samples - 1)) + + @testset "RiemannianMetric (PDMat-style)" begin + metric = RiemannianMetric((D,), 
G, ∂G∂θ) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 12) + kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + + acceptance_rate = 0.7 + adaptor = StepSizeAdaptor(acceptance_rate, integrator) + + samples, stats = sample( + rng, hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=false + ) + @test mean(samples) ≈ zeros(D) atol = mean_tol + @test Statistics.var(samples) ≈ ones(D) atol = var_tol + end + + @testset "SoftAbsRiemannianMetric" begin + # We do not need SoftAbs for Gaussian target, so using small α + metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 1.0) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 12) + kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + + acceptance_rate = 0.7 + adaptor = StepSizeAdaptor(acceptance_rate, integrator) + + samples, stats = sample( + rng, hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=false + ) + @test mean(samples) ≈ zeros(D) atol = mean_tol + @test Statistics.var(samples) ≈ ones(D) atol = var_tol + end +end From 909d72aa11f0189e9df84377ab84e58d270fb682 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Wed, 14 Jan 2026 15:40:05 +0800 Subject: [PATCH 10/19] Format --- src/hamiltonian.jl | 6 +++++- src/trajectory.jl | 15 +++++++++++---- test/riemannian.jl | 18 ++++++++++++++++-- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index 24f744ca6..ece931c44 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -101,7 +101,11 @@ function Base.similar(z::PhasePoint{<:AbstractVecOrMat{T}}) where {T<:AbstractFl end function phasepoint( - h::Hamiltonian, θ::T, r::T; ℓπ=∂H∂θ(h, θ), ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)) + h::Hamiltonian, + θ::T, + r::T; + 
ℓπ=∂H∂θ(h, θ), + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), ) where {T<:AbstractVecOrMat} return PhasePoint(θ, r, ℓπ, ℓκ) end diff --git a/src/trajectory.jl b/src/trajectory.jl index 52fdd9fa6..8ef4700b7 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -141,8 +141,9 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -SliceTS(rng::AbstractRNG, z0::PhasePoint) = +function SliceTS(rng::AbstractRNG, z0::PhasePoint) SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) +end """ $(TYPEDEF) @@ -565,7 +566,9 @@ Ref: https://arxiv.org/abs/1701.02434 """ function isterminated(::GeneralisedNoUTurn, h::Hamiltonian, t::BinaryTree) rho = t.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, t.zright.θ, t.zright.r)) + s = generalised_uturn_criterion( + rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, t.zright.θ, t.zright.r) + ) return Termination(s, false) end @@ -595,7 +598,9 @@ phase point of `tright`, the right subtree. 
""" function check_left_subtree(h::Hamiltonian, t::T, tleft::T, tright::T) where {T<:BinaryTree} rho = tleft.ts.rho + tright.zleft.r - s = generalised_uturn_criterion(rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, tright.zleft.θ, tright.zleft.r)) + s = generalised_uturn_criterion( + rho, ∂H∂r(h, t.zleft.θ, t.zleft.r), ∂H∂r(h, tright.zleft.θ, tright.zleft.r) + ) return Termination(s, false) end @@ -608,7 +613,9 @@ function check_right_subtree( h::Hamiltonian, t::T, tleft::T, tright::T ) where {T<:BinaryTree} rho = tleft.zright.r + tright.ts.rho - s = generalised_uturn_criterion(rho, ∂H∂r(h, tleft.zright.θ, tleft.zright.r), ∂H∂r(h, t.zright.θ, t.zright.r)) + s = generalised_uturn_criterion( + rho, ∂H∂r(h, tleft.zright.θ, tleft.zright.r), ∂H∂r(h, t.zright.θ, t.zright.r) + ) return Termination(s, false) end diff --git a/test/riemannian.jl b/test/riemannian.jl index c6dad39d6..690111269 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -382,7 +382,14 @@ end adaptor = StepSizeAdaptor(acceptance_rate, integrator) samples, stats = sample( - rng, hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=false + rng, + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, ) @test mean(samples) ≈ zeros(D) atol = mean_tol @test Statistics.var(samples) ≈ ones(D) atol = var_tol @@ -402,7 +409,14 @@ end adaptor = StepSizeAdaptor(acceptance_rate, integrator) samples, stats = sample( - rng, hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=false + rng, + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, ) @test mean(samples) ≈ zeros(D) atol = mean_tol @test Statistics.var(samples) ≈ ones(D) atol = var_tol From 642585f20ef84d990401b6a90d925195c632825c Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Sat, 17 Jan 2026 14:32:00 +0000 Subject: [PATCH 11/19] Add funnel test and fix gaussian validation test --- test/riemannian.jl | 79 
++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 7 deletions(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index 690111269..117b3df16 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -347,7 +347,7 @@ end #### Validation tests #### -@testset "Validation testing" begin +@testset "Validation testing (Gaussian)" begin target = HighDimGaussian(2) rng = MersenneTwister(125) λ = 1e-2 @@ -360,8 +360,6 @@ end _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) D = dim(target) - x = zeros(D) - r = randn(rng, D) n_samples = 100 n_adapts = 50 @@ -375,10 +373,10 @@ end hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) initial_ϵ = 0.01 - integrator = GeneralizedLeapfrog(initial_ϵ, 12) + integrator = GeneralizedLeapfrog(initial_ϵ, 15) kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) - acceptance_rate = 0.7 + acceptance_rate = 0.9 adaptor = StepSizeAdaptor(acceptance_rate, integrator) samples, stats = sample( @@ -402,10 +400,10 @@ end hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) initial_ϵ = 0.01 - integrator = GeneralizedLeapfrog(initial_ϵ, 12) + integrator = GeneralizedLeapfrog(initial_ϵ, 15) kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) - acceptance_rate = 0.7 + acceptance_rate = 0.9 adaptor = StepSizeAdaptor(acceptance_rate, integrator) samples, stats = sample( @@ -422,3 +420,70 @@ end @test Statistics.var(samples) ≈ ones(D) atol = var_tol end end + +@testset "Validation testing (Funnel)" begin + + # 1D Wasserstein-1 distance + function w1(a::AbstractVector, b::AbstractVector) + sa = sort(a) + sb = sort(b) + return mean(abs.(sa .- sb)) + end + + target = Funnel() + rng = MersenneTwister(234) + λ = 1e-2 + + initial_θ = rand(rng, dim(target)) + + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ) + + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + + D = dim(target) + + n_samples = 1000 + 
n_adapts = 500 + + # True samples + v_true = 3 .* randn(rng, n_samples) + X_true = Matrix{Float64}(undef, n_samples, 1) + for n in 1:n_samples + s = exp(v_true[n] / 2) + @inbounds X_true[n, :] .= s .* randn(rng, 1) + end + + tol_1 = 10 / sqrt(n_samples) + tol_2 = 30 / sqrt(n_samples) + + @testset "SoftAbsRiemannianMetric" begin + metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 20.0) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 15) + kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + + acceptance_rate = 0.9 + adaptor = StepSizeAdaptor(acceptance_rate, integrator) + + samples, stats = sample( + rng, + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, + ) + + θ = reduce(vcat, (permutedims(s) for s in samples)) + # 1st marginal + @test w1(θ[:, 1], v_true) < tol_1 + # 2nd marginal + @test w1(θ[:, 2], X_true[:, 1]) < tol_2 + end +end \ No newline at end of file From effb1154da27c4a1efab97ecce8ff453c732fe13 Mon Sep 17 00:00:00 2001 From: Jamie Price <49832778+J-Price-3@users.noreply.github.com> Date: Sat, 17 Jan 2026 16:25:46 +0000 Subject: [PATCH 12/19] Format test/riemannian.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/riemannian.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index 117b3df16..fac311725 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -486,4 +486,4 @@ end # 2nd marginal @test w1(θ[:, 2], X_true[:, 1]) < tol_2 end -end \ No newline at end of file +end From cb43e2a82dd1eb304038f602895f69822d208cab Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Mon, 19 Jan 2026 11:07:10 +0000 Subject: [PATCH 13/19] Update validation tests to both use w1 distance and a more logical tolerance --- test/riemannian.jl | 294 ++++++++++++++++++++++++++++----------------- 1 
file changed, 181 insertions(+), 113 deletions(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index fac311725..c4d394e65 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -347,143 +347,211 @@ end #### Validation tests #### -@testset "Validation testing (Gaussian)" begin - target = HighDimGaussian(2) - rng = MersenneTwister(125) - λ = 1e-2 +@testset "Validation testing" begin + # 1D Wasserstein-1 distance + function w1(a::AbstractVector, b::AbstractVector) + sa = sort(a) + sb = sort(b) + return mean(abs.(sa .- sb)) + end - initial_θ = rand(rng, dim(target)) + @testset "Validation testing (Gaussian)" begin - ℓπ = MCMCLogDensityProblems.gen_logpdf(target) - ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ) + # 1D normal Wasserstein-1 distance tolerance estimator + function w1_tol_normal_1d(; + n::Int, reps::Int=200, q::Float64=0.999, rng=Random.default_rng() + ) + poolN = max(50_000, 50n) + pool = randn(rng, poolN) + + vals = Vector{Float64}(undef, reps) + for i in 1:reps + a = pool[rand(rng, 1:poolN, n)] + b = pool[rand(rng, 1:poolN, n)] + vals[i] = w1(a, b) + end + sort!(vals) + return vals[clamp(ceil(Int, q * reps), 1, reps)] + end - _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + target = HighDimGaussian(2) + rng = MersenneTwister(125) + λ = 1e-2 - D = dim(target) + initial_θ = rand(rng, dim(target)) - n_samples = 100 - n_adapts = 50 + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ) - mean_tol = 3 / sqrt(n_samples) - var_tol = 1.5 * sqrt(2 / (n_samples - 1)) + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) - @testset "RiemannianMetric (PDMat-style)" begin - metric = RiemannianMetric((D,), G, ∂G∂θ) - kinetic = GaussianKinetic() - hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + D = dim(target) - initial_ϵ = 0.01 - integrator = GeneralizedLeapfrog(initial_ϵ, 15) - kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) - - 
acceptance_rate = 0.9 - adaptor = StepSizeAdaptor(acceptance_rate, integrator) - - samples, stats = sample( - rng, - hamiltonian, - kernel, - initial_θ, - n_samples, - adaptor, - n_adapts; - progress=false, - ) - @test mean(samples) ≈ zeros(D) atol = mean_tol - @test Statistics.var(samples) ≈ ones(D) atol = var_tol + n_samples = 100 + n_adapts = 50 + + tol_w1 = w1_tol_normal_1d(; n=n_samples, rng=rng) + + # Samples are RHMC so we relax the tolerance + tol_w1 *= 2.0 + + x_true = randn(rng, n_samples) + y_true = randn(rng, n_samples) + + @testset "RiemannianMetric (PDMat-style)" begin + metric = RiemannianMetric((D,), G, ∂G∂θ) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 15) + kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + + acceptance_rate = 0.9 + adaptor = StepSizeAdaptor(acceptance_rate, integrator) + + samples, stats = sample( + rng, + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, + ) + θ = reduce(vcat, (permutedims(s) for s in samples)) + # 1st marginal + @test w1(θ[:, 1], x_true) < tol_w1 + # 2nd marginal + @test w1(θ[:, 2], y_true) < tol_w1 + end + + @testset "SoftAbsRiemannianMetric" begin + # We do not need SoftAbs for Gaussian target, so using small α + metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 1.0) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 15) + kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + + acceptance_rate = 0.9 + adaptor = StepSizeAdaptor(acceptance_rate, integrator) + + samples, stats = sample( + rng, + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, + ) + + θ = reduce(vcat, (permutedims(s) for s in samples)) + # 1st marginal + @test w1(θ[:, 1], x_true) < tol_w1 + # 2nd 
marginal + @test w1(θ[:, 2], y_true) < tol_w1 + end end - @testset "SoftAbsRiemannianMetric" begin - # We do not need SoftAbs for Gaussian target, so using small α - metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 1.0) - kinetic = GaussianKinetic() - hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + @testset "Validation testing (Funnel)" begin + + # Funnel i.i.d. sampler + # θ layout: [v, x1] + function funnel_iid(rng::AbstractRNG, n::Int) + v = 3.0 .* randn(rng, n) + X = Matrix{Float64}(undef, n, 1) + for i in 1:n + s = exp(v[i] / 2) + @inbounds X[i, :] .= s .* randn(rng, 1) + end + return v, X + end - initial_ϵ = 0.01 - integrator = GeneralizedLeapfrog(initial_ϵ, 15) - kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) - - acceptance_rate = 0.9 - adaptor = StepSizeAdaptor(acceptance_rate, integrator) - - samples, stats = sample( - rng, - hamiltonian, - kernel, - initial_θ, - n_samples, - adaptor, - n_adapts; - progress=false, + # 1D Wasserstein-1 distance tolerances for Funnel marginals + function funnel_w1_tols(; + n::Int, + reps::Int=200, + q::Float64=0.999, + inflate::Float64=2.0, + rng::AbstractRNG=Random.default_rng(), ) - @test mean(samples) ≈ zeros(D) atol = mean_tol - @test Statistics.var(samples) ≈ ones(D) atol = var_tol - end -end + vals_v = Vector{Float64}(undef, reps) + vals_x1 = Vector{Float64}(undef, reps) -@testset "Validation testing (Funnel)" begin + for i in 1:reps + vA, XA = funnel_iid(rng, n) + vB, XB = funnel_iid(rng, n) - # 1D Wasserstein-1 distance - function w1(a::AbstractVector, b::AbstractVector) - sa = sort(a) - sb = sort(b) - return mean(abs.(sa .- sb)) - end + x1A = XA[:, 1] + x1B = XB[:, 1] - target = Funnel() - rng = MersenneTwister(234) - λ = 1e-2 + vals_v[i] = w1(vA, vB) + vals_x1[i] = w1(x1A, x1B) + end - initial_θ = rand(rng, dim(target)) + sort!(vals_v) + sort!(vals_x1) + idx = clamp(ceil(Int, q * reps), 1, reps) - ℓπ = MCMCLogDensityProblems.gen_logpdf(target) - ∂ℓπ∂θ = 
MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ) + return (tol_v=inflate * vals_v[idx], tol_x1=inflate * vals_x1[idx]) + end - _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) + target = Funnel() + rng = MersenneTwister(234) + λ = 1e-2 - D = dim(target) + initial_θ = rand(rng, dim(target)) - n_samples = 1000 - n_adapts = 500 + ℓπ = MCMCLogDensityProblems.gen_logpdf(target) + ∂ℓπ∂θ = MCMCLogDensityProblems.gen_logpdf_grad(target, initial_θ) - # True samples - v_true = 3 .* randn(rng, n_samples) - X_true = Matrix{Float64}(undef, n_samples, 1) - for n in 1:n_samples - s = exp(v_true[n] / 2) - @inbounds X_true[n, :] .= s .* randn(rng, 1) - end + _, _, G, ∂G∂θ = prepare_sample(ℓπ, initial_θ, λ) - tol_1 = 10 / sqrt(n_samples) - tol_2 = 30 / sqrt(n_samples) + D = dim(target) - @testset "SoftAbsRiemannianMetric" begin - metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 20.0) - kinetic = GaussianKinetic() - hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) + n_samples = 1000 + n_adapts = 500 - initial_ϵ = 0.01 - integrator = GeneralizedLeapfrog(initial_ϵ, 15) - kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) - - acceptance_rate = 0.9 - adaptor = StepSizeAdaptor(acceptance_rate, integrator) - - samples, stats = sample( - rng, - hamiltonian, - kernel, - initial_θ, - n_samples, - adaptor, - n_adapts; - progress=false, - ) + # True samples + v_true, X_true = funnel_iid(rng, n_samples) + + # Wasserstein-1 distance tolerances + tols = funnel_w1_tols(; n=n_samples, rng=rng) + + @testset "SoftAbsRiemannianMetric" begin + metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 20.0) + kinetic = GaussianKinetic() + hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) - θ = reduce(vcat, (permutedims(s) for s in samples)) - # 1st marginal - @test w1(θ[:, 1], v_true) < tol_1 - # 2nd marginal - @test w1(θ[:, 2], X_true[:, 1]) < tol_2 + initial_ϵ = 0.01 + integrator = GeneralizedLeapfrog(initial_ϵ, 15) + kernel = 
HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) + + acceptance_rate = 0.9 + adaptor = StepSizeAdaptor(acceptance_rate, integrator) + + samples, stats = sample( + rng, + hamiltonian, + kernel, + initial_θ, + n_samples, + adaptor, + n_adapts; + progress=false, + ) + + θ = reduce(vcat, (permutedims(s) for s in samples)) + # 1st marginal + @test w1(θ[:, 1], v_true) < tols.tol_v + # 2nd marginal + @test w1(θ[:, 2], X_true[:, 1]) < tols.tol_x1 + end end end From 3680b871c1b07eda195e10223ba0c1f6e777f7db Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Mon, 19 Jan 2026 14:21:36 +0000 Subject: [PATCH 14/19] Reduce validation test tolerance --- test/riemannian.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index c4d394e65..6fa498609 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -392,9 +392,6 @@ end tol_w1 = w1_tol_normal_1d(; n=n_samples, rng=rng) - # Samples are RHMC so we relax the tolerance - tol_w1 *= 2.0 - x_true = randn(rng, n_samples) y_true = randn(rng, n_samples) @@ -478,7 +475,7 @@ end n::Int, reps::Int=200, q::Float64=0.999, - inflate::Float64=2.0, + inflate::Float64=1.0, rng::AbstractRNG=Random.default_rng(), ) vals_v = Vector{Float64}(undef, reps) From 7e914955bdc482b139ba523fa2dbf58e3c7985a7 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Mon, 19 Jan 2026 15:15:10 +0000 Subject: [PATCH 15/19] Increase validation test tolerance slightly --- test/riemannian.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index 6fa498609..8adce58c5 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -392,6 +392,8 @@ end tol_w1 = w1_tol_normal_1d(; n=n_samples, rng=rng) + tol_w1 *= 1.5 + x_true = randn(rng, n_samples) y_true = randn(rng, n_samples) @@ -475,7 +477,7 @@ end n::Int, reps::Int=200, q::Float64=0.999, - inflate::Float64=1.0, + inflate::Float64=1.5, rng::AbstractRNG=Random.default_rng(), ) vals_v = 
Vector{Float64}(undef, reps) From af46f2ea409ed916089a71fd82b29281c427cd15 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Wed, 21 Jan 2026 12:30:47 +0000 Subject: [PATCH 16/19] Prevent test type instability --- test/riemannian.jl | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index 8adce58c5..ae76d3f83 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -24,19 +24,32 @@ using Statistics #### function gen_hess_fwd(func, x::AbstractVector) + cfg = ForwardDiff.HessianConfig(func, x) + H = Matrix{eltype(x)}(undef, length(x), length(x)) + function hess(x::AbstractVector) - return nothing, nothing, ForwardDiff.hessian(func, x) + ForwardDiff.hessian!(H, func, x, cfg) + return H end return hess end function gen_∂G∂θ_fwd(Vfunc, x; f=identity) - _Hfunc = gen_hess_fwd(Vfunc, x) - Hfunc = x -> _Hfunc(x)[3] - cfg = ForwardDiff.JacobianConfig(Hfunc, x) + chunk = ForwardDiff.Chunk(x) + tag = ForwardDiff.Tag(Vfunc, eltype(x)) + jac_cfg = ForwardDiff.JacobianConfig(Vfunc, x, chunk, tag) + hess_cfg = ForwardDiff.HessianConfig(Vfunc, jac_cfg.duals, chunk, tag) + d = length(x) out = zeros(eltype(x), d^2, d) - return x -> ForwardDiff.jacobian!(out, Hfunc, x, cfg) + + function ∂G∂θ_fwd(y) + hess = z -> ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}()) + ForwardDiff.jacobian!(out, hess, y, jac_cfg, Val{false}()) + return out + end + + return ∂G∂θ_fwd end function reshape_∂G∂θ(H) @@ -46,12 +59,12 @@ end function prepare_sample(ℓπ, initial_θ, λ) Vfunc = x -> -ℓπ(x) - _Hfunc = MCMCLogDensityProblems.gen_hess(Vfunc, initial_θ) + _Hfunc = gen_hess_fwd(Vfunc, initial_θ) Hfunc = x -> copy.(_Hfunc(x)) fstabilize = H -> H + λ * I Gfunc = x -> begin - H = fstabilize(Hfunc(x)[3]) + H = fstabilize(Hfunc(x)) all(isfinite, H) ? 
H : diagm(ones(length(x))) end _∂G∂θfunc = gen_∂G∂θ_fwd(x -> -ℓπ(x), initial_θ; f=fstabilize) From 41267f124e1db7285f902b122400a31dee0dcf47 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Wed, 21 Jan 2026 12:58:32 +0000 Subject: [PATCH 17/19] Fix tests --- test/riemannian.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index ae76d3f83..f5ffe5213 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -197,7 +197,7 @@ end @testset "Autodiff utilities" begin @test δ(finite_difference_gradient(ℓπ, x), ∂ℓπ∂θ(x)[end]) < 1e-4 - @test δ(finite_difference_hessian(Vfunc, x), Hfunc(x)[end]) < 1e-4 + @test δ(finite_difference_hessian(Vfunc, x), Hfunc(x)) < 1e-4 @test δ(reshape_∂G∂θ(finite_difference_jacobian(Gfunc, x)), ∂G∂θfunc(x)) < 1e-4 end From c9d301679ea9da6ac8ba91c716027ce2a1195fdc Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Wed, 21 Jan 2026 16:48:15 +0000 Subject: [PATCH 18/19] Fix flaky test --- test/riemannian.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index f5ffe5213..a76506b7f 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -545,7 +545,7 @@ end integrator = GeneralizedLeapfrog(initial_ϵ, 15) kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) - acceptance_rate = 0.9 + acceptance_rate = 0.7 adaptor = StepSizeAdaptor(acceptance_rate, integrator) samples, stats = sample( From 6494b34a4e2a7bca51c0cf6955b57818ef0a4562 Mon Sep 17 00:00:00 2001 From: Jamie Price Date: Thu, 22 Jan 2026 11:12:56 +0000 Subject: [PATCH 19/19] Fix flaky test (I promise it works this time) --- test/riemannian.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/riemannian.jl b/test/riemannian.jl index a76506b7f..15edbb925 100644 --- a/test/riemannian.jl +++ b/test/riemannian.jl @@ -490,7 +490,7 @@ end n::Int, reps::Int=200, q::Float64=0.999, - inflate::Float64=1.5, + inflate::Float64=1.8, 
rng::AbstractRNG=Random.default_rng(), ) vals_v = Vector{Float64}(undef, reps) @@ -537,12 +537,12 @@ end tols = funnel_w1_tols(; n=n_samples, rng=rng) @testset "SoftAbsRiemannianMetric" begin - metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 20.0) + metric = SoftAbsRiemannianMetric((D,), G, ∂G∂θ, 40.0) kinetic = GaussianKinetic() hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ) initial_ϵ = 0.01 - integrator = GeneralizedLeapfrog(initial_ϵ, 15) + integrator = GeneralizedLeapfrog(initial_ϵ, 5) kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) acceptance_rate = 0.7