From 936ec064bc3c597b502e01c04806af66b4bb64fa Mon Sep 17 00:00:00 2001 From: AshtonSBradley Date: Thu, 7 May 2026 18:46:10 +1200 Subject: [PATCH 1/5] Fix Zygote gradient through scaled concretization --- src/basic.jl | 9 ++++----- test/zygote.jl | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/basic.jl b/src/basic.jl index c7a34060..7e166d2d 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -252,7 +252,7 @@ $TYPEDEF (λ L)*(v) = λ * L(v) """ -struct ScaledOperator{ +mutable struct ScaledOperator{ T, λType, LType, @@ -340,10 +340,9 @@ Base.resize!(L::ScaledOperator, n::Integer) = (resize!(L.L, n); L) LinearAlgebra.opnorm(L::ScaledOperator, p::Real = 2) = abs(L.λ) * opnorm(L.L, p) function update_coefficients(L::ScaledOperator, u, p, t; kwargs...) - @reset L.L = update_coefficients(L.L, u, p, t; kwargs...) - @reset L.λ = update_coefficients(L.λ, u, p, t; kwargs...) - - return L + λ = update_coefficients(L.λ, u, p, t; kwargs...) + L_inner = update_coefficients(L.L, u, p, t; kwargs...) + return ScaledOperator(λ, L_inner) end function update_coefficients!(L::ScaledOperator, u, p, t; kwargs...) diff --git a/test/zygote.jl b/test/zygote.jl index ea3a46d4..1a80fb2a 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -111,3 +111,17 @@ for (LType, L) in ( end end end + +@testset "Zygote update_coefficients concretize scaled operator" begin + A1 = MatrixOperator([1.0 0.0; 0.0 1.0]) + A2 = MatrixOperator([1.0 0.0; 0.0 0.0]) + coeff = ScalarOperator(0.0, (a, u, p, t) -> p) + L = A1 + coeff * A2 + + operator_entry(p) = (update_coefficients(L, 0, p, 0) |> concretize)[1, 1] + matrix_entry(p) = ([1.0 0.0; 0.0 1.0] + p * [1.0 0.0; 0.0 0.0])[1, 1] + + p = 1.0 + @test operator_entry(p) == matrix_entry(p) + @test Zygote.gradient(operator_entry, p)[1] == Zygote.gradient(matrix_entry, p)[1] +end From 85914bd3fc6a2c1dce30b73e4bc4bd642439df43 Mon Sep 17 00:00:00 2001 From: AshtonSBradley Date: Sat, 9 May 2026 07:59:25 +1200 Subject: [PATCH 2/5] Keep scaled operators immutable for AD fix --- src/basic.jl | 4 +-- src/scalar.jl | 42 ++++++++++++++++++++++++++++ test/ad_semantics.jl | 66 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 ++ 4 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 test/ad_semantics.jl diff --git a/src/basic.jl b/src/basic.jl index 7e166d2d..6d22b676 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -252,7 +252,7 @@ $TYPEDEF (λ L)*(v) = λ * L(v) """ -mutable struct ScaledOperator{ +struct ScaledOperator{ T, λType, LType, @@ -340,7 +340,7 @@ Base.resize!(L::ScaledOperator, n::Integer) = (resize!(L.L, n); L) LinearAlgebra.opnorm(L::ScaledOperator, p::Real = 2) = abs(L.λ) * opnorm(L.L, p) function update_coefficients(L::ScaledOperator, u, p, t; kwargs...) - λ = update_coefficients(L.λ, u, p, t; kwargs...) + λ = _freeze_updated_scalar(update_coefficients(L.λ, u, p, t; kwargs...)) L_inner = update_coefficients(L.L, u, p, t; kwargs...) return ScaledOperator(λ, L_inner) end diff --git a/src/scalar.jl b/src/scalar.jl index 2404e026..d8c4f9e0 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -126,6 +126,15 @@ mutable struct ScalarOperator{T <: Number, F} <: AbstractSciMLScalarOperator{T} update_func::F end +# Immutable snapshot used by out-of-place updates after ScalarOperator has evaluated its state. +struct _UpdatedScalarOperator{T <: Number, F} <: AbstractSciMLScalarOperator{T} + val::T + update_func::F +end + +_freeze_updated_scalar(α) = α +_freeze_updated_scalar(α::ScalarOperator) = _UpdatedScalarOperator(α.val, α.update_func) + """ $SIGNATURES @@ -186,6 +195,7 @@ end # constructors Base.convert(T::Type{<:Number}, α::ScalarOperator) = convert(T, α.val) +Base.convert(T::Type{<:Number}, α::_UpdatedScalarOperator) = convert(T, α.val) Base.convert(::Type{ScalarOperator}, α::Number) = ScalarOperator(α) ScalarOperator(α::AbstractSciMLScalarOperator) = α @@ -193,6 +203,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ) # traits Base.show(io::IO, α::ScalarOperator) = print(io, "ScalarOperator($(α.val))") +Base.show(io::IO, α::_UpdatedScalarOperator) = print(io, "ScalarOperator($(α.val))") function Base.conj(α::ScalarOperator) # TODO - test val = conj(α.val) update_func = ( @@ -208,12 +219,28 @@ function Base.conj(α::ScalarOperator) # TODO - test return ScalarOperator(val; update_func = update_func, accepted_kwargs = NoKwargFilter()) end +function Base.conj(α::_UpdatedScalarOperator) # TODO - test + val = conj(α.val) + update_func = ( + oldval, u, p, t; + kwargs..., + ) -> α.update_func( + oldval |> conj, + u, + p, + t; + kwargs... + ) |> conj + return _UpdatedScalarOperator(val, update_func) +end + Base.one(::AbstractSciMLScalarOperator{T}) where {T} = ScalarOperator(one(T)) Base.zero(::AbstractSciMLScalarOperator{T}) where {T} = ScalarOperator(zero(T)) Base.one(::Type{<:AbstractSciMLScalarOperator}) = ScalarOperator(true) Base.zero(::Type{<:AbstractSciMLScalarOperator}) = ScalarOperator(false) Base.abs(α::ScalarOperator) = abs(α.val) +Base.abs(α::_UpdatedScalarOperator) = abs(α.val) function LinearAlgebra.exp(α::AbstractSciMLScalarOperator) update_func = ( @@ -226,11 +253,16 @@ function LinearAlgebra.exp(α::AbstractSciMLScalarOperator) end Base.iszero(α::ScalarOperator) = iszero(α.val) +Base.iszero(α::_UpdatedScalarOperator) = iszero(α.val) getops(α::ScalarOperator) = (α.val,) +getops(α::_UpdatedScalarOperator) = (α.val,) isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func) +isconstant(α::_UpdatedScalarOperator) = update_func_isconstant(α.update_func) has_ldiv(α::ScalarOperator) = !iszero(α.val) +has_ldiv(α::_UpdatedScalarOperator) = !iszero(α.val) has_ldiv!(α::ScalarOperator) = has_ldiv(α) +has_ldiv!(α::_UpdatedScalarOperator) = has_ldiv(α) function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...) L.val = L.update_func(L.val, u, p, t; kwargs...) @@ -241,11 +273,21 @@ function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs.. return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func) end +function SciMLOperators.update_coefficients( + L::_UpdatedScalarOperator, u, p, t; kwargs... + ) + return _UpdatedScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func) +end + # Copy method to avoid aliasing function Base.copy(L::ScalarOperator) return ScalarOperator(L.val, L.update_func) end +function Base.copy(L::_UpdatedScalarOperator) + return _UpdatedScalarOperator(L.val, L.update_func) +end + # Add ScalarOperator specific implementations for the new interface function (α::ScalarOperator)(v::AbstractArray, u, p, t; kwargs...) α = update_coefficients(α, u, p, t; kwargs...) diff --git a/test/ad_semantics.jl b/test/ad_semantics.jl new file mode 100644 index 00000000..450ee987 --- /dev/null +++ b/test/ad_semantics.jl @@ -0,0 +1,66 @@ +using SciMLOperators, LinearAlgebra, Test, Zygote + +using SciMLOperators: concretize + +const ad_n = 3 +const ad_u = [0.3, -0.2, 0.7] +const ad_v = [1.0, -2.0, 0.5] +const ad_t = 0.4 +const ad_pmat = [ + 0.0 2.0 -1.0 + 1.0 0.0 0.5 + -0.25 0.75 0.0 +] + +ad_scalar() = ScalarOperator(0.0, (_, _, p, _) -> p) +ad_matrix() = MatrixOperator(ad_pmat) +ad_added_operator() = MatrixOperator(Matrix{Float64}(I, ad_n, ad_n)) + ad_scalar() * ad_matrix() + +function ad_expected_scaled(p) + return p .* (ad_pmat * ad_v) +end + +function ad_expected_added(p) + return (Matrix{Float64}(I, ad_n, ad_n) + p .* ad_pmat) * ad_v +end + +@testset "AD semantic equivalence" begin + p = 1.7 + + @testset "ScalarOperator * MatrixOperator" begin + L = ad_scalar() * ad_matrix() + + concretized_loss(p) = sum(concretize(update_coefficients(L, ad_u, p, ad_t)) * ad_v) + direct_loss(p) = sum(L(ad_v, ad_u, p, ad_t)) + + @test concretize(update_coefficients(L, ad_u, p, ad_t)) ≈ p .* ad_pmat + @test L(ad_v, ad_u, p, ad_t) ≈ ad_expected_scaled(p) + + w = similar(ad_v) + L(w, ad_v, ad_u, p, ad_t) + @test w ≈ ad_expected_scaled(p) + + expected_grad = sum(ad_pmat * ad_v) + @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad + @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad + end + + @testset "MatrixOperator + ScalarOperator * MatrixOperator" begin + L = ad_added_operator() + + concretized_loss(p) = sum(concretize(update_coefficients(L, ad_u, p, ad_t)) * ad_v) + direct_loss(p) = sum(L(ad_v, ad_u, p, ad_t)) + + @test concretize(update_coefficients(L, ad_u, p, ad_t)) ≈ + Matrix{Float64}(I, ad_n, ad_n) + p .* ad_pmat + @test L(ad_v, ad_u, p, ad_t) ≈ ad_expected_added(p) + + w = similar(ad_v) + L(w, ad_v, ad_u, p, ad_t) + @test w ≈ ad_expected_added(p) + + expected_grad = sum(ad_pmat * ad_v) + @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad + @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d27c6e77..95903e74 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,9 @@ end @time @safetestset "Zygote.jl" begin include("zygote.jl") end + @time @safetestset "AD semantics" begin + include("ad_semantics.jl") + end @time @safetestset "Copy methods" begin include("copy.jl") end From 50fa2215c081ce1bc410dcef83d34bf271d47c4c Mon Sep 17 00:00:00 2001 From: AshtonSBradley Date: Sat, 9 May 2026 08:21:01 +1200 Subject: [PATCH 3/5] Fix downgrade test resolution --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 177b75d3..526d1402 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore" [compat] Accessors = "0.1.42" Adapt = "4" -ArrayInterface = "7.19" +ArrayInterface = "7.24" DocStringExtensions = "0.9.4" LinearAlgebra = "1.10" LoopVectorization = "0.12" From 2bc53208bd94524faff12c516c651982dc92b4d2 Mon Sep 17 00:00:00 2001 From: AshtonSBradley Date: Sat, 9 May 2026 08:52:02 +1200 Subject: [PATCH 4/5] Align Adapt lower bound for downgrade tests --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 526d1402..47b06047 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,7 @@ SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore" [compat] Accessors = "0.1.42" -Adapt = "4" +Adapt = "4.5.2" ArrayInterface = "7.24" DocStringExtensions = "0.9.4" LinearAlgebra = "1.10" From c7faca86f16fd52a5a295debb4b47fd6837db2ed Mon Sep 17 00:00:00 2001 From: AshtonSBradley Date: Sat, 9 May 2026 09:23:26 +1200 Subject: [PATCH 5/5] Clarify updated scalar in-place behavior --- src/basic.jl | 28 ++++++++++++++++++++++++++++ src/scalar.jl | 2 +- test/ad_semantics.jl | 10 ++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/basic.jl b/src/basic.jl index 6d22b676..82577f00 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -345,6 +345,20 @@ function update_coefficients(L::ScaledOperator, u, p, t; kwargs...) return ScaledOperator(λ, L_inner) end +function _throw_updated_scaled_inplace() + throw( + ArgumentError( + "cannot update coefficients in-place after an out-of-place ScaledOperator update; call update_coefficients instead", + ) + ) +end + +function update_coefficients!( + L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any}, u, p, t; kwargs... + ) + _throw_updated_scaled_inplace() +end + function update_coefficients!(L::ScaledOperator, u, p, t; kwargs...) update_coefficients!(L.L, u, p, t; kwargs...) update_coefficients!(L.λ, u, p, t; kwargs...) @@ -444,6 +458,13 @@ end return L.L(w, v, u, p, t, a, false; kwargs...) end +@inline function (L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any})( + w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs... + ) + L = update_coefficients(L, u, p, t; kwargs...) + return mul!(w, L, v) +end + # In-place with scaling: w = α*(L*v) + β*w @inline function (L::ScaledOperator)( w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs... @@ -453,6 +474,13 @@ end return L.L(w, v, u, p, t, a, β; kwargs...) end +@inline function (L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any})( + w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs... + ) + L = update_coefficients(L, u, p, t; kwargs...) + return mul!(w, L, v, α, β) +end + """ Lazy operator addition diff --git a/src/scalar.jl b/src/scalar.jl index d8c4f9e0..4ce9e3d3 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -219,7 +219,7 @@ function Base.conj(α::ScalarOperator) # TODO - test return ScalarOperator(val; update_func = update_func, accepted_kwargs = NoKwargFilter()) end -function Base.conj(α::_UpdatedScalarOperator) # TODO - test +function Base.conj(α::_UpdatedScalarOperator) val = conj(α.val) update_func = ( oldval, u, p, t; diff --git a/test/ad_semantics.jl b/test/ad_semantics.jl index 450ee987..4650ef3b 100644 --- a/test/ad_semantics.jl +++ b/test/ad_semantics.jl @@ -43,6 +43,16 @@ end expected_grad = sum(ad_pmat * ad_v) @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad + + updated_L = update_coefficients(L, ad_u, p, ad_t) + @test updated_L(ad_v, ad_u, p + 1, ad_t) ≈ ad_expected_scaled(p + 1) + @test_throws ArgumentError update_coefficients!(updated_L, ad_u, p + 1, ad_t) + updated_L(w, ad_v, ad_u, p + 1, ad_t) + @test w ≈ ad_expected_scaled(p + 1) + + w .= 0.25 + updated_L(w, ad_v, ad_u, p + 1, ad_t, 2.0, 0.5) + @test w ≈ 2 .* ad_expected_scaled(p + 1) .+ 0.125 end @testset "MatrixOperator + ScalarOperator * MatrixOperator" begin