@@ -720,13 +720,46 @@ function test_transformer_stacked_residual_gradient()
720720 return _check_transformer_loss (build)
721721end
722722
723- # `sum(x; dims=N)` builds a `:sum_dims` node that reduces along the given
724- # dims, keeping the input ndims with the reduced axes collapsed to size 1.
725- # Verify both value and gradient against finite differences across
726- # `dims=1`, `dims=2`, and `dims=(1,2)` for a 2×3 matrix variable.
727- function test_sum_dims_gradient ()
728- @testset " dims=$dims " for dims in (1 , 2 , (1 , 2 ))
729- _check_transformer_loss (x -> sum (sum (x; dims = dims) .^ 2 ))
# Broadcasted `./` between a JuMP matrix variable `W` and a constant operand
# that cycles through every shape combination ArrayDiff supports: scalar,
# full matrix, column vector (length `rows`) and row vector (1×`cols`).
# Each shape is exercised twice: once with `W` in the denominator (`c ./ W`)
# and once with `W` in the numerator (`W ./ c`).
#
# The loss is `norm(expr)`, so by the chain rule its gradient with respect to
# `W` is `dexpr_dW .* expr ./ norm(expr)`, where `dexpr_dW` is the elementwise
# derivative of `expr` with respect to `W`. That analytic gradient is compared
# elementwise against the AD-computed one. `W` is evaluated at strictly
# positive values so that no case divides by zero.
function test_broadcast_divide_gradient()
    rows, cols = 2, 3
    n = rows * cols
    c = 2.5
    model = Model()
    @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables)
    v = [10.0, 20.0]          # column-vector operand, length `rows`
    r = [100.0 200.0 300.0]   # row-vector operand, 1×`cols`
    # Full-matrix operand 11, 21, 31, ... sized from `rows * cols`, so the
    # fixture keeps working if the dimensions above are changed.
    M = reshape(collect(range(11.0; step = 10.0, length = n)), rows, cols)
    # Evaluation point for W: 1.0, 2.0, ..., n (all positive).
    x = collect(Float64, 1:n)
    W_val = reshape(x, rows, cols)
    # `v` and `r` expanded to the full (rows, cols) shape they broadcast to.
    v_full = v .* ones(rows, cols)
    r_full = ones(rows) .* r
    # Each case: (name, AD expression, reference value, ∂expr/∂W elementwise).
    cases = [
        # W as denominator: ∂(c ./ W)/∂W_ij = -c_ij / W_ij^2 (c broadcast)
        (" scalar ./ W", c ./ W, c ./ W_val, -c ./ W_val .^ 2),
        (" v ./ W", v ./ W, v ./ W_val, -v_full ./ W_val .^ 2),
        (" r ./ W", r ./ W, r ./ W_val, -r_full ./ W_val .^ 2),
        (" M ./ W", M ./ W, M ./ W_val, -M ./ W_val .^ 2),
        # W as numerator: ∂(W ./ c)/∂W_ij = 1 / c_ij (c broadcast)
        (" W ./ scalar", W ./ c, W_val ./ c, fill(1 / c, rows, cols)),
        (" W ./ v", W ./ v, W_val ./ v, 1 ./ v_full),
        (" W ./ r", W ./ r, W_val ./ r, 1 ./ r_full),
        (" W ./ M", W ./ M, W_val ./ M, 1 ./ M),
    ]
    @testset " $(name) " for (name, expr, ref_mat, dexpr_dW) in cases
        sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x)
        # Tape: norm (k=1, scalar), broadcast `./` (k=2) inheriting (rows, cols)
        # from the result shape, then the two children.
        @test sizes.ndims[1] == 0
        @test sizes.ndims[2] == 2
        b_off = sizes.size_offset[2]
        @test sizes.size[b_off+1] == rows
        @test sizes.size[b_off+2] == cols
        # Value and gradient of norm(expr) against the closed-form reference.
        ref_norm = LinearAlgebra.norm(ref_mat)
        @test val ≈ ref_norm
        @test g ≈ vec(dexpr_dW .* ref_mat) ./ ref_norm
    end
    return
end
0 commit comments