@@ -411,84 +411,15 @@ function _forward_eval(
411411 - ,
412412 )
413413 elseif node. index == 3 # :* (broadcasted)
414- # Node `k` is not scalar, so we do element-wise multiply
415- # (with scalar-broadcast support: when one operand is
416- # scalar, broadcast it across the matrix output).
417- if f. sizes. ndims[k] != 0
418- @assert N == 2
419- idx1 = first (children_indices)
420- idx2 = last (children_indices)
421- @inbounds ix1 = children_arr[idx1]
422- @inbounds ix2 = children_arr[idx2]
423- out = _view_linear (f. forward_storage, f. sizes, k)
424- ndims1 = f. sizes. ndims[ix1]
425- ndims2 = f. sizes. ndims[ix2]
426- if ndims1 == 0 && ndims2 != 0
427- s = _getscalar (f. forward_storage, f. sizes, ix1)
428- v = _view_linear (f. forward_storage, f. sizes, ix2)
429- out .= s .* v
430- # Per-element partial w.r.t. the matrix child is
431- # the scalar; the scalar child's reverse is handled
432- # by the broadcasted-`:*` reverse branch below
433- # (sum of `rev_parent .* v`).
434- fill! (_view_linear (f. partials_storage, f. sizes, ix2), s)
435- elseif ndims1 != 0 && ndims2 == 0
436- v = _view_linear (f. forward_storage, f. sizes, ix1)
437- s = _getscalar (f. forward_storage, f. sizes, ix2)
438- out .= v .* s
439- fill! (_view_linear (f. partials_storage, f. sizes, ix1), s)
440- else
441- # Both children are arrays of the same shape —
442- # original element-wise path.
443- v1 = zeros (_size (f. sizes, ix1)... )
444- v2 = zeros (_size (f. sizes, ix2)... )
445- for j in _eachindex (f. sizes, ix1)
446- v1[j] = @j f. forward_storage[ix1]
447- @j f. partials_storage[ix2] = v1[j]
448- end
449- for j in _eachindex (f. sizes, ix2)
450- v2[j] = @j f. forward_storage[ix2]
451- @j f. partials_storage[ix1] = v2[j]
452- end
453- for j in _eachindex (f. sizes, k)
454- @j f. forward_storage[k] = v1[j] * v2[j]
455- end
456- end
457- # Node `k` is scalar
458- else
459- tmp_prod = one (T)
460- for c_idx in children_indices
461- @inbounds tmp_prod *=
462- f. forward_storage[children_arr[c_idx]]
463- end
464- if tmp_prod == zero (T) || N <= 2
465- # This is inefficient if there are a lot of children.
466- # 2 is chosen as a limit because (x*y)/y does not always
467- # equal x for floating-point numbers. This can produce
468- # unexpected error in partials. There's still an error when
469- # multiplying three or more terms, but users are less likely
470- # to complain about it.
471- for c_idx in children_indices
472- prod_others = one (T)
473- for c_idx2 in children_indices
474- (c_idx == c_idx2) && continue
475- ix = children_arr[c_idx2]
476- prod_others *= f. forward_storage[ix]
477- end
478- f. partials_storage[children_arr[c_idx]] =
479- prod_others
480- end
481- else
482- # Compute all-minus-one partial derivatives by dividing from
483- # the total product.
484- for c_idx in children_indices
485- ix = children_arr[c_idx]
486- f. partials_storage[ix] =
487- tmp_prod / f. forward_storage[ix]
488- end
489- end
490- @inbounds f. forward_storage[k] = tmp_prod
491- end
414+ @assert N == 2
415+ child1 = first (children_indices)
416+ _reshape_call (
417+ f. forward_storage,
418+ f. sizes,
419+ (k, children_arr[child1], children_arr[child1+ 1 ]),
420+ broadcast!,
421+ * ,
422+ )
492423 elseif node. index == 4 # :^ (broadcasted), array .^ scalar
493424 @assert N == 2
494425 idx1 = first (children_indices)
@@ -605,6 +536,35 @@ function _forward_eval(
605536 return f. forward_storage[1 ]
606537end
607538
539+ function _reverse_broadcasted_mul (dout, dlhs, drhs, lhs, rhs)
540+ # Would need `conj` once we support `Complex`
541+ Base. mapreducedim! (
542+ identity,
543+ Base. add_sum,
544+ dlhs,
545+ Broadcast. instantiate (Broadcast. broadcasted (* , dout, rhs)),
546+ )
547+ Base. mapreducedim! (
548+ identity,
549+ Base. add_sum,
550+ drhs,
551+ Broadcast. instantiate (Broadcast. broadcasted (* , lhs, dout)),
552+ )
553+ return
554+ end
555+
556+ function __reverse_broadcasted_mul (f, ilhs, irhs, dout, dlhs, drhs)
557+ return _reshape_call (
558+ f. forward_storage,
559+ f. sizes,
560+ (ilhs, irhs),
561+ _reverse_broadcasted_mul,
562+ dout,
563+ dlhs,
564+ drhs,
565+ )
566+ end
567+
608568"""
609569 _reverse_eval(f::_SubexpressionStorage)
610570
@@ -889,105 +849,15 @@ function _reverse_eval(
889849 continue
890850 end
891851 if op == :*
892- if f. sizes. ndims[k] != 0
893- idx1 = first (children_indices)
894- idx2 = last (children_indices)
895- ix1 = children_arr[idx1]
896- ix2 = children_arr[idx2]
897- rev_parent = _view_linear (f. reverse_storage, f. sizes, k)
898- ndims1 = f. sizes. ndims[ix1]
899- ndims2 = f. sizes. ndims[ix2]
900- if ndims1 == 0 && ndims2 != 0
901- v2 = _view_linear (f. forward_storage, f. sizes, ix2)
902- s1 = _getscalar (f. forward_storage, f. sizes, ix1)
903- rev_v2 =
904- _view_linear (f. reverse_storage, f. sizes, ix2)
905- rev_v2 .= rev_parent .* s1
906- _setscalar! (
907- f. reverse_storage,
908- LinearAlgebra. dot (rev_parent, v2),
909- f. sizes,
910- ix1,
911- )
912- continue
913- elseif ndims1 != 0 && ndims2 == 0
914- v1 = _view_linear (f. forward_storage, f. sizes, ix1)
915- s2 = _getscalar (f. forward_storage, f. sizes, ix2)
916- rev_v1 =
917- _view_linear (f. reverse_storage, f. sizes, ix1)
918- rev_v1 .= rev_parent .* s2
919- _setscalar! (
920- f. reverse_storage,
921- LinearAlgebra. dot (rev_parent, v1),
922- f. sizes,
923- ix2,
924- )
925- continue
926- end
927- # Both children are arrays of the same shape —
928- # original element-wise path.
929- v1 = zeros (_size (f. sizes, ix1)... )
930- v2 = zeros (_size (f. sizes, ix2)... )
931- for j in _eachindex (f. sizes, ix1)
932- v1[j] = @j f. forward_storage[ix1]
933- end
934- for j in _eachindex (f. sizes, ix2)
935- v2[j] = @j f. forward_storage[ix2]
936- end
937- rev_parent = zeros (_size (f. sizes, k)... )
938- for j in _eachindex (f. sizes, k)
939- rev_parent[j] = @j f. reverse_storage[k]
940- end
941- rev_v1 = zeros (_size (f. sizes, ix1)... )
942- rev_v2 = zeros (_size (f. sizes, ix2)... )
943- for j in _eachindex (f. sizes, ix1)
944- rev_v1[j] = rev_parent[j] * v2[j]
945- @j f. reverse_storage[ix1] = rev_v1[j]
946- end
947- for j in _eachindex (f. sizes, ix2)
948- rev_v2[j] = rev_parent[j] * v1[j]
949- @j f. reverse_storage[ix2] = rev_v2[j]
950- end
951- continue
952- end
953- elseif op == :^
954- # Broadcasted array .^ scalar: vectorize the per-element
955- # base reverse (with 0*Inf guard preserved) and reduce
956- # the exponent contribution as a single `sum` over GPU
957- # arrays.
958- @assert length (children_indices) == 2
959- idx1 = first (children_indices)
960- idx2 = last (children_indices)
961- @inbounds ix1 = children_arr[idx1]
962- @inbounds ix2 = children_arr[idx2]
963- rev_parent = _view_linear (f. reverse_storage, f. sizes, k)
964- rev_v1 = _view_linear (f. reverse_storage, f. sizes, ix1)
965- partial = _view_linear (f. partials_storage, f. sizes, ix1)
966- rev_v1 .= ifelse .(
967- (rev_parent .== 0 ) .& .! isfinite .(partial),
968- rev_parent,
969- rev_parent .* partial,
852+ _reshape_call (
853+ f. reverse_storage,
854+ f. sizes,
855+ (k, lhs, rhs),
856+ __reverse_broadcasted_mul,
857+ f,
858+ lhs,
859+ rhs,
970860 )
971- base_view = _view_linear (f. forward_storage, f. sizes, ix1)
972- out_view = _view_linear (f. forward_storage, f. sizes, k)
973- # `mapreduce(f, +, base_view, rev_parent, out_view)`
974- # would express this directly, but multi-iterable
975- # `mapreduce` materializes an intermediate today
976- # (JuliaLang/julia#53417). Wrap the inputs in `zip` so
977- # the single-iterable specialization fires and the
978- # reduction stays allocation-free. Once
979- # https://github.com/JuliaLang/julia/pull/55301 lands
980- # we can drop the `zip` and use the multi-arg form.
981- T = eltype (rev_parent)
982- rev_exp_total = mapreduce (
983- + ,
984- zip (base_view, rev_parent, out_view);
985- init = zero (T),
986- ) do (b, rp, o)
987- return b > 0 ? rp * o * log (b) : zero (T)
988- end
989- pos2 = _scalar_pos (f. sizes, ix2)
990- view (f. reverse_storage, pos2: pos2) .= rev_exp_total
991861 continue
992862 end
993863 end
0 commit comments