Skip to content

Commit 76dde61

Browse files
committed
Make it work on Float32
1 parent 2c4c148 commit 76dde61

8 files changed

Lines changed: 75 additions & 69 deletions

File tree

src/ArrayDiff.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ Fork of `MOI.Nonlinear.SparseReverseMode` to add array support.
2020
2121
The type parameter `S` is the storage type used for the AD tape (forward,
2222
partials, and reverse storage of each subexpression). It must satisfy
23-
`S<:AbstractVector{Float64}`. Defaults to `Vector{Float64}`. Pass a different
24-
`S` (for example `CuVector{Float64}`) to keep the tape on a GPU.
23+
`S<:AbstractVector{<:Real}`. Defaults to `Vector{Float64}`. Pass a different
24+
`S` (for example `Vector{Float32}` or `CuVector{Float64}`) to run AD in
25+
another precision or keep the tape on a GPU.
2526
"""
26-
struct Mode{S<:AbstractVector{Float64}} <:
27+
struct Mode{S<:AbstractVector{<:Real}} <:
2728
MOI.Nonlinear.AbstractAutomaticDifferentiation end
2829

2930
Mode() = Mode{Vector{Float64}}()
@@ -65,7 +66,7 @@ include("evaluator.jl")
6566
include("array_nonlinear_function.jl")
6667
include("parse_moi.jl")
6768

68-
model(::Mode{S}) where {S} = Model()
69+
model(::Mode{S}) where {S} = Model{eltype(S)}()
6970

7071
# Extend MOI.Nonlinear.set_objective so that solvers calling
7172
# MOI.Nonlinear.set_objective(arraydiff_model, snf) dispatch here.
@@ -84,8 +85,8 @@ function Evaluator(
8485
model::ArrayDiff.Model,
8586
::Mode{S},
8687
ordered_variables::Vector{MOI.VariableIndex},
87-
) where {S<:AbstractVector{Float64}}
88-
return Evaluator(model, NLPEvaluator{S}(model, ordered_variables))
88+
) where {S<:AbstractVector{<:Real}}
89+
return Evaluator(model, NLPEvaluator{eltype(S),S}(model, ordered_variables))
8990
end
9091

9192
# Called by solvers via MOI.Nonlinear.Evaluator(nlp_model, ad_backend, vars).

src/JuMP/moi_bridge.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ function _to_moi_arg(x::GenericArrayExpr{V,N}) where {V,N}
1212
return ArrayNonlinearFunction{N}(x.head, args, x.size, x.broadcasted)
1313
end
1414

15-
_to_moi_arg(x::Matrix{Float64}) = x
15+
_to_moi_arg(x::Matrix{<:Real}) = x
1616

17-
_to_moi_arg(x::Real) = Float64(x)
17+
_to_moi_arg(x::Real) = x
1818

1919
function JuMP.moi_function(x::GenericArrayExpr{V,N}) where {V,N}
2020
return _to_moi_arg(x)

src/array_nonlinear_function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function _map_indices_arg(index_map::F, x::ArrayOfContiguousVariables) where {F}
8787
return MOI.Utilities.map_indices(index_map, x)
8888
end
8989

90-
function _map_indices_arg(::F, x::Matrix{Float64}) where {F}
90+
function _map_indices_arg(::F, x::Matrix{<:Real}) where {F}
9191
return x
9292
end
9393

src/mathoptinterface_api.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ function MOI.features_available(d::NLPEvaluator)
2020
end
2121

2222
function MOI.initialize(
23-
d::NLPEvaluator{S},
23+
d::NLPEvaluator{T,S},
2424
requested_features::Vector{Symbol},
25-
) where {S<:AbstractVector{Float64}}
25+
) where {T<:Real,S<:AbstractVector{T}}
2626
# Check that we support the features requested by the user.
2727
available_features = MOI.features_available(d)
2828
for feature in requested_features
@@ -40,10 +40,10 @@ function MOI.initialize(
4040
end
4141
d.objective = nothing
4242
d.residual = nothing
43-
d.user_output_buffer = zeros(largest_user_input_dimension)
44-
d.jac_storage = zeros(max(N, largest_user_input_dimension))
45-
d.constraints = _FunctionStorage{S}[]
46-
d.last_x = fill(NaN, N)
43+
d.user_output_buffer = zeros(T, largest_user_input_dimension)
44+
d.jac_storage = zeros(T, max(N, largest_user_input_dimension))
45+
d.constraints = _FunctionStorage{T,S}[]
46+
d.last_x = fill(T(NaN), N)
4747
d.want_hess = :Hess in requested_features
4848
want_hess_storage = (:HessVec in requested_features) || d.want_hess
4949
coloring_storage = Coloring.IndexedSet(N)
@@ -67,9 +67,9 @@ function MOI.initialize(
6767
subexpression_edgelist =
6868
Vector{Set{Tuple{Int,Int}}}(undef, num_subexpressions)
6969
d.subexpressions =
70-
Vector{_SubexpressionStorage{S}}(undef, num_subexpressions)
71-
d.subexpression_forward_values = zeros(num_subexpressions)
72-
d.subexpression_reverse_values = zeros(num_subexpressions)
70+
Vector{_SubexpressionStorage{T,S}}(undef, num_subexpressions)
71+
d.subexpression_forward_values = zeros(T, num_subexpressions)
72+
d.subexpression_reverse_values = zeros(T, num_subexpressions)
7373
for k in d.subexpression_order
7474
# Only load expressions which actually are used
7575
d.subexpression_forward_values[k] = NaN
@@ -145,6 +145,7 @@ function MOI.initialize(
145145
moi_index_to_consecutive_index,
146146
shared_partials_storage_ϵ,
147147
d,
148+
S,
148149
)
149150
residual = _FunctionStorage(
150151
subexpr,
@@ -235,7 +236,7 @@ function MOI.eval_objective_gradient(d::NLPEvaluator, g, x)
235236
error("No nonlinear objective.")
236237
end
237238
_reverse_mode(d, x)
238-
fill!(g, 0.0)
239+
fill!(g, zero(eltype(g)))
239240
_extract_reverse_pass(g, d, something(d.objective))
240241
return
241242
end

src/reverse_mode.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -945,9 +945,9 @@ function _extract_reverse_pass(
945945
f::_FunctionStorage,
946946
) where {T}
947947
for i in f.dependent_subexpressions
948-
d.subexpression_reverse_values[i] = 0.0
948+
d.subexpression_reverse_values[i] = zero(T)
949949
end
950-
_extract_reverse_pass_inner(g, f, d.subexpression_reverse_values, 1.0)
950+
_extract_reverse_pass_inner(g, f, d.subexpression_reverse_values, one(T))
951951
for i in length(f.dependent_subexpressions):-1:1
952952
k = f.dependent_subexpressions[i]
953953
_extract_reverse_pass_inner(

src/sizes.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -438,11 +438,11 @@ function _infer_sizes(
438438
return sizes
439439
end
440440

441-
struct _SubexpressionStorage{S<:AbstractVector{Float64}}
441+
struct _SubexpressionStorage{T<:Real,S<:AbstractVector{T}}
442442
nodes::Vector{Node}
443443
adj::SparseArrays.SparseMatrixCSC{Bool,Int}
444444
sizes::Sizes
445-
const_values::Vector{Float64}
445+
const_values::Vector{T}
446446
forward_storage::S
447447
partials_storage::S
448448
reverse_storage::S
@@ -452,19 +452,19 @@ struct _SubexpressionStorage{S<:AbstractVector{Float64}}
452452
function _SubexpressionStorage(
453453
nodes::Vector{Node},
454454
adj::SparseArrays.SparseMatrixCSC{Bool,Int},
455-
const_values::Vector{Float64},
455+
const_values::Vector{T},
456456
block_shapes::Dict{Int,Vector{Int}},
457457
partials_storage_ϵ::Vector{Float64},
458458
linearity::Linearity,
459-
::Type{S} = Vector{Float64},
460-
) where {S<:AbstractVector{Float64}}
459+
::Type{S} = Vector{T},
460+
) where {T<:Real,S<:AbstractVector{T}}
461461
sizes = _infer_sizes(nodes, adj, block_shapes)
462462
N = _length(sizes)
463463
# Pre-load value blocks into forward_storage once at construction;
464464
# each block is a contiguous-to-contiguous bulk copy. Individual
465465
# `NODE_VALUE` scalars (rare — exponents, constant divisors, etc) and
466466
# variable nodes are loaded by `_forward_eval` in the per-node loop.
467-
cpu_buffer = zeros(N)
467+
cpu_buffer = zeros(T, N)
468468
for k in 1:length(nodes)
469469
node = nodes[k]
470470
if node.type == NODE_VALUE_BLOCK
@@ -475,14 +475,14 @@ struct _SubexpressionStorage{S<:AbstractVector{Float64}}
475475
end
476476
end
477477
forward_storage = convert(S, cpu_buffer)
478-
return new{S}(
478+
return new{T,S}(
479479
nodes,
480480
adj,
481481
sizes,
482482
const_values,
483483
forward_storage,
484-
fill!(S(undef, N), 0.0), # partials_storage,
485-
fill!(S(undef, N), 0.0), # reverse_storage,
484+
fill!(S(undef, N), zero(T)), # partials_storage,
485+
fill!(S(undef, N), zero(T)), # reverse_storage,
486486
partials_storage_ϵ,
487487
linearity,
488488
)

src/types.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function _subexpression_and_linearity(
103103
partials_storage_ϵ::Vector{Float64},
104104
d,
105105
::Type{S} = Vector{Float64},
106-
) where {S<:AbstractVector{Float64}}
106+
) where {S<:AbstractVector{<:Real}}
107107
nodes = _replace_moi_variables(expr.nodes, moi_index_to_consecutive_index)
108108
adj = adjacency_matrix(nodes)
109109
linearity = if d.want_hess
@@ -114,7 +114,7 @@ function _subexpression_and_linearity(
114114
return _SubexpressionStorage(
115115
nodes,
116116
adj,
117-
convert(Vector{Float64}, expr.values),
117+
convert(Vector{eltype(S)}, expr.values),
118118
copy(expr.block_shapes),
119119
partials_storage_ϵ,
120120
linearity[1],
@@ -123,28 +123,28 @@ function _subexpression_and_linearity(
123123
linearity
124124
end
125125

126-
struct _FunctionStorage{S<:AbstractVector{Float64}}
127-
expr::_SubexpressionStorage{S}
126+
struct _FunctionStorage{T<:Real,S<:AbstractVector{T}}
127+
expr::_SubexpressionStorage{T,S}
128128
grad_sparsity::Vector{Int}
129129
# Nonzero pattern of Hessian matrix
130130
hess_I::Vector{Int}
131131
hess_J::Vector{Int}
132132
rinfo::Coloring.RecoveryInfo # coloring info for hessians
133-
seed_matrix::Matrix{Float64}
133+
seed_matrix::Matrix{T}
134134
# subexpressions which this function depends on, ordered for forward pass.
135135
dependent_subexpressions::Vector{Int}
136136

137137
function _FunctionStorage(
138-
expr::_SubexpressionStorage{S},
138+
expr::_SubexpressionStorage{T,S},
139139
num_variables,
140140
coloring_storage::Coloring.IndexedSet,
141141
want_hess::Bool,
142-
subexpressions::Vector{_SubexpressionStorage{S}},
142+
subexpressions::Vector{_SubexpressionStorage{T,S}},
143143
dependent_subexpressions,
144144
subexpression_edgelist,
145145
subexpression_variables,
146146
linearity::Vector{Linearity},
147-
) where {S<:AbstractVector{Float64}}
147+
) where {T<:Real,S<:AbstractVector{T}}
148148
empty!(coloring_storage)
149149
_compute_gradient_sparsity!(coloring_storage, expr)
150150
for k in dependent_subexpressions
@@ -166,7 +166,7 @@ struct _FunctionStorage{S<:AbstractVector{Float64}}
166166
coloring_storage,
167167
)
168168
seed_matrix = Coloring.seed_matrix(rinfo)
169-
return new{S}(
169+
return new{T,S}(
170170
expr,
171171
grad_sparsity,
172172
hess_I,
@@ -176,13 +176,13 @@ struct _FunctionStorage{S<:AbstractVector{Float64}}
176176
dependent_subexpressions,
177177
)
178178
else
179-
return new{S}(
179+
return new{T,S}(
180180
expr,
181181
grad_sparsity,
182182
Int[],
183183
Int[],
184184
Coloring.RecoveryInfo(),
185-
Array{Float64}(undef, 0, 0),
185+
Array{T}(undef, 0, 0),
186186
dependent_subexpressions,
187187
)
188188
end
@@ -305,30 +305,30 @@ interface.
305305
!!! warning
306306
Before using, you must initialize the evaluator using `MOI.initialize`.
307307
"""
308-
mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <:
308+
mutable struct NLPEvaluator{T<:Real,S<:AbstractVector{T}} <:
309309
MOI.AbstractNLPEvaluator
310310
data::Model
311311
ordered_variables::Vector{MOI.VariableIndex}
312312

