@@ -111,22 +111,6 @@ const gdemo_default = gdemo_d()
111111 end
112112 end
113113
114- @testset " promotion of VNT accumulators in TSVI" begin
115- # See https://github.com/TuringLang/DynamicPPL.jl/pull/1284.
116- @model function f ()
117- x = zeros (10 )
118- for i in eachindex (x)
119- x[i] ~ Normal ()
120- end
121- end
122- model = setthreadsafe (f (), true )
123-
124- vi = OnlyAccsVarInfo (RawValueAccumulator (false ))
125- _, vi = DynamicPPL. init!! (model, vi, InitFromPrior (), UnlinkAll ())
126- vnt = get_raw_values (vi)
127- @test vnt[@varname (x)] isa Vector{Float64}
128- end
129-
130114 @testset " check_model with threadsafe" begin
131115 # This is a partial test for https://github.com/TuringLang/DynamicPPL.jl/issues/1157
132116 @model function f ()
@@ -138,6 +122,34 @@ const gdemo_default = gdemo_d()
138122 @test ! check_model (model)
139123 end
140124
125+ @testset " assumes are threadsafe" begin
126+ # See https://github.com/TuringLang/DynamicPPL.jl/pull/1284.
127+ #
128+ # Note: anything that involves VarInfo is still thread-unsafe. But anything
129+ # that uses OnlyAccsVarInfo is fine
130+ @model function threaded_assume ()
131+ x = zeros (10 )
132+ Threads. @threads for i in eachindex (x)
133+ x[i] ~ Normal ()
134+ end
135+ end
136+ model = setthreadsafe (threaded_assume (), true )
137+
138+ @testset " rand" begin
139+ vnt = rand (model)
140+ for i in 1 : 10
141+ @test haskey (vnt, @varname (x[i]))
142+ end
143+ end
144+ @testset " logprob" begin
145+ xfixed = rand (10 )
146+ params = VarNamedTuple (; x= xfixed)
147+ @test logprior (model, params) ≈ sum (logpdf .(Normal (), xfixed))
148+ @test iszero (loglikelihood (model, params))
149+ @test logjoint (model, params) ≈ sum (logpdf .(Normal (), xfixed))
150+ end
151+ end
152+
141153 @testset " logprob correctness" begin
142154 x = rand (10_000 )
143155
0 commit comments