File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -968,14 +968,11 @@ function _extract_reverse_pass_inner(
968968 @assert length (f. reverse_storage) >= _length (f. sizes)
969969 for (k, node) in enumerate (f. nodes)
970970 if node. type == NODE_VARIABLE_BLOCK
971- # Each block has a contiguous tape range and a contiguous `output`
972- # range: gather the adjoint, transfer to host in one memcpy, and
973- # accumulate into the matching slice of `output`.
974971 tape_range = _storage_range (f. sizes, k)
975972 len = length (tape_range)
976973 x_range = node. index: (node. index+ len- 1 )
977- cpu_buf = convert (Vector{T}, view (f . reverse_storage, tape_range))
978- view (output, x_range) .+ = scale .* cpu_buf
974+ view (output, x_range) .+ =
975+ scale .* view (f . reverse_storage, tape_range)
979976 elseif node. type == NODE_VARIABLE
980977 # Per-leaf scalar — rare, so the per-leaf `cudaMemcpy` is fine.
981978 output[node. index] += scale * @s f. reverse_storage[k]
You can’t perform that action at this time.
0 commit comments