From 49a3e9b22e6a07ad5a4b8b31c12a101c65544643 Mon Sep 17 00:00:00 2001 From: Lorenzo Van Munoz Date: Thu, 11 Jun 2026 21:55:14 -0400 Subject: [PATCH 1/2] add dilation distance to ssp --- src/julia/SSP/src/project.jl | 32 +++++++++++++++++-------------- src/julia/SSP/src/pythonic_api.jl | 9 +++++---- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/julia/SSP/src/project.jl b/src/julia/SSP/src/project.jl index 82d69bc..660f5e3 100644 --- a/src/julia/SSP/src/project.jl +++ b/src/julia/SSP/src/project.jl @@ -6,18 +6,21 @@ import SSP: init!, solve!, adjoint_solve! public ProjectionProblem, SSP1_linear, SSP1, SSP2 """ - ProjectionProblem(; rho_filtered, grid, target_points, beta=Inf, eta=1/2) + ProjectionProblem(; rho_filtered, grid, target_points, beta=Inf, eta=1/2, dilation_distance=0) Define a problem for projecting smoothed data `rho_filtered` defined on a `grid`, i.e. a tuple of range, at a list of selected `target_points`, i.e. a vector of coordinate tuples. -The projection pushes `rho` values above `eta` towards 1 and `rho` values below `eta` towards 0 with a stiffnes parameter `beta`. +The projection pushes `rho` values above `eta` towards 1 and `rho` values below `eta` towards 0 with a stiffness parameter `beta`. +The `dilation_distance` parameter sets how far to expand or contract the contour of the projection threshold. +When `dilation_distance` is positive the region above the threshold is dilated and when negative it is eroded. +Furthermore, `dilation_distance` is dimensionless in units of pixels from the boundary and must satisfy `abs(dilation_distance) < 1` """ -Base.@kwdef struct ProjectionProblem{D,G,T,B} +Base.@kwdef struct ProjectionProblem{D,G,T,B,DD} rho_filtered::D grid::G target_points::T beta::B = eltype(rho_filtered)(Inf) eta::B = eltype(rho_filtered)(1//2) - # dilation/erosion distance = 0 + dilation_distance::DD=0 end function Base.copy(prob::ProjectionProblem) @@ -27,17 +30,18 @@ function Base.copy(prob::ProjectionProblem) target_points = copy(prob.target_points), beta = prob.beta, eta = prob.eta, + dilation_distance = prob.dilation_distance, ) return newprob end -mutable struct ProjectionSolver{D,G,T,B,A,C} +mutable struct ProjectionSolver{D,G,T,B,DD,A,C} rho_filtered::D const grid::G const target_points::T beta::B eta::B - # dilation/erosion distance + dilation_distance::DD alg::A cacheval::C end @@ -72,7 +76,7 @@ The `smoothing_radius` keyword sets the radius of the smoothing kernel relative SSP2(; kws...) = SSPAlg(; interp=CubicInterp(; deriv=ValueWithGradientAndHessian()), kws...) function init!(prob::ProjectionProblem, alg::SSPAlg) - (; rho_filtered, grid, target_points, beta, eta) = prob + (; rho_filtered, grid, target_points, beta, eta, dilation_distance) = prob interp_prob = InterpolationProblem(; data=rho_filtered, grid, target_points) interp_alg = alg.interp @@ -90,7 +94,7 @@ function init!(prob::ProjectionProblem, alg::SSPAlg) cacheval = (; interp_solver, rho_projected, adj_rho_filtered_interp)# adj_rho_filtered_value, adj_rho_filtered_gradient, adj_rho_filtered_hessian) - return ProjectionSolver(rho_filtered, grid, target_points, beta, eta, alg, cacheval) + return ProjectionSolver(rho_filtered, grid, target_points, beta, eta, dilation_distance, alg, cacheval) end function solve!(solver::ProjectionSolver) @@ -99,7 +103,7 @@ end function proj_solve!(solver, alg::SSPAlg) - (; rho_filtered, grid, beta, eta, cacheval) = solver + (; rho_filtered, grid, beta, eta, dilation_distance, cacheval) = solver (; interp_solver, rho_projected) = cacheval interp_solver.data = rho_filtered @@ -118,7 +122,7 @@ function proj_solve!(solver, alg::SSPAlg) (haskey(rho_filtered_interp, :hessian) ? (; hessian = view(rho_filtered_interp.hessian, i, :, :)) : (;))... ) rho_filtered_interp_derivs_normsq = map(Base.Fix1(sum, abs2), rho_filtered_interp_derivs) - rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta) + rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance) rho_projected[i] = rho_p end return (; value=rho_projected, tape=nothing) @@ -163,7 +167,7 @@ function adjoint_tanh_projection(adj_out, x, beta, eta) end -function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothing, beta, eta) +function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothing, beta, eta, dilation_distance) rho_projected = tanh_projection(rho_filtered, beta, eta) den_helper = if haskey(rho_filtered_derivs_normsq, :hessian) @@ -178,7 +182,7 @@ function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothi den_eff = sqrt(ifelse(nonzero_norm, den_helper, oneunit(den_helper))) d = (eta - rho_filtered) / den_eff - d_R = d / R_smoothing + d_R = d / R_smoothing - dilation_distance needs_smoothing = nonzero_norm & (abs(d_R) < one(d_R)) F_plus = ifelse(needs_smoothing, 1//2 + d_R * evalpoly(d_R^2, (-15//16, 5//8, -3//16)), one(d_R)) @@ -202,7 +206,7 @@ end function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape) - (; rho_filtered, grid, beta, eta, cacheval) = solver + (; rho_filtered, grid, beta, eta, dilation_distance, cacheval) = solver (; interp_solver, adj_rho_filtered_interp) = cacheval # (; interp_solver, adj_rho_filtered_value, adj_rho_filtered_gradient, adj_rho_filtered_hessian) = cacheval @@ -223,7 +227,7 @@ function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape) (haskey(rho_filtered_interp, :hessian) ? (; hessian = view(rho_filtered_interp.hessian, i, :, :)) : (;))... ) rho_filtered_interp_derivs_normsq = map(Base.Fix1(sum, abs2), rho_filtered_interp_derivs) - rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta) + rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance) adj_rho_f, adj_rho_derivs_normsq = adjoint_smoothed_projection(adj_proj, tape, rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta) adj_rho_filtered_interp.value[i] = adj_rho_f adj_rho_filtered_interp.gradient[i, :] .= 2adj_rho_derivs_normsq.gradient .* rho_filtered_interp_derivs.gradient diff --git a/src/julia/SSP/src/pythonic_api.jl b/src/julia/SSP/src/pythonic_api.jl index 1a3b58b..351477b 100644 --- a/src/julia/SSP/src/pythonic_api.jl +++ b/src/julia/SSP/src/pythonic_api.jl @@ -47,7 +47,7 @@ function conic_filter_rrule(adj_depad_value, padsolver, convsolver, depadsolver) return adj_padprob.data end -function ssp_withsolver(alg, rho_filtered, beta, eta, grid) +function ssp_withsolver(alg, rho_filtered, beta, eta, grid, dilation_distance=0) target_points = vec(collect(Iterators.product(grid...))) prob = Project.ProjectionProblem(; rho_filtered, @@ -55,6 +55,7 @@ function ssp_withsolver(alg, rho_filtered, beta, eta, grid) target_points, beta, eta, + dilation_distance, ) solver = init(prob, alg) sol = solve!(solver) @@ -62,7 +63,7 @@ function ssp_withsolver(alg, rho_filtered, beta, eta, grid) end """ - ssp1_linear(rho_filtered, beta, eta, grid) + ssp1_linear(rho_filtered, beta, eta, grid, [dilation_distance=0]) Project using the original [SSP1 algorithm] [1] with linear interpolation. @@ -77,7 +78,7 @@ At `beta=Inf`, this projection is not differentiable through topology changes, a ssp1_linear(args...; kws...) = ssp_withsolver(Project.SSP1_linear(; kws...), args...)[1] """ - ssp1(rho_filtered, beta, eta, grid) + ssp1(rho_filtered, beta, eta, grid, [dilation_distance=0]) Project using the original [SSP1 algorithm] [1] with cubic interpolation. @@ -92,7 +93,7 @@ At `beta=Inf`, this projection is not differentiable through topology changes, a ssp1(args...; kws...) = ssp_withsolver(Project.SSP1(; kws...), args...)[1] """ - ssp2(rho_filtered, beta, eta, grid) + ssp2(rho_filtered, beta, eta, grid, [dilation_distance=0]) Project using the improved [SSP2 algorithm] [1] with cubic interpolation. From 37cf59373feb64639b4186bf4889cd3f4ff47daf Mon Sep 17 00:00:00 2001 From: Lorenzo Van Munoz Date: Wed, 17 Jun 2026 16:14:00 -0400 Subject: [PATCH 2/2] fix dilation routine and prepare pr --- examples/julia/ssp_dilation.jl | 41 +++++++++++++++++++++++ src/julia/SSP/ext/SSPChainRulesCoreExt.jl | 18 ++++++++++ src/julia/SSP/src/project.jl | 23 +++++++------ src/julia/SSP/test/project.jl | 6 ++-- src/julia/SSP/test/pythonic_api.jl | 16 +++++++++ 5 files changed, 92 insertions(+), 12 deletions(-) create mode 100644 examples/julia/ssp_dilation.jl diff --git a/examples/julia/ssp_dilation.jl b/examples/julia/ssp_dilation.jl new file mode 100644 index 0000000..7b358fa --- /dev/null +++ b/examples/julia/ssp_dilation.jl @@ -0,0 +1,41 @@ +using SSP: conic_filter, ssp2 + +using Random +using CairoMakie +using CairoMakie: colormap +using NLopt +using Zygote + + +Nx = Ny = 100 +grid = ( + range(-1, 1, length=Nx), + range(-1, 1, length=Ny), +) +# Random.seed!(42) +# design_vars = rand(Nx, Ny) +# design_vars = [sinpi(x) * sinpi(y) for (x, y) in Iterators.product(grid...)] +design_vars = let a = 0.5, b = 0.499 + # Cassini oval + [((x^2 + y^2)^2 - 2a^2 * (x^2 - y^2) + a^4 - b^4) + 0.5 for (x, y) in Iterators.product(grid...)] +end +radius = 0.1 +beta = Inf +eta = 0.5 +dilations = [-0.1, 0.0, 0.1] + +ssp_projections = map(dilations) do dilation + rho_filtered = conic_filter(design_vars, radius, grid) + rho_projected = ssp2(rho_filtered, beta, eta, grid, dilation) + return rho_projected +end + +let + fig = Figure(size = (1200, 400)) + for (i, (dilation, rho_projected)) in enumerate(zip(dilations, ssp_projections)) + ax = Axis(fig[1,2i-1]; title = "dilation = $dilation", aspect=DataAspect()) + h = heatmap!(grid..., rho_projected; colormap=colormap("grays")) + Colorbar(fig[1,2i], h) + end + save("dilation_comparison.png", fig) +end \ No newline at end of file diff --git a/src/julia/SSP/ext/SSPChainRulesCoreExt.jl b/src/julia/SSP/ext/SSPChainRulesCoreExt.jl index 8cfea91..f73718f 100644 --- a/src/julia/SSP/ext/SSPChainRulesCoreExt.jl +++ b/src/julia/SSP/ext/SSPChainRulesCoreExt.jl @@ -31,6 +31,24 @@ function rrule(::typeof(ssp2), rho_filtered, beta, eta, grid; kws...) _ssp_rrule(Project.SSP2(; kws...), rho_filtered, beta, eta, grid) end +function _ssp_rrule(alg, rho_filtered, beta, eta, grid, dilation_distance) + rho_projected, solver = ssp_withsolver(alg, rho_filtered, beta, eta, grid, dilation_distance) + function ssp_pullback(adj_rho_projected) + adj_rho_filtered = ssp_rrule(unthunk(adj_rho_projected), solver) + return NoTangent(), adj_rho_filtered, NoTangent(), NoTangent(), NoTangent(), NoTangent() + end + return rho_projected, ssp_pullback +end +function rrule(::typeof(ssp1_linear), rho_filtered, beta, eta, grid, dilation_distance; kws...) + _ssp_rrule(Project.SSP1_linear(; kws...), rho_filtered, beta, eta, grid, dilation_distance) +end +function rrule(::typeof(ssp1), rho_filtered, beta, eta, grid, dilation_distance; kws...) + _ssp_rrule(Project.SSP1(; kws...), rho_filtered, beta, eta, grid, dilation_distance) +end +function rrule(::typeof(ssp2), rho_filtered, beta, eta, grid, dilation_distance; kws...) + _ssp_rrule(Project.SSP2(; kws...), rho_filtered, beta, eta, grid, dilation_distance) +end + function _lengthconstraint_rrule(material, rho_filtered, rho_projected, grid, target_length) constraint, solver = lengthconstraint_withsolver(material, rho_filtered, rho_projected, grid, target_length) function lengthconstraint_pullback(adj_constraint) diff --git a/src/julia/SSP/src/project.jl b/src/julia/SSP/src/project.jl index 660f5e3..28df353 100644 --- a/src/julia/SSP/src/project.jl +++ b/src/julia/SSP/src/project.jl @@ -10,9 +10,8 @@ public ProjectionProblem, SSP1_linear, SSP1, SSP2 Define a problem for projecting smoothed data `rho_filtered` defined on a `grid`, i.e. a tuple of range, at a list of selected `target_points`, i.e. a vector of coordinate tuples. The projection pushes `rho` values above `eta` towards 1 and `rho` values below `eta` towards 0 with a stiffness parameter `beta`. -The `dilation_distance` parameter sets how far to expand or contract the contour of the projection threshold. +The `dilation_distance` parameter sets an approximate length how far to expand or contract the contour of the projection threshold. When `dilation_distance` is positive the region above the threshold is dilated and when negative it is eroded. -Furthermore, `dilation_distance` is dimensionless in units of pixels from the boundary and must satisfy `abs(dilation_distance) < 1` """ Base.@kwdef struct ProjectionProblem{D,G,T,B,DD} rho_filtered::D @@ -167,8 +166,7 @@ function adjoint_tanh_projection(adj_out, x, beta, eta) end -function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothing, beta, eta, dilation_distance) - rho_projected = tanh_projection(rho_filtered, beta, eta) +function smoothed_projection(rho_filtered_no_dilation, rho_filtered_derivs_normsq, R_smoothing, beta, eta, dilation_distance) den_helper = if haskey(rho_filtered_derivs_normsq, :hessian) # SSP2 @@ -181,8 +179,9 @@ function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothi nonzero_norm = abs(den_helper) > zero(den_helper) den_eff = sqrt(ifelse(nonzero_norm, den_helper, oneunit(den_helper))) + rho_filtered = rho_filtered_no_dilation + dilation_distance * den_eff d = (eta - rho_filtered) / den_eff - d_R = d / R_smoothing - dilation_distance + d_R = d / R_smoothing needs_smoothing = nonzero_norm & (abs(d_R) < one(d_R)) F_plus = ifelse(needs_smoothing, 1//2 + d_R * evalpoly(d_R^2, (-15//16, 5//8, -3//16)), one(d_R)) @@ -195,7 +194,9 @@ function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothi rho_plus_eff_projected = tanh_projection(rho_filtered_plus, beta, eta) rho_projected_smoothed = (1 - F_plus) * rho_minus_eff_projected + F_plus * rho_plus_eff_projected - tape = (; F_plus, F_minus, d, d_R, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected) + rho_projected = tanh_projection(rho_filtered, beta, eta) + + tape = (; F_plus, F_minus, d, d_R, rho_filtered, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected) return ifelse(needs_smoothing, rho_projected_smoothed, rho_projected), tape end @@ -228,7 +229,7 @@ function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape) ) rho_filtered_interp_derivs_normsq = map(Base.Fix1(sum, abs2), rho_filtered_interp_derivs) rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance) - adj_rho_f, adj_rho_derivs_normsq = adjoint_smoothed_projection(adj_proj, tape, rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta) + adj_rho_f, adj_rho_derivs_normsq = adjoint_smoothed_projection(adj_proj, tape, rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance) adj_rho_filtered_interp.value[i] = adj_rho_f adj_rho_filtered_interp.gradient[i, :] .= 2adj_rho_derivs_normsq.gradient .* rho_filtered_interp_derivs.gradient if haskey(rho_filtered_interp, :hessian) @@ -241,8 +242,8 @@ function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape) return (; rho_filtered=adj_interp_prob.data, grid=nothing, target_points=nothing) end -function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho_filtered, rho_filtered_derivs_normsq, R_smoothing, beta, eta) - (; F_plus, F_minus, d, d_R, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected) = tape +function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho_filtered_no_dilation, rho_filtered_derivs_normsq, R_smoothing, beta, eta, dilation_distance) + (; F_plus, F_minus, d, d_R, rho_filtered, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected) = tape adj_rho_projected = ifelse(needs_smoothing, zero(adj_rho_projected_maybe_smoothed), adj_rho_projected_maybe_smoothed) adj_rho_filtered = adjoint_tanh_projection(adj_rho_projected, rho_filtered, beta, eta) @@ -262,7 +263,9 @@ function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho adj_d = adj_d_R / R_smoothing adj_rho_filtered -= adj_d / den_eff + adj_rho_filtered_no_dilation = adj_rho_filtered adj_den_eff -= adj_d * d / den_eff + adj_den_eff += dilation_distance * adj_rho_filtered adj_den_helper = ifelse(nonzero_norm, adj_den_eff, zero(adj_den_eff)) / 2den_eff adj_rho_filtered_derivs_normsq = (; @@ -270,6 +273,6 @@ function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho (haskey(rho_filtered_derivs_normsq, :hessian) ? (; hessian = adj_den_helper * R_smoothing^2) : (;))... ) - return adj_rho_filtered, adj_rho_filtered_derivs_normsq + return adj_rho_filtered_no_dilation, adj_rho_filtered_derivs_normsq end end diff --git a/src/julia/SSP/test/project.jl b/src/julia/SSP/test/project.jl index 592d190..50eda38 100644 --- a/src/julia/SSP/test/project.jl +++ b/src/julia/SSP/test/project.jl @@ -20,17 +20,18 @@ for alg in ( SSP.Project.SSP1_linear(), SSP.Project.SSP1(), SSP.Project.SSP2(), -) +), dilation in [-0.3, 0.0, 0.3] # test that adjoints match finite differences Random.seed!(0) perturb = randn(size(data)) - test = let perturb=perturb, data=copy(data), alg=alg, grid=grid, target_points=target_points + test = let perturb=perturb, data=copy(data), alg=alg, grid=grid, target_points=target_points, dilation=dilation function (h) prob = SSP.Project.ProjectionProblem(; rho_filtered=data + h * perturb, grid, target_points, + dilation_distance = dilation, beta = Inf, eta = 0.5 ) @@ -44,6 +45,7 @@ for alg in ( rho_filtered=data, grid, target_points, + dilation_distance = dilation, beta = Inf, eta = 0.5 ) diff --git a/src/julia/SSP/test/pythonic_api.jl b/src/julia/SSP/test/pythonic_api.jl index 7d41326..cfce1cb 100644 --- a/src/julia/SSP/test/pythonic_api.jl +++ b/src/julia/SSP/test/pythonic_api.jl @@ -57,6 +57,22 @@ for ssp in ssp_algs dtest_ssp_di_fd = central_fdm(5, 1)(test_ssp_i, 0.0) ddata_ssp, = Zygote.gradient(test_ssp, myfilt) @test dtest_ssp_di_fd ≈ sum(ddata_ssp .* perturb) + + dilation = 0.2 + test_ssp = let radius=radius, grid=grid, ssp=ssp, beta=beta, eta=eta + function (data) + rho_projected = ssp(data, beta, eta, grid, dilation) + return sum(abs2, rho_projected) + end + end + test_ssp_i = let perturb=perturb, data=copy(myfilt), test_ssp=test_ssp + h -> test_ssp(data + h * perturb) + end + + dtest_ssp_di_fd = central_fdm(5, 1)(test_ssp_i, 0.0) + ddata_ssp, = Zygote.gradient(test_ssp, myfilt) + @test dtest_ssp_di_fd ≈ sum(ddata_ssp .* perturb) + end constraint_algs = (