From 029306968f4675f4187d3ea63bc30dfd62b95b55 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 2 Jun 2026 06:19:45 +0200 Subject: [PATCH 1/4] nonzero, where changes from #938 --- heat/core/indexing.py | 178 +++++++++++++++++++++++++++--------- tests/core/test_indexing.py | 43 ++++++--- 2 files changed, 165 insertions(+), 56 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 916aa450df..4d42e65aaa 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -3,22 +3,22 @@ """ import torch -from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence from .communication import MPI from .dndarray import DNDarray -from . import sanitation +from . import factories from . import types +from . import manipulations __all__ = ["nonzero", "where"] -def nonzero(x: DNDarray) -> DNDarray: +def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarray: """ - Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``). - If ``x`` is split then the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray` + Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``, + containing the indices of the non-zero elements in that dimension. If ``x`` is split then + the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` can be UNBALANCED as it contains the indices of the non-zero elements on each node. - Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. The values in ``x`` are always tested and returned in row-major, C-style order. The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``. @@ -26,16 +26,16 @@ def nonzero(x: DNDarray) -> DNDarray: ---------- x: DNDarray Input array + as_tuple: bool, optional + Default is True for numpy-style nonzero output. If False, the output is a torch-style single 2D ``DNDarray`` of shape `(num_nonzero, ndim)` containing the indices of the non-zero elements. Examples -------- >>> import heat as ht >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0) >>> ht.nonzero(x) - DNDarray([[0, 0], - [1, 1], - [1, 2], - [2, 1]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None)) >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0) >>> y > 3 DNDarray([[False, False, False], @@ -48,48 +48,108 @@ def nonzero(x: DNDarray) -> DNDarray: [2, 0], [2, 1], [2, 2]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None)) >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ - sanitation.sanitize_in(x) - - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - - # add offsets mapping from local indices to global indices if x is split - if x.split is not None: - _, _, slices = x.comm.chunk(x.shape, x.split) - lcl_nonzero[..., x.split] += slices[x.split].start - - if x.ndim == 1: - lcl_nonzero = lcl_nonzero.squeeze(dim=1) + try: + local_x = x.larray + except AttributeError: + raise TypeError("Input must be a DNDarray, is {}".format(type(x))) + + if not x.is_distributed(): + # nonzero indices as tuple + nonzero = torch.nonzero(input=local_x, as_tuple=as_tuple) + # bookkeeping for final DNDarray construct + if as_tuple: + nonzero = list(nonzero) + for i, nz_tensor in enumerate(nonzero): + nonzero[i] = factories.array(nz_tensor, device=x.device, comm=x.comm) + return tuple(nonzero) + # nonzero indices as single 2D DNDarray + return factories.array(nonzero, device=x.device, comm=x.comm) + + # distributed case + lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) + nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device=lcl_nonzero.device) + nonzero_dtype = types.canonical_heat_type(lcl_nonzero.dtype) + + # global nonzero_size + x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) + # correct indices along split axis + _, displs = x.counts_displs() + lcl_nonzero[:, x.split] += displs[x.comm.rank] + + if x.split != 0: + # construct global 2D DNDarray of nz indices: + shape_2d = (nonzero_size.item(), x.ndim) + global_nonzero = DNDarray( + lcl_nonzero, + gshape=shape_2d, + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # vectorized sorting of nz indices along axis 0 + global_nonzero.balance_() + global_nonzero = manipulations.unique(global_nonzero, axis=0) + if not as_tuple: + # return indices as single 2D DNDarray + return global_nonzero + # return indices as tuple of 1D DNDarrays + lcl_nonzero = global_nonzero.larray.unbind(dim=1) + return tuple( + DNDarray( + nz_tensor, + gshape=(nonzero_size.item(),), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=True, + ) + for nz_tensor in lcl_nonzero + ) - # compute global shape of the index array - gout = list(lcl_nonzero.shape) - if x.split is None: - is_split = None - else: - gout[0] = x.comm.allreduce(gout[0], MPI.SUM) - is_split = 0 - - return DNDarray( - lcl_nonzero, - gshape=tuple(gout), - dtype=types.canonical_heat_type(lcl_nonzero.dtype), - split=is_split, - device=x.device, - comm=x.comm, - balanced=False, + # for split=0, the local nonzero indices are already globally ordered along the split axis + if not as_tuple: + # return indices as single 2D DNDarray + return DNDarray( + lcl_nonzero, + gshape=(nonzero_size.item(), x.ndim), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # return indices as tuple of 1D DNDarrays + lcl_nonzero = lcl_nonzero.unbind(dim=1) + return tuple( + DNDarray( + nz_tensor, + gshape=(nonzero_size.item(),), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + for nz_tensor in lcl_nonzero ) -DNDarray.nonzero = lambda self: nonzero(self) +DNDarray.nonzero = lambda self: nonzero(self, as_tuple=True) DNDarray.nonzero.__doc__ = nonzero.__doc__ def where( cond: DNDarray, - x: Union[None, int, float, DNDarray] = None, - y: Union[None, int, float, DNDarray] = None, + x: None | int | float | DNDarray = None, + y: None | int | float | DNDarray = None, ) -> DNDarray: """ Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition. @@ -128,20 +188,52 @@ def where( [ 0, 2, -1], [ 0, 3, -1]], dtype=ht.int64, device=cpu:0, split=None) """ + # ---- binary where(cond, x, y) branch ------------------------------------ if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)): if (isinstance(x, DNDarray) and cond.split != x.split) or ( isinstance(y, DNDarray) and cond.split != y.split ): - if len(y.shape) >= 1 and y.shape[0] > 1: + # Only raise if the "other" array has a meaningful first dimension. + if isinstance(y, DNDarray) and len(y.shape) >= 1 and y.shape[0] > 1: raise NotImplementedError("binary op not implemented for different split axes") + if isinstance(x, (DNDarray, int, float)) and isinstance(y, (DNDarray, int, float)): + # Simple elementwise selection using arithmetic: + # cond == 0 -> take y, cond == 1 -> take x for var in [x, y]: if isinstance(var, int): var = float(var) return cond.dtype(cond == 0) * y + cond * x + + # ---- where(cond) "indices only" branch ---------------------------------- elif x is None and y is None: - return nonzero(cond) + # General rule: delegate to nonzero(cond), and only wrap into a 2-D + # coordinate matrix in the special distributed case where the array + # is split along a non-zero axis. + nz = nonzero(cond) # tuple of DNDarrays, one per dimension + + # 1) Non-distributed: behave exactly like ht.nonzero(cond) + if cond.split is None: + return nz + + # 2) Distributed along axis 0: keep the legacy tuple-of-indices API. + # This is relied upon in several parts of the code base (e.g. KMeans). + if cond.split == 0: + return nz + + # 3) Distributed along a non-zero axis (split > 0) + coords = manipulations.stack(nz, axis=1) + coords = coords.astype(types.int64, copy=False) + + # Ensure indices are split along axis 0 for stable distributed behavior + if coords.split is None: + coords.resplit_(0) + + return coords + + # ---- invalid combinations ---------------------------------------------- else: raise TypeError( - f"either both or neither x and y must be given and both must be DNDarrays or numerical scalars({type(x)}, {type(y)})" + "either both or neither x and y must be given and both must be " + f"DNDarrays or numerical scalars (got {type(x)}, {type(y)})" ) diff --git a/tests/core/test_indexing.py b/tests/core/test_indexing.py index 61dda3fa4f..4750ab2de6 100644 --- a/tests/core/test_indexing.py +++ b/tests/core/test_indexing.py @@ -1,6 +1,7 @@ import heat as ht from heat.testing.basic_test import TestCase +import torch class TestIndexing(TestCase): def test_nonzero(self): @@ -9,18 +10,18 @@ def test_nonzero(self): a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) cond = a > 3 nz = ht.nonzero(cond) - self.assertEqual(nz.gshape, (5, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, None) + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 5) + self.assertEqual(nz[0].dtype, ht.int64) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3 nz = cond.nonzero() - self.assertEqual(nz.gshape, (6, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, 0) - a[nz] = 10.0 + self.assertEqual(len(nz), 2) + self.assertEqual(len(nz[0]), 6) + self.assertEqual(nz[0].dtype, ht.int64) + a[nz] = 10 self.assertEqual(ht.all(a[nz] == 10), 1) # edge case: single non-zero element @@ -28,11 +29,26 @@ def test_nonzero(self): a = ht.zeros((4, 3), dtype=ht.bool, split=split) a[1, 2] = True nz = ht.indexing.nonzero(a) - a.resplit_(None) - nz.resplit_(None) - self.assertEqual(nz.gshape, (1, 2)) self.assertTrue(ht.allclose(a[nz], a[a])) + a.comm.Barrier() + # as_tuple = False (torch-style output) + a = ht.array([[1, 0, 0], [0, 4, 1], [0, 6, 0]], split=1) + nz = ht.nonzero(a, as_tuple=False) + self.assertEqual(nz.gshape, (4, 2)) + self.assertEqual(nz.dtype, ht.int64) + if a.is_distributed(): + self.assertEqual(nz.split, 0) + else: + self.assertEqual(nz.split, None) + t_a = a.resplit_(None).larray + t_nz = torch.nonzero(t_a, as_tuple=False) + self.assertTrue(ht.equal(nz, ht.array(t_nz))) + + # attribute error + a = a.numpy() + with self.assertRaises(TypeError): + ht.nonzero(a) def test_where(self): # cases to test @@ -40,9 +56,10 @@ def test_where(self): a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=None) cond = a > 3 wh = ht.where(cond) - self.assertEqual(wh.gshape, (6, 2)) - self.assertEqual(wh.dtype, ht.int64) - self.assertEqual(wh.split, None) + self.assertEqual(len(wh), 2) + self.assertEqual(wh[0].gshape[0], 6) + self.assertEqual(wh[0].dtype, ht.int64) + self.assertEqual(wh[0].split, None) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3 From 4ea407e2515a0156683604cebffda23ef2b333f5 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Tue, 2 Jun 2026 08:23:11 +0200 Subject: [PATCH 2/4] Fixed tests --- heat/core/dndarray.py | 4 ++-- tests/core/test_indexing.py | 46 ++++++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 16ce355700..012b742670 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -920,7 +920,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # TODO: remove this resplit!! key = manipulations.resplit(key) if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) + key = indexing.nonzero(key, as_tuple=False) if key.ndim > 1: key = list(key.larray.split(1, dim=1)) @@ -1626,7 +1626,7 @@ def __setitem__( to be used.""" key = manipulations.resplit(key) if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) + key = indexing.nonzero(key, as_tuple=False) if key.ndim > 1: key = list(key.larray.split(1, dim=1)) diff --git a/tests/core/test_indexing.py b/tests/core/test_indexing.py index 4750ab2de6..aea1e220a8 100644 --- a/tests/core/test_indexing.py +++ b/tests/core/test_indexing.py @@ -2,33 +2,43 @@ from heat.testing.basic_test import TestCase import torch +import numpy as np class TestIndexing(TestCase): def test_nonzero(self): - # cases to test: - # not split - a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) - cond = a > 3 - nz = ht.nonzero(cond) - self.assertEqual(len(nz), 2) - self.assertEqual(len(nz[0]), 5) - self.assertEqual(nz[0].dtype, ht.int64) + for split in [None, 0, 1]: + for cond_type in ['mean', 'max']: + a = ht.random.random((2*self.comm.size, 3*self.comm.size, 4*self.comm.size)) + if cond_type == 'mean': + cond = a > a.mean() / 2 + elif cond_type == 'max': + cond = a == a.max() + else: + raise NotImplementedError + + nz_as_tuple = ht.nonzero(cond, as_tuple=True) + nz_as_tuple_ref = np.nonzero(cond.numpy()) + for i in range(len(nz_as_tuple)): + self.assertEqual(nz_as_tuple[i].dtype, ht.int64) + self.assertTrue(np.allclose(nz_as_tuple[i].numpy(), nz_as_tuple_ref[i])) + + nz_no_tuple = ht.nonzero(cond, as_tuple=False) + nz_no_tuple_ref = torch.nonzero(cond.resplit(None), as_tuple=False) + self.assertEqual(nz_no_tuple.dtype, ht.int64) + self.assertTrue(np.allclose(nz_no_tuple.numpy(), nz_no_tuple_ref.numpy())) + + if cond_type == 'max': + self.assertEqual(len(cond[cond]), 1) + for me in nz_as_tuple: + self.assertEqual(me.shape, (1,)) + self.assertEqual(nz_no_tuple.shape, (1, a.ndim)) - # split - a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) - cond = a > 3 - nz = cond.nonzero() - self.assertEqual(len(nz), 2) - self.assertEqual(len(nz[0]), 6) - self.assertEqual(nz[0].dtype, ht.int64) - a[nz] = 10 - self.assertEqual(ht.all(a[nz] == 10), 1) # edge case: single non-zero element for split in [None, 0, 1]: a = ht.zeros((4, 3), dtype=ht.bool, split=split) a[1, 2] = True - nz = ht.indexing.nonzero(a) + nz = ht.indexing.nonzero(a, as_tuple=False) self.assertTrue(ht.allclose(a[nz], a[a])) a.comm.Barrier() From 40758907f3cdae53e6022a4e7a17aa0f2797906e Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Tue, 2 Jun 2026 08:34:45 +0200 Subject: [PATCH 3/4] Small refactoring --- heat/core/indexing.py | 75 +++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 4d42e65aaa..d9b640c864 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -7,6 +7,7 @@ from .communication import MPI from .dndarray import DNDarray from . import factories +from .sanitation import sanitize_in from . import types from . import manipulations @@ -17,7 +18,7 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr """ Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. If ``x`` is split then - the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` + the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray` can be UNBALANCED as it contains the indices of the non-zero elements on each node. The values in ``x`` are always tested and returned in row-major, C-style order. The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``. @@ -53,26 +54,24 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ - try: - local_x = x.larray - except AttributeError: - raise TypeError("Input must be a DNDarray, is {}".format(type(x))) + sanitize_in(x) if not x.is_distributed(): # nonzero indices as tuple - nonzero = torch.nonzero(input=local_x, as_tuple=as_tuple) + nonzero = torch.nonzero(input=x.larray, as_tuple=as_tuple) # bookkeeping for final DNDarray construct if as_tuple: nonzero = list(nonzero) for i, nz_tensor in enumerate(nonzero): nonzero[i] = factories.array(nz_tensor, device=x.device, comm=x.comm) return tuple(nonzero) - # nonzero indices as single 2D DNDarray - return factories.array(nonzero, device=x.device, comm=x.comm) + else: + # nonzero indices as single 2D DNDarray + return factories.array(nonzero, device=x.device, comm=x.comm) # distributed case - lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False) - nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device=lcl_nonzero.device) + lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) + nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device="cpu") nonzero_dtype = types.canonical_heat_type(lcl_nonzero.dtype) # global nonzero_size @@ -81,7 +80,34 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr _, displs = x.counts_displs() lcl_nonzero[:, x.split] += displs[x.comm.rank] - if x.split != 0: + if x.split == 0: + # for split=0, the local nonzero indices are already globally ordered along the split axis + if not as_tuple: + # return indices as single 2D DNDarray + return DNDarray( + lcl_nonzero, + gshape=(nonzero_size.item(), x.ndim), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # return indices as tuple of 1D DNDarrays + lcl_nonzero = lcl_nonzero.unbind(dim=1) + return tuple( + DNDarray( + nz_tensor, + gshape=(nonzero_size.item(),), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + for nz_tensor in lcl_nonzero + ) + else: # construct global 2D DNDarray of nz indices: shape_2d = (nonzero_size.item(), x.ndim) global_nonzero = DNDarray( @@ -114,33 +140,6 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr for nz_tensor in lcl_nonzero ) - # for split=0, the local nonzero indices are already globally ordered along the split axis - if not as_tuple: - # return indices as single 2D DNDarray - return DNDarray( - lcl_nonzero, - gshape=(nonzero_size.item(), x.ndim), - dtype=nonzero_dtype, - split=0, - device=x.device, - comm=x.comm, - balanced=False, - ) - # return indices as tuple of 1D DNDarrays - lcl_nonzero = lcl_nonzero.unbind(dim=1) - return tuple( - DNDarray( - nz_tensor, - gshape=(nonzero_size.item(),), - dtype=nonzero_dtype, - split=0, - device=x.device, - comm=x.comm, - balanced=False, - ) - for nz_tensor in lcl_nonzero - ) - DNDarray.nonzero = lambda self: nonzero(self, as_tuple=True) DNDarray.nonzero.__doc__ = nonzero.__doc__ From ae15914d48605f11b9da3f017edc5afdb8c3fa4f Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Tue, 2 Jun 2026 08:38:52 +0200 Subject: [PATCH 4/4] Disabling fail-fast --- .github/workflows/pr_update.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr_update.yml b/.github/workflows/pr_update.yml index 8ee728157e..98aaeddec4 100644 --- a/.github/workflows/pr_update.yml +++ b/.github/workflows/pr_update.yml @@ -12,7 +12,7 @@ jobs: quick-tests: runs-on: ubuntu-latest strategy: - fail-fast: true + fail-fast: false matrix: py-version: - '3.11' # Oldest supported