This repository was archived by the owner on May 6, 2025. It is now read-only.
I have jax 0.5.2 and jaxlib 0.5.1 (the default versions that are installed using pip). But when trying to import the library:
from neural_tangents import stax
I get the error:
AttributeError: module 'jax.random' has no attribute 'KeyArray'
It looks like this was deprecated and then removed in newer JAX versions, so I installed jax==0.4.23. However, I am then unable to install the matching jaxlib version:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.23 (from versions: 0.4.34, 0.4.35, 0.4.36, 0.4.38, 0.5.0, 0.5.1)
How can I fix this?
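One possible stopgap, sketched here rather than confirmed as a fix: check which versions are actually installed, then restore the missing `KeyArray` name before importing `neural_tangents`. The helpers `installed_version` and `ensure_attr` are hypothetical names made up for this sketch. Aliasing `KeyArray` to `jax.Array` is an assumption, and it may not be enough if the library relies on other APIs that newer JAX releases removed.

```python
from importlib import metadata
from typing import Any, Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def ensure_attr(module: Any, name: str, fallback: Any) -> bool:
    """Set `module.<name>` to `fallback` only if it is missing.

    Returns True if the attribute was added, False if it already existed.
    """
    if hasattr(module, name):
        return False
    setattr(module, name, fallback)
    return True


# Report what pip actually installed (prints None for a missing package).
print("jax:", installed_version("jax"))
print("jaxlib:", installed_version("jaxlib"))

# Hypothetical usage, to run before `from neural_tangents import stax`.
# Assumption: treating KeyArray as a plain jax.Array is acceptable here.
#
#   import jax
#   ensure_attr(jax.random, "KeyArray", jax.Array)
#   from neural_tangents import stax
```

The patch only touches `jax.random` when the attribute is absent, so it should be a no-op on older JAX versions that still define `KeyArray`.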