Skip to content

Commit 0fe2432

Browse files
committed
Vectorized loading
1 parent 15655b6 commit 0fe2432

7 files changed

Lines changed: 157 additions & 110 deletions

File tree

src/graph_tools.jl

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ field, which should be interpreted as follows:
2323
* `NODE_CALL_UNIVARIATE_BROADCASTED`: the index into `operators.univariate_operators`
2424
* `NODE_CALL_MULTIVARIATE_BROADCASTED`: the index into `operators.multivariate_operators`
2525
* `NODE_CALL_REDUCE`: the index into `operators.multivariate_operators`
26+
* `NODE_MOI_VARIABLE_BLOCK`: a contiguous block of `MOI.VariableIndex`. The
27+
`index` field is the value of the FIRST `MOI.VariableIndex` in the block;
28+
the shape `(m, n)` is stored in `Expression.block_shapes`. The block holds
29+
`m * n` variables in column-major order.
30+
* `NODE_VARIABLE_BLOCK`: same as `NODE_MOI_VARIABLE_BLOCK` but after MOI
31+
variable indices have been remapped to consecutive 1-based internal
32+
indices. `index` is the FIRST internal index of the block.
33+
* `NODE_VALUE_BLOCK`: a contiguous block of constants. `index` is the start
34+
index in the `.values` field; the next `m * n` entries (column-major) are
35+
the block's data. Shape stored in `Expression.block_shapes`.
2636
"""
2737
@enum(
2838
NodeType,
@@ -51,6 +61,12 @@ field, which should be interpreted as follows:
5161
NODE_CALL_MULTIVARIATE_BROADCASTED,
5262
# Index into the multivariate operators, with reduction
5363
NODE_CALL_REDUCE,
64+
# Block-of-variables node, before MOI → internal index remap.
65+
NODE_MOI_VARIABLE_BLOCK,
66+
# Block-of-variables node, after MOI → internal index remap.
67+
NODE_VARIABLE_BLOCK,
68+
# Block-of-constants node.
69+
NODE_VALUE_BLOCK,
5470
)
5571

5672
@enum(Linearity, CONSTANT, LINEAR, PIECEWISE_LINEAR, NONLINEAR)
@@ -96,6 +112,16 @@ function _replace_moi_variables(
96112
moi_index_to_consecutive_index[MOI.VariableIndex(node.index)],
97113
node.parent,
98114
)
115+
elseif node.type == NODE_MOI_VARIABLE_BLOCK
116+
# `node.index` is the FIRST MOI variable index in the block. For an
117+
# `ArrayOfContiguousVariables`, all variables in the block are
118+
# contiguous in MOI's ordering, so we just remap the first index
119+
# and reuse it as the consecutive offset. The block's length comes
120+
# from `Expression.block_shapes[i]`, which the caller threads
121+
# through separately.
122+
first_consec =
123+
moi_index_to_consecutive_index[MOI.VariableIndex(node.index)]
124+
new_nodes[i] = Node(NODE_VARIABLE_BLOCK, first_consec, node.parent)
99125
else
100126
new_nodes[i] = node
101127
end
@@ -122,10 +148,10 @@ function _classify_linearity(
122148
children_arr = SparseArrays.rowvals(adj)
123149
for k in length(nodes):-1:1
124150
node = nodes[k]
125-
if node.type == NODE_VARIABLE
151+
if node.type == NODE_VARIABLE || node.type == NODE_VARIABLE_BLOCK
126152
linearity[k] = LINEAR
127153
continue
128-
elseif node.type == NODE_VALUE
154+
elseif node.type == NODE_VALUE || node.type == NODE_VALUE_BLOCK
129155
linearity[k] = CONSTANT
130156
continue
131157
elseif node.type == NODE_PARAMETER
@@ -218,24 +244,29 @@ function _classify_linearity(
218244
end
219245

220246
"""
221-
_compute_gradient_sparsity!(
222-
indices::Coloring.IndexedSet,
223-
nodes::Vector{Nonlinear.Node},
224-
)
247+
_compute_gradient_sparsity!(indices::Coloring.IndexedSet, f)
248+
249+
Compute the sparsity pattern of the gradient of an expression (that is, a list
250+
of which variable indices are present).
225251
226-
Compute the sparsity pattern of the gradient of an expression (that is, a list of
227-
which variable indices are present).
252+
`f` is duck-typed (as its type is defined later) to a
253+
`_SubexpressionStorage`-like object exposing `f.nodes` and `f.sizes`.
254+
For `NODE_VARIABLE_BLOCK` nodes, the block's length is read from `f.sizes`
255+
(via `_length`), and the block contributes that many consecutive variable
256+
indices starting at `nodes[k].index`.
228257
"""
229-
function _compute_gradient_sparsity!(
230-
indices::Coloring.IndexedSet,
231-
nodes::Vector{Node},
232-
)
233-
for node in nodes
258+
function _compute_gradient_sparsity!(indices::Coloring.IndexedSet, f)
259+
for (k, node) in enumerate(f.nodes)
234260
if node.type == NODE_VARIABLE
235261
push!(indices, node.index)
236-
elseif node.type == NODE_MOI_VARIABLE
262+
elseif node.type == NODE_VARIABLE_BLOCK
263+
len = _length(f.sizes, k)
264+
for i in 0:(len - 1)
265+
push!(indices, node.index + i)
266+
end
267+
elseif node.type == NODE_MOI_VARIABLE || node.type == NODE_MOI_VARIABLE_BLOCK
237268
error(
238-
"Internal error: Invalid to compute sparsity if NODE_MOI_VARIABLE " *
269+
"Internal error: Invalid to compute sparsity if $(node.type) " *
239270
"nodes are present.",
240271
)
241272
end
@@ -341,6 +372,7 @@ function _compute_hessian_sparsity(
341372
child_group_variables = Dict{Int,Set{Int}}()
342373
for (k, node) in enumerate(nodes)
343374
@assert node.type != NODE_MOI_VARIABLE
375+
@assert node.type != NODE_MOI_VARIABLE_BLOCK
344376
if input_linearity[k] == CONSTANT
345377
continue # No hessian contribution from constant nodes
346378
end
@@ -376,6 +408,13 @@ function _compute_hessian_sparsity(
376408
child_group_variables[child_group_idx],
377409
nodes[r].index,
378410
)
411+
elseif nodes[r].type == NODE_VARIABLE_BLOCK
412+
# TODO `NODE_VARIABLE_BLOCK` would need its block shape to
413+
# enumerate the variable indices.
414+
error(
415+
"Internal error: Hessian sparsity for " *
416+
"NODE_VARIABLE_BLOCK is not yet supported.",
417+
)
379418
elseif nodes[r].type == NODE_SUBEXPRESSION
380419
sub_vars = subexpression_variables[nodes[r].index]
381420
if !haskey(child_group_variables, child_group_idx)

src/mathoptinterface_api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function MOI.initialize(
8787
max(max_expr_with_sub_length, length(subex.nodes))
8888
if d.want_hess
8989
empty!(coloring_storage)
90-
_compute_gradient_sparsity!(coloring_storage, subex.nodes)
90+
_compute_gradient_sparsity!(coloring_storage, subex)
9191
# union with all dependent expressions
9292
for idx in _list_subexpressions(subex.nodes)
9393
union!(coloring_storage, subexpression_variables[idx])

src/parse_moi.jl

Lines changed: 22 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -140,98 +140,38 @@ end
140140
# ── ArrayOfContiguousVariables ───────────────────────────────────────────────────
141141

142142
function _parse_moi_stack!(
143-
stack::Vector{Tuple{Int,Any}},
144-
data::Model,
145-
expr::Expression,
146-
x::ArrayOfContiguousVariables{2},
147-
parent_index::Int,
148-
)
149-
m, n = x.size
150-
# Build vcat(row(v11, v12, ...), row(v21, v22, ...), ...).
151-
#
152-
# The outer loop is `1:m` (forward order), NOT `m:-1:1`. The `:row` nodes
153-
# we push end up at consecutive positions in `expr.nodes`, and `:vcat`
154-
# later reads its children in tape-index order (CSC `nzrange`) — so the
155-
# row with the smallest tape index becomes row 1 of the output matrix.
156-
# If the outer loop ran in reverse, `row_m` would land at the smallest
157-
# tape index and `:vcat` would silently place it as row 1, producing a
158-
# row-flipped matrix on the tape (a latent bug, fixed here).
159-
#
160-
# The inner loop stays `n:-1:1` because the items go on the stack and pop
161-
# in LIFO order — pushing in reverse j order gives forward j-order on
162-
# pop, which matches the column-major layout below.
163-
vcat_id = data.operators.multivariate_operator_to_id[:vcat]
164-
row_id = data.operators.multivariate_operator_to_id[:row]
165-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vcat_id, parent_index))
166-
vcat_idx = length(expr.nodes)
167-
for i in 1:m
168-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, row_id, vcat_idx))
169-
row_idx = length(expr.nodes)
170-
for j in n:-1:1
171-
vi = MOI.VariableIndex(x.offset + (j - 1) * m + i)
172-
push!(stack, (row_idx, vi))
173-
end
174-
end
175-
return
176-
end
177-
178-
function _parse_moi_stack!(
179-
stack::Vector{Tuple{Int,Any}},
180-
data::Model,
143+
::Vector{Tuple{Int,Any}},
144+
::Model,
181145
expr::Expression,
182-
x::ArrayOfContiguousVariables{1},
146+
x::ArrayOfContiguousVariables,
183147
parent_index::Int,
184148
)
185-
m = x.size[1]
186-
vect_id = data.operators.multivariate_operator_to_id[:vect]
187-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vect_id, parent_index))
188-
vect_idx = length(expr.nodes)
189-
for i in m:-1:1
190-
vi = MOI.VariableIndex(x.offset + i)
191-
push!(stack, (vect_idx, vi))
192-
end
149+
# Emit a single block node. The block represents the contiguous range of
150+
# MOI variable indices `x.offset+1, ...`, laid out in
151+
# column-major order (matching `Array{Float64}` and `Base.LinearIndices`),
152+
# which is the layout `_view_array` will see at evaluation time.
153+
push!(expr.nodes, Node(NODE_MOI_VARIABLE_BLOCK, x.offset + 1, parent_index))
154+
expr.block_shapes[length(expr.nodes)] = collect(x.size)
193155
return
194156
end
195157

196-
# ── Constant matrices and vectors ────────────────────────────────────────────
158+
# ── Constant arrays ────────────────────────────────────────────
197159

198160
function _parse_moi_stack!(
199-
stack::Vector{Tuple{Int,Any}},
200-
data::Model,
201-
expr::Expression,
202-
x::AbstractMatrix{<:Real},
203-
parent_index::Int,
204-
)
205-
m, n = size(x)
206-
# See the `ArrayOfContiguousVariables{2}` overload for the rationale on
207-
# the `1:m` outer loop (the previous `m:-1:1` produced a row-flipped
208-
# matrix on the tape).
209-
vcat_id = data.operators.multivariate_operator_to_id[:vcat]
210-
row_id = data.operators.multivariate_operator_to_id[:row]
211-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vcat_id, parent_index))
212-
vcat_idx = length(expr.nodes)
213-
for i in 1:m
214-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, row_id, vcat_idx))
215-
row_idx = length(expr.nodes)
216-
for j in n:-1:1
217-
push!(stack, (row_idx, x[i, j]))
218-
end
219-
end
220-
return
221-
end
222-
223-
function _parse_moi_stack!(
224-
stack::Vector{Tuple{Int,Any}},
225-
data::Model,
161+
::Vector{Tuple{Int,Any}},
162+
::Model,
226163
expr::Expression,
227-
x::AbstractVector{<:Real},
164+
x::AbstractArray{<:Real},
228165
parent_index::Int,
229166
)
230-
vect_id = data.operators.multivariate_operator_to_id[:vect]
231-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vect_id, parent_index))
232-
vect_idx = length(expr.nodes)
233-
for i in length(x):-1:1
234-
push!(stack, (vect_idx, x[i]))
235-
end
167+
# Emit a single value block. We push the flat values to
168+
# `expr.values` in column-major order (matching `Array{Float64}`'s memory
169+
# layout); `node.index` records the start of that contiguous range so
170+
# `_SubexpressionStorage` can copy it into the tape in one block at
171+
# construction time.
172+
start_idx = length(expr.values) + 1
173+
append!(expr.values, x)
174+
push!(expr.nodes, Node(NODE_VALUE_BLOCK, start_idx, parent_index))
175+
expr.block_shapes[length(expr.nodes)] = collect(size(x))
236176
return
237177
end

src/reverse_mode.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ function _forward_eval(
125125
# f.forward_storage[k] = x[node.index]
126126
elseif node.type == NODE_VALUE
127127
f.forward_storage[j] = f.const_values[node.index]
128+
elseif node.type == NODE_VARIABLE_BLOCK
129+
# Contiguous-to-contiguous copy from `x` into the tape: on CPU a
130+
# `copyto!`, on GPU a single `cudaMemcpy`. This is the fast path
131+
# for matrix variables from `ArrayOfContiguousVariables{2}`.
132+
tape_range = _storage_range(f.sizes, k)
133+
len = length(tape_range)
134+
copyto!(
135+
view(f.forward_storage, tape_range),
136+
view(x, node.index:(node.index + len - 1)),
137+
)
138+
elseif node.type == NODE_VALUE_BLOCK
139+
# Pre-loaded into `forward_storage` at construction.
128140
elseif node.type == NODE_SUBEXPRESSION
129141
f.forward_storage[j] = d.subexpression_forward_values[node.index]
130142
elseif node.type == NODE_PARAMETER
@@ -945,7 +957,18 @@ function _extract_reverse_pass_inner(
945957
) where {T}
946958
@assert length(f.reverse_storage) >= _length(f.sizes)
947959
for (k, node) in enumerate(f.nodes)
948-
if node.type == NODE_VARIABLE
960+
if node.type == NODE_VARIABLE_BLOCK
961+
# Each block has a contiguous tape range and a contiguous `output`
962+
# range: gather the adjoint, transfer to host in one memcpy, and
963+
# accumulate into the matching slice of `output`.
964+
tape_range = _storage_range(f.sizes, k)
965+
len = length(tape_range)
966+
x_range = node.index:(node.index + len - 1)
967+
cpu_buf =
968+
convert(Vector{T}, view(f.reverse_storage, tape_range))
969+
view(output, x_range) .+= scale .* cpu_buf
970+
elseif node.type == NODE_VARIABLE
971+
# Per-leaf scalar — rare, so the per-leaf `cudaMemcpy` is fine.
949972
output[node.index] += scale * @s f.reverse_storage[k]
950973
elseif node.type == NODE_SUBEXPRESSION
951974
subexpressions[node.index] += scale * @s f.reverse_storage[k]

src/sizes.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ macro j(expr)
170170
end
171171

172172
# /!\ Can only be called in decreasing `k` order
173-
function _add_size!(sizes::Sizes, k::Int, size::Tuple)
173+
function _add_size!(sizes::Sizes, k::Int, size)
174174
sizes.ndims[k] = length(size)
175175
sizes.size_offset[k] = length(sizes.size)
176176
append!(sizes.size, size)
@@ -198,6 +198,7 @@ end
198198
function _infer_sizes(
199199
nodes::Vector{Node},
200200
adj::SparseArrays.SparseMatrixCSC{Bool,Int},
201+
block_shapes::Dict{Int,Vector{Int}} = Dict{Int,Vector{Int}}(),
201202
)
202203
sizes = Sizes(
203204
zeros(Int, length(nodes)),
@@ -208,6 +209,15 @@ function _infer_sizes(
208209
children_arr = SparseArrays.rowvals(adj)
209210
for k in length(nodes):-1:1
210211
node = nodes[k]
212+
# Block leaves carry their shape in `block_shapes`; they're 2D (m, n)
213+
# by construction.
214+
if node.type == NODE_VARIABLE_BLOCK ||
215+
node.type == NODE_VALUE_BLOCK ||
216+
node.type == NODE_MOI_VARIABLE_BLOCK
217+
shape = block_shapes[k]
218+
_add_size!(sizes, k, shape)
219+
continue
220+
end
211221
children_indices = SparseArrays.nzrange(adj, k)
212222
N = length(children_indices)
213223
if node.type == NODE_CALL_MULTIVARIATE
@@ -429,18 +439,34 @@ struct _SubexpressionStorage{S<:AbstractVector{Float64}}
429439
nodes::Vector{Node},
430440
adj::SparseArrays.SparseMatrixCSC{Bool,Int},
431441
const_values::Vector{Float64},
442+
block_shapes::Dict{Int,Vector{Int}},
432443
partials_storage_ϵ::Vector{Float64},
433444
linearity::Linearity,
434445
::Type{S} = Vector{Float64},
435446
) where {S<:AbstractVector{Float64}}
436-
sizes = _infer_sizes(nodes, adj)
447+
sizes = _infer_sizes(nodes, adj, block_shapes)
437448
N = _length(sizes)
449+
# Pre-load value blocks into forward_storage once at construction;
450+
# each block is a contiguous-to-contiguous bulk copy. Individual
451+
# `NODE_VALUE` scalars (rare — exponents, constant divisors, etc) and
452+
# variable nodes are loaded by `_forward_eval` in the per-node loop.
453+
cpu_buffer = zeros(N)
454+
for k in 1:length(nodes)
455+
node = nodes[k]
456+
if node.type == NODE_VALUE_BLOCK
457+
j = sizes.storage_offset[k] + 1
458+
len = _length(sizes, k)
459+
cpu_buffer[j:(j + len - 1)] .=
460+
view(const_values, node.index:(node.index + len - 1))
461+
end
462+
end
463+
forward_storage = convert(S, cpu_buffer)
438464
return new{S}(
439465
nodes,
440466
adj,
441467
sizes,
442468
const_values,
443-
fill!(S(undef, N), 0.0), # forward_storage,
469+
forward_storage,
444470
fill!(S(undef, N), 0.0), # partials_storage,
445471
fill!(S(undef, N), 0.0), # reverse_storage,
446472
partials_storage_ϵ,

0 commit comments

Comments
 (0)