Skip to content

Commit db9db16

Browse files
authored
Vectorized power (#50)
* Vectorized power * Fix * Fix * Fix
1 parent 59f0d97 commit db9db16

2 files changed

Lines changed: 34 additions & 21 deletions

File tree

src/reverse_mode.jl

Lines changed: 31 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -852,33 +852,43 @@ function _reverse_eval(
852852
continue
853853
end
854854
elseif op == :^
855-
# Broadcasted array .^ scalar: per-j reverse for the base,
856-
# and a sum-reduced reverse for the (scalar) exponent.
855+
# Broadcasted array .^ scalar: vectorize the per-element
856+
# base reverse (with 0*Inf guard preserved) and reduce
857+
# the exponent contribution as a single `mapreduce` sum over GPU
858+
# arrays.
857859
@assert length(children_indices) == 2
858860
idx1 = first(children_indices)
859861
idx2 = last(children_indices)
860862
@inbounds ix1 = children_arr[idx1]
861863
@inbounds ix2 = children_arr[idx2]
862-
for j in _eachindex(f.sizes, k)
863-
rev_parent = @j f.reverse_storage[k]
864-
partial = @j f.partials_storage[ix1]
865-
val = ifelse(
866-
rev_parent == 0.0 && !isfinite(partial),
867-
rev_parent,
868-
rev_parent * partial,
869-
)
870-
@j f.reverse_storage[ix1] = val
871-
end
872-
rev_exp = zero(Float64)
873-
for j in _eachindex(f.sizes, k)
874-
rev_parent = @j f.reverse_storage[k]
875-
base = @j f.forward_storage[ix1]
876-
out = @j f.forward_storage[k]
877-
if base > 0
878-
rev_exp += rev_parent * out * log(base)
879-
end
864+
rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
865+
rev_v1 = _view_linear(f.reverse_storage, f.sizes, ix1)
866+
partial = _view_linear(f.partials_storage, f.sizes, ix1)
867+
rev_v1 .= ifelse.(
868+
(rev_parent .== 0) .& .!isfinite.(partial),
869+
rev_parent,
870+
rev_parent .* partial,
871+
)
872+
base_view = _view_linear(f.forward_storage, f.sizes, ix1)
873+
out_view = _view_linear(f.forward_storage, f.sizes, k)
874+
# `mapreduce(f, +, base_view, rev_parent, out_view)`
875+
# would express this directly, but multi-iterable
876+
# `mapreduce` materializes an intermediate today
877+
# (JuliaLang/julia#53417). Wrap the inputs in `zip` so
878+
# the single-iterable specialization fires and the
879+
# reduction stays allocation-free. Once
880+
# https://github.com/JuliaLang/julia/pull/55301 lands
881+
# we can drop the `zip` and use the multi-arg form.
882+
T = eltype(rev_parent)
883+
rev_exp_total = mapreduce(
884+
+,
885+
zip(base_view, rev_parent, out_view);
886+
init = zero(T),
887+
) do (b, rp, o)
888+
return b > 0 ? rp * o * log(b) : zero(T)
880889
end
881-
@s f.reverse_storage[ix2] = rev_exp
890+
pos2 = _scalar_pos(f.sizes, ix2)
891+
view(f.reverse_storage, pos2:pos2) .= rev_exp_total
882892
continue
883893
end
884894
end

test/JuMP.jl

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -314,6 +314,9 @@ end
314314
# is the path that actually re-runs forward+reverse, not the
315315
# `last_x == x` short-circuit).
316316
function test_neural_allocations()
317+
if VERSION < v"1.12"
318+
return
319+
end
317320
n = 2
318321
X = [1.0 0.5; 0.3 0.8]
319322
target = [0.5 0.2; 0.1 0.7]

0 commit comments

Comments
 (0)