313-
objective::Union{Nothing,_FunctionStorage{S}}
314-
residual::Union{Nothing,_FunctionStorage{S}}
315-
constraints::Vector{_FunctionStorage{S}}
316-
subexpressions::Vector{_SubexpressionStorage{S}}
313+
objective::Union{Nothing,_FunctionStorage{T,S}}
314+
residual::Union{Nothing,_FunctionStorage{T,S}}
315+
constraints::Vector{_FunctionStorage{T,S}}
316+
subexpressions::Vector{_SubexpressionStorage{T,S}}
317317
subexpression_order::Vector{Int}
318318
# Storage for the subexpressions in reverse-mode automatic differentiation.
319-
subexpression_forward_values::Vector{Float64}
320-
subexpression_reverse_values::Vector{Float64}
319+
subexpression_forward_values::Vector{T}
320+
subexpression_reverse_values::Vector{T}
321321
subexpression_linearity::Vector{Linearity}
322322

323323
# A cache of the last x. This is used to guide whether we need to re-run
324324
# reverse-mode automatic differentiation.
325-
last_x::Vector{Float64}
325+
last_x::Vector{T}
326326

327327
# Temporary storage for computing Jacobians. This is also used as temporary
328328
# storage for the input of multivariate functions.
329-
jac_storage::Vector{Float64}
329+
jac_storage::Vector{T}
330330
# Temporary storage for the gradient of multivariate functions
331-
user_output_buffer::Vector{Float64}
331+
user_output_buffer::Vector{T}
332332

333333
# storage for computing hessians
334334
# these Float64 vectors are reinterpreted to hold multiple epsilon components
@@ -343,10 +343,10 @@ mutable struct NLPEvaluator{S<:AbstractVector{Float64}} <:
343343
hessian_sparsity::Vector{Tuple{Int64,Int64}}
344344
max_chunk::Int # chunk size for which we've allocated storage
345345

346-
function NLPEvaluator{S}(
346+
function NLPEvaluator{T,S}(
347347
data::Model,
348348
ordered_variables::Vector{MOI.VariableIndex},
349-
) where {S<:AbstractVector{Float64}}
350-
return new{S}(data, ordered_variables)
349+
) where {T<:Real,S<:AbstractVector{T}}
350+
return new{T,S}(data, ordered_variables)
351351
end
352352
end

0 commit comments

Comments
 (0)