Skip to content

Commit 15655b6

Browse files
authored
Add allocation tests (#48)
* Add allocation tests * Fix allocation * Fix format * Fix
1 parent b554b3b commit 15655b6

3 files changed

Lines changed: 108 additions & 42 deletions

File tree

src/reverse_mode.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ function _forward_eval(
165165
idx2 = last(children_indices)
166166
@inbounds ix1 = children_arr[idx1]
167167
@inbounds ix2 = children_arr[idx2]
168-
v1 = _view_array(f.forward_storage, f.sizes, ix1)
169-
v2 = _view_array(f.forward_storage, f.sizes, ix2)
170-
out = _view_array(f.forward_storage, f.sizes, k)
168+
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
169+
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
170+
out = _view_matrix(f.forward_storage, f.sizes, k)
171171
LinearAlgebra.mul!(out, v1, v2)
172172
# We deliberately don't write v1/v2 into partials_storage
173173
# here: the matmul reverse branch reads forward_storage
@@ -343,8 +343,8 @@ function _forward_eval(
343343
elseif node.index == 15 # sum
344344
@assert N == 1
345345
ix = children_arr[first(children_indices)]
346-
inp = _view_array(f.forward_storage, f.sizes, ix)
347-
fill!(_view_array(f.partials_storage, f.sizes, ix), one(T))
346+
inp = _view_linear(f.forward_storage, f.sizes, ix)
347+
fill!(_view_linear(f.partials_storage, f.sizes, ix), one(T))
348348
@s f.forward_storage[k] = sum(inp)
349349
elseif node.index == 16 # row
350350
for j in _eachindex(f.sizes, k)
@@ -393,12 +393,12 @@ function _forward_eval(
393393
child1 = first(children_indices)
394394
@inbounds ix1 = children_arr[child1]
395395
@inbounds ix2 = children_arr[child1+1]
396-
out = _view_array(f.forward_storage, f.sizes, k)
397-
v1 = _view_array(f.forward_storage, f.sizes, ix1)
398-
v2 = _view_array(f.forward_storage, f.sizes, ix2)
396+
out = _view_linear(f.forward_storage, f.sizes, k)
397+
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
398+
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
399399
out .= v1 .- v2
400-
fill!(_view_array(f.partials_storage, f.sizes, ix1), one(T))
401-
fill!(_view_array(f.partials_storage, f.sizes, ix2), -one(T))
400+
fill!(_view_linear(f.partials_storage, f.sizes, ix1), one(T))
401+
fill!(_view_linear(f.partials_storage, f.sizes, ix2), -one(T))
402402
elseif node.index == 3 # :* (broadcasted)
403403
# Node `k` is not scalar, so we do matrix multiplication
404404
if f.sizes.ndims[k] != 0
@@ -466,9 +466,9 @@ function _forward_eval(
466466
f.forward_storage,
467467
f.sizes.storage_offset[ix2]+1,
468468
)
469-
out = _view_array(f.forward_storage, f.sizes, k)
470-
inp = _view_array(f.forward_storage, f.sizes, ix1)
471-
partials = _view_array(f.partials_storage, f.sizes, ix1)
469+
out = _view_linear(f.forward_storage, f.sizes, k)
470+
inp = _view_linear(f.forward_storage, f.sizes, ix1)
471+
partials = _view_linear(f.partials_storage, f.sizes, ix1)
472472
if exponent == 2
473473
out .= inp .* inp
474474
partials .= 2 .* inp
@@ -518,9 +518,9 @@ function _forward_eval(
518518
@j f.forward_storage[k] = -val
519519
end
520520
elseif operators.univariate_operators[node.index] === :tanh
521-
out = _view_array(f.forward_storage, f.sizes, k)
522-
inp = _view_array(f.forward_storage, f.sizes, child_idx)
523-
partials = _view_array(f.partials_storage, f.sizes, child_idx)
521+
out = _view_linear(f.forward_storage, f.sizes, k)
522+
inp = _view_linear(f.forward_storage, f.sizes, child_idx)
523+
partials = _view_linear(f.partials_storage, f.sizes, child_idx)
524524
out .= tanh.(inp)
525525
partials .= one(T) .- out .* out
526526
else
@@ -618,11 +618,11 @@ function _reverse_eval(
618618
idx2 = last(children_indices)
619619
ix1 = children_arr[idx1]
620620
ix2 = children_arr[idx2]
621-
v1 = _view_array(f.forward_storage, f.sizes, ix1)
622-
v2 = _view_array(f.forward_storage, f.sizes, ix2)
623-
rev_parent = _view_array(f.reverse_storage, f.sizes, k)
624-
rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1)
625-
rev_v2 = _view_array(f.reverse_storage, f.sizes, ix2)
621+
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
622+
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
623+
rev_parent = _view_matrix(f.reverse_storage, f.sizes, k)
624+
rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1)
625+
rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2)
626626
LinearAlgebra.mul!(rev_v1, rev_parent, v2')
627627
LinearAlgebra.mul!(rev_v2, v1', rev_parent)
628628
continue
@@ -881,12 +881,12 @@ function _reverse_eval(
881881
# diagonal entries are stored in `f.partials_storage`. We broadcast
882882
# `rev_child .= rev_parent .* partial` over the whole array (with the
883883
# 0 * Inf guard preserved).
884-
rev_parent = _view_array(f.reverse_storage, f.sizes, k)
884+
rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
885885
for child_idx in children_indices
886886
ix = children_arr[child_idx]
887887
@assert _size(f.sizes, k) == _size(f.sizes, ix)
888-
rev_child = _view_array(f.reverse_storage, f.sizes, ix)
889-
partial = _view_array(f.partials_storage, f.sizes, ix)
888+
rev_child = _view_linear(f.reverse_storage, f.sizes, ix)
889+
partial = _view_linear(f.partials_storage, f.sizes, ix)
890890
rev_child .= ifelse.(
891891
(rev_parent .== 0) .& .!isfinite.(partial),
892892
rev_parent,

src/sizes.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,27 +68,41 @@ implementation just calls `getindex`; this is a hook for storage backends
6868
_scalar_load(storage::AbstractVector, idx::Int) = @inbounds storage[idx]
6969

7070
"""
71-
_view_array(storage, sizes, k) -> AbstractArray
71+
_view_linear(storage, sizes, k) -> SubArray
7272
73-
Return a view of the slice of `storage` that holds node `k`'s array value,
74-
reshaped to that node's natural shape. The view aliases the underlying
75-
`storage` (no copy), so mutating the returned array writes back into the tape.
76-
For a scalar (`ndims[k] == 0`) node this returns a length-1 vector view.
73+
Return a flat 1-D view of the slice of `storage` that holds node `k`'s array
74+
value. The view aliases the underlying `storage` (no copy), so mutating it
75+
writes back into the tape. For a scalar (`ndims[k] == 0`) node this returns
76+
a length-1 vector view.
77+
78+
Use this for elementwise (broadcasted) operations and reductions that don't
79+
need the array's natural shape — keeping the return type-stable
80+
(`SubArray{T,1,...}`) avoids the heap-boxing that a multi-shape return type
81+
would force.
7782
"""
78-
function _view_array(storage::AbstractVector, sizes::Sizes, k::Int)
79-
nd = sizes.ndims[k]
83+
function _view_linear(storage::AbstractVector, sizes::Sizes, k::Int)
8084
offset = sizes.storage_offset[k]
81-
if nd == 0
82-
return view(storage, (offset+1):(offset+1))
83-
elseif nd == 1
84-
n = sizes.size[sizes.size_offset[k]+1]
85-
return view(storage, (offset+1):(offset+n))
86-
else
87-
N = _length(sizes, k)
88-
v = view(storage, (offset+1):(offset+N))
89-
szs = ntuple(d -> sizes.size[sizes.size_offset[k]+d], nd)
90-
return reshape(v, szs)
91-
end
85+
N = _length(sizes, k)
86+
return view(storage, (offset+1):(offset+N))
87+
end
88+
89+
"""
90+
_view_matrix(storage, sizes, k) -> ReshapedArray
91+
92+
Return a 2-D view of the slice of `storage` that holds node `k`'s array
93+
value. Node `k` must be 2-D: this is asserted, and 1-D or 0-D nodes are not
94+
promoted to `(n, 1)` or `(1, 1)`. Always returns a 2-D `Base.ReshapedArray`, which is what callers
95+
like `LinearAlgebra.mul!` need; keeping the return type-stable avoids
96+
heap-boxing.
97+
"""
98+
function _view_matrix(storage::AbstractVector, sizes::Sizes, k::Int)
99+
@assert sizes.ndims[k] == 2
100+
offset = sizes.storage_offset[k]
101+
size_off = sizes.size_offset[k]
102+
m = sizes.size[size_off+1]
103+
n = sizes.size[size_off+2]
104+
v = view(storage, (offset+1):(offset+m*n))
105+
return reshape(v, (m, n))
92106
end
93107

94108
"""

test/JuMP.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,58 @@ function test_neural()
307307
end
308308
end
309309

310+
# Builds the same `sum((W2*tanh.(W1*X) - target)^2)` MLP that `test_neural`
311+
# exercises and checks that, after warmup, both `eval_objective` and
312+
# `eval_objective_gradient` are allocation-free on the CPU `Vector{Float64}`
313+
# tape — including when the input `x` has changed since the last call (which
314+
# is the path that actually re-runs forward+reverse, not the
315+
# `last_x == x` short-circuit).
316+
function test_neural_allocations()
317+
n = 2
318+
X = [1.0 0.5; 0.3 0.8]
319+
target = [0.5 0.2; 0.1 0.7]
320+
model = Model()
321+
@variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
322+
@variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
323+
Y = W2 * tanh.(W1 * X)
324+
loss = sum((Y .- target) .^ 2)
325+
mode = ArrayDiff.Mode()
326+
ad = ArrayDiff.model(mode)
327+
MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss))
328+
evaluator = MOI.Nonlinear.Evaluator(
329+
ad,
330+
mode,
331+
JuMP.index.(JuMP.all_variables(model)),
332+
)
333+
MOI.initialize(evaluator, [:Grad])
334+
x1 = Float64.(collect(1:8))
335+
x2 = Float64.(collect(2:9))
336+
g = zeros(8)
337+
# Wrapped in typed functions so `@allocated` doesn't capture the
338+
# return-value boxing that happens when calling `eval_objective`
339+
# directly from the macro's untyped scope (each `MOI.eval_objective`
340+
# returns a `Float64` which then escapes into `Any`-typed scope).
341+
_obj(ev, x) = MOI.eval_objective(ev, x)
342+
function _grad!(ev, g, x)
343+
MOI.eval_objective_gradient(ev, g, x)
344+
return nothing
345+
end
346+
# Warmup: trigger JIT compilation for both `eval_objective` and
347+
# `eval_objective_gradient`. Two distinct inputs so `_reverse_mode`'s
348+
# `last_x == x` short-circuit doesn't elide the work on the second call.
349+
_obj(evaluator, x1)
350+
_obj(evaluator, x2)
351+
_grad!(evaluator, g, x1)
352+
_grad!(evaluator, g, x2)
353+
# Now alternate: each measured call sees `last_x ≠ x`, so it actually
354+
# runs the full forward + reverse passes through the block tape.
355+
@test 0 == @allocated _obj(evaluator, x1)
356+
@test 0 == @allocated _obj(evaluator, x2)
357+
@test 0 == @allocated _grad!(evaluator, g, x1)
358+
@test 0 == @allocated _grad!(evaluator, g, x2)
359+
return
360+
end
361+
310362
function test_moi_function()
311363
model = Model()
312364
@variable(model, W[1:2, 1:2], container = ArrayDiff.ArrayOfVariables)

0 commit comments

Comments
 (0)