Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/julia/ssp_dilation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using SSP: conic_filter, ssp2

using Random
using CairoMakie
using CairoMakie: colormap
using NLopt
using Zygote


Nx = Ny = 100
grid = (
range(-1, 1, length=Nx),
range(-1, 1, length=Ny),
)
# Random.seed!(42)
# design_vars = rand(Nx, Ny)
# design_vars = [sinpi(x) * sinpi(y) for (x, y) in Iterators.product(grid...)]
design_vars = let a = 0.5, b = 0.499
# Cassini oval
[((x^2 + y^2)^2 - 2a^2 * (x^2 - y^2) + a^4 - b^4) + 0.5 for (x, y) in Iterators.product(grid...)]
end
radius = 0.1
beta = Inf
eta = 0.5
dilations = [-0.1, 0.0, 0.1]

ssp_projections = map(dilations) do dilation
rho_filtered = conic_filter(design_vars, radius, grid)
rho_projected = ssp2(rho_filtered, beta, eta, grid, dilation)
return rho_projected
end

let
fig = Figure(size = (1200, 400))
for (i, (dilation, rho_projected)) in enumerate(zip(dilations, ssp_projections))
ax = Axis(fig[1,2i-1]; title = "dilation = $dilation", aspect=DataAspect())
h = heatmap!(grid..., rho_projected; colormap=colormap("grays"))
Colorbar(fig[1,2i], h)
end
save("dilation_comparison.png", fig)
end
18 changes: 18 additions & 0 deletions src/julia/SSP/ext/SSPChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ function rrule(::typeof(ssp2), rho_filtered, beta, eta, grid; kws...)
_ssp_rrule(Project.SSP2(; kws...), rho_filtered, beta, eta, grid)
end

function _ssp_rrule(alg, rho_filtered, beta, eta, grid, dilation_distance)
rho_projected, solver = ssp_withsolver(alg, rho_filtered, beta, eta, grid, dilation_distance)
function ssp_pullback(adj_rho_projected)
adj_rho_filtered = ssp_rrule(unthunk(adj_rho_projected), solver)
return NoTangent(), adj_rho_filtered, NoTangent(), NoTangent(), NoTangent(), NoTangent()
end
return rho_projected, ssp_pullback
end
function rrule(::typeof(ssp1_linear), rho_filtered, beta, eta, grid, dilation_distance; kws...)
_ssp_rrule(Project.SSP1_linear(; kws...), rho_filtered, beta, eta, grid, dilation_distance)
end
function rrule(::typeof(ssp1), rho_filtered, beta, eta, grid, dilation_distance; kws...)
_ssp_rrule(Project.SSP1(; kws...), rho_filtered, beta, eta, grid, dilation_distance)
end
function rrule(::typeof(ssp2), rho_filtered, beta, eta, grid, dilation_distance; kws...)
_ssp_rrule(Project.SSP2(; kws...), rho_filtered, beta, eta, grid, dilation_distance)
end

function _lengthconstraint_rrule(material, rho_filtered, rho_projected, grid, target_length)
constraint, solver = lengthconstraint_withsolver(material, rho_filtered, rho_projected, grid, target_length)
function lengthconstraint_pullback(adj_constraint)
Expand Down
45 changes: 26 additions & 19 deletions src/julia/SSP/src/project.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@ import SSP: init!, solve!, adjoint_solve!
public ProjectionProblem, SSP1_linear, SSP1, SSP2

