Skip to content

Commit 69dbb4f

Browse files
committed
Fix problems
* Add OperatorRegistry because immutable in MOI Nonlinear (and Int fields will need to be modified) * Add DEFAULT_MULTIVARIATE_OPERATORS to extend it from MOI Nonlinear * Add OrderedCollections that was used in Model
1 parent b8f31cf commit 69dbb4f

6 files changed

Lines changed: 115 additions & 17 deletions

File tree

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1010
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1111
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
12+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1213

1314
[compat]
1415
DataStructures = "0.18, 0.19"

src/ArrayDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import ForwardDiff
1010
import MathOptInterface as MOI
1111
const Nonlinear = MOI.Nonlinear
1212
import SparseArrays
13+
import OrderedCollections: OrderedDict
1314

1415
"""
1516
Mode() <: AbstractAutomaticDifferentiation

src/MOI_Nonlinear_fork.jl

Lines changed: 108 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,118 @@
11
# Inspired by MathOptInterface/src/Nonlinear/parse_expression.jl
22

3+
const DEFAULT_MULTIVARIATE_OPERATORS = [
4+
:+,
5+
:-,
6+
:*,
7+
:^,
8+
:/,
9+
:ifelse,
10+
:atan,
11+
:min,
12+
:max,
13+
:vect,
14+
:dot,
15+
:hcat,
16+
:vcat,
17+
:norm,
18+
:sum,
19+
:row,
20+
]
21+
22+
struct OperatorRegistry
23+
# NODE_CALL_UNIVARIATE
24+
univariate_operators::Vector{Symbol}
25+
univariate_operator_to_id::Dict{Symbol,Int}
26+
univariate_user_operator_start::Int
27+
registered_univariate_operators::Vector{MOI.Nonlinear._UnivariateOperator}
28+
# NODE_CALL_MULTIVARIATE
29+
multivariate_operators::Vector{Symbol}
30+
multivariate_operator_to_id::Dict{Symbol,Int}
31+
multivariate_user_operator_start::Int
32+
registered_multivariate_operators::Vector{
33+
MOI.Nonlinear._MultivariateOperator,
34+
}
35+
# NODE_LOGIC
36+
logic_operators::Vector{Symbol}
37+
logic_operator_to_id::Dict{Symbol,Int}
38+
# NODE_COMPARISON
39+
comparison_operators::Vector{Symbol}
40+
comparison_operator_to_id::Dict{Symbol,Int}
41+
function OperatorRegistry()
42+
univariate_operators = copy(DEFAULT_UNIVARIATE_OPERATORS)
43+
multivariate_operators = copy(DEFAULT_MULTIVARIATE_OPERATORS)
44+
logic_operators = [:&&, :||]
45+
comparison_operators = [:<=, :(==), :>=, :<, :>]
46+
return new(
47+
# NODE_CALL_UNIVARIATE
48+
univariate_operators,
49+
Dict{Symbol,Int}(
50+
op => i for (i, op) in enumerate(univariate_operators)
51+
),
52+
length(univariate_operators),
53+
_UnivariateOperator[],
54+
# NODE_CALL
55+
multivariate_operators,
56+
Dict{Symbol,Int}(
57+
op => i for (i, op) in enumerate(multivariate_operators)
58+
),
59+
length(multivariate_operators),
60+
_MultivariateOperator[],
61+
# NODE_LOGIC
62+
logic_operators,
63+
Dict{Symbol,Int}(op => i for (i, op) in enumerate(logic_operators)),
64+
# NODE_COMPARISON
65+
comparison_operators,
66+
Dict{Symbol,Int}(
67+
op => i for (i, op) in enumerate(comparison_operators)
68+
),
69+
)
70+
end
71+
end
72+
73+
"""
74+
Model()
75+
76+
The core datastructure for representing a nonlinear optimization problem.
77+
78+
It has the following fields:
79+
* `objective::Union{Nothing,Expression}` : holds the nonlinear objective
80+
function, if one exists, otherwise `nothing`.
81+
* `expressions::Vector{Expression}` : a vector of expressions in the model.
82+
* `constraints::OrderedDict{ConstraintIndex,Constraint}` : a map from
83+
[`ConstraintIndex`](@ref) to the corresponding [`Constraint`](@ref). An
84+
`OrderedDict` is used instead of a `Vector` to support constraint deletion.
85+
* `parameters::Vector{Float64}` : holds the current values of the parameters.
86+
* `operators::OperatorRegistry` : stores the operators used in the model.
87+
"""
88+
mutable struct Model
89+
objective::Union{Nothing,MOI.Nonlinear.Expression}
90+
expressions::Vector{MOI.Nonlinear.Expression}
91+
constraints::OrderedDict{
92+
MOI.Nonlinear.ConstraintIndex,
93+
MOI.Nonlinear.Constraint,
94+
}
95+
parameters::Vector{Float64}
96+
operators::OperatorRegistry
97+
# This is a private field, used only to increment the ConstraintIndex.
98+
last_constraint_index::Int64
99+
function Model()
100+
model = MOI.Nonlinear.Model()
101+
ops = [:vect, :dot, :hcat, :vcat, :norm, :sum, :row]
102+
start = length(model.operators.multivariate_operators)
103+
append!(model.operators.multivariate_operators, ops)
104+
for (i, op) in enumerate(ops)
105+
model.operators.multivariate_operator_to_id[op] = start + i
106+
end
107+
return model
108+
end
109+
end
110+
3111
function set_objective(model::MOI.Nonlinear.Model, obj)
4112
model.objective = parse_expression(model, obj)
5113
return
6114
end
7115

8-
function Model()
9-
model = MOI.Nonlinear.Model()
10-
append!(
11-
model.operators.multivariate_operators,
12-
[:vect, :dot, :hcat, :vcat, :norm, :sum, :row],
13-
)
14-
return model
15-
end
16-
17116
function parse_expression(data::MOI.Nonlinear.Model, input)
18117
expr = MOI.Nonlinear.Expression()
19118
parse_expression(data, expr, input, -1)

src/reverse_mode.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,8 @@ function _reverse_eval(f::_SubexpressionStorage)
430430
node = f.nodes[k]
431431
children_indices = SparseArrays.nzrange(f.adj, k)
432432
if node.type == MOI.Nonlinear.NODE_CALL_MULTIVARIATE
433-
if node.index in
434-
eachindex(MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS)
435-
op = MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS[node.index]
433+
if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS)
434+
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
436435
if op == :vect
437436
@assert _eachindex(f.sizes, k) ==
438437
eachindex(children_indices)

src/sizes.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,11 @@ function _infer_sizes(
163163
children_indices = SparseArrays.nzrange(adj, k)
164164
N = length(children_indices)
165165
if node.type == Nonlinear.NODE_CALL_MULTIVARIATE
166-
if !(
167-
node.index in
168-
eachindex(MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS)
169-
)
166+
if !(node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS))
170167
# TODO user-defined operators
171168
continue
172169
end
173-
op = MOI.Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS[node.index]
170+
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
174171
if op == :vect
175172
_assert_scalar_children(
176173
sizes,

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
44
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
55
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
66
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"

0 commit comments

Comments
 (0)