Skip to content

Commit 3089852

Browse files
committed
Make setscalar GPU-friendly
1 parent 15655b6 commit 3089852

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/sizes.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ function _getscalar(x, sizes::Sizes, k::Int)
4646
end
4747

4848
function _setscalar!(x, value, sizes::Sizes, k::Int)
49-
return x[sizes.storage_offset[k]+1] = value
49+
# Use a 1-element view + broadcast so this works on GPU storage as well as
50+
# `Vector{Float64}`. Direct `x[idx] = value` is a scalar setindex which
51+
# GPUArrays disallows by default.
52+
pos = sizes.storage_offset[k] + 1
53+
view(x, pos:pos) .= value
54+
return value
5055
end
5156

5257
function _getindex(x, sizes::Sizes, k::Int, j)

0 commit comments

Comments
 (0)