Skip to content

Commit ff5105c

Browse files
committed
simplify
1 parent dd18103 commit ff5105c

2 files changed

Lines changed: 14 additions & 54 deletions

File tree

src/graph_tools.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,10 +409,8 @@ function _compute_hessian_sparsity(
409409
nodes[r].index,
410410
)
411411
elseif nodes[r].type == NODE_VARIABLE_BLOCK
412-
# `NODE_VARIABLE_BLOCK` would need its block shape to
413-
# enumerate the variable indices. Hessian sparsity for
414-
# block-variable expressions is a follow-up; for now we
415-
# error so the gradient-only path stays correct.
412+
# TODO `NODE_VARIABLE_BLOCK` would need its block shape to
413+
# enumerate the variable indices.
416414
error(
417415
"Internal error: Hessian sparsity for " *
418416
"NODE_VARIABLE_BLOCK is not yet supported.",

src/parse_moi.jl

Lines changed: 12 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -143,76 +143,38 @@ function _parse_moi_stack!(
143143
::Vector{Tuple{Int,Any}},
144144
::Model,
145145
expr::Expression,
146-
x::ArrayOfContiguousVariables{2},
146+
x::ArrayOfContiguousVariables,
147147
parent_index::Int,
148148
)
149-
m, n = x.size
150149
# Emit a single block node. The block represents the contiguous range of
151-
# `m * n` MOI variable indices `x.offset+1 .. x.offset+m*n`, laid out in
152-
# column-major order (matching `Matrix{Float64}` and `Base.LinearIndices`),
150+
# MOI variable indices `x.offset+1, ...`, laid out in
151+
# column-major order (matching `Array{Float64}` and `Base.LinearIndices`),
153152
# which is the layout `_view_array` will see at evaluation time.
154153
push!(
155154
expr.nodes,
156155
Node(NODE_MOI_VARIABLE_BLOCK, x.offset + 1, parent_index),
157156
)
158-
expr.block_shapes[length(expr.nodes)] = [m, n]
159-
return
160-
end
161-
162-
function _parse_moi_stack!(
163-
stack::Vector{Tuple{Int,Any}},
164-
data::Model,
165-
expr::Expression,
166-
x::ArrayOfContiguousVariables{1},
167-
parent_index::Int,
168-
)
169-
m = x.size[1]
170-
vect_id = data.operators.multivariate_operator_to_id[:vect]
171-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vect_id, parent_index))
172-
vect_idx = length(expr.nodes)
173-
for i in m:-1:1
174-
vi = MOI.VariableIndex(x.offset + i)
175-
push!(stack, (vect_idx, vi))
176-
end
157+
expr.block_shapes[length(expr.nodes)] = collect(x.size)
177158
return
178159
end
179160

180-
# ── Constant matrices and vectors ────────────────────────────────────────────
161+
# ── Constant arrays ────────────────────────────────────────────
181162

182163
function _parse_moi_stack!(
183164
::Vector{Tuple{Int,Any}},
184165
::Model,
185-
expr::Expression{T},
186-
x::AbstractMatrix{<:Real},
166+
expr::Expression,
167+
x::AbstractArray{<:Real},
187168
parent_index::Int,
188-
) where {T}
189-
m, n = size(x)
190-
# Emit a single value block. We push the `m * n` flat values to
191-
# `expr.values` in column-major order (matching `Matrix{Float64}`'s memory
169+
)
170+
# Emit a single value block. We push the flat values to
171+
# `expr.values` in column-major order (matching `Array{Float64}`'s memory
192172
# layout); `node.index` records the start of that contiguous range so
193173
# `_SubexpressionStorage` can copy it into the tape in one block at
194174
# construction time.
195175
start_idx = length(expr.values) + 1
196-
for j in 1:n, i in 1:m
197-
push!(expr.values, convert(T, x[i, j]))
198-
end
176+
append!(expr.values, x)
199177
push!(expr.nodes, Node(NODE_VALUE_BLOCK, start_idx, parent_index))
200-
expr.block_shapes[length(expr.nodes)] = [m, n]
201-
return
202-
end
203-
204-
function _parse_moi_stack!(
205-
stack::Vector{Tuple{Int,Any}},
206-
data::Model,
207-
expr::Expression,
208-
x::AbstractVector{<:Real},
209-
parent_index::Int,
210-
)
211-
vect_id = data.operators.multivariate_operator_to_id[:vect]
212-
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, vect_id, parent_index))
213-
vect_idx = length(expr.nodes)
214-
for i in length(x):-1:1
215-
push!(stack, (vect_idx, x[i]))
216-
end
178+
expr.block_shapes[length(expr.nodes)] = collect(size(x))
217179
return
218180
end

0 commit comments

Comments
 (0)