Skip to content

Commit 9fdb588

Browse files
committed
Simplify *
1 parent 43d9348 commit 9fdb588

1 file changed

Lines changed: 46 additions & 176 deletions

File tree

src/reverse_mode.jl

Lines changed: 46 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -411,84 +411,15 @@ function _forward_eval(
411411
-,
412412
)
413413
elseif node.index == 3 # :* (broadcasted)
414-
# Node `k` is not scalar, so we do element-wise multiply
415-
# (with scalar-broadcast support: when one operand is
416-
# scalar, broadcast it across the matrix output).
417-
if f.sizes.ndims[k] != 0
418-
@assert N == 2
419-
idx1 = first(children_indices)
420-
idx2 = last(children_indices)
421-
@inbounds ix1 = children_arr[idx1]
422-
@inbounds ix2 = children_arr[idx2]
423-
out = _view_linear(f.forward_storage, f.sizes, k)
424-
ndims1 = f.sizes.ndims[ix1]
425-
ndims2 = f.sizes.ndims[ix2]
426-
if ndims1 == 0 && ndims2 != 0
427-
s = _getscalar(f.forward_storage, f.sizes, ix1)
428-
v = _view_linear(f.forward_storage, f.sizes, ix2)
429-
out .= s .* v
430-
# Per-element partial w.r.t. the matrix child is
431-
# the scalar; the scalar child's reverse is handled
432-
# by the broadcasted-`:*` reverse branch below
433-
# (sum of `rev_parent .* v`).
434-
fill!(_view_linear(f.partials_storage, f.sizes, ix2), s)
435-
elseif ndims1 != 0 && ndims2 == 0
436-
v = _view_linear(f.forward_storage, f.sizes, ix1)
437-
s = _getscalar(f.forward_storage, f.sizes, ix2)
438-
out .= v .* s
439-
fill!(_view_linear(f.partials_storage, f.sizes, ix1), s)
440-
else
441-
# Both children are arrays of the same shape —
442-
# original element-wise path.
443-
v1 = zeros(_size(f.sizes, ix1)...)
444-
v2 = zeros(_size(f.sizes, ix2)...)
445-
for j in _eachindex(f.sizes, ix1)
446-
v1[j] = @j f.forward_storage[ix1]
447-
@j f.partials_storage[ix2] = v1[j]
448-
end
449-
for j in _eachindex(f.sizes, ix2)
450-
v2[j] = @j f.forward_storage[ix2]
451-
@j f.partials_storage[ix1] = v2[j]
452-
end
453-
for j in _eachindex(f.sizes, k)
454-
@j f.forward_storage[k] = v1[j] * v2[j]
455-
end
456-
end
457-
# Node `k` is scalar
458-
else
459-
tmp_prod = one(T)
460-
for c_idx in children_indices
461-
@inbounds tmp_prod *=
462-
f.forward_storage[children_arr[c_idx]]
463-
end
464-
if tmp_prod == zero(T) || N <= 2
465-
# This is inefficient if there are a lot of children.
466-
# 2 is chosen as a limit because (x*y)/y does not always
467-
# equal x for floating-point numbers. This can produce
468-
# unexpected error in partials. There's still an error when
469-
# multiplying three or more terms, but users are less likely
470-
# to complain about it.
471-
for c_idx in children_indices
472-
prod_others = one(T)
473-
for c_idx2 in children_indices
474-
(c_idx == c_idx2) && continue
475-
ix = children_arr[c_idx2]
476-
prod_others *= f.forward_storage[ix]
477-
end
478-
f.partials_storage[children_arr[c_idx]] =
479-
prod_others
480-
end
481-
else
482-
# Compute all-minus-one partial derivatives by dividing from
483-
# the total product.
484-
for c_idx in children_indices
485-
ix = children_arr[c_idx]
486-
f.partials_storage[ix] =
487-
tmp_prod / f.forward_storage[ix]
488-
end
489-
end
490-
@inbounds f.forward_storage[k] = tmp_prod
491-
end
414+
@assert N == 2
415+
child1 = first(children_indices)
416+
_reshape_call(
417+
f.forward_storage,
418+
f.sizes,
419+
(k, children_arr[child1], children_arr[child1+1]),
420+
broadcast!,
421+
*,
422+
)
492423
elseif node.index == 4 # :^ (broadcasted), array .^ scalar
493424
@assert N == 2
494425
idx1 = first(children_indices)
@@ -605,6 +536,35 @@ function _forward_eval(
605536
return f.forward_storage[1]
606537
end
607538

539+
function _reverse_broadcasted_mul(dout, dlhs, drhs, lhs, rhs)
540+
# Would need `conj` once we support `Complex`
541+
Base.mapreducedim!(
542+
identity,
543+
Base.add_sum,
544+
dlhs,
545+
Broadcast.instantiate(Broadcast.broadcasted(*, dout, rhs)),
546+
)
547+
Base.mapreducedim!(
548+
identity,
549+
Base.add_sum,
550+
drhs,
551+
Broadcast.instantiate(Broadcast.broadcasted(*, lhs, dout)),
552+
)
553+
return
554+
end
555+
556+
function __reverse_broadcasted_mul(f, ilhs, irhs, dout, dlhs, drhs)
557+
return _reshape_call(
558+
f.forward_storage,
559+
f.sizes,
560+
(ilhs, irhs),
561+
_reverse_broadcasted_mul,
562+
dout,
563+
dlhs,
564+
drhs,
565+
)
566+
end
567+
608568
"""
609569
_reverse_eval(f::_SubexpressionStorage)
610570
@@ -889,105 +849,15 @@ function _reverse_eval(
889849
continue
890850
end
891851
if op == :*
892-
if f.sizes.ndims[k] != 0
893-
idx1 = first(children_indices)
894-
idx2 = last(children_indices)
895-
ix1 = children_arr[idx1]
896-
ix2 = children_arr[idx2]
897-
rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
898-
ndims1 = f.sizes.ndims[ix1]
899-
ndims2 = f.sizes.ndims[ix2]
900-
if ndims1 == 0 && ndims2 != 0
901-
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
902-
s1 = _getscalar(f.forward_storage, f.sizes, ix1)
903-
rev_v2 =
904-
_view_linear(f.reverse_storage, f.sizes, ix2)
905-
rev_v2 .= rev_parent .* s1
906-
_setscalar!(
907-
f.reverse_storage,
908-
LinearAlgebra.dot(rev_parent, v2),
909-
f.sizes,
910-
ix1,
911-
)
912-
continue
913-
elseif ndims1 != 0 && ndims2 == 0
914-
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
915-
s2 = _getscalar(f.forward_storage, f.sizes, ix2)
916-
rev_v1 =
917-
_view_linear(f.reverse_storage, f.sizes, ix1)
918-
rev_v1 .= rev_parent .* s2
919-
_setscalar!(
920-
f.reverse_storage,
921-
LinearAlgebra.dot(rev_parent, v1),
922-
f.sizes,
923-
ix2,
924-
)
925-
continue
926-
end
927-
# Both children are arrays of the same shape —
928-
# original element-wise path.
929-
v1 = zeros(_size(f.sizes, ix1)...)
930-
v2 = zeros(_size(f.sizes, ix2)...)
931-
for j in _eachindex(f.sizes, ix1)
932-
v1[j] = @j f.forward_storage[ix1]
933-
end
934-
for j in _eachindex(f.sizes, ix2)
935-
v2[j] = @j f.forward_storage[ix2]
936-
end
937-
rev_parent = zeros(_size(f.sizes, k)...)
938-
for j in _eachindex(f.sizes, k)
939-
rev_parent[j] = @j f.reverse_storage[k]
940-
end
941-
rev_v1 = zeros(_size(f.sizes, ix1)...)
942-
rev_v2 = zeros(_size(f.sizes, ix2)...)
943-
for j in _eachindex(f.sizes, ix1)
944-
rev_v1[j] = rev_parent[j] * v2[j]
945-
@j f.reverse_storage[ix1] = rev_v1[j]
946-
end
947-
for j in _eachindex(f.sizes, ix2)
948-
rev_v2[j] = rev_parent[j] * v1[j]
949-
@j f.reverse_storage[ix2] = rev_v2[j]
950-
end
951-
continue
952-
end
953-
elseif op == :^
954-
# Broadcasted array .^ scalar: vectorize the per-element
955-
# base reverse (with 0*Inf guard preserved) and reduce
956-
# the exponent contribution as a single `sum` over GPU
957-
# arrays.
958-
@assert length(children_indices) == 2
959-
idx1 = first(children_indices)
960-
idx2 = last(children_indices)
961-
@inbounds ix1 = children_arr[idx1]
962-
@inbounds ix2 = children_arr[idx2]
963-
rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
964-
rev_v1 = _view_linear(f.reverse_storage, f.sizes, ix1)
965-
partial = _view_linear(f.partials_storage, f.sizes, ix1)
966-
rev_v1 .= ifelse.(
967-
(rev_parent .== 0) .& .!isfinite.(partial),
968-
rev_parent,
969-
rev_parent .* partial,
852+
_reshape_call(
853+
f.reverse_storage,
854+
f.sizes,
855+
(k, lhs, rhs),
856+
__reverse_broadcasted_mul,
857+
f,
858+
lhs,
859+
rhs,
970860
)
971-
base_view = _view_linear(f.forward_storage, f.sizes, ix1)
972-
out_view = _view_linear(f.forward_storage, f.sizes, k)
973-
# `mapreduce(f, +, base_view, rev_parent, out_view)`
974-
# would express this directly, but multi-iterable
975-
# `mapreduce` materializes an intermediate today
976-
# (JuliaLang/julia#53417). Wrap the inputs in `zip` so
977-
# the single-iterable specialization fires and the
978-
# reduction stays allocation-free. Once
979-
# https://github.com/JuliaLang/julia/pull/55301 lands
980-
# we can drop the `zip` and use the multi-arg form.
981-
T = eltype(rev_parent)
982-
rev_exp_total = mapreduce(
983-
+,
984-
zip(base_view, rev_parent, out_view);
985-
init = zero(T),
986-
) do (b, rp, o)
987-
return b > 0 ? rp * o * log(b) : zero(T)
988-
end
989-
pos2 = _scalar_pos(f.sizes, ix2)
990-
view(f.reverse_storage, pos2:pos2) .= rev_exp_total
991861
continue
992862
end
993863
end

0 commit comments

Comments
 (0)