Skip to content

Commit b66fdde

Browse files
committed
Fix broadcast division
1 parent 6cf63bb commit b66fdde

3 files changed

Lines changed: 97 additions & 15 deletions

File tree

src/reverse_mode.jl

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,16 @@ function _forward_eval(
432432
broadcast!,
433433
(*,),
434434
)
435+
elseif node.index == 5 # :/ (broadcasted)
436+
@assert N == 2
437+
child1 = first(children_indices)
438+
_reshape_call(
439+
f.forward_storage,
440+
f.sizes,
441+
(k, children_arr[child1], children_arr[child1+1]),
442+
broadcast!,
443+
(/,),
444+
)
435445
elseif node.index == 4 # :^ (broadcasted), array .^ scalar
436446
@assert N == 2
437447
idx1 = first(children_indices)
@@ -578,15 +588,46 @@ function __reverse_broadcasted_mul(f, ilhs, irhs, dout, dlhs, drhs)
578588
)
579589
end
580590

581-
# Reverse for `sum_dims`: broadcast the parent's gradient back to the
582-
# input's shape. Parent has size 1 in the reduced dimensions, child has the
583-
# original input shape.
584-
# Good news: `dchild .= dparent` does the expansion via Julia broadcasting.
585-
function _reverse_sum_dims!(dchild, dparent)
586-
dchild .= dparent
591+
# Reverse pass for broadcasted `:/`. With `z = x ./ y`:
#   ∂z/∂x = 1 ./ y         → `dlhs` receives `dout ./ rhs`
#   ∂z/∂y = -x ./ y .^ 2   → `drhs` receives `-dout .* lhs ./ rhs .^ 2`
# Each term is reduced with `Base.add_sum` into the shape of its destination,
# so `dlhs` and `drhs` may be smaller than `dout` when an operand was broadcast.
function _reverse_broadcasted_div(dout, dlhs, drhs, lhs, rhs)
    # Why zero the destination first? See comment in `_reverse_broadcasted_mul`.
    fill!(dlhs, zero(eltype(dlhs)))
    lhs_terms = Broadcast.instantiate(Broadcast.broadcasted(/, dout, rhs))
    Base.mapreducedim!(identity, Base.add_sum, dlhs, lhs_terms)

    fill!(drhs, zero(eltype(drhs)))
    # -dout * lhs / rhs^2, kept as a lazy broadcast so no temporary materializes.
    rhs_terms = Broadcast.instantiate(
        Broadcast.broadcasted(dout, lhs, rhs) do g, num, den
            return -g * num / (den * den)
        end,
    )
    Base.mapreducedim!(identity, Base.add_sum, drhs, rhs_terms)
    return
end
589620

621+
# Adapter between the reverse-pass dispatch and `_reverse_broadcasted_div`:
# forwards the forward-storage buffer, the size table and the two operand node
# indices to `_reshape_call`, together with the kernel and the adjoint arrays
# `(dout, dlhs, drhs)`.
# NOTE(review): presumably `_reshape_call` reshapes the storage slices of
# `ilhs`/`irhs` and appends them to the kernel's arguments — confirm against
# its definition.
function __reverse_broadcasted_div(f, ilhs, irhs, dout, dlhs, drhs)
    operand_nodes = (ilhs, irhs)
    adjoints = (dout, dlhs, drhs)
    return _reshape_call(
        f.forward_storage,
        f.sizes,
        operand_nodes,
        _reverse_broadcasted_div,
        adjoints,
    )
