diff --git a/.github/workflows/pr_update.yml b/.github/workflows/pr_update.yml index 8ee728157e..98aaeddec4 100644 --- a/.github/workflows/pr_update.yml +++ b/.github/workflows/pr_update.yml @@ -12,7 +12,7 @@ jobs: quick-tests: runs-on: ubuntu-latest strategy: - fail-fast: true + fail-fast: false matrix: py-version: - '3.11' # Oldest supported diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 16ce355700..012b742670 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -920,7 +920,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar # TODO: remove this resplit!! key = manipulations.resplit(key) if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) + key = indexing.nonzero(key, as_tuple=False) if key.ndim > 1: key = list(key.larray.split(1, dim=1)) @@ -1626,7 +1626,7 @@ def __setitem__( to be used.""" key = manipulations.resplit(key) if key.larray.dtype in [torch.bool, torch.uint8]: - key = indexing.nonzero(key) + key = indexing.nonzero(key, as_tuple=False) if key.ndim > 1: key = list(key.larray.split(1, dim=1)) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 916aa450df..d9b640c864 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -3,22 +3,23 @@ """ import torch -from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence from .communication import MPI from .dndarray import DNDarray -from . import sanitation +from . import factories +from .sanitation import sanitize_in from . import types +from . import manipulations __all__ = ["nonzero", "where"] -def nonzero(x: DNDarray) -> DNDarray: +def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarray: """ - Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``). - If ``x`` is split then the result is split in the first dimension. 
However, this :class:`~heat.core.dndarray.DNDarray` + Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``, + containing the indices of the non-zero elements in that dimension. If ``x`` is split then + the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray` can be UNBALANCED as it contains the indices of the non-zero elements on each node. - Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. The values in ``x`` are always tested and returned in row-major, C-style order. The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``. @@ -26,16 +27,16 @@ def nonzero(x: DNDarray) -> DNDarray: ---------- x: DNDarray Input array + as_tuple: bool, optional + Default is True for numpy-style nonzero output. If False, the output is a torch-style single 2D ``DNDarray`` of shape `(num_nonzero, ndim)` containing the indices of the non-zero elements. 
Examples -------- >>> import heat as ht >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0) >>> ht.nonzero(x) - DNDarray([[0, 0], - [1, 1], - [1, 2], - [2, 1]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None)) >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0) >>> y > 3 DNDarray([[False, False, False], @@ -48,48 +49,106 @@ def nonzero(x: DNDarray) -> DNDarray: [2, 0], [2, 1], [2, 2]], dtype=ht.int64, device=cpu:0, split=0) + (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None), + DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None)) >>> y[ht.nonzero(y > 3)] DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) """ - sanitation.sanitize_in(x) - + sanitize_in(x) + + if not x.is_distributed(): + # nonzero indices as tuple + nonzero = torch.nonzero(input=x.larray, as_tuple=as_tuple) + # bookkeeping for final DNDarray construct + if as_tuple: + nonzero = list(nonzero) + for i, nz_tensor in enumerate(nonzero): + nonzero[i] = factories.array(nz_tensor, device=x.device, comm=x.comm) + return tuple(nonzero) + else: + # nonzero indices as single 2D DNDarray + return factories.array(nonzero, device=x.device, comm=x.comm) + + # distributed case lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - - # add offsets mapping from local indices to global indices if x is split - if x.split is not None: - _, _, slices = x.comm.chunk(x.shape, x.split) - lcl_nonzero[..., x.split] += slices[x.split].start - - if x.ndim == 1: - lcl_nonzero = lcl_nonzero.squeeze(dim=1) - - # compute global shape of the index array - gout = list(lcl_nonzero.shape) - if x.split is None: - is_split = None + nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device="cpu") + nonzero_dtype = types.canonical_heat_type(lcl_nonzero.dtype) + + # global nonzero_size + 
x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) + # correct indices along split axis + _, displs = x.counts_displs() + lcl_nonzero[:, x.split] += displs[x.comm.rank] + + if x.split == 0: + # for split=0, the local nonzero indices are already globally ordered along the split axis + if not as_tuple: + # return indices as single 2D DNDarray + return DNDarray( + lcl_nonzero, + gshape=(nonzero_size.item(), x.ndim), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # return indices as tuple of 1D DNDarrays + lcl_nonzero = lcl_nonzero.unbind(dim=1) + return tuple( + DNDarray( + nz_tensor, + gshape=(nonzero_size.item(),), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + for nz_tensor in lcl_nonzero + ) else: - gout[0] = x.comm.allreduce(gout[0], MPI.SUM) - is_split = 0 - - return DNDarray( - lcl_nonzero, - gshape=tuple(gout), - dtype=types.canonical_heat_type(lcl_nonzero.dtype), - split=is_split, - device=x.device, - comm=x.comm, - balanced=False, - ) + # construct global 2D DNDarray of nz indices: + shape_2d = (nonzero_size.item(), x.ndim) + global_nonzero = DNDarray( + lcl_nonzero, + gshape=shape_2d, + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=False, + ) + # vectorized sorting of nz indices along axis 0 + global_nonzero.balance_() + global_nonzero = manipulations.unique(global_nonzero, axis=0) + if not as_tuple: + # return indices as single 2D DNDarray + return global_nonzero + # return indices as tuple of 1D DNDarrays + lcl_nonzero = global_nonzero.larray.unbind(dim=1) + return tuple( + DNDarray( + nz_tensor, + gshape=(nonzero_size.item(),), + dtype=nonzero_dtype, + split=0, + device=x.device, + comm=x.comm, + balanced=True, + ) + for nz_tensor in lcl_nonzero + ) -DNDarray.nonzero = lambda self: nonzero(self) +DNDarray.nonzero = lambda self: nonzero(self, as_tuple=True) DNDarray.nonzero.__doc__ = nonzero.__doc__ def where( cond: DNDarray, - x: 
Union[None, int, float, DNDarray] = None, - y: Union[None, int, float, DNDarray] = None, + x: None | int | float | DNDarray = None, + y: None | int | float | DNDarray = None, ) -> DNDarray: """ Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition. @@ -128,20 +187,52 @@ def where( [ 0, 2, -1], [ 0, 3, -1]], dtype=ht.int64, device=cpu:0, split=None) """ + # ---- binary where(cond, x, y) branch ------------------------------------ if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)): if (isinstance(x, DNDarray) and cond.split != x.split) or ( isinstance(y, DNDarray) and cond.split != y.split ): - if len(y.shape) >= 1 and y.shape[0] > 1: + # Only raise if the "other" array has a meaningful first dimension. + if isinstance(y, DNDarray) and len(y.shape) >= 1 and y.shape[0] > 1: raise NotImplementedError("binary op not implemented for different split axes") + if isinstance(x, (DNDarray, int, float)) and isinstance(y, (DNDarray, int, float)): + # Simple elementwise selection using arithmetic: + # cond == 0 -> take y, cond == 1 -> take x for var in [x, y]: if isinstance(var, int): var = float(var) return cond.dtype(cond == 0) * y + cond * x + + # ---- where(cond) "indices only" branch ---------------------------------- elif x is None and y is None: - return nonzero(cond) + # General rule: delegate to nonzero(cond), and only wrap into a 2-D + # coordinate matrix in the special distributed case where the array + # is split along a non-zero axis. + nz = nonzero(cond) # tuple of DNDarrays, one per dimension + + # 1) Non-distributed: behave exactly like ht.nonzero(cond) + if cond.split is None: + return nz + + # 2) Distributed along axis 0: keep the legacy tuple-of-indices API. + # This is relied upon in several parts of the code base (e.g. KMeans). 
+ if cond.split == 0: + return nz + + # 3) Distributed along a non-zero axis (split > 0) + coords = manipulations.stack(nz, axis=1) + coords = coords.astype(types.int64, copy=False) + + # Ensure indices are split along axis 0 for stable distributed behavior + if coords.split is None: + coords.resplit_(0) + + return coords + + # ---- invalid combinations ---------------------------------------------- else: raise TypeError( - f"either both or neither x and y must be given and both must be DNDarrays or numerical scalars({type(x)}, {type(y)})" + "either both or neither x and y must be given and both must be " + f"DNDarrays or numerical scalars (got {type(x)}, {type(y)})" ) diff --git a/tests/core/test_indexing.py b/tests/core/test_indexing.py index 61dda3fa4f..aea1e220a8 100644 --- a/tests/core/test_indexing.py +++ b/tests/core/test_indexing.py @@ -1,38 +1,64 @@ import heat as ht from heat.testing.basic_test import TestCase +import torch +import numpy as np class TestIndexing(TestCase): def test_nonzero(self): - # cases to test: - # not split - a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) - cond = a > 3 - nz = ht.nonzero(cond) - self.assertEqual(nz.gshape, (5, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, None) + for split in [None, 0, 1]: + for cond_type in ['mean', 'max']: + a = ht.random.random((2*self.comm.size, 3*self.comm.size, 4*self.comm.size), split=split) + if cond_type == 'mean': + cond = a > a.mean() / 2 + elif cond_type == 'max': + cond = a == a.max() + else: + raise NotImplementedError + + nz_as_tuple = ht.nonzero(cond, as_tuple=True) + nz_as_tuple_ref = np.nonzero(cond.numpy()) + for i in range(len(nz_as_tuple)): + self.assertEqual(nz_as_tuple[i].dtype, ht.int64) + self.assertTrue(np.allclose(nz_as_tuple[i].numpy(), nz_as_tuple_ref[i])) + + nz_no_tuple = ht.nonzero(cond, as_tuple=False) + nz_no_tuple_ref = torch.nonzero(cond.resplit(None).larray, as_tuple=False) + self.assertEqual(nz_no_tuple.dtype, ht.int64) + 
self.assertTrue(np.allclose(nz_no_tuple.numpy(), nz_no_tuple_ref.cpu().numpy())) + + if cond_type == 'max': + self.assertEqual(len(cond[cond]), 1) + for me in nz_as_tuple: + self.assertEqual(me.shape, (1,)) + self.assertEqual(nz_no_tuple.shape, (1, a.ndim)) - # split - a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) - cond = a > 3 - nz = cond.nonzero() - self.assertEqual(nz.gshape, (6, 2)) - self.assertEqual(nz.dtype, ht.int64) - self.assertEqual(nz.split, 0) - a[nz] = 10.0 - self.assertEqual(ht.all(a[nz] == 10), 1) # edge case: single non-zero element for split in [None, 0, 1]: a = ht.zeros((4, 3), dtype=ht.bool, split=split) a[1, 2] = True - nz = ht.indexing.nonzero(a) - a.resplit_(None) - nz.resplit_(None) - self.assertEqual(nz.gshape, (1, 2)) + nz = ht.indexing.nonzero(a, as_tuple=False) self.assertTrue(ht.allclose(a[nz], a[a])) + a.comm.Barrier() + # as_tuple = False (torch-style output) + a = ht.array([[1, 0, 0], [0, 4, 1], [0, 6, 0]], split=1) + nz = ht.nonzero(a, as_tuple=False) + self.assertEqual(nz.gshape, (4, 2)) + self.assertEqual(nz.dtype, ht.int64) + if a.is_distributed(): + self.assertEqual(nz.split, 0) + else: + self.assertEqual(nz.split, None) + t_a = a.resplit_(None).larray + t_nz = torch.nonzero(t_a, as_tuple=False) + self.assertTrue(ht.equal(nz, ht.array(t_nz))) + + # non-DNDarray input must raise TypeError + a = a.numpy() + with self.assertRaises(TypeError): + ht.nonzero(a) def test_where(self): # cases to test @@ -40,9 +66,10 @@ def test_where(self): a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=None) cond = a > 3 wh = ht.where(cond) - self.assertEqual(wh.gshape, (6, 2)) - self.assertEqual(wh.dtype, ht.int64) - self.assertEqual(wh.split, None) + self.assertEqual(len(wh), 2) + self.assertEqual(wh[0].gshape[0], 6) + self.assertEqual(wh[0].dtype, ht.int64) + self.assertEqual(wh[0].split, None) # split a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) cond = a > 3