"""
ProjectionProblem(; rho_filtered, grid, target_points, beta=Inf, eta=1/2)
ProjectionProblem(; rho_filtered, grid, target_points, beta=Inf, eta=1/2, dilation_distance=0)

Define a problem for projecting smoothed data `rho_filtered` defined on a `grid`, i.e. a tuple of range, at a list of selected `target_points`, i.e. a vector of coordinate tuples.
The projection pushes `rho` values above `eta` towards 1 and `rho` values below `eta` towards 0 with a stiffnes parameter `beta`.
The projection pushes `rho` values above `eta` towards 1 and `rho` values below `eta` towards 0 with a stiffness parameter `beta`.
The `dilation_distance` parameter sets an approximate length how far to expand or contract the contour of the projection threshold.
When `dilation_distance` is positive the region above the threshold is dilated and when negative it is eroded.
"""
Base.@kwdef struct ProjectionProblem{D,G,T,B}
Base.@kwdef struct ProjectionProblem{D,G,T,B,DD}
rho_filtered::D
grid::G
target_points::T
beta::B = eltype(rho_filtered)(Inf)
eta::B = eltype(rho_filtered)(1//2)
# dilation/erosion distance = 0
dilation_distance::DD=0
end

function Base.copy(prob::ProjectionProblem)
Expand All @@ -27,17 +29,18 @@ function Base.copy(prob::ProjectionProblem)
target_points = copy(prob.target_points),
beta = prob.beta,
eta = prob.eta,
dilation_distance = prob.dilation_distance,
)
return newprob
end

mutable struct ProjectionSolver{D,G,T,B,A,C}
mutable struct ProjectionSolver{D,G,T,B,DD,A,C}
rho_filtered::D
const grid::G
const target_points::T
beta::B
eta::B
# dilation/erosion distance
dilation_distance::DD
alg::A
cacheval::C
end
Expand Down Expand Up @@ -72,7 +75,7 @@ The `smoothing_radius` keyword sets the radius of the smoothing kernel relative
SSP2(; kws...) = SSPAlg(; interp=CubicInterp(; deriv=ValueWithGradientAndHessian()), kws...)

function init!(prob::ProjectionProblem, alg::SSPAlg)
(; rho_filtered, grid, target_points, beta, eta) = prob
(; rho_filtered, grid, target_points, beta, eta, dilation_distance) = prob

interp_prob = InterpolationProblem(; data=rho_filtered, grid, target_points)
interp_alg = alg.interp
Expand All @@ -90,7 +93,7 @@ function init!(prob::ProjectionProblem, alg::SSPAlg)

cacheval = (; interp_solver, rho_projected, adj_rho_filtered_interp)# adj_rho_filtered_value, adj_rho_filtered_gradient, adj_rho_filtered_hessian)

return ProjectionSolver(rho_filtered, grid, target_points, beta, eta, alg, cacheval)
return ProjectionSolver(rho_filtered, grid, target_points, beta, eta, dilation_distance, alg, cacheval)
end

function solve!(solver::ProjectionSolver)
Expand All @@ -99,7 +102,7 @@ end

function proj_solve!(solver, alg::SSPAlg)

(; rho_filtered, grid, beta, eta, cacheval) = solver
(; rho_filtered, grid, beta, eta, dilation_distance, cacheval) = solver
(; interp_solver, rho_projected) = cacheval

interp_solver.data = rho_filtered
Expand All @@ -118,7 +121,7 @@ function proj_solve!(solver, alg::SSPAlg)
(haskey(rho_filtered_interp, :hessian) ? (; hessian = view(rho_filtered_interp.hessian, i, :, :)) : (;))...
)
rho_filtered_interp_derivs_normsq = map(Base.Fix1(sum, abs2), rho_filtered_interp_derivs)
rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta)
rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance)
rho_projected[i] = rho_p
end
return (; value=rho_projected, tape=nothing)
Expand Down Expand Up @@ -163,8 +166,7 @@ function adjoint_tanh_projection(adj_out, x, beta, eta)
end


function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothing, beta, eta)
rho_projected = tanh_projection(rho_filtered, beta, eta)
function smoothed_projection(rho_filtered_no_dilation, rho_filtered_derivs_normsq, R_smoothing, beta, eta, dilation_distance)

den_helper = if haskey(rho_filtered_derivs_normsq, :hessian)
# SSP2
Expand All @@ -177,6 +179,7 @@ function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothi
nonzero_norm = abs(den_helper) > zero(den_helper)

den_eff = sqrt(ifelse(nonzero_norm, den_helper, oneunit(den_helper)))
rho_filtered = rho_filtered_no_dilation + dilation_distance * den_eff
d = (eta - rho_filtered) / den_eff
d_R = d / R_smoothing

