@@ -165,9 +165,9 @@ function _forward_eval(
165165 idx2 = last (children_indices)
166166 @inbounds ix1 = children_arr[idx1]
167167 @inbounds ix2 = children_arr[idx2]
168- v1 = _view_array (f. forward_storage, f. sizes, ix1)
169- v2 = _view_array (f. forward_storage, f. sizes, ix2)
170- out = _view_array (f. forward_storage, f. sizes, k)
168+ v1 = _view_matrix (f. forward_storage, f. sizes, ix1)
169+ v2 = _view_matrix (f. forward_storage, f. sizes, ix2)
170+ out = _view_matrix (f. forward_storage, f. sizes, k)
171171 LinearAlgebra. mul! (out, v1, v2)
172172 # We deliberately don't write v1/v2 into partials_storage
173173 # here: the matmul reverse branch reads forward_storage
@@ -343,8 +343,8 @@ function _forward_eval(
343343 elseif node. index == 15 # sum
344344 @assert N == 1
345345 ix = children_arr[first (children_indices)]
346- inp = _view_array (f. forward_storage, f. sizes, ix)
347- fill! (_view_array (f. partials_storage, f. sizes, ix), one (T))
346+ inp = _view_linear (f. forward_storage, f. sizes, ix)
347+ fill! (_view_linear (f. partials_storage, f. sizes, ix), one (T))
348348 @s f. forward_storage[k] = sum (inp)
349349 elseif node. index == 16 # row
350350 for j in _eachindex (f. sizes, k)
@@ -393,12 +393,12 @@ function _forward_eval(
393393 child1 = first (children_indices)
394394 @inbounds ix1 = children_arr[child1]
395395 @inbounds ix2 = children_arr[child1+ 1 ]
396- out = _view_array (f. forward_storage, f. sizes, k)
397- v1 = _view_array (f. forward_storage, f. sizes, ix1)
398- v2 = _view_array (f. forward_storage, f. sizes, ix2)
396+ out = _view_linear (f. forward_storage, f. sizes, k)
397+ v1 = _view_linear (f. forward_storage, f. sizes, ix1)
398+ v2 = _view_linear (f. forward_storage, f. sizes, ix2)
399399 out .= v1 .- v2
400- fill! (_view_array (f. partials_storage, f. sizes, ix1), one (T))
401- fill! (_view_array (f. partials_storage, f. sizes, ix2), - one (T))
400+ fill! (_view_linear (f. partials_storage, f. sizes, ix1), one (T))
401+ fill! (_view_linear (f. partials_storage, f. sizes, ix2), - one (T))
402402 elseif node. index == 3 # :* (broadcasted)
403403 # Node `k` is not scalar, so we do matrix multiplication
404404 if f. sizes. ndims[k] != 0
@@ -466,9 +466,9 @@ function _forward_eval(
466466 f. forward_storage,
467467 f. sizes. storage_offset[ix2]+ 1 ,
468468 )
469- out = _view_array (f. forward_storage, f. sizes, k)
470- inp = _view_array (f. forward_storage, f. sizes, ix1)
471- partials = _view_array (f. partials_storage, f. sizes, ix1)
469+ out = _view_linear (f. forward_storage, f. sizes, k)
470+ inp = _view_linear (f. forward_storage, f. sizes, ix1)
471+ partials = _view_linear (f. partials_storage, f. sizes, ix1)
472472 if exponent == 2
473473 out .= inp .* inp
474474 partials .= 2 .* inp
@@ -518,9 +518,9 @@ function _forward_eval(
518518 @j f. forward_storage[k] = - val
519519 end
520520 elseif operators. univariate_operators[node. index] === :tanh
521- out = _view_array (f. forward_storage, f. sizes, k)
522- inp = _view_array (f. forward_storage, f. sizes, child_idx)
523- partials = _view_array (f. partials_storage, f. sizes, child_idx)
521+ out = _view_linear (f. forward_storage, f. sizes, k)
522+ inp = _view_linear (f. forward_storage, f. sizes, child_idx)
523+ partials = _view_linear (f. partials_storage, f. sizes, child_idx)
524524 out .= tanh .(inp)
525525 partials .= one (T) .- out .* out
526526 else
@@ -618,11 +618,11 @@ function _reverse_eval(
618618 idx2 = last (children_indices)
619619 ix1 = children_arr[idx1]
620620 ix2 = children_arr[idx2]
621- v1 = _view_array (f. forward_storage, f. sizes, ix1)
622- v2 = _view_array (f. forward_storage, f. sizes, ix2)
623- rev_parent = _view_array (f. reverse_storage, f. sizes, k)
624- rev_v1 = _view_array (f. reverse_storage, f. sizes, ix1)
625- rev_v2 = _view_array (f. reverse_storage, f. sizes, ix2)
621+ v1 = _view_matrix (f. forward_storage, f. sizes, ix1)
622+ v2 = _view_matrix (f. forward_storage, f. sizes, ix2)
623+ rev_parent = _view_matrix (f. reverse_storage, f. sizes, k)
624+ rev_v1 = _view_matrix (f. reverse_storage, f. sizes, ix1)
625+ rev_v2 = _view_matrix (f. reverse_storage, f. sizes, ix2)
626626 LinearAlgebra. mul! (rev_v1, rev_parent, v2' )
627627 LinearAlgebra. mul! (rev_v2, v1' , rev_parent)
628628 continue
@@ -881,12 +881,12 @@ function _reverse_eval(
881881 # diagonal entries are stored in `f.partials_storage`. We broadcast
882882 # `rev_child .= rev_parent .* partial` over the whole array (with the
883883 # 0 * Inf guard preserved).
884- rev_parent = _view_array (f. reverse_storage, f. sizes, k)
884+ rev_parent = _view_linear (f. reverse_storage, f. sizes, k)
885885 for child_idx in children_indices
886886 ix = children_arr[child_idx]
887887 @assert _size (f. sizes, k) == _size (f. sizes, ix)
888- rev_child = _view_array (f. reverse_storage, f. sizes, ix)
889- partial = _view_array (f. partials_storage, f. sizes, ix)
888+ rev_child = _view_linear (f. reverse_storage, f. sizes, ix)
889+ partial = _view_linear (f. partials_storage, f. sizes, ix)
890890 rev_child .= ifelse .(
891891 (rev_parent .== 0 ) .& .! isfinite .(partial),
892892 rev_parent,
0 commit comments