@@ -852,33 +852,43 @@ function _reverse_eval(
852852 continue
853853 end
854854 elseif op == :^
855- # Broadcasted array .^ scalar: per-j reverse for the base,
856- # and a sum-reduced reverse for the (scalar) exponent.
855+ # Broadcasted array .^ scalar: vectorize the per-element
856+ # base reverse (with 0*Inf guard preserved) and reduce
857+ # the exponent contribution with a single `mapreduce`
858+ # over GPU arrays.
857859 @assert length (children_indices) == 2
858860 idx1 = first (children_indices)
859861 idx2 = last (children_indices)
860862 @inbounds ix1 = children_arr[idx1]
861863 @inbounds ix2 = children_arr[idx2]
862- for j in _eachindex (f. sizes, k)
863- rev_parent = @j f. reverse_storage[k]
864- partial = @j f. partials_storage[ix1]
865- val = ifelse (
866- rev_parent == 0.0 && ! isfinite (partial),
867- rev_parent,
868- rev_parent * partial,
869- )
870- @j f. reverse_storage[ix1] = val
871- end
872- rev_exp = zero (Float64)
873- for j in _eachindex (f. sizes, k)
874- rev_parent = @j f. reverse_storage[k]
875- base = @j f. forward_storage[ix1]
876- out = @j f. forward_storage[k]
877- if base > 0
878- rev_exp += rev_parent * out * log (base)
879- end
864+ rev_parent = _view_linear (f. reverse_storage, f. sizes, k)
865+ rev_v1 = _view_linear (f. reverse_storage, f. sizes, ix1)
866+ partial = _view_linear (f. partials_storage, f. sizes, ix1)
867+ rev_v1 .= ifelse .(
868+ (rev_parent .== 0 ) .& .! isfinite .(partial),
869+ rev_parent,
870+ rev_parent .* partial,
871+ )
872+ base_view = _view_linear (f. forward_storage, f. sizes, ix1)
873+ out_view = _view_linear (f. forward_storage, f. sizes, k)
874+ # `mapreduce(f, +, base_view, rev_parent, out_view)`
875+ # would express this directly, but multi-iterable
876+ # `mapreduce` materializes an intermediate today
877+ # (JuliaLang/julia#53417). Wrap the inputs in `zip` so
878+ # the single-iterable specialization fires and the
879+ # reduction stays allocation-free. Once
880+ # https://github.com/JuliaLang/julia/pull/55301 lands
881+ # we can drop the `zip` and use the multi-arg form.
882+ T = eltype (rev_parent)
883+ rev_exp_total = mapreduce (
884+ + ,
885+ zip (base_view, rev_parent, out_view);
886+ init = zero (T),
887+ ) do (b, rp, o)
888+ return b > 0 ? rp * o * log (b) : zero (T)
880889 end
881- @s f. reverse_storage[ix2] = rev_exp
890+ pos2 = _scalar_pos (f. sizes, ix2)
891+ view (f. reverse_storage, pos2: pos2) .= rev_exp_total
882892 continue
883893 end
884894 end
0 commit comments