@@ -171,19 +171,32 @@ function _forward_eval(
171171 end
172172 elseif node. index == 3 # :*
173173 # Node `k` is not scalar, so we do matrix multiplication
174+ # (or scalar `*` matrix scaling when one operand is scalar).
174175 if f. sizes. ndims[k] != 0
175176 @assert N == 2
176177 idx1 = first (children_indices)
177178 idx2 = last (children_indices)
178179 @inbounds ix1 = children_arr[idx1]
179180 @inbounds ix2 = children_arr[idx2]
180- v1 = _view_matrix (f. forward_storage, f. sizes, ix1)
181- v2 = _view_matrix (f. forward_storage, f. sizes, ix2)
182- out = _view_matrix (f. forward_storage, f. sizes, k)
183- LinearAlgebra. mul! (out, v1, v2)
181+ out = _view_linear (f. forward_storage, f. sizes, k)
182+ if f. sizes. ndims[ix1] == 0
183+ s = _getscalar (f. forward_storage, f. sizes, ix1)
184+ v = _view_linear (f. forward_storage, f. sizes, ix2)
185+ out .= s .* v
186+ elseif f. sizes. ndims[ix2] == 0
187+ v = _view_linear (f. forward_storage, f. sizes, ix1)
188+ s = _getscalar (f. forward_storage, f. sizes, ix2)
189+ out .= v .* s
190+ else
191+ v1 = _view_matrix (f. forward_storage, f. sizes, ix1)
192+ v2 = _view_matrix (f. forward_storage, f. sizes, ix2)
193+ out_m = _view_matrix (f. forward_storage, f. sizes, k)
194+ LinearAlgebra. mul! (out_m, v1, v2)
195+ end
184196 # We deliberately don't write v1/v2 into partials_storage
185- # here: the matmul reverse branch reads forward_storage
186- # directly, so those writes were dead.
197+ # here: the matmul (or scalar-scaling) reverse branch
198+ # reads forward_storage directly, so those writes were
199+ # dead.
187200 # Node `k` is scalar
188201 else
189202 tmp_prod = one (T)
@@ -391,46 +404,118 @@ function _forward_eval(
391404 children_indices = SparseArrays. nzrange (f. adj, k)
392405 N = length (children_indices)
393406 if node. index == 1 # :+ (broadcasted)
394- for j in _eachindex (f. sizes, k)
395- tmp_sum = zero (T)
396- for c_idx in children_indices
397- ix = children_arr[c_idx]
398- @j f. partials_storage[ix] = one (T)
399- tmp_sum += @j f. forward_storage[ix]
407+ # Broadcast-aware sum: scalar children contribute their
408+ # single value to every output slot.
409+ out = _view_linear (f. forward_storage, f. sizes, k)
410+ fill! (out, zero (T))
411+ for c_idx in children_indices
412+ ix = children_arr[c_idx]
413+ if f. sizes. ndims[ix] == 0
414+ s = _getscalar (f. forward_storage, f. sizes, ix)
415+ out .+ = s
416+ _setscalar! (
417+ f. partials_storage,
418+ one (T),
419+ f. sizes,
420+ ix,
421+ )
422+ else
423+ v = _view_linear (f. forward_storage, f. sizes, ix)
424+ out .+ = v
425+ fill! (
426+ _view_linear (f. partials_storage, f. sizes, ix),
427+ one (T),
428+ )
400429 end
401- @j f. forward_storage[k] = tmp_sum
402430 end
403431 elseif node. index == 2 # :- (broadcasted)
404432 @assert N == 2
405433 child1 = first (children_indices)
406434 @inbounds ix1 = children_arr[child1]
407435 @inbounds ix2 = children_arr[child1+ 1 ]
408436 out = _view_linear (f. forward_storage, f. sizes, k)
409- v1 = _view_linear (f. forward_storage, f. sizes, ix1)
410- v2 = _view_linear (f. forward_storage, f. sizes, ix2)
411- out .= v1 .- v2
412- fill! (_view_linear (f. partials_storage, f. sizes, ix1), one (T))
413- fill! (_view_linear (f. partials_storage, f. sizes, ix2), - one (T))
437+ ndims1 = f. sizes. ndims[ix1]
438+ ndims2 = f. sizes. ndims[ix2]
439+ if ndims1 == 0 && ndims2 != 0
440+ s1 = _getscalar (f. forward_storage, f. sizes, ix1)
441+ v2 = _view_linear (f. forward_storage, f. sizes, ix2)
442+ out .= s1 .- v2
443+ _setscalar! (f. partials_storage, one (T), f. sizes, ix1)
444+ fill! (
445+ _view_linear (f. partials_storage, f. sizes, ix2),
446+ - one (T),
447+ )
448+ elseif ndims1 != 0 && ndims2 == 0
449+ v1 = _view_linear (f. forward_storage, f. sizes, ix1)
450+ s2 = _getscalar (f. forward_storage, f. sizes, ix2)
451+ out .= v1 .- s2
452+ fill! (
453+ _view_linear (f. partials_storage, f. sizes, ix1),
454+ one (T),
455+ )
456+ _setscalar! (f. partials_storage, - one (T), f. sizes, ix2)
457+ else
458+ v1 = _view_linear (f. forward_storage, f. sizes, ix1)
459+ v2 = _view_linear (f. forward_storage, f. sizes, ix2)
460+ out .= v1 .- v2
461+ fill! (
462+ _view_linear (f. partials_storage, f. sizes, ix1),
463+ one (T),
464+ )
465+ fill! (
466+ _view_linear (f. partials_storage, f. sizes, ix2),
467+ - one (T),
468+ )
469+ end
414470 elseif node. index == 3 # :* (broadcasted)
415- # Node `k` is not scalar, so we do matrix multiplication
471+ # Node `k` is not scalar, so we do element-wise multiply
472+ # (with scalar-broadcast support: when one operand is
473+ # scalar, broadcast it across the matrix output).
416474 if f. sizes. ndims[k] != 0
417475 @assert N == 2
418476 idx1 = first (children_indices)
419477 idx2 = last (children_indices)
420478 @inbounds ix1 = children_arr[idx1]
421479 @inbounds ix2 = children_arr[idx2]
422- v1 = zeros (_size (f. sizes, ix1)... )
423- v2 = zeros (_size (f. sizes, ix2)... )
424- for j in _eachindex (f. sizes, ix1)
425- v1[j] = @j f. forward_storage[ix1]
426- @j f. partials_storage[ix2] = v1[j]
427- end
428- for j in _eachindex (f. sizes, ix2)
429- v2[j] = @j f. forward_storage[ix2]
430- @j f. partials_storage[ix1] = v2[j]
431- end
432- for j in _eachindex (f. sizes, k)
433- @j f. forward_storage[k] = v1[j] * v2[j]
480+ out = _view_linear (f. forward_storage, f. sizes, k)
481+ ndims1 = f. sizes. ndims[ix1]
482+ ndims2 = f. sizes. ndims[ix2]
483+ if ndims1 == 0 && ndims2 != 0
484+ s = _getscalar (f. forward_storage, f. sizes, ix1)
485+ v = _view_linear (f. forward_storage, f. sizes, ix2)
486+ out .= s .* v
487+ # Per-element partial w.r.t. the matrix child is
488+ # the scalar; the scalar child's reverse is handled
489+ # by the broadcasted-`:*` reverse branch below
490+ # (sum of `rev_parent .* v`).
491+ fill! (
492+ _view_linear (f. partials_storage, f. sizes, ix2),
493+ s,
494+ )
495+ elseif ndims1 != 0 && ndims2 == 0
496+ v = _view_linear (f. forward_storage, f. sizes, ix1)
497+ s = _getscalar (f. forward_storage, f. sizes, ix2)
498+ out .= v .* s
499+ fill! (
500+ _view_linear (f. partials_storage, f. sizes, ix1),
501+ s,
502+ )
503+ else
504+ # Both children are arrays of the same shape —
505+ # original element-wise path.
506+ v1 = zeros (_size (f. sizes, ix1)... )
507+ v2 = zeros (_size (f. sizes, ix2)... )
508+ for j in _eachindex (f. sizes, ix1)
509+ v1[j] = @j f. forward_storage[ix1]
510+ @j f. partials_storage[ix2] = v1[j]
511+ end
512+ for j in _eachindex (f. sizes, ix2)
513+ v2[j] = @j f. forward_storage[ix2]
514+ @j f. partials_storage[ix1] = v2[j]
515+ end
516+ for j in _eachindex (f. sizes, k)
517+ @j f. forward_storage[k] = v1[j] * v2[j]
518+ end
434519 end
435520 # Node `k` is scalar
436521 else
@@ -620,23 +705,54 @@ function _reverse_eval(
620705 op = DEFAULT_MULTIVARIATE_OPERATORS[node. index]
621706 if op == :*
622707 if f. sizes. ndims[k] != 0
623- # Matrix multiplication: rev_v1 = rev_parent * v2',
624- # rev_v2 = v1' * rev_parent. Both v1 and v2 are read
625- # straight from forward_storage ( the matmul forward
626- # branch deliberately doesn't snapshot them into
627- # partials_storage), and the reverse views are written
628- # in place .
708+ # Matmul (or `scalar * matrix` scaling): rev_v1 =
709+ # rev_parent * v2', rev_v2 = v1' * rev_parent. With
710+ # a scalar operand, the result is `s .* M`, so
711+ # rev[s] = sum(rev_parent .* M) and rev[M] =
712+ # rev_parent .* s. Both v1 and v2 are read straight
713+ # from forward_storage .
629714 idx1 = first (children_indices)
630715 idx2 = last (children_indices)
631716 ix1 = children_arr[idx1]
632717 ix2 = children_arr[idx2]
633- v1 = _view_matrix (f. forward_storage, f. sizes, ix1)
634- v2 = _view_matrix (f. forward_storage, f. sizes, ix2)
635- rev_parent = _view_matrix (f. reverse_storage, f. sizes, k)
636- rev_v1 = _view_matrix (f. reverse_storage, f. sizes, ix1)
637- rev_v2 = _view_matrix (f. reverse_storage, f. sizes, ix2)
638- LinearAlgebra. mul! (rev_v1, rev_parent, v2' )
639- LinearAlgebra. mul! (rev_v2, v1' , rev_parent)
718+ rev_parent =
719+ _view_linear (f. reverse_storage, f. sizes, k)
720+ ndims1 = f. sizes. ndims[ix1]
721+ ndims2 = f. sizes. ndims[ix2]
722+ if ndims1 == 0 && ndims2 != 0
723+ v2 = _view_linear (f. forward_storage, f. sizes, ix2)
724+ s1 = _getscalar (f. forward_storage, f. sizes, ix1)
725+ rev_v2 = _view_linear (f. reverse_storage, f. sizes, ix2)
726+ rev_v2 .= rev_parent .* s1
727+ _setscalar! (
728+ f. reverse_storage,
729+ LinearAlgebra. dot (rev_parent, v2),
730+ f. sizes,
731+ ix1,
732+ )
733+ elseif ndims1 != 0 && ndims2 == 0
734+ v1 = _view_linear (f. forward_storage, f. sizes, ix1)
735+ s2 = _getscalar (f. forward_storage, f. sizes, ix2)
736+ rev_v1 = _view_linear (f. reverse_storage, f. sizes, ix1)
737+ rev_v1 .= rev_parent .* s2
738+ _setscalar! (
739+ f. reverse_storage,
740+ LinearAlgebra. dot (rev_parent, v1),
741+ f. sizes,
742+ ix2,
743+ )
744+ else
745+ v1 = _view_matrix (f. forward_storage, f. sizes, ix1)
746+ v2 = _view_matrix (f. forward_storage, f. sizes, ix2)
747+ rev_parent_m =
748+ _view_matrix (f. reverse_storage, f. sizes, k)
749+ rev_v1 =
750+ _view_matrix (f. reverse_storage, f. sizes, ix1)
751+ rev_v2 =
752+ _view_matrix (f. reverse_storage, f. sizes, ix2)
753+ LinearAlgebra. mul! (rev_v1, rev_parent_m, v2' )
754+ LinearAlgebra. mul! (rev_v2, v1' , rev_parent_m)
755+ end
640756 continue
641757 end
642758 elseif op == :vect
@@ -832,13 +948,82 @@ function _reverse_eval(
832948 elseif node. type == NODE_CALL_MULTIVARIATE_BROADCASTED
833949 if node. index in eachindex (DEFAULT_MULTIVARIATE_OPERATORS)
834950 op = DEFAULT_MULTIVARIATE_OPERATORS[node. index]
951+ # Broadcasted +/- with at least one scalar child: the
952+ # scalar's reverse is the (signed) sum of the parent's
953+ # adjoint over the broadcast positions. Handle both scalar
954+ # and matrix children here so the generic
955+ # diagonal-partial path below doesn't trip its
956+ # `_size(k) == _size(ix)` assertion.
957+ if (op == :+ || op == :- ) && any (
958+ c -> f. sizes. ndims[children_arr[c]] == 0 ,
959+ children_indices,
960+ ) && f. sizes. ndims[k] != 0
961+ Tr = eltype (f. reverse_storage)
962+ rev_parent =
963+ _view_linear (f. reverse_storage, f. sizes, k)
964+ for c_idx in children_indices
965+ ix = children_arr[c_idx]
966+ # `:-` flips the sign for the second operand, mirroring
967+ # the partial we wrote in the forward pass.
968+ partial_sign =
969+ (op == :- && c_idx != first (children_indices)) ?
970+ - one (Tr) : one (Tr)
971+ if f. sizes. ndims[ix] == 0
972+ _setscalar! (
973+ f. reverse_storage,
974+ partial_sign * sum (rev_parent),
975+ f. sizes,
976+ ix,
977+ )
978+ else
979+ rev_child =
980+ _view_linear (f. reverse_storage, f. sizes, ix)
981+ rev_child .= partial_sign .* rev_parent
982+ end
983+ end
984+ continue
985+ end
835986 if op == :*
836987 if f. sizes. ndims[k] != 0
837- # Node `k` is not scalar, so we do matrix multiplication or broadcasted multiplication
838988 idx1 = first (children_indices)
839989 idx2 = last (children_indices)
840990 ix1 = children_arr[idx1]
841991 ix2 = children_arr[idx2]
992+ rev_parent =
993+ _view_linear (f. reverse_storage, f. sizes, k)
994+ ndims1 = f. sizes. ndims[ix1]
995+ ndims2 = f. sizes. ndims[ix2]
996+ if ndims1 == 0 && ndims2 != 0
997+ v2 =
998+ _view_linear (f. forward_storage, f. sizes, ix2)
999+ s1 = _getscalar (f. forward_storage, f. sizes, ix1)
1000+ rev_v2 =
1001+ _view_linear (f. reverse_storage, f. sizes, ix2)
1002+ rev_v2 .= rev_parent .* s1
1003+ _setscalar! (
1004+ f. reverse_storage,
1005+ LinearAlgebra. dot (rev_parent, v2),
1006+ f. sizes,
1007+ ix1,
1008+ )
1009+ continue
1010+ elseif ndims1 != 0 && ndims2 == 0
1011+ v1 =
1012+ _view_linear (f. forward_storage, f. sizes, ix1)
1013+ s2 = _getscalar (f. forward_storage, f. sizes, ix2)
1014+ rev_v1 =
1015+ _view_linear (f. reverse_storage, f. sizes, ix1)
1016+ rev_v1 .= rev_parent .* s2
1017+ _setscalar! (
1018+ f. reverse_storage,
1019+ LinearAlgebra. dot (rev_parent, v1),
1020+ f. sizes,
1021+ ix2,
1022+ )
1023+ continue
1024+ end
1025+ # Both children are arrays of the same shape —
1026+ # original element-wise path.
8421027 v1 = zeros (_size (f. sizes, ix1)... )
8431028 v2 = zeros (_size (f. sizes, ix2)... )
8441029 for j in _eachindex (f. sizes, ix1)
@@ -847,18 +1032,18 @@ function _reverse_eval(
8471032 for j in _eachindex (f. sizes, ix2)
8481033 v2[j] = @j f. forward_storage[ix2]
8491034 end
850- rev_parent = zeros (_size (f. sizes, k)... )
1035+ rev_parent_arr = zeros (_size (f. sizes, k)... )
8511036 for j in _eachindex (f. sizes, k)
852- rev_parent [j] = @j f. reverse_storage[k]
1037+ rev_parent_arr [j] = @j f. reverse_storage[k]
8531038 end
8541039 rev_v1 = zeros (_size (f. sizes, ix1)... )
8551040 rev_v2 = zeros (_size (f. sizes, ix2)... )
8561041 for j in _eachindex (f. sizes, ix1)
857- rev_v1[j] = rev_parent [j] * v2[j]
1042+ rev_v1[j] = rev_parent_arr [j] * v2[j]
8581043 @j f. reverse_storage[ix1] = rev_v1[j]
8591044 end
8601045 for j in _eachindex (f. sizes, ix2)
861- rev_v2[j] = rev_parent [j] * v1[j]
1046+ rev_v2[j] = rev_parent_arr [j] * v1[j]
8621047 @j f. reverse_storage[ix2] = rev_v2[j]
8631048 end
8641049 continue
0 commit comments