Skip to content

Commit 09f1f95

Browse files
committed
Vectorized power
1 parent b554b3b commit 09f1f95

1 file changed

Lines changed: 23 additions & 22 deletions

File tree

src/reverse_mode.jl

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -842,33 +842,34 @@ function _reverse_eval(
842842
continue
843843
end
844844
elseif op == :^
845-
# Broadcasted array .^ scalar: per-j reverse for the base,
846-
# and a sum-reduced reverse for the (scalar) exponent.
845+
# Broadcasted array .^ scalar: vectorize the per-element
846+
# base reverse (keeping the guard against 0 * Inf = NaN)
847+
# and reduce the exponent contribution as a single `sum`
848+
# over GPU arrays.
847849
@assert length(children_indices) == 2
848850
idx1 = first(children_indices)
849851
idx2 = last(children_indices)
850852
@inbounds ix1 = children_arr[idx1]
851853
@inbounds ix2 = children_arr[idx2]
852-
for j in _eachindex(f.sizes, k)
853-
rev_parent = @j f.reverse_storage[k]
854-
partial = @j f.partials_storage[ix1]
855-
val = ifelse(
856-
rev_parent == 0.0 && !isfinite(partial),
857-
rev_parent,
858-
rev_parent * partial,
859-
)
860-
@j f.reverse_storage[ix1] = val
861-
end
862-
rev_exp = zero(Float64)
863-
for j in _eachindex(f.sizes, k)
864-
rev_parent = @j f.reverse_storage[k]
865-
base = @j f.forward_storage[ix1]
866-
out = @j f.forward_storage[k]
867-
if base > 0
868-
rev_exp += rev_parent * out * log(base)
869-
end
870-
end
871-
@s f.reverse_storage[ix2] = rev_exp
854+
rev_parent = _view_array(f.reverse_storage, f.sizes, k)
855+
rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1)
856+
partial = _view_array(f.partials_storage, f.sizes, ix1)
857+
rev_v1 .= ifelse.(
858+
(rev_parent .== 0) .& .!isfinite.(partial),
859+
rev_parent,
860+
rev_parent .* partial,
861+
)
862+
base_view = _view_array(f.forward_storage, f.sizes, ix1)
863+
out_view = _view_array(f.forward_storage, f.sizes, k)
864+
rev_exp_total = sum(
865+
ifelse.(
866+
base_view .> 0,
867+
rev_parent .* out_view .* log.(abs.(base_view)),
868+
zero(Float64),
869+
),
870+
)
871+
pos2 = _scalar_pos(f.sizes, ix2)
872+
view(f.reverse_storage, pos2:pos2) .= rev_exp_total
872873
continue
873874
end
874875
end

0 commit comments

Comments
 (0)