Expand All @@ -191,7 +194,9 @@ function smoothed_projection(rho_filtered, rho_filtered_derivs_normsq, R_smoothi
rho_plus_eff_projected = tanh_projection(rho_filtered_plus, beta, eta)

rho_projected_smoothed = (1 - F_plus) * rho_minus_eff_projected + F_plus * rho_plus_eff_projected
tape = (; F_plus, F_minus, d, d_R, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected)
rho_projected = tanh_projection(rho_filtered, beta, eta)

tape = (; F_plus, F_minus, d, d_R, rho_filtered, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected)
return ifelse(needs_smoothing, rho_projected_smoothed, rho_projected), tape
end

Expand All @@ -202,7 +207,7 @@ end

function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape)

(; rho_filtered, grid, beta, eta, cacheval) = solver
(; rho_filtered, grid, beta, eta, dilation_distance, cacheval) = solver
(; interp_solver, adj_rho_filtered_interp) = cacheval
# (; interp_solver, adj_rho_filtered_value, adj_rho_filtered_gradient, adj_rho_filtered_hessian) = cacheval

Expand All @@ -223,8 +228,8 @@ function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape)
(haskey(rho_filtered_interp, :hessian) ? (; hessian = view(rho_filtered_interp.hessian, i, :, :)) : (;))...
)
rho_filtered_interp_derivs_normsq = map(Base.Fix1(sum, abs2), rho_filtered_interp_derivs)
rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta)
adj_rho_f, adj_rho_derivs_normsq = adjoint_smoothed_projection(adj_proj, tape, rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta)
rho_p, tape = smoothed_projection(rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance)
adj_rho_f, adj_rho_derivs_normsq = adjoint_smoothed_projection(adj_proj, tape, rho_f, rho_filtered_interp_derivs_normsq, R_smoothing, beta, eta, dilation_distance)
adj_rho_filtered_interp.value[i] = adj_rho_f
adj_rho_filtered_interp.gradient[i, :] .= 2adj_rho_derivs_normsq.gradient .* rho_filtered_interp_derivs.gradient
if haskey(rho_filtered_interp, :hessian)
Expand All @@ -237,8 +242,8 @@ function adjoint_proj_solve!(solver, alg::SSPAlg, adj_sol, tape)
return (; rho_filtered=adj_interp_prob.data, grid=nothing, target_points=nothing)
end

function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho_filtered, rho_filtered_derivs_normsq, R_smoothing, beta, eta)
(; F_plus, F_minus, d, d_R, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected) = tape
function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho_filtered_no_dilation, rho_filtered_derivs_normsq, R_smoothing, beta, eta, dilation_distance)
(; F_plus, F_minus, d, d_R, rho_filtered, rho_projected, den_helper, den_eff, nonzero_norm, needs_smoothing, rho_filtered_minus, rho_filtered_plus, rho_minus_eff_projected, rho_plus_eff_projected) = tape

