Add JAX array support to pted and pted_coverage_test (#14)
* Initial plan
* Add JAX compatibility to pted and pted_coverage_test
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
* Add jax as a dev dependency
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
* Update README for JAX support and add utils unit tests with mocking
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
* Optimize _jax_cdist: squared-norm identity for L2, vmap for general p
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
* Update pted() signatures/docstrings for jax.Array; add jax optional extra in pyproject.toml
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
* Add cross-backend consistency tests for numpy/torch/JAX energy distance utils
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
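The "squared-norm identity for L2" named in the commit message presumably refers to the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, which lets all pairwise Euclidean distances be computed from per-row squared norms and dot products. Below is a minimal sketch of that idea, written in plain Python so it runs without any array library. It is not the repository's `_jax_cdist`: the function names `pairwise_l2_via_norms` and `pairwise_minkowski` are hypothetical, and the direct Minkowski loop serves only as a reference for the "general p" case.

```python
import math


def pairwise_l2_via_norms(xs, ys):
    """Pairwise Euclidean distances via ||x-y||^2 = ||x||^2 + ||y||^2 - 2 x.y."""
    x_sq = [sum(v * v for v in x) for x in xs]  # squared norm of each row of xs
    y_sq = [sum(v * v for v in y) for y in ys]  # squared norm of each row of ys
    out = []
    for x, x2 in zip(xs, x_sq):
        row = []
        for y, y2 in zip(ys, y_sq):
            dot = sum(a * b for a, b in zip(x, y))
            # Clamp at zero: rounding can push the squared distance slightly negative.
            row.append(math.sqrt(max(x2 + y2 - 2.0 * dot, 0.0)))
        out.append(row)
    return out


def pairwise_minkowski(xs, ys, p):
    """Reference implementation: direct Minkowski distance of finite order p > 0."""
    return [
        [sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p) for y in ys]
        for x in xs
    ]


# Hypothetical toy data: two sets of 2-D points.
xs = [[0.0, 0.0], [3.0, 4.0]]
ys = [[0.0, 0.0], [6.0, 8.0]]
d_fast = pairwise_l2_via_norms(xs, ys)
d_ref = pairwise_minkowski(xs, ys, 2.0)
```

In an array library the same identity becomes one matrix product plus two broadcast additions, which is why it is a common optimization for the p = 2 case. The clamp before the square root matters because cancellation can leave tiny negative values for near-identical points.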
-* **x** *(Union[np.ndarray, Tensor])*: first set of samples. Shape (N, *D)
-* **y** *(Union[np.ndarray, Tensor])*: second set of samples. Shape (M, *D)
+* **x** *(Union[np.ndarray, Tensor, jax.Array])*: first set of samples. Shape (N, *D)
+* **y** *(Union[np.ndarray, Tensor, jax.Array])*: second set of samples. Shape (M, *D)
 * **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed.
-* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf.
+* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. When using JAX arrays, the metric is passed as the "ord" for jnp.linalg.norm and therefore is also a float.
 * **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False)
 * **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset.
 * **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size.
 * **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed.
-* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf.
+* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. When using JAX arrays, the metric is passed as the "ord" for jnp.linalg.norm and therefore is also a float.
 * **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False)
 * **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset.
 * **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size.
@@ -315,9 +321,9 @@ def pted_coverage_test(
 ## GPU Compatibility

 PTED works on both CPU and GPU. All that is needed is to pass the `x` and `y` as
-PyTorch Tensors on the appropriate device.
+PyTorch Tensors or JAX Arrays on the appropriate device.
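As background for the `permutations` and `metric` parameters documented in the diff above, here is a rough sketch of a permutation test built on the energy distance. It is plain Python and does not call the `pted` package. The function names, the biased (diagonal-included) estimator, the Euclidean metric, and the `(1 + exceed) / (1 + permutations)` p-value convention are all assumptions made for illustration, not statements about the library's internals.

```python
import math
import random


def energy_distance(dmat, idx_x, idx_y):
    """Energy distance 2*E|X-Y| - E|X-X'| - E|Y-Y'| from a precomputed distance matrix."""

    def mean_dist(rows, cols):
        return sum(dmat[i][j] for i in rows for j in cols) / (len(rows) * len(cols))

    return (
        2.0 * mean_dist(idx_x, idx_y)
        - mean_dist(idx_x, idx_x)
        - mean_dist(idx_y, idx_y)
    )


def permutation_energy_test(x, y, permutations=100, seed=0):
    """Return a p-value for the null hypothesis that x and y share a distribution."""
    rng = random.Random(seed)
    z = x + y  # pooled samples
    n = len(x)
    # Pairwise Euclidean distances are computed once and reused for every permutation.
    dmat = [[math.dist(a, b) for b in z] for a in z]
    idx = list(range(len(z)))
    observed = energy_distance(dmat, idx[:n], idx[n:])
    exceed = 0
    for _ in range(permutations):
        rng.shuffle(idx)  # random relabelling of the pooled samples
        if energy_distance(dmat, idx[:n], idx[n:]) >= observed:
            exceed += 1
    # Assumed convention: counting the observed split keeps the p-value above zero.
    return (1 + exceed) / (1 + permutations)


# Hypothetical data: two well-separated 2-D Gaussian clouds.
rng = random.Random(1)
x = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(20)]
y_far = [[rng.gauss(5.0, 1.0), rng.gauss(5.0, 1.0)] for _ in range(20)]
p_far = permutation_energy_test(x, y_far, permutations=100)
print(p_far)
```

Under this convention the smallest attainable p-value is 1 / (1 + permutations), which is one way to read the note that the number of permutations determines how accurately the p-value is computed.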