Skip to content

Commit ae7bf66

Browse files
committed
fix
1 parent 0a109ec commit ae7bf66

3 files changed

Lines changed: 286 additions & 104 deletions

File tree

src/reverse_mode.jl

Lines changed: 234 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -171,19 +171,32 @@ function _forward_eval(
171171
end
172172
elseif node.index == 3 # :*
173173
# Node `k` is not scalar, so we do matrix multiplication
174+
# (or scalar `*` matrix scaling when one operand is scalar).
174175
if f.sizes.ndims[k] != 0
175176
@assert N == 2
176177
idx1 = first(children_indices)
177178
idx2 = last(children_indices)
178179
@inbounds ix1 = children_arr[idx1]
179180
@inbounds ix2 = children_arr[idx2]
180-
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
181-
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
182-
out = _view_matrix(f.forward_storage, f.sizes, k)
183-
LinearAlgebra.mul!(out, v1, v2)
181+
out = _view_linear(f.forward_storage, f.sizes, k)
182+
if f.sizes.ndims[ix1] == 0
183+
s = _getscalar(f.forward_storage, f.sizes, ix1)
184+
v = _view_linear(f.forward_storage, f.sizes, ix2)
185+
out .= s .* v
186+
elseif f.sizes.ndims[ix2] == 0
187+
v = _view_linear(f.forward_storage, f.sizes, ix1)
188+
s = _getscalar(f.forward_storage, f.sizes, ix2)
189+
out .= v .* s
190+
else
191+
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
192+
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
193+
out_m = _view_matrix(f.forward_storage, f.sizes, k)
194+
LinearAlgebra.mul!(out_m, v1, v2)
195+
end
184196
# We deliberately don't write v1/v2 into partials_storage
185-
# here: the matmul reverse branch reads forward_storage
186-
# directly, so those writes were dead.
197+
# here: the matmul (or scalar-scaling) reverse branch
198+
# reads forward_storage directly, so those writes were
199+
# dead.
187200
# Node `k` is scalar
188201
else
189202
tmp_prod = one(T)
@@ -391,46 +404,118 @@ function _forward_eval(
391404
children_indices = SparseArrays.nzrange(f.adj, k)
392405
N = length(children_indices)
393406
if node.index == 1 # :+ (broadcasted)
394-
for j in _eachindex(f.sizes, k)
395-
tmp_sum = zero(T)
396-
for c_idx in children_indices
397-
ix = children_arr[c_idx]
398-
@j f.partials_storage[ix] = one(T)
399-
tmp_sum += @j f.forward_storage[ix]
407+
# Broadcast-aware sum: scalar children contribute their
408+
# single value to every output slot.
409+
out = _view_linear(f.forward_storage, f.sizes, k)
410+
fill!(out, zero(T))
411+
for c_idx in children_indices
412+
ix = children_arr[c_idx]
413+
if f.sizes.ndims[ix] == 0
414+
s = _getscalar(f.forward_storage, f.sizes, ix)
415+
out .+= s
416+
_setscalar!(
417+
f.partials_storage,
418+
one(T),
419+
f.sizes,
420+
ix,
421+
)
422+
else
423+
v = _view_linear(f.forward_storage, f.sizes, ix)
424+
out .+= v
425+
fill!(
426+
_view_linear(f.partials_storage, f.sizes, ix),
427+
one(T),
428+
)
400429
end
401-
@j f.forward_storage[k] = tmp_sum
402430
end
403431
elseif node.index == 2 # :- (broadcasted)
404432
@assert N == 2
405433
child1 = first(children_indices)
406434
@inbounds ix1 = children_arr[child1]
407435
@inbounds ix2 = children_arr[child1+1]
408436
out = _view_linear(f.forward_storage, f.sizes, k)
409-
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
410-
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
411-
out .= v1 .- v2
412-
fill!(_view_linear(f.partials_storage, f.sizes, ix1), one(T))
413-
fill!(_view_linear(f.partials_storage, f.sizes, ix2), -one(T))
437+
ndims1 = f.sizes.ndims[ix1]
438+
ndims2 = f.sizes.ndims[ix2]
439+
if ndims1 == 0 && ndims2 != 0
440+
s1 = _getscalar(f.forward_storage, f.sizes, ix1)
441+
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
442+
out .= s1 .- v2
443+
_setscalar!(f.partials_storage, one(T), f.sizes, ix1)
444+
fill!(
445+
_view_linear(f.partials_storage, f.sizes, ix2),
446+
-one(T),
447+
)
448+
elseif ndims1 != 0 && ndims2 == 0
449+
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
450+
s2 = _getscalar(f.forward_storage, f.sizes, ix2)
451+
out .= v1 .- s2
452+
fill!(
453+
_view_linear(f.partials_storage, f.sizes, ix1),
454+
one(T),
455+
)
456+
_setscalar!(f.partials_storage, -one(T), f.sizes, ix2)
457+
else
458+
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
459+
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
460+
out .= v1 .- v2
461+
fill!(
462+
_view_linear(f.partials_storage, f.sizes, ix1),
463+
one(T),
464+
)
465+
fill!(
466+
_view_linear(f.partials_storage, f.sizes, ix2),
467+
-one(T),
468+
)
469+
end
414470
elseif node.index == 3 # :* (broadcasted)
415-
# Node `k` is not scalar, so we do matrix multiplication
471+
# Node `k` is not scalar, so we do element-wise multiply
472+
# (with scalar-broadcast support: when one operand is
473+
# scalar, broadcast it across the matrix output).
416474
if f.sizes.ndims[k] != 0
417475
@assert N == 2
418476
idx1 = first(children_indices)
419477
idx2 = last(children_indices)
420478
@inbounds ix1 = children_arr[idx1]
421479
@inbounds ix2 = children_arr[idx2]
422-
v1 = zeros(_size(f.sizes, ix1)...)
423-
v2 = zeros(_size(f.sizes, ix2)...)
424-
for j in _eachindex(f.sizes, ix1)
425-
v1[j] = @j f.forward_storage[ix1]
426-
@j f.partials_storage[ix2] = v1[j]
427-
end
428-
for j in _eachindex(f.sizes, ix2)
429-
v2[j] = @j f.forward_storage[ix2]
430-
@j f.partials_storage[ix1] = v2[j]
431-
end
432-
for j in _eachindex(f.sizes, k)
433-
@j f.forward_storage[k] = v1[j] * v2[j]
480+
out = _view_linear(f.forward_storage, f.sizes, k)
481+
ndims1 = f.sizes.ndims[ix1]
482+
ndims2 = f.sizes.ndims[ix2]
483+
if ndims1 == 0 && ndims2 != 0
484+
s = _getscalar(f.forward_storage, f.sizes, ix1)
485+
v = _view_linear(f.forward_storage, f.sizes, ix2)
486+
out .= s .* v
487+
# Per-element partial w.r.t. the matrix child is
488+
# the scalar; the scalar child's reverse is handled
489+
# by the broadcasted-`:*` reverse branch below
490+
# (sum of `rev_parent .* v`).
491+
fill!(
492+
_view_linear(f.partials_storage, f.sizes, ix2),
493+
s,
494+
)
495+
elseif ndims1 != 0 && ndims2 == 0
496+
v = _view_linear(f.forward_storage, f.sizes, ix1)
497+
s = _getscalar(f.forward_storage, f.sizes, ix2)
498+
out .= v .* s
499+
fill!(
500+
_view_linear(f.partials_storage, f.sizes, ix1),
501+
s,
502+
)
503+
else
504+
# Both children are arrays of the same shape —
505+
# original element-wise path.
506+
v1 = zeros(_size(f.sizes, ix1)...)
507+
v2 = zeros(_size(f.sizes, ix2)...)
508+
for j in _eachindex(f.sizes, ix1)
509+
v1[j] = @j f.forward_storage[ix1]
510+
@j f.partials_storage[ix2] = v1[j]
511+
end
512+
for j in _eachindex(f.sizes, ix2)
513+
v2[j] = @j f.forward_storage[ix2]
514+
@j f.partials_storage[ix1] = v2[j]
515+
end
516+
for j in _eachindex(f.sizes, k)
517+
@j f.forward_storage[k] = v1[j] * v2[j]
518+
end
434519
end
435520
# Node `k` is scalar
436521
else
@@ -620,23 +705,54 @@ function _reverse_eval(
620705
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
621706
if op == :*
622707
if f.sizes.ndims[k] != 0
623-
# Matrix multiplication: rev_v1 = rev_parent * v2',
624-
# rev_v2 = v1' * rev_parent. Both v1 and v2 are read
625-
# straight from forward_storage (the matmul forward
626-
# branch deliberately doesn't snapshot them into
627-
# partials_storage), and the reverse views are written
628-
# in place.
708+
# Matmul (or `scalar * matrix` scaling): rev_v1 =
709+
# rev_parent * v2', rev_v2 = v1' * rev_parent. With
710+
# a scalar operand, the result is `s .* M`, so
711+
# rev[s] = sum(rev_parent .* M) and rev[M] =
712+
# rev_parent .* s. Both v1 and v2 are read straight
713+
# from forward_storage.
629714
idx1 = first(children_indices)
630715
idx2 = last(children_indices)
631716
ix1 = children_arr[idx1]
632717
ix2 = children_arr[idx2]
633-
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
634-
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
635-
rev_parent = _view_matrix(f.reverse_storage, f.sizes, k)
636-
rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1)
637-
rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2)
638-
LinearAlgebra.mul!(rev_v1, rev_parent, v2')
639-
LinearAlgebra.mul!(rev_v2, v1', rev_parent)
718+
rev_parent =
719+
_view_linear(f.reverse_storage, f.sizes, k)
720+
ndims1 = f.sizes.ndims[ix1]
721+
ndims2 = f.sizes.ndims[ix2]
722+
if ndims1 == 0 && ndims2 != 0
723+
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
724+
s1 = _getscalar(f.forward_storage, f.sizes, ix1)
725+
rev_v2 = _view_linear(f.reverse_storage, f.sizes, ix2)
726+
rev_v2 .= rev_parent .* s1
727+
_setscalar!(
728+
f.reverse_storage,
729+
LinearAlgebra.dot(rev_parent, v2),
730+
f.sizes,
731+
ix1,
732+
)
733+
elseif ndims1 != 0 && ndims2 == 0
734+
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
735+
s2 = _getscalar(f.forward_storage, f.sizes, ix2)
736+
rev_v1 = _view_linear(f.reverse_storage, f.sizes, ix1)
737+
rev_v1 .= rev_parent .* s2
738+
_setscalar!(
739+
f.reverse_storage,
740+
LinearAlgebra.dot(rev_parent, v1),
741+
f.sizes,
742+
ix2,
743+
)
744+
else
745+
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
746+
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
747+
rev_parent_m =
748+
_view_matrix(f.reverse_storage, f.sizes, k)
749+
rev_v1 =
750+
_view_matrix(f.reverse_storage, f.sizes, ix1)
751+
rev_v2 =
752+
_view_matrix(f.reverse_storage, f.sizes, ix2)
753+
LinearAlgebra.mul!(rev_v1, rev_parent_m, v2')
754+
LinearAlgebra.mul!(rev_v2, v1', rev_parent_m)
755+
end
640756
continue
641757
end
642758
elseif op == :vect
@@ -832,13 +948,82 @@ function _reverse_eval(
832948
elseif node.type == NODE_CALL_MULTIVARIATE_BROADCASTED
833949
if node.index in eachindex(DEFAULT_MULTIVARIATE_OPERATORS)
834950
op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
951+
# Broadcasted +/- with at least one scalar child: the
952+
# scalar's reverse is the (signed) sum of the parent's
953+
# adjoint over the broadcast positions. Handle both scalar
954+
# and matrix children here so the generic
955+
# diagonal-partial path below doesn't trip its
956+
# `_size(k) == _size(ix)` assertion.
957+
if (op == :+ || op == :-) && any(
958+
c -> f.sizes.ndims[children_arr[c]] == 0,
959+
children_indices,
960+
) && f.sizes.ndims[k] != 0
961+
Tr = eltype(f.reverse_storage)
962+
rev_parent =
963+
_view_linear(f.reverse_storage, f.sizes, k)
964+
for c_idx in children_indices
965+
ix = children_arr[c_idx]
966+
# `:-` flips the sign for the second operand, mirroring
967+
# the partial we wrote in the forward pass.
968+
partial_sign =
969+
(op == :- && c_idx != first(children_indices)) ?
970+
-one(Tr) : one(Tr)
971+
if f.sizes.ndims[ix] == 0
972+
_setscalar!(
973+
f.reverse_storage,
974+
partial_sign * sum(rev_parent),
975+
f.sizes,
976+
ix,
977+
)
978+
else
979+
rev_child =
980+
_view_linear(f.reverse_storage, f.sizes, ix)
981+
rev_child .= partial_sign .* rev_parent
982+
end
983+
end
984+
continue
985+
end
835986
if op == :*
836987
if f.sizes.ndims[k] != 0
837-
# Node `k` is not scalar, so we do matrix multiplication or broadcasted multiplication
838988
idx1 = first(children_indices)
839989
idx2 = last(children_indices)
840990
ix1 = children_arr[idx1]
841991
ix2 = children_arr[idx2]
992+
rev_parent =
993+
_view_linear(f.reverse_storage, f.sizes, k)
994+
ndims1 = f.sizes.ndims[ix1]
995+
ndims2 = f.sizes.ndims[ix2]
996+
if ndims1 == 0 && ndims2 != 0
997+
v2 =
998+
_view_linear(f.forward_storage, f.sizes, ix2)
999+
s1 = _getscalar(f.forward_storage, f.sizes, ix1)
1000+
rev_v2 =
1001+
_view_linear(f.reverse_storage, f.sizes, ix2)
1002+
rev_v2 .= rev_parent .* s1
1003+
_setscalar!(
1004+
f.reverse_storage,
1005+
LinearAlgebra.dot(rev_parent, v2),
1006+
f.sizes,
1007+
ix1,
1008+
)
1009+
continue
1010+
elseif ndims1 != 0 && ndims2 == 0
1011+
v1 =
1012+
_view_linear(f.forward_storage, f.sizes, ix1)
1013+
s2 = _getscalar(f.forward_storage, f.sizes, ix2)
1014+
rev_v1 =
1015+
_view_linear(f.reverse_storage, f.sizes, ix1)
1016+
rev_v1 .= rev_parent .* s2
1017+
_setscalar!(
1018+
f.reverse_storage,
1019+
LinearAlgebra.dot(rev_parent, v1),
1020+
f.sizes,
1021+
ix2,
1022+
)
1023+
continue
1024+
end
1025+
# Both children are arrays of the same shape —
1026+
# original element-wise path.
8421027
v1 = zeros(_size(f.sizes, ix1)...)
8431028
v2 = zeros(_size(f.sizes, ix2)...)
8441029
for j in _eachindex(f.sizes, ix1)
@@ -847,18 +1032,18 @@ function _reverse_eval(
8471032
for j in _eachindex(f.sizes, ix2)
8481033
v2[j] = @j f.forward_storage[ix2]
8491034
end
850-
rev_parent = zeros(_size(f.sizes, k)...)
1035+
rev_parent_arr = zeros(_size(f.sizes, k)...)
8511036
for j in _eachindex(f.sizes, k)
852-
rev_parent[j] = @j f.reverse_storage[k]
1037+
rev_parent_arr[j] = @j f.reverse_storage[k]
8531038
end
8541039
rev_v1 = zeros(_size(f.sizes, ix1)...)
8551040
rev_v2 = zeros(_size(f.sizes, ix2)...)
8561041
for j in _eachindex(f.sizes, ix1)
857-
rev_v1[j] = rev_parent[j] * v2[j]
1042+
rev_v1[j] = rev_parent_arr[j] * v2[j]
8581043
@j f.reverse_storage[ix1] = rev_v1[j]
8591044
end
8601045
for j in _eachindex(f.sizes, ix2)
861-
rev_v2[j] = rev_parent[j] * v1[j]
1046+
rev_v2[j] = rev_parent_arr[j] * v1[j]
8621047
@j f.reverse_storage[ix2] = rev_v2[j]
8631048
end
8641049
continue

0 commit comments

Comments
 (0)