Skip to content

Commit 78634c5

Browse files
committed
Fix alloc
1 parent bcc45c3 commit 78634c5

1 file changed

Lines changed: 2 additions & 5 deletions

File tree

src/reverse_mode.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -968,14 +968,11 @@ function _extract_reverse_pass_inner(
     @assert length(f.reverse_storage) >= _length(f.sizes)
     for (k, node) in enumerate(f.nodes)
         if node.type == NODE_VARIABLE_BLOCK
-            # Each block has a contiguous tape range and a contiguous `output`
-            # range: gather the adjoint, transfer to host in one memcpy, and
-            # accumulate into the matching slice of `output`.
             tape_range = _storage_range(f.sizes, k)
             len = length(tape_range)
             x_range = node.index:(node.index+len-1)
-            cpu_buf = convert(Vector{T}, view(f.reverse_storage, tape_range))
-            view(output, x_range) .+= scale .* cpu_buf
+            view(output, x_range) .+=
+                scale .* view(f.reverse_storage, tape_range)
         elseif node.type == NODE_VARIABLE
             # Per-leaf scalar — rare, so the per-leaf `cudaMemcpy` is fine.
             output[node.index] += scale * @s f.reverse_storage[k]

0 commit comments

Comments
 (0)