Commit 728bd53
Add JAX array support to pted and pted_coverage_test (#14)
* Initial plan
* Add JAX compatibility to pted and pted_coverage_test
* Add jax as a dev dependency
* Update README for JAX support and add utils unit tests with mocking
* Optimize _jax_cdist: squared-norm identity for L2, vmap for general p
* Update pted() signatures/docstrings for jax.Array; add jax optional extra in pyproject.toml
* Add cross-backend consistency tests for numpy/torch/JAX energy distance utils

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
1 parent a5a45a8 commit 728bd53

5 files changed

Lines changed: 423 additions & 28 deletions

README.md

Lines changed: 31 additions & 12 deletions
@@ -29,6 +29,12 @@ If you want to run PTED on GPUs using PyTorch, then also install torch:
 pip install torch
 ```
 
+If you want to use JAX arrays as inputs, then also install jax:
+
+```bash
+pip install jax
+```
+
 The two functions are ``pted.pted`` and ``pted.pted_coverage_test``. For
 information about each argument, just use ``help(pted.pted)`` or
 ``help(pted.pted_coverage_test)``.
@@ -261,8 +267,8 @@ results you are getting!
 
 ```python
 def pted(
-    x: Union[np.ndarray, "Tensor"],
-    y: Union[np.ndarray, "Tensor"],
+    x: Union[np.ndarray, "Tensor", "jax.Array"],
+    y: Union[np.ndarray, "Tensor", "jax.Array"],
     permutations: int = 1000,
     metric: Union[str, float] = "euclidean",
     return_all: bool = False,
@@ -273,10 +279,10 @@ def pted(
 ) -> Union[float, tuple[float, np.ndarray, float]]:
 ```
 
-* **x** *(Union[np.ndarray, Tensor])*: first set of samples. Shape (N, *D)
-* **y** *(Union[np.ndarray, Tensor])*: second set of samples. Shape (M, *D)
+* **x** *(Union[np.ndarray, Tensor, jax.Array])*: first set of samples. Shape (N, *D)
+* **y** *(Union[np.ndarray, Tensor, jax.Array])*: second set of samples. Shape (M, *D)
 * **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed.
-* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf.
+* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. When using JAX arrays, the metric is passed as the "ord" for jnp.linalg.norm and therefore is also a float.
 * **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False)
 * **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset.
 * **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size.
@@ -287,8 +293,8 @@ def pted(
 
 ```python
 def pted_coverage_test(
-    g: Union[np.ndarray, "Tensor"],
-    s: Union[np.ndarray, "Tensor"],
+    g: Union[np.ndarray, "Tensor", "jax.Array"],
+    s: Union[np.ndarray, "Tensor", "jax.Array"],
     permutations: int = 1000,
     metric: Union[str, float] = "euclidean",
     warn_confidence: Optional[float] = 1e-3,
@@ -301,10 +307,10 @@ def pted_coverage_test(
 ) -> Union[float, tuple[np.ndarray, np.ndarray, float]]:
 ```
 
-* **g** *(Union[np.ndarray, Tensor])*: Ground truth samples. Shape (n_sims, *D)
-* **s** *(Union[np.ndarray, Tensor])*: Posterior samples. Shape (n_samples, n_sims, *D)
+* **g** *(Union[np.ndarray, Tensor, jax.Array])*: Ground truth samples. Shape (n_sims, *D)
+* **s** *(Union[np.ndarray, Tensor, jax.Array])*: Posterior samples. Shape (n_samples, n_sims, *D)
 * **permutations** *(int)*: number of permutations to run. This determines how accurately the p-value is computed.
-* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf.
+* **metric** *(Union[str, float])*: distance metric to use. See scipy.spatial.distance.cdist for the list of available metrics with numpy. See torch.cdist when using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. When using JAX arrays, the metric is passed as the "ord" for jnp.linalg.norm and therefore is also a float.
 * **return_all** *(bool)*: if True, return the test statistic and the permuted statistics with the p-value. If False, just return the p-value. bool (default: False)
 * **chunk_size** *(Optional[int])*: if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full dataset.
 * **chunk_iter** *(Optional[int])*: The chunk iter is the number of iterations to use with the given chunk size.
@@ -315,9 +321,9 @@ def pted_coverage_test(
 ## GPU Compatibility
 
 PTED works on both CPU and GPU. All that is needed is to pass the `x` and `y` as
-PyTorch Tensors on the appropriate device.
+PyTorch Tensors or JAX Arrays on the appropriate device.
 
-Example:
+Example with PyTorch:
 ```python
 from pted import pted
 import numpy as np
@@ -330,6 +336,19 @@ p_value = pted(torch.tensor(x), torch.tensor(y))
 print(f"p-value: {p_value:.3f}") # expect uniform random from 0-1
 ```
 
+Example with JAX:
+```python
+from pted import pted
+import numpy as np
+import jax.numpy as jnp
+
+x = np.random.normal(size = (500, 10)) # (n_samples_x, n_dimensions)
+y = np.random.normal(size = (400, 10)) # (n_samples_y, n_dimensions)
+
+p_value = pted(jnp.array(x), jnp.array(y))
+print(f"p-value: {p_value:.3f}") # expect uniform random from 0-1
+```
+
 ## Memory and Compute limitations
 
 If a GPU isn't enough to get PTED running fast enough for you, or if you are
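Under the hood, PTED's permutation test computes one pooled distance matrix and re-labels which samples belong to `x` versus `y` by permuting its rows and columns (this is visible in the `pted_jax` diff below). A minimal NumPy-only sketch of that idea, with hypothetical helper names rather than the library's internals:

```python
import numpy as np

def energy_distance(D, nx):
    # D: (nx+ny, nx+ny) pairwise distance matrix over the pooled samples.
    xy = D[:nx, nx:].mean()  # cross-term E|X - Y|
    xx = D[:nx, :nx].mean()  # within-term E|X - X'|
    yy = D[nx:, nx:].mean()  # within-term E|Y - Y'|
    return 2 * xy - xx - yy

def permutation_p_value(x, y, permutations=1000, seed=0):
    rng = np.random.default_rng(seed)
    z = np.concatenate([x, y], axis=0)
    # Euclidean distance matrix via broadcasting (fine for small N).
    D = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)
    nx = len(x)
    test = energy_distance(D, nx)
    permuted = []
    for _ in range(permutations):
        # Permuting rows and columns of D relabels x vs y membership
        # without recomputing any distances.
        I = rng.permutation(len(z))
        permuted.append(energy_distance(D[I][:, I], nx))
    # Fraction of permuted statistics at least as large as the observed one.
    return np.mean(np.array(permuted) >= test)

x = np.random.default_rng(1).normal(size=(100, 5))
y = np.random.default_rng(2).normal(size=(80, 5))
print(f"p-value: {permutation_p_value(x, y):.3f}")  # same distribution, so p should not be tiny
```

This is the reason the full-matrix backends scale as O((N+M)^2) memory, and why the chunked variants exist.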

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -41,11 +41,15 @@ dev = [
     "pytest-cov>=4.1,<5",
     "pytest-mock>=3.12,<4",
     "torch>=2.0,<3",
+    "jax>=0.4,<1",
     "matplotlib",
 ]
 torch = [
     "torch>=2.0,<3",
 ]
+jax = [
+    "jax>=0.4,<1",
+]
 
 [tool.hatch.metadata.hooks.requirements_txt]
 files = ["requirements.txt"]
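With the `jax` extra defined above, users could opt into the optional dependency at install time. A hedged sketch, assuming the package is published on PyPI under the name `pted` (an assumption, not confirmed by this diff):

```shell
# Hypothetical install commands, assuming the distribution name "pted":
pip install "pted[jax]"    # pulls in jax>=0.4,<1
pip install "pted[torch]"  # pulls in torch>=2.0,<3
```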

src/pted/pted.py

Lines changed: 37 additions & 16 deletions
@@ -4,10 +4,13 @@
 
 from .utils import (
     is_torch_tensor,
+    is_jax_array,
     pted_torch,
     pted_numpy,
     pted_chunk_torch,
     pted_chunk_numpy,
+    pted_jax,
+    pted_chunk_jax,
     two_tailed_p,
     confidence_alert,
     simulation_based_calibration_histogram,
@@ -17,8 +20,8 @@
 
 
 def pted(
-    x: Union[np.ndarray, "Tensor"],
-    y: Union[np.ndarray, "Tensor"],
+    x: Union[np.ndarray, "Tensor", "jax.Array"],
+    y: Union[np.ndarray, "Tensor", "jax.Array"],
     permutations: int = 1000,
     metric: Union[str, float] = "euclidean",
     return_all: bool = False,
@@ -72,14 +75,17 @@ def pted(
 
     Parameters
     ----------
-    x (Union[np.ndarray, Tensor]): first set of samples. Shape (N, *D)
-    y (Union[np.ndarray, Tensor]): second set of samples. Shape (M, *D)
+    x (Union[np.ndarray, Tensor, jax.Array]): first set of samples. Shape (N, *D)
+    y (Union[np.ndarray, Tensor, jax.Array]): second set of samples. Shape (M, *D)
     permutations (int): number of permutations to run. This determines how
         accurately the p-value is computed.
-    metric (Union[str, float]): distance metric to use. See scipy.spatial.distance.cdist
-        for the list of available metrics with numpy. See torch.cdist when
-        using PyTorch, note that the metric is passed as the "p" for
-        torch.cdist and therefore is a float from 0 to inf.
+    metric (Union[str, float]): distance metric to use. For NumPy inputs,
+        see scipy.spatial.distance.cdist for available metrics. For PyTorch
+        inputs, the metric is passed as the "p" argument to torch.cdist and
+        therefore is a float from 0 to inf. For JAX inputs, "euclidean" uses
+        the squared-norm identity (p=2), and any float p uses
+        jnp.linalg.norm with ord=p; string metrics other than "euclidean"
+        are not supported for JAX.
     return_all (bool): if True, return the test statistic and the permuted
         statistics with the p-value. If False, just return the p-value.
         bool (default: False)
@@ -140,6 +146,18 @@ def pted(
         )
     elif is_torch_tensor(x):
        test, permute = pted_torch(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
+    elif is_jax_array(x) and chunk_size is not None:
+        test, permute = pted_chunk_jax(
+            x,
+            y,
+            permutations=permutations,
+            metric=metric,
+            chunk_size=int(chunk_size),
+            chunk_iter=int(chunk_iter),
+            prog_bar=prog_bar,
+        )
+    elif is_jax_array(x):
+        test, permute = pted_jax(x, y, permutations=permutations, metric=metric, prog_bar=prog_bar)
     elif chunk_size is not None:
         test, permute = pted_chunk_numpy(
             x,
@@ -170,8 +188,8 @@ def pted(
 
 
 def pted_coverage_test(
-    g: Union[np.ndarray, "Tensor"],
-    s: Union[np.ndarray, "Tensor"],
+    g: Union[np.ndarray, "Tensor", "jax.Array"],
+    s: Union[np.ndarray, "Tensor", "jax.Array"],
     permutations: int = 1000,
     metric: Union[str, float] = "euclidean",
     warn_confidence: Optional[float] = 1e-3,
@@ -228,14 +246,17 @@ def pted_coverage_test(
 
     Parameters
     ----------
-    g (Union[np.ndarray, Tensor]): Ground truth samples. Shape (n_sims, *D)
-    s (Union[np.ndarray, Tensor]): Posterior samples. Shape (n_samples, n_sims, *D)
+    g (Union[np.ndarray, Tensor, jax.Array]): Ground truth samples. Shape (n_sims, *D)
+    s (Union[np.ndarray, Tensor, jax.Array]): Posterior samples. Shape (n_samples, n_sims, *D)
     permutations (int): number of permutations to run. This determines how
         accurately the p-value is computed.
-    metric (Union[str, float]): distance metric to use. See scipy.spatial.distance.cdist
-        for the list of available metrics with numpy. See torch.cdist when using
-        PyTorch, note that the metric is passed as the "p" for torch.cdist and
-        therefore is a float from 0 to inf.
+    metric (Union[str, float]): distance metric to use. For NumPy inputs,
+        see scipy.spatial.distance.cdist for available metrics. For PyTorch
+        inputs, the metric is passed as the "p" argument to torch.cdist and
+        therefore is a float from 0 to inf. For JAX inputs, "euclidean" uses
+        the squared-norm identity (p=2), and any float p uses
+        jnp.linalg.norm with ord=p; string metrics other than "euclidean"
+        are not supported for JAX.
     return_all (bool): if True, return the test statistic and the permuted
         statistics with the p-value. If False, just return the p-value. bool
         (default: False)
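The dispatch chain added to `pted()` relies on backend checks that fail soft when the optional dependency is missing: `is_jax_array` must return False, not raise, if JAX is absent. A standalone sketch of that pattern (the `dispatch` helper is illustrative only; the real function also handles torch tensors):

```python
import numpy as np

try:
    import jax
except ImportError:  # JAX is optional; fall back cleanly
    jax = None

def is_jax_array(o):
    # Same guarded check as in the diff: never raises when jax is missing.
    if jax is None:
        return False
    return isinstance(o, jax.Array)

def dispatch(x, chunk_size=None):
    # Mirrors the elif chain in pted(): most specific backend first,
    # chunked variants before full-matrix variants.
    if is_jax_array(x) and chunk_size is not None:
        return "pted_chunk_jax"
    elif is_jax_array(x):
        return "pted_jax"
    elif chunk_size is not None:
        return "pted_chunk_numpy"
    return "pted_numpy"

print(dispatch(np.ones((4, 2))))                # → pted_numpy
print(dispatch(np.ones((4, 2)), chunk_size=2))  # → pted_chunk_numpy
```

Note the ordering matters: because `is_jax_array(x) and chunk_size is not None` is tested before the bare `is_jax_array(x)`, a JAX array with chunking requested never falls through to the full-matrix path.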

src/pted/utils.py

Lines changed: 113 additions & 0 deletions
@@ -16,12 +16,23 @@ class torch:
         Tensor = np.ndarray
 
 
+try:
+    import jax
+    import jax.numpy as jnp
+except ImportError:
+    jax = None
+    jnp = None
+
+
 __all__ = (
     "is_torch_tensor",
+    "is_jax_array",
     "pted_numpy",
     "pted_chunk_numpy",
     "pted_torch",
     "pted_chunk_torch",
+    "pted_jax",
+    "pted_chunk_jax",
     "two_tailed_p",
     "confidence_alert",
     "simulation_based_calibration_histogram",
@@ -39,6 +50,12 @@ def is_torch_tensor(o):
     )
 
 
+def is_jax_array(o):
+    if jax is None:
+        return False
+    return isinstance(o, jax.Array)
+
+
 def _energy_distance_precompute(
     D: Union[np.ndarray, torch.Tensor], nx: int, ny: int
 ) -> Union[float, torch.Tensor]:
@@ -110,6 +127,49 @@ def _energy_distance_estimate_torch(
     return np.mean(E_est)
 
 
+def _jax_cdist(x, y, p: float = 2.0):
+    if p == 2.0:
+        # Squared-norm identity avoids materializing the (nx, ny, d) diff tensor.
+        # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * x_i . y_j
+        x_sq = jnp.sum(x ** 2, axis=-1)  # (nx,)
+        y_sq = jnp.sum(y ** 2, axis=-1)  # (ny,)
+        sq_dist = x_sq[:, None] + y_sq[None, :] - 2.0 * (x @ y.T)
+        return jnp.sqrt(jnp.maximum(sq_dist, 0.0))
+    # For general p-norms use vmap to avoid the (nx, ny, d) intermediate.
+    return jax.vmap(lambda xi: jnp.linalg.norm(xi - y, ord=p, axis=-1))(x)
+
+
+def _energy_distance_jax(x, y, metric: Union[str, float] = "euclidean") -> float:
+    nx = len(x)
+    ny = len(y)
+    z = jnp.concatenate([x, y], axis=0)
+    if metric == "euclidean":
+        metric = 2.0
+    D = _jax_cdist(z, z, p=metric)
+    return float(_energy_distance_precompute(D, nx, ny))
+
+
+def _energy_distance_estimate_jax(
+    x,
+    y,
+    chunk_size: int,
+    chunk_iter: int,
+    metric: Union[str, float] = "euclidean",
+) -> float:
+
+    E_est = []
+    for _ in range(chunk_iter):
+        # Randomly sample a chunk of data
+        idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)
+        x_chunk = x[idx]
+        idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)
+        y_chunk = y[idy]
+
+        # Compute the energy distance
+        E_est.append(_energy_distance_jax(x_chunk, y_chunk, metric=metric))
+    return np.mean(E_est)
+
+
 def pted_chunk_numpy(
     x: np.ndarray,
     y: np.ndarray,
@@ -210,6 +270,59 @@ def pted_torch(
     return test_stat, permute_stats
 
 
+def pted_jax(
+    x,
+    y,
+    permutations: int = 100,
+    metric: Union[str, float] = "euclidean",
+    prog_bar: bool = False,
+) -> tuple[float, list[float]]:
+    assert jax is not None, "JAX is not installed! try: `pip install jax`"
+    z = jnp.concatenate([x, y], axis=0)
+    assert jnp.all(jnp.isfinite(z)), "Input contains NaN or Inf!"
+    if metric == "euclidean":
+        metric = 2.0
+    dmatrix = _jax_cdist(z, z, p=metric)
+    assert jnp.all(
+        jnp.isfinite(dmatrix)
+    ), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)."
+    nx = len(x)
+    ny = len(y)
+
+    test_stat = float(_energy_distance_precompute(dmatrix, nx, ny))
+    permute_stats = []
+    for _ in trange(permutations, disable=not prog_bar):
+        I = np.random.permutation(len(z))
+        dmatrix = dmatrix[I][:, I]
+        permute_stats.append(float(_energy_distance_precompute(dmatrix, nx, ny)))
+    return test_stat, permute_stats
+
+
+def pted_chunk_jax(
+    x,
+    y,
+    permutations: int = 100,
+    metric: Union[str, float] = "euclidean",
+    chunk_size: int = 100,
+    chunk_iter: int = 10,
+    prog_bar: bool = False,
+) -> tuple[float, list[float]]:
+    assert jax is not None, "JAX is not installed! try: `pip install jax`"
+    assert jnp.all(jnp.isfinite(x)) and jnp.all(jnp.isfinite(y)), "Input contains NaN or Inf!"
+    nx = len(x)
+
+    test_stat = _energy_distance_estimate_jax(x, y, chunk_size, chunk_iter, metric=metric)
+    permute_stats = []
+    for _ in trange(permutations, disable=not prog_bar):
+        z = jnp.concatenate([x, y], axis=0)
+        z = z[np.random.permutation(len(z))]
+        x, y = z[:nx], z[nx:]
+        permute_stats.append(
+            _energy_distance_estimate_jax(x, y, chunk_size, chunk_iter, metric=metric)
+        )
+    return test_stat, permute_stats
+
+
 def two_tailed_p(chi2, df):
     assert df > 2, "Degrees of freedom must be greater than 2 for two-tailed p-value calculation."
     alpha = chi2_dist.pdf(chi2, df)
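The fast path of `_jax_cdist` rests on the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i·y_j, which trades the (nx, ny, d) difference tensor for a single matrix product. A quick NumPy check of that identity against the naive broadcasted computation (NumPy stands in for jnp here):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))
y = rng.normal(size=(5, 4))

# Identity used in _jax_cdist: ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j
x_sq = np.sum(x ** 2, axis=-1)  # (6,)
y_sq = np.sum(y ** 2, axis=-1)  # (5,)
sq = x_sq[:, None] + y_sq[None, :] - 2.0 * (x @ y.T)
fast = np.sqrt(np.maximum(sq, 0.0))  # clamp tiny negatives from rounding

# Naive version that materializes the (nx, ny, d) difference tensor.
naive = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

print(np.allclose(fast, naive))  # → True
```

The `jnp.maximum(sq, 0.0)` clamp in the diff serves the same purpose as here: floating-point cancellation can make the squared distance slightly negative for near-identical points, and `sqrt` of a negative would produce NaN.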
