@@ -842,33 +842,34 @@ function _reverse_eval(
842842 continue
843843 end
844844 elseif op == :^
845- # Broadcasted array .^ scalar: per-j reverse for the base,
846- # and a sum-reduced reverse for the (scalar) exponent.
845+ # Broadcasted array .^ scalar: vectorize the per-element
846+ # base reverse (with 0*Inf guard preserved) and reduce
847+ # the exponent contribution as a single `sum` over GPU
848+ # arrays.
847849 @assert length (children_indices) == 2
848850 idx1 = first (children_indices)
849851 idx2 = last (children_indices)
850852 @inbounds ix1 = children_arr[idx1]
851853 @inbounds ix2 = children_arr[idx2]
852- for j in _eachindex (f. sizes, k)
853- rev_parent = @j f. reverse_storage[k]
854- partial = @j f. partials_storage[ix1]
855- val = ifelse (
856- rev_parent == 0.0 && ! isfinite (partial),
857- rev_parent,
858- rev_parent * partial,
859- )
860- @j f. reverse_storage[ix1] = val
861- end
862- rev_exp = zero (Float64)
863- for j in _eachindex (f. sizes, k)
864- rev_parent = @j f. reverse_storage[k]
865- base = @j f. forward_storage[ix1]
866- out = @j f. forward_storage[k]
867- if base > 0
868- rev_exp += rev_parent * out * log (base)
869- end
870- end
871- @s f. reverse_storage[ix2] = rev_exp
854+ rev_parent = _view_array (f. reverse_storage, f. sizes, k)
855+ rev_v1 = _view_array (f. reverse_storage, f. sizes, ix1)
856+ partial = _view_array (f. partials_storage, f. sizes, ix1)
857+ rev_v1 .= ifelse .(
858+ (rev_parent .== 0 ) .& .! isfinite .(partial),
859+ rev_parent,
860+ rev_parent .* partial,
861+ )
862+ base_view = _view_array (f. forward_storage, f. sizes, ix1)
863+ out_view = _view_array (f. forward_storage, f. sizes, k)
864+ rev_exp_total = sum (
865+ ifelse .(
866+ base_view .> 0 ,
867+ rev_parent .* out_view .* log .(abs .(base_view)),
868+ zero (Float64),
869+ ),
870+ )
871+ pos2 = _scalar_pos (f. sizes, ix2)
872+ view (f. reverse_storage, pos2: pos2) .= rev_exp_total
872873 continue
873874 end
874875 end
0 commit comments