Skip to content

Commit 1929759

Browse files
nsicchaclaude
andcommitted
Make JSON + StanLogDensityProblems direct deps, fix deanon_size for compound expressions
- Move StanLogDensityProblems from weakdeps to deps, remove extension - Move instantiate implementation from ext into stan module - Make deanon_size recursive to handle compound size expressions (e.g. _arg1 + _arg2) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4964ca4 commit 1929759

4 files changed

Lines changed: 31 additions & 36 deletions

File tree

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1111
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
12+
StanLogDensityProblems = "a545de4d-8dba-46db-9d34-4e41d3f07807"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1314

1415
[weakdeps]
@@ -18,7 +19,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1819
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1920
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2021
PosteriorDB = "1c4bc282-d2f5-44f9-b6d1-8c4424a23ad4"
21-
StanLogDensityProblems = "a545de4d-8dba-46db-9d34-4e41d3f07807"
2222

2323
[extensions]
2424
DifferentiationInterfaceExt = "DifferentiationInterface"
@@ -27,7 +27,6 @@ MarkdownExt = "Markdown"
2727
MooncakeExt = "Mooncake"
2828
OrdinaryDiffEqExt = "OrdinaryDiffEq"
2929
PosteriorDBExt = "PosteriorDB"
30-
StanLogDensityProblemsExt = "StanLogDensityProblems"
3130

3231
[sources]
3332
TestModules = {url = "https://github.com/nsiccha/TestModules.jl"}

ext/StanLogDensityProblemsExt.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/StanBlocks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export @slic, @defsig, @deffun
55
export stan_code, stan_model, stan_instantiate
66
export StanBlocksError
77

8-
using LinearAlgebra, Statistics, Distributions, LogExpFunctions, JSON
8+
using LinearAlgebra, Statistics, Distributions, LogExpFunctions, JSON, StanLogDensityProblems
99

1010
# --- Error type for StanBlocks computations (defined early so submodules can use it) ---
1111

src/slic_stan/slic.jl

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module stan
2-
using OrderedCollections
2+
using OrderedCollections, JSON, StanLogDensityProblems
33
const RV_NAME = gensym("RV")
44
dumperror(x) = (dump(x); error(x))
55
"""
@@ -544,11 +544,17 @@ forward!(x::StringExpr; info) = join(map(stan_code, forward!(x.args; info)))
544544
deanon_size(s, x) = s
545545
deanon_size(s::StanExpr, x::CanonicalExpr) = begin
546546
e = expr(s)
547-
isa(e, Symbol) || return s
548-
m = match(r"^_arg(\d+)$", string(e))
549-
isnothing(m) && return s
550-
i = parse(Int, m[1])
551-
i <= length(x.args) ? x.args[i] : s
547+
if isa(e, Symbol)
548+
m = match(r"^_arg(\d+)$", string(e))
549+
isnothing(m) && return s
550+
i = parse(Int, m[1])
551+
return i <= length(x.args) ? x.args[i] : s
552+
elseif isa(e, CanonicalExpr)
553+
new_args = map(a -> deanon_size(a, x), e.args)
554+
new_args == e.args && return s
555+
return StanExpr(remake(e, new_args...), type(s))
556+
end
557+
s
552558
end
553559
deanon_type(tt::StanType, x::CanonicalExpr) = begin
554560
sz = stan_size(tt)
@@ -995,12 +1001,24 @@ prepare_for_stan(x::Tuple) = prepare_for_stan(Dict(enumerate(x)))
9951001
bridgestan_data(x::Dict) = JSON.json(prepare_for_stan(x))
9961002
"""
9971003
Returns the StanLogDensityProblem (a compiled posterior).
998-
999-
**Warning:**
1000-
1001-
Requires loading StanLogDensityProblems.jl and JSON.jl.
10021004
"""
1003-
instantiate(args...; kwargs...) = error("Using instantiate requires loading StanLogDensityProblems.jl and JSON.jl!")
1005+
instantiate(x::Union{SlicModel,StanModel}; nan_on_error=true, make_args=["STAN_THREADS=true"], warn=false, kwargs...) = begin
1006+
sc = stan_code(x)
1007+
stan_path = get(kwargs, :path, joinpath("tmp", string(hash(sc)) * ".stan"))
1008+
mkpath(dirname(stan_path))
1009+
if !isfile(stan_path)
1010+
open(stan_path, "w") do fd
1011+
write(fd, sc)
1012+
end
1013+
end
1014+
StanLogDensityProblems.StanProblem(
1015+
stan_path,
1016+
bridgestan_data(stan_data(x));
1017+
nan_on_error,
1018+
make_args,
1019+
warn
1020+
)
1021+
end
10041022
debug_instantiate(x; kwargs...) = instantiate(x; nan_on_error=false, kwargs...)
10051023
passinstantiate(x; kwargs...) = (instantiate(x; kwargs...); x)
10061024
stan_data(x::SlicModel) = stan_data(stan_model(x))

0 commit comments

Comments
 (0)