Skip to content

Commit 69a1419

Browse files
authored
Merge branch 'main' into remove-logdensityfunctionwrapper-v2
2 parents 35fe956 + 2748445 commit 69a1419

5 files changed

Lines changed: 73 additions & 20 deletions

File tree

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ Removed `getargnames`, `getmissings`, and `Base.nameof(::Model)` from the public
55
Removed `LogDensityFunctionWrapper` and `VarInfo(::MarginalLogDensity, ...)` from the MarginalLogDensities extension.
66
Users should now use `DynamicPPL.InitFromVector(mld, ...)` to obtain an initialisation strategy and pass it to `init!!` to get a consistent `VarInfo`.
77

8+
# 0.40.18
9+
10+
Added a check on `unflatten!!` to error if the input vector was too long.
11+
812
# 0.40.17
913

1014
Implemented missing methods for `Base.copy` on internal structs.

src/varinfo.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,14 @@ end
537537
function unflatten!!(vi::VarInfo, vec::AbstractVector)
538538
vci = VectorChunkIterator!(vec, 1)
539539
new_values = map_values!!(vci, vi.values)
540+
expected_len = vci.index - 1
541+
if length(vec) != expected_len
542+
throw(
543+
DimensionMismatch(
544+
"expected a vector of length $(expected_len), but got length $(length(vec))"
545+
),
546+
)
547+
end
540548
return VarInfo(vi.transform_strategy, new_values, vi.accs)
541549
end
542550

test/logdensityfunction.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,30 @@ using Mooncake: Mooncake
8383
end
8484
end
8585

86-
@testset "LogDensityFunction: Correctness" begin
86+
@testset "LogDensityFunction: correctness with multiple threads" begin
87+
@testset "Threaded assume" begin
88+
@model function threaded_assume()
89+
x = zeros(10)
90+
Threads.@threads for i in eachindex(x)
91+
x[i] ~ Normal()
92+
end
93+
end
94+
model = setthreadsafe(threaded_assume(), true)
95+
ldf = DynamicPPL.LogDensityFunction(model)
96+
xs = rand(ldf)
97+
@test xs isa Vector{Float64} && length(xs) == 10
98+
@test LogDensityProblems.logdensity(ldf, xs) sum(logpdf.(Normal(), xs))
99+
end
100+
87101
@testset "Threaded observe" begin
88-
@model function threaded(y)
102+
@model function threaded_observe(y)
89103
x ~ Normal()
90104
Threads.@threads for i in eachindex(y)
91105
y[i] ~ Normal(x)
92106
end
93107
end
94108
N = 100
95-
model = setthreadsafe(threaded(zeros(N)), true)
109+
model = setthreadsafe(threaded_observe(zeros(N)), true)
96110
ldf = DynamicPPL.LogDensityFunction(model)
97111

98112
xs = [1.0]

test/threadsafe.jl

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,22 +111,6 @@ const gdemo_default = gdemo_d()
111111
end
112112
end
113113

114-
@testset "promotion of VNT accumulators in TSVI" begin
115-
# See https://github.com/TuringLang/DynamicPPL.jl/pull/1284.
116-
@model function f()
117-
x = zeros(10)
118-
for i in eachindex(x)
119-
x[i] ~ Normal()
120-
end
121-
end
122-
model = setthreadsafe(f(), true)
123-
124-
vi = OnlyAccsVarInfo(RawValueAccumulator(false))
125-
_, vi = DynamicPPL.init!!(model, vi, InitFromPrior(), UnlinkAll())
126-
vnt = get_raw_values(vi)
127-
@test vnt[@varname(x)] isa Vector{Float64}
128-
end
129-
130114
@testset "check_model with threadsafe" begin
131115
# This is a partial test for https://github.com/TuringLang/DynamicPPL.jl/issues/1157
132116
@model function f()
@@ -138,6 +122,34 @@ const gdemo_default = gdemo_d()
138122
@test !check_model(model)
139123
end
140124

125+
@testset "assumes are threadsafe" begin
126+
# See https://github.com/TuringLang/DynamicPPL.jl/pull/1284.
127+
#
128+
# Note: anything that involves VarInfo is still thread-unsafe. But anything
129+
# that uses OnlyAccsVarInfo is fine
130+
@model function threaded_assume()
131+
x = zeros(10)
132+
Threads.@threads for i in eachindex(x)
133+
x[i] ~ Normal()
134+
end
135+
end
136+
model = setthreadsafe(threaded_assume(), true)
137+
138+
@testset "rand" begin
139+
vnt = rand(model)
140+
for i in 1:10
141+
@test haskey(vnt, @varname(x[i]))
142+
end
143+
end
144+
@testset "logprob" begin
145+
xfixed = rand(10)
146+
params = VarNamedTuple(; x=xfixed)
147+
@test logprior(model, params) sum(logpdf.(Normal(), xfixed))
148+
@test iszero(loglikelihood(model, params))
149+
@test logjoint(model, params) sum(logpdf.(Normal(), xfixed))
150+
end
151+
end
152+
141153
@testset "logprob correctness" begin
142154
x = rand(10_000)
143155

test/varinfo.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ end
115115
lp_d = logpdf(Normal(), values.d)
116116
m = demo() | (; c=values.c, d=values.d)
117117

118-
vi = DynamicPPL.unflatten!!(VarInfo(m), collect(values))
118+
vi = DynamicPPL.unflatten!!(VarInfo(m), [values.a, values.b])
119119

120120
vi = last(DynamicPPL.evaluate_nowarn!!(m, deepcopy(vi)))
121121
@test getlogprior(vi) == lp_a + lp_b
@@ -453,6 +453,21 @@ end
453453
end
454454
end
455455

456+
@testset "unflatten!! length check" begin
457+
@model function demo_lc()
458+
x ~ Normal()
459+
return y ~ Normal(x, 1)
460+
end
461+
model = demo_lc() | (; y=0.0)
462+
varinfo = VarInfo(model)
463+
n = length(varinfo[:])
464+
# Correct length should work.
465+
@test DynamicPPL.unflatten!!(varinfo, zeros(n)) isa VarInfo
466+
# Too many parameters should throw a DimensionMismatch.
467+
@test_throws DimensionMismatch DynamicPPL.unflatten!!(varinfo, zeros(n + 1))
468+
@test_throws DimensionMismatch DynamicPPL.unflatten!!(varinfo, zeros(2n))
469+
end
470+
456471
@testset "unflatten!! type stability" begin
457472
@model function demo(y)
458473
x ~ Normal()

0 commit comments

Comments
 (0)