The following tests currently fail on JAX, all due to random-seed-related issues:
tests/test_augmentations.py
tests/test_flow_matching.py
tests/test_generative.py
This was originally fixed in #373, but including it made that PR too big to review, so I'm splitting those changes off into this separate PR.