Skip to content

Commit 69e9117

Browse files
committed
user api
1 parent d667f6d commit 69e9117

1 file changed

Lines changed: 63 additions & 13 deletions

File tree

src/sets.jl

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,80 @@ struct Interval{VT} <: AbstractSet
1818
u::VT
1919
end
2020

21-
@inline distance_to_set(v, s) = distance_to_set(DefaultDistance(), v, s)
22-
@inline distance_to_set(::DefaultDistance, v, s::S) where {S} = distance_to_set(EpigraphViolationDistance(), v, s)
23-
@inline distance_to_set(::NormedEpigraphDistance{p}, s::S) where {p,S} = LinearAlgebra.norm(distance_to_set(EpigraphViolationDistance(), v, s), p)
21+
@inline violation(v, s) = violation(DefaultDistance(), v, s)
22+
@inline violation(::DefaultDistance, v, s::S) where {S} = violation(EpigraphViolationDistance(), v, s)
23+
@inline violation(::NormedEpigraphDistance{p}, s::S) where {p,S} = LinearAlgebra.norm(violation(EpigraphViolationDistance(), v, s), p)
2424

25-
distance_to_set!(d, v, s) = begin
26-
d .= distance_to_set(DefaultDistance(), v, s)
25+
violation!(d, v, s) = begin
26+
d .= violation(DefaultDistance(), v, s)
2727
end
28-
distance_to_set!(::DefaultDistance, d, v, s::S) where {S} = begin
29-
d .= distance_to_set(EpigraphViolationDistance(), v, s)
28+
violation!(::DefaultDistance, d, v, s::S) where {S} = begin
29+
d .= violation(EpigraphViolationDistance(), v, s)
3030
end
31-
distance_to_set!(::NormedEpigraphDistance{p}, d, v, s::S) where {p,S} = begin
32-
d .= LinearAlgebra.norm(distance_to_set(EpigraphViolationDistance(), v, s), p)
31+
violation!(::NormedEpigraphDistance{p}, d, v, s::S) where {p,S} = begin
32+
d .= LinearAlgebra.norm(violation(EpigraphViolationDistance(), v, s), p)
3333
end
3434

3535

36-
@inline distance_to_set(::EpigraphViolationDistance, s::LessThan) = begin
36+
@inline violation(::EpigraphViolationDistance, s::LessThan) = begin
3737
@. max(v - s.u, zero(v))
3838
end
39-
@inline distance_to_set(::EpigraphViolationDistance, s::GreaterThan) = begin
39+
@inline violation(::EpigraphViolationDistance, s::GreaterThan) = begin
4040
@. max(s.l - v, zero(v))
4141
end
42-
@inline distance_to_set(::EpigraphViolationDistance, s::EqualTo) = begin
42+
@inline violation(::EpigraphViolationDistance, s::EqualTo) = begin
4343
@. abs(v - s.v)
4444
end
45-
@inline distance_to_set(::EpigraphViolationDistance, s::Interval) = begin
45+
@inline violation(::EpigraphViolationDistance, s::Interval) = begin
4646
@. max(s.l - v, v - s.u, zero(v))
4747
end
48+
49+
# FIXME is interval slow?
50+
struct BatchViolation{MT,E}
51+
model::E
52+
batch_size::Int
53+
54+
# constraints
55+
in_cons_out::MT
56+
in_cons::Interval
57+
58+
# variable bounds
59+
in_vars_out::MT
60+
in_vars::Interval
61+
end
62+
63+
function BatchViolation(model::E, batch_size::Int) where {E}
64+
lcon = model.meta.lcon
65+
ucon = model.meta.ucon
66+
67+
in_cons_out = similar(lcon, length(lcon), batch_size)
68+
in_cons = Interval(lcon, ucon)
69+
70+
lvar = model.meta.lvar
71+
uvar = model.meta.uvar
72+
73+
in_vars_out = similar(lvar, length(lvar), batch_size)
74+
in_vars = Interval(lvar, uvar)
75+
76+
return BatchViolation(
77+
model, batch_size,
78+
in_cons_out, in_cons,
79+
in_vars_out, in_vars
80+
)
81+
end
82+
83+
84+
function _constraint_violations!(b::BatchViolation, V::AbstractMatrix)
85+
violation!.(eachcol(b.in_cons_out), eachcol(V), Ref(b.in_cons))
86+
end
87+
88+
function all_violations!(bm::BatchModel, b::BatchViolation, X::AbstractMatrix)
89+
V = cons_nln_batch!(bm, X, Θ)
90+
_constraint_violations!(b, V)
91+
violation!.(eachcol(b.in_vars_out), eachcol(X), Ref(b.in_vars))
92+
end
93+
function all_violations!(bm::BatchModel, b::BatchViolation, X::AbstractMatrix, Θ::AbstractMatrix)
94+
V = cons_nln_batch!(bm, X, Θ)
95+
_constraint_violations!(b, V)
96+
violation!.(eachcol(b.in_vars_out), eachcol(X), Ref(b.in_vars))
97+
end

0 commit comments

Comments
 (0)