make CUDA randn work with Zygote#2581
Draft
bgctw wants to merge 1 commit into
Draft
Conversation
maleadt
requested changes
Dec 10, 2024
Member
maleadt
left a comment
There was a problem hiding this comment.
Thanks. Needs a rebase for the CI failure.
| module ChainRulesCoreExt | ||
|
|
||
| using CUDA: CuArray | ||
| using CUDA: CuArray, CUDA |
Member
There was a problem hiding this comment.
Suggested change
| using CUDA: CuArray, CUDA | |
| using CUDA |
| @@ -0,0 +1,19 @@ | |||
| using GPUArraysCore: GPUArraysCore | |||
Member
There was a problem hiding this comment.
Suggested change
| using GPUArraysCore: GPUArraysCore | |
| using GPUArrays |
GPUArrays re-exports the Core functionality.
| function call_rand(v::AbstractVector{T}) where {T} | ||
| randn(T, 4,4) * v[1:4] | ||
| end | ||
| function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T} |
Member
There was a problem hiding this comment.
Suggested change
| function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T} | |
| function call_rand(v::AbstractGPUVector{T}) where {T} |
mcabbott
reviewed
Dec 15, 2024
| ChainRulesCore.is_inplaceable_destination(::CuArray) = true | ||
|
|
||
| # allow usage of rand with Zygote | ||
| ChainRulesCore.@non_differentiable CUDA.randn(::Any...) |
Contributor
There was a problem hiding this comment.
rand too?
Suggested change
| ChainRulesCore.@non_differentiable CUDA.randn(::Any...) | |
| ChainRulesCore.@non_differentiable CUDA.rand(::Any...) | |
| ChainRulesCore.@non_differentiable CUDA.randn(::Any...) |
Tab completion says there are a few more, but not marked public, so IDK:
julia> CUDA.rand
rand rand_logn
rand_logn! rand_poisson
rand_poisson! randexp_unlikely
randn randn_unlikely
5d585c4 to
c850163
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Currently, I get errors when using CUDA in combination with Zygote and random numbers.
mcabbott adviced to add a
@non_differentiable CUDA.randnrule forCUDA.randnto CUDAs ChainRulesCoreExt, so that all users can benefit.