end
630+
590631
"""
591632
_reverse_eval(f::_SubexpressionStorage)
592633
@@ -855,7 +896,7 @@ function _reverse_eval(
855896
# and matrix children here so the generic
856897
# diagonal-partial path below doesn't trip its
857898
# `_size(k) == _size(ix)` assertion.
858-
if op == :+ || op == :- || op == :* || op == :^
899+
if op == :+ || op == :- || op == :* || op == :^ || op == :/
859900
@assert length(children_indices) == 2
860901
child1 = first(children_indices)
861902
lhs = children_arr[child1]
@@ -868,6 +909,14 @@ function _reverse_eval(
868909
__reverse_broadcasted_mul,
869910
(f, lhs, rhs),
870911
)
912+
elseif op == :/
913+
_reshape_call(
914+
f.reverse_storage,
915+
f.sizes,
916+
(k, lhs, rhs),
917+
__reverse_broadcasted_div,
918+
(f, lhs, rhs),
919+
)
871920
elseif op == :^
872921
# We start with just .^2 to simplify
873922
@assert f.sizes.ndims[rhs] == 0 "Broadcasted ^ requires scalar exponent"

src/sizes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ function _infer_sizes(
495495
continue
496496
end
497497
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
498-
if op == :+ || op == :- || op == :*
498+
if op == :+ || op == :- || op == :* || op == :/
499499
sizes.ndims[k] = maximum(children_indices, init = 0) do i
500500
return sizes.ndims[children_arr[i]]
501501
end

test/JuMP.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -720,13 +720,46 @@ function test_transformer_stacked_residual_gradient()
720720
return _check_transformer_loss(build)
721721
end
722722

723-
# `sum(x; dims=N)` builds a `:sum_dims` node that reduces along the given
724-
# dims, keeping the input ndims with the reduced axes collapsed to size 1.
725-
# Verify both value and gradient against finite differences across
726-
# `dims=1`, `dims=2`, and `dims=(1,2)` for a 2×3 matrix variable.
727-
function test_sum_dims_gradient()
728-
@testset "dims=$dims" for dims in (1, 2, (1, 2))
729-
_check_transformer_loss(x -> sum(sum(x; dims = dims) .^ 2))
723+
# Broadcasted `./` between a JuMP matrix variable `W` and another operand
# that cycles through every shape combination ArrayDiff supports: scalar,
# full matrix, column vector (length rows), row vector (1×cols).
# The first group of cases puts `W` in the denominator (`c ./ W`), the second
# puts it in the numerator (`W ./ c`). The loss is `norm(expr)`, so the
# analytic gradient is `dexpr_dW .* expr ./ norm(expr)`; it is compared
# elementwise with the AD-computed gradient. `W` is initialized to positive
# values to avoid the division-by-zero blow-up.
function test_broadcast_divide_gradient()
    rows, cols = 2, 3
    c = 2.5
    model = Model()
    @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables)
    v = [10.0, 20.0]
    r = [100.0 200.0 300.0]
    M = reshape(collect(11.0:10.0:160.0)[1:(rows*cols)], rows, cols)
    x = Float64.(collect(1:(rows*cols)))
    W_val = reshape(x, rows, cols)
    @testset "$(name)" for (name, expr, ref_mat, dexpr_dW) in [
        # W as denominator: ∂(c ./ W)/∂W_ij = -c_ij / W_ij^2 (c broadcast)
        ("scalar ./ W", c ./ W, c ./ W_val, -c ./ W_val .^ 2),
        ("v ./ W", v ./ W, v ./ W_val, -(v .* ones(rows, cols)) ./ W_val .^ 2),
        ("r ./ W", r ./ W, r ./ W_val, -(ones(rows) .* r) ./ W_val .^ 2),
        ("M ./ W", M ./ W, M ./ W_val, -M ./ W_val .^ 2),
        # W as numerator: ∂(W ./ c)/∂W_ij = 1 / c_ij (c broadcast)
        ("W ./ scalar", W ./ c, W_val ./ c, fill(1 / c, rows, cols)),
        ("W ./ v", W ./ v, W_val ./ v, 1 ./ (v .* ones(rows, cols))),
        ("W ./ r", W ./ r, W_val ./ r, 1 ./ (ones(rows) .* r)),
        ("W ./ M", W ./ M, W_val ./ M, 1 ./ M),
    ]
        sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x)
        # Tape: norm (k=1, scalar), broadcast `./` (k=2) inheriting (rows, cols)
        # from the result shape, then the two children.
        @test sizes.ndims[1] == 0
        @test sizes.ndims[2] == 2
        b_off = sizes.size_offset[2]
        @test sizes.size[b_off+1] == rows
        @test sizes.size[b_off+2] == cols
        # Floating-point results: compare with `≈` rather than `==`.
        @test val ≈ LinearAlgebra.norm(ref_mat)
        @test g ≈ vec(dexpr_dW .* ref_mat) ./ LinearAlgebra.norm(ref_mat)
    end
    return
end

0 commit comments

Comments
 (0)