adj_rho_projected = ifelse(needs_smoothing, zero(adj_rho_projected_maybe_smoothed), adj_rho_projected_maybe_smoothed)
adj_rho_filtered = adjoint_tanh_projection(adj_rho_projected, rho_filtered, beta, eta)
Expand All @@ -258,14 +263,16 @@ function adjoint_smoothed_projection(adj_rho_projected_maybe_smoothed, tape, rho

adj_d = adj_d_R / R_smoothing
adj_rho_filtered -= adj_d / den_eff
adj_rho_filtered_no_dilation = adj_rho_filtered
adj_den_eff -= adj_d * d / den_eff
adj_den_eff += dilation_distance * adj_rho_filtered
adj_den_helper = ifelse(nonzero_norm, adj_den_eff, zero(adj_den_eff)) / 2den_eff

adj_rho_filtered_derivs_normsq = (;
gradient = adj_den_helper,
(haskey(rho_filtered_derivs_normsq, :hessian) ? (; hessian = adj_den_helper * R_smoothing^2) : (;))...
)

return adj_rho_filtered, adj_rho_filtered_derivs_normsq
return adj_rho_filtered_no_dilation, adj_rho_filtered_derivs_normsq
end
end
9 changes: 5 additions & 4 deletions src/julia/SSP/src/pythonic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,23 @@ function conic_filter_rrule(adj_depad_value, padsolver, convsolver, depadsolver)
return adj_padprob.data
end

function ssp_withsolver(alg, rho_filtered, beta, eta, grid)
function ssp_withsolver(alg, rho_filtered, beta, eta, grid, dilation_distance=0)
target_points = vec(collect(Iterators.product(grid...)))
prob = Project.ProjectionProblem(;
rho_filtered,
grid,
target_points,
beta,
eta,
dilation_distance,
)
solver = init(prob, alg)
sol = solve!(solver)
return reshape(sol.value, size(rho_filtered)), solver
end

"""
ssp1_linear(rho_filtered, beta, eta, grid)
ssp1_linear(rho_filtered, beta, eta, grid, [dilation_distance=0])

Project using the original [SSP1 algorithm] [1] with linear interpolation.

Expand All @@ -77,7 +78,7 @@ At `beta=Inf`, this projection is not differentiable through topology changes, a
ssp1_linear(args...; kws...) = ssp_withsolver(Project.SSP1_linear(; kws...), args...)[1]

"""
ssp1(rho_filtered, beta, eta, grid)
ssp1(rho_filtered, beta, eta, grid, [dilation_distance=0])

Project using the original [SSP1 algorithm] [1] with cubic interpolation.

Expand All @@ -92,7 +93,7 @@ At `beta=Inf`, this projection is not differentiable through topology changes, a
ssp1(args...; kws...) = ssp_withsolver(Project.SSP1(; kws...), args...)[1]

"""
ssp2(rho_filtered, beta, eta, grid)
ssp2(rho_filtered, beta, eta, grid, [dilation_distance=0])

Project using the improved [SSP2 algorithm] [1] with cubic interpolation.

Expand Down
6 changes: 4 additions & 2 deletions src/julia/SSP/test/project.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ for alg in (
SSP.Project.SSP1_linear(),
SSP.Project.SSP1(),
SSP.Project.SSP2(),
)
), dilation in [-0.3, 0.0, 0.3]

# test that adjoints match finite differences
Random.seed!(0)
perturb = randn(size(data))
test = let perturb=perturb, data=copy(data), alg=alg, grid=grid, target_points=target_points
test = let perturb=perturb, data=copy(data), alg=alg, grid=grid, target_points=target_points, dilation=dilation
function (h)
prob = SSP.Project.ProjectionProblem(;
rho_filtered=data + h * perturb,
grid,
target_points,
dilation_distance = dilation,
beta = Inf,
eta = 0.5
)
Expand All @@ -44,6 +45,7 @@ for alg in (
rho_filtered=data,
grid,
target_points,
dilation_distance = dilation,
beta = Inf,
eta = 0.5
)
Expand Down
16 changes: 16 additions & 0 deletions src/julia/SSP/test/pythonic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ for ssp in ssp_algs
dtest_ssp_di_fd = central_fdm(5, 1)(test_ssp_i, 0.0)
ddata_ssp, = Zygote.gradient(test_ssp, myfilt)
@test dtest_ssp_di_fd ≈ sum(ddata_ssp .* perturb)

dilation = 0.2
test_ssp = let radius=radius, grid=grid, ssp=ssp, beta=beta, eta=eta
function (data)
rho_projected = ssp(data, beta, eta, grid, dilation)
return sum(abs2, rho_projected)
end
end
test_ssp_i = let perturb=perturb, data=copy(myfilt), test_ssp=test_ssp
h -> test_ssp(data + h * perturb)
end

dtest_ssp_di_fd = central_fdm(5, 1)(test_ssp_i, 0.0)
ddata_ssp, = Zygote.gradient(test_ssp, myfilt)
@test dtest_ssp_di_fd ≈ sum(ddata_ssp .* perturb)

end

constraint_algs = (
Expand Down
Loading