make CUDA randn work with Zygote #2581

Draft
bgctw wants to merge 1 commit into JuliaGPU:main from bgctw:chainrules_randn

bgctw commented Dec 9, 2024

Currently, I get errors when using CUDA in combination with Zygote and random numbers.
mcabbott advised adding a @non_differentiable rule for CUDA.randn to CUDA's ChainRulesCoreExt, so that all users can benefit.
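
For readers unfamiliar with the macro, here is a minimal sketch of what such a rule does. It is not the code in this PR: `my_randn` is a hypothetical CPU stand-in for CUDA.randn, used so the snippet runs without a GPU, and the check goes through `ChainRulesCore.rrule` directly rather than through Zygote.

```julia
using ChainRulesCore

# Hypothetical stand-in for CUDA.randn (an assumption for illustration only).
my_randn(dims::Integer...) = randn(Float32, dims...)

# Declare that no gradient flows through my_randn, whatever its arguments.
ChainRulesCore.@non_differentiable my_randn(::Any...)

# The macro defines an rrule; its pullback returns NoTangent() for the
# function itself and for each argument.
y, pullback = ChainRulesCore.rrule(my_randn, 4, 4)
tangents = pullback(ones(Float32, 4, 4))
```

With a rule like this in place, an AD system built on ChainRules, such as Zygote, treats the sampled array as a constant instead of trying to differentiate the sampler's internals.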

maleadt (Member) left a comment

Thanks. Needs a rebase for the CI failure.

Comment thread ext/ChainRulesCoreExt.jl
module ChainRulesCoreExt

-using CUDA: CuArray
+using CUDA: CuArray, CUDA
Member

Suggested change
-using CUDA: CuArray, CUDA
+using CUDA

Comment thread test/extensions/zygote.jl
@@ -0,0 +1,19 @@
using GPUArraysCore: GPUArraysCore
Member

Suggested change
-using GPUArraysCore: GPUArraysCore
+using GPUArrays

GPUArrays re-exports the Core functionality.

Comment thread test/extensions/zygote.jl
function call_rand(v::AbstractVector{T}) where {T}
    randn(T, 4,4) * v[1:4]
end
function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T}
Member

Suggested change
-function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T}
+function call_rand(v::AbstractGPUVector{T}) where {T}

Comment thread ext/ChainRulesCoreExt.jl
ChainRulesCore.is_inplaceable_destination(::CuArray) = true

# allow usage of rand with Zygote
ChainRulesCore.@non_differentiable CUDA.randn(::Any...)
Contributor

rand too?

Suggested change
-ChainRulesCore.@non_differentiable CUDA.randn(::Any...)
+ChainRulesCore.@non_differentiable CUDA.rand(::Any...)
+ChainRulesCore.@non_differentiable CUDA.randn(::Any...)

Tab completion says there are a few more, but not marked public, so IDK:

julia> CUDA.rand
rand              rand_logn
rand_logn!        rand_poisson
rand_poisson!     randexp_unlikely
randn             randn_unlikely
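
If more samplers were to be covered, one possible pattern is to generate the declarations in a loop. The sketch below is an assumption for illustration, not part of this PR: `my_rand` and `my_randn` are hypothetical CPU stand-ins so it runs without a GPU, and which CUDA functions should actually be listed remains the open question raised above.

```julia
using ChainRulesCore

# Hypothetical stand-ins for GPU samplers (assumptions for illustration only).
my_rand(dims::Integer...) = rand(Float32, dims...)
my_randn(dims::Integer...) = randn(Float32, dims...)

# Generate one @non_differentiable declaration per sampler name.
for f in (:my_rand, :my_randn)
    @eval ChainRulesCore.@non_differentiable $f(::Any...)
end

# Each sampler now has an rrule whose pullback yields only NoTangent().
for f in (my_rand, my_randn)
    _, pullback = ChainRulesCore.rrule(f, 2, 3)
    @assert all(t -> t isa NoTangent, pullback(ones(Float32, 2, 3)))
end
```

For functions that live in another module, the name inside `@eval` would need to be qualified, which is one reason listing the two public functions explicitly, as in the suggestion, may be the simpler choice.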

maleadt force-pushed the master branch 15 times, most recently from 5d585c4 to c850163, December 20, 2024 08:18
maleadt marked this pull request as draft January 8, 2025 10:06
maleadt added the enhancement (New feature or request) and needs changes (Changes are needed.) labels Feb 4, 2025
Labels: enhancement, needs changes

3 participants