diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 5a90faa..47202ed 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -391,14 +391,24 @@ function _forward_eval( children_indices = SparseArrays.nzrange(f.adj, k) N = length(children_indices) if node.index == 1 # :+ (broadcasted) - for j in _eachindex(f.sizes, k) - tmp_sum = zero(T) - for c_idx in children_indices - ix = children_arr[c_idx] - @j f.partials_storage[ix] = one(T) - tmp_sum += @j f.forward_storage[ix] + # Broadcast-aware sum: scalar children contribute their + # single value to every output slot. + out = _view_linear(f.forward_storage, f.sizes, k) + fill!(out, zero(T)) + for c_idx in children_indices + ix = children_arr[c_idx] + if f.sizes.ndims[ix] == 0 + s = _getscalar(f.forward_storage, f.sizes, ix) + out .+= s + _setscalar!(f.partials_storage, one(T), f.sizes, ix) + else + v = _view_linear(f.forward_storage, f.sizes, ix) + out .+= v + fill!( + _view_linear(f.partials_storage, f.sizes, ix), + one(T), + ) end - @j f.forward_storage[k] = tmp_sum end elseif node.index == 2 # :- (broadcasted) @assert N == 2 @@ -406,31 +416,82 @@ function _forward_eval( @inbounds ix1 = children_arr[child1] @inbounds ix2 = children_arr[child1+1] out = _view_linear(f.forward_storage, f.sizes, k) - v1 = _view_linear(f.forward_storage, f.sizes, ix1) - v2 = _view_linear(f.forward_storage, f.sizes, ix2) - out .= v1 .- v2 - fill!(_view_linear(f.partials_storage, f.sizes, ix1), one(T)) - fill!(_view_linear(f.partials_storage, f.sizes, ix2), -one(T)) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + s1 = _getscalar(f.forward_storage, f.sizes, ix1) + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + out .= s1 .- v2 + _setscalar!(f.partials_storage, one(T), f.sizes, ix1) + fill!( + _view_linear(f.partials_storage, f.sizes, ix2), + -one(T), + ) + elseif ndims1 != 0 && ndims2 == 0 + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + s2 = _getscalar(f.forward_storage, f.sizes, ix2) + out .= v1 .- s2 + fill!( + _view_linear(f.partials_storage, f.sizes, ix1), + one(T), + ) + _setscalar!(f.partials_storage, -one(T), f.sizes, ix2) + else + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + out .= v1 .- v2 + fill!( + _view_linear(f.partials_storage, f.sizes, ix1), + one(T), + ) + fill!( + _view_linear(f.partials_storage, f.sizes, ix2), + -one(T), + ) + end elseif node.index == 3 # :* (broadcasted) - # Node `k` is not scalar, so we do matrix multiplication + # Node `k` is not scalar, so we do element-wise multiply + # (with scalar-broadcast support: when one operand is + # scalar, broadcast it across the matrix output). if f.sizes.ndims[k] != 0 @assert N == 2 idx1 = first(children_indices) idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - v1 = zeros(_size(f.sizes, ix1)...) - v2 = zeros(_size(f.sizes, ix2)...) - for j in _eachindex(f.sizes, ix1) - v1[j] = @j f.forward_storage[ix1] - @j f.partials_storage[ix2] = v1[j] - end - for j in _eachindex(f.sizes, ix2) - v2[j] = @j f.forward_storage[ix2] - @j f.partials_storage[ix1] = v2[j] - end - for j in _eachindex(f.sizes, k) - @j f.forward_storage[k] = v1[j] * v2[j] + out = _view_linear(f.forward_storage, f.sizes, k) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + s = _getscalar(f.forward_storage, f.sizes, ix1) + v = _view_linear(f.forward_storage, f.sizes, ix2) + out .= s .* v + # Per-element partial w.r.t. the matrix child is + # the scalar; the scalar child's reverse is handled + # by the broadcasted-`:*` reverse branch below + # (sum of `rev_parent .* v`). + fill!(_view_linear(f.partials_storage, f.sizes, ix2), s) + elseif ndims1 != 0 && ndims2 == 0 + v = _view_linear(f.forward_storage, f.sizes, ix1) + s = _getscalar(f.forward_storage, f.sizes, ix2) + out .= v .* s + fill!(_view_linear(f.partials_storage, f.sizes, ix1), s) + else + # Both children are arrays of the same shape — + # original element-wise path. + v1 = zeros(_size(f.sizes, ix1)...) + v2 = zeros(_size(f.sizes, ix2)...) + for j in _eachindex(f.sizes, ix1) + v1[j] = @j f.forward_storage[ix1] + @j f.partials_storage[ix2] = v1[j] + end + for j in _eachindex(f.sizes, ix2) + v2[j] = @j f.forward_storage[ix2] + @j f.partials_storage[ix1] = v2[j] + end + for j in _eachindex(f.sizes, k) + @j f.forward_storage[k] = v1[j] * v2[j] + end end # Node `k` is scalar else @@ -832,13 +893,80 @@ function _reverse_eval( elseif node.type == NODE_CALL_MULTIVARIATE_BROADCASTED if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS) op = DEFAULT_MULTIVARIATE_OPERATORS[node.index] + # Broadcasted +/- with at least one scalar child: the + # scalar's reverse is the (signed) sum of the parent's + # adjoint over the broadcast positions. Handle both scalar + # and matrix children here so the generic + # diagonal-partial path below doesn't trip its + # `_size(k) == _size(ix)` assertion. + if (op == :+ || op == :-) && + any( + c -> f.sizes.ndims[children_arr[c]] == 0, + children_indices, + ) && + f.sizes.ndims[k] != 0 + Tr = eltype(f.reverse_storage) + rev_parent = _view_linear(f.reverse_storage, f.sizes, k) + for c_idx in children_indices + ix = children_arr[c_idx] + # `:-` flips the sign for the second operand, mirroring + # the partial we wrote in the forward pass. + partial_sign = + (op == :- && c_idx != first(children_indices)) ? + -one(Tr) : one(Tr) + if f.sizes.ndims[ix] == 0 + _setscalar!( + f.reverse_storage, + partial_sign * sum(rev_parent), + f.sizes, + ix, + ) + else + rev_child = + _view_linear(f.reverse_storage, f.sizes, ix) + rev_child .= partial_sign .* rev_parent + end + end + continue + end if op == :* if f.sizes.ndims[k] != 0 - # Node `k` is not scalar, so we do matrix multiplication or broadcasted multiplication idx1 = first(children_indices) idx2 = last(children_indices) ix1 = children_arr[idx1] ix2 = children_arr[idx2] + rev_parent = _view_linear(f.reverse_storage, f.sizes, k) + ndims1 = f.sizes.ndims[ix1] + ndims2 = f.sizes.ndims[ix2] + if ndims1 == 0 && ndims2 != 0 + v2 = _view_linear(f.forward_storage, f.sizes, ix2) + s1 = _getscalar(f.forward_storage, f.sizes, ix1) + rev_v2 = + _view_linear(f.reverse_storage, f.sizes, ix2) + rev_v2 .= rev_parent .* s1 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v2), + f.sizes, + ix1, + ) + continue + elseif ndims1 != 0 && ndims2 == 0 + v1 = _view_linear(f.forward_storage, f.sizes, ix1) + s2 = _getscalar(f.forward_storage, f.sizes, ix2) + rev_v1 = + _view_linear(f.reverse_storage, f.sizes, ix1) + rev_v1 .= rev_parent .* s2 + _setscalar!( + f.reverse_storage, + LinearAlgebra.dot(rev_parent, v1), + f.sizes, + ix2, + ) + continue + end + # Both children are arrays of the same shape — + # original element-wise path. v1 = zeros(_size(f.sizes, ix1)...) v2 = zeros(_size(f.sizes, ix2)...) for j in _eachindex(f.sizes, ix1) diff --git a/test/JuMP.jl b/test/JuMP.jl index 533e631..f397220 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -457,35 +457,40 @@ function test_broadcast_nonsquare_matrix() return end -function test_broadcast_scalar_matrix_size_inference() +# Cover every `Number op MatrixVar` / `MatrixVar op Number` broadcast +# pattern that JuMP's `Base.broadcasted` produces — both the size inference +# (broadcast node inherits the matrix child's shape, not the old `(1, 1)` +# stub) and the eval/reverse paths (`out .= s op v`, `rev_s = +# ±sum(rev_parent)` or `dot(rev_parent, v)`). Loss is `norm(c op W)` so the +# analytic gradient is `dexpr_dW .* (c op W) ./ norm(c op W)`. +function test_broadcast_scalar_matrix_gradient() + c = 2.5 + rows, cols = 2, 3 model = Model() - @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables) - mode = ArrayDiff.Mode() - @testset "$(name)" for (name, expr) in [ - ("scalar .* M", LinearAlgebra.norm(2.5 .* W)), - ("M .* scalar", LinearAlgebra.norm(W .* 2.5)), - ("scalar .+ M", LinearAlgebra.norm(2.5 .+ W)), - ("M .+ scalar", LinearAlgebra.norm(W .+ 2.5)), - ("scalar .- M", LinearAlgebra.norm(2.5 .- W)), - ("M .- scalar", LinearAlgebra.norm(W .- 2.5)), + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + @testset "$(name)" for (name, expr, ref_mat, dexpr_dW) in [ + ("scalar .+ M", c .+ W, c .+ W_val, fill(1.0, rows, cols)), + ("M .+ scalar", W .+ c, W_val .+ c, fill(1.0, rows, cols)), + ("scalar .- M", c .- W, c .- W_val, fill(-1.0, rows, cols)), + ("M .- scalar", W .- c, W_val .- c, fill(1.0, rows, cols)), + ("scalar .* M", c .* W, c .* W_val, fill(c, rows, cols)), + ("M .* scalar", W .* c, W_val .* c, fill(c, rows, cols)), ] - ad = ArrayDiff.model(mode) - MOI.Nonlinear.set_objective(ad, JuMP.moi_function(expr)) - evaluator = MOI.Nonlinear.Evaluator( - ad, - mode, - JuMP.index.(JuMP.all_variables(model)), - ) - MOI.initialize(evaluator, [:Grad]) - sizes = evaluator.backend.objective.expr.sizes - # Broadcast node is at index 2; it should inherit the matrix child's - # (2, 3) shape, not the old `(1, 1)` stub. + sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x) + # Outer norm scalar (k=1), then the broadcast (k=2) which must + # inherit the matrix child's (rows, cols) shape — not the old + # `(1, 1)` stub — then the two children (one scalar leaf, one + # matrix leaf) in some order. + @test sizes.ndims[1] == 0 @test sizes.ndims[2] == 2 - broadcast_size_off = sizes.size_offset[2] - @test sizes.size[broadcast_size_off+1] == 2 - @test sizes.size[broadcast_size_off+2] == 3 - # And the scalar leaf among the children stays ndims=0. + b_off = sizes.size_offset[2] + @test sizes.size[b_off+1] == rows + @test sizes.size[b_off+2] == cols @test 0 in sizes.ndims[3:4] + @test val ≈ LinearAlgebra.norm(ref_mat) + @test g ≈ vec(dexpr_dW .* ref_mat) ./ LinearAlgebra.norm(ref_mat) end return end