Skip to content

Commit 298ecae

Browse files
committed
Add tests
1 parent 02898c0 commit 298ecae

10 files changed

Lines changed: 129 additions & 5 deletions

File tree

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@g
77
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1012
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1113
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1214
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -17,6 +19,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1719
Calculus = "0.5.2"
1820
DataStructures = "0.18, 0.19"
1921
ForwardDiff = "1"
22+
JuMP = "1.29.4"
23+
LinearAlgebra = "1.12.0"
2024
MathOptInterface = "1.40"
2125
NaNMath = "1"
2226
SparseArrays = "1.10"

perf/neural.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
# Needs https://github.com/jump-dev/JuMP.jl/pull/3451
22
using JuMP
3-
4-
include(joinpath(@__DIR__, "array_of_variables.jl"))
5-
include(joinpath(@__DIR__, "array_expr.jl"))
3+
using ArrayDiff
64

75
n = 2
86
X = rand(n, n)
97
model = Model()
10-
@variable(model, W[1:n, 1:n], container = ArrayOfVariables)
8+
@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
119
W * X
12-
tanh.(W * X)

src/ArrayDiff.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,6 @@ function Evaluator(
6262
return Evaluator(model, NLPEvaluator(model, ordered_variables))
6363
end
6464

65+
include("JuMP/JuMP.jl")
66+
6567
end # module

src/JuMP/JuMP.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# JuMP extension
2+
3+
import JuMP
4+
5+
# Equivalent of `AbstractJuMPScalar` but for arrays
6+
abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end
7+
8+
include("variables.jl")
9+
include("nlp_expr.jl")
10+
include("operators.jl")
11+
include("print.jl")

src/JuMP/nlp_expr.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
struct GenericArrayExpr{V<:JuMP.AbstractVariableRef,N} <: AbstractJuMPArray{JuMP.GenericNonlinearExpr{V},N}
2+
head::Symbol
3+
args::Vector{Any}
4+
size::NTuple{N,Int}
5+
end
6+
7+
const ArrayExpr{N} = GenericArrayExpr{JuMP.VariableRef,N}
8+
9+
function Base.getindex(::GenericArrayExpr, args...)
10+
error("`getindex` not implemented, build vectorized expression instead")
11+
end
12+
13+
Base.size(expr::GenericArrayExpr) = expr.size

src/JuMP/operators.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
function Base.:(*)(A::MatrixOfVariables, B::Matrix)
2+
return GenericArrayExpr{JuMP.variable_ref_type(A.model),2}(:*, Any[A, B], (size(A, 1), size(B, 2)))
3+
end

src/JuMP/print.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
function Base.show(io::IO, ::MIME"text/plain", v::ArrayOfVariables)
2+
return print(io, Base.summary(v), " with offset ", v.offset)
3+
end
4+
5+
function Base.show(io::IO, ::MIME"text/plain", v::GenericArrayExpr)
6+
return print(io, Base.summary(v))
7+
end
8+
9+
function Base.show(io::IO, v::AbstractJuMPArray)
10+
return show(io, MIME"text/plain"(), v)
11+
end

src/JuMP/variables.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Taken out of GenOpt, we can add ArrayDiff as dependency to GenOpt and remove it in GenOpt
2+
3+
struct ArrayOfVariables{T,N} <: AbstractJuMPArray{JuMP.GenericVariableRef{T},N}
4+
model::JuMP.GenericModel{T}
5+
offset::Int64
6+
size::NTuple{N,Int64}
7+
end
8+
9+
const MatrixOfVariables{T} = ArrayOfVariables{T,2}
10+
11+
Base.size(array::ArrayOfVariables) = array.size
12+
function Base.getindex(A::ArrayOfVariables{T}, I...) where {T}
13+
index =
14+
A.offset + Base._to_linear_index(Base.CartesianIndices(A.size), I...)
15+
return JuMP.GenericVariableRef{T}(A.model, MOI.VariableIndex(index))
16+
end
17+
18+
function JuMP.Containers.container(
19+
f::Function,
20+
indices::JuMP.Containers.VectorizedProductIterator{NTuple{N,Base.OneTo{Int}}},
21+
::Type{ArrayOfVariables},
22+
) where {N}
23+
return to_generator(JuMP.Containers.container(f, indices, Array))
24+
end
25+
26+
JuMP._is_real(::ArrayOfVariables) = true
27+
28+
function Base.convert(
29+
::Type{ArrayOfVariables{T,N}},
30+
array::Array{JuMP.GenericVariableRef{T},N},
31+
) where {T,N}
32+
model = JuMP.owner_model(array[1])
33+
offset = JuMP.index(array[1]).value - 1
34+
for i in eachindex(array)
35+
@assert JuMP.owner_model(array[i]) === model
36+
@assert JuMP.index(array[i]).value == offset + i
37+
end
38+
return ArrayOfVariables{T,N}(model, offset, size(array))
39+
end
40+
41+
function to_generator(array::Array{JuMP.GenericVariableRef{T},N}) where {T,N}
42+
return convert(ArrayOfVariables{T,N}, array)
43+
end

test/JuMP.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
module TestJuMP
2+
3+
using Test
4+
5+
using JuMP
6+
using ArrayDiff
7+
8+
function runtests()
9+
for name in names(@__MODULE__; all = true)
10+
if startswith("$(name)", "test_")
11+
@testset "$(name)" begin
12+
getfield(@__MODULE__, name)()
13+
end
14+
end
15+
end
16+
return
17+
end
18+
19+
function test_array_product()
20+
n = 2
21+
X = rand(n, n)
22+
model = Model()
23+
@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
24+
@test W isa ArrayDiff.MatrixOfVariables{Float64}
25+
@test JuMP.index(W[1, 1]) == MOI.VariableIndex(1)
26+
@test JuMP.index(W[2, 1]) == MOI.VariableIndex(2)
27+
@test JuMP.index(W[2]) == MOI.VariableIndex(2)
28+
@test sprint(show, W) == "2×2 ArrayDiff.ArrayOfVariables{Float64, 2} with offset 0"
29+
prod = W * X
30+
@test prod isa ArrayDiff.ArrayExpr{2}
31+
@test sprint(show, prod) == "2×2 ArrayDiff.GenericArrayExpr{VariableRef, 2}"
32+
err = ErrorException("`getindex` not implemented, build vectorized expression instead")
33+
@test_throws err prod[1, 1]
34+
return
35+
end
36+
37+
end # module
38+
39+
TestJuMP.runtests()

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
include("ReverseAD.jl")
22
include("ArrayDiff.jl")
3+
include("JuMP.jl")

0 commit comments

Comments
 (0)