-
Notifications
You must be signed in to change notification settings - Fork 65
nonzero, where fixes from #938 #2332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,39 +3,40 @@ | |
| """ | ||
|
|
||
| import torch | ||
| from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence | ||
|
|
||
| from .communication import MPI | ||
| from .dndarray import DNDarray | ||
| from . import sanitation | ||
| from . import factories | ||
| from .sanitation import sanitize_in | ||
| from . import types | ||
| from . import manipulations | ||
|
|
||
| __all__ = ["nonzero", "where"] | ||
|
|
||
|
|
||
| def nonzero(x: DNDarray) -> DNDarray: | ||
| def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarray: | ||
| """ | ||
| Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``). | ||
| If ``x`` is split then the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray` | ||
| Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``, | ||
| containing the indices of the non-zero elements in that dimension. If ``x`` is split then | ||
| the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray` | ||
| can be UNBALANCED as it contains the indices of the non-zero elements on each node. | ||
| Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. | ||
| The values in ``x`` are always tested and returned in row-major, C-style order. | ||
| The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| x: DNDarray | ||
| Input array | ||
| as_tuple: bool, optional | ||
| Default is True for numpy-style nonzero output. If False, the output is a torch-style single 2D ``DNDarray`` of shape `(num_nonzero, ndim)` containing the indices of the non-zero elements. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import heat as ht | ||
| >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0) | ||
| >>> ht.nonzero(x) | ||
| DNDarray([[0, 0], | ||
| [1, 1], | ||
| [1, 2], | ||
| [2, 1]], dtype=ht.int64, device=cpu:0, split=0) | ||
| (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None), | ||
| DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None)) | ||
| >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0) | ||
| >>> y > 3 | ||
| DNDarray([[False, False, False], | ||
|
|
@@ -48,48 +49,106 @@ def nonzero(x: DNDarray) -> DNDarray: | |
| [2, 0], | ||
| [2, 1], | ||
| [2, 2]], dtype=ht.int64, device=cpu:0, split=0) | ||
| (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None), | ||
| DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None)) | ||
| >>> y[ht.nonzero(y > 3)] | ||
| DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0) | ||
| """ | ||
| sanitation.sanitize_in(x) | ||
|
|
||
| sanitize_in(x) | ||
|
|
||
| if not x.is_distributed(): | ||
| # nonzero indices as tuple | ||
| nonzero = torch.nonzero(input=x.larray, as_tuple=as_tuple) | ||
| # bookkeeping for final DNDarray construct | ||
| if as_tuple: | ||
| nonzero = list(nonzero) | ||
| for i, nz_tensor in enumerate(nonzero): | ||
| nonzero[i] = factories.array(nz_tensor, device=x.device, comm=x.comm) | ||
| return tuple(nonzero) | ||
| else: | ||
| # nonzero indices as single 2D DNDarray | ||
| return factories.array(nonzero, device=x.device, comm=x.comm) | ||
|
|
||
| # distributed case | ||
| lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) | ||
|
|
||
| # add offsets mapping from local indices to global indices if x is split | ||
| if x.split is not None: | ||
| _, _, slices = x.comm.chunk(x.shape, x.split) | ||
| lcl_nonzero[..., x.split] += slices[x.split].start | ||
|
|
||
| if x.ndim == 1: | ||
| lcl_nonzero = lcl_nonzero.squeeze(dim=1) | ||
|
|
||
| # compute global shape of the index array | ||
| gout = list(lcl_nonzero.shape) | ||
| if x.split is None: | ||
| is_split = None | ||
| nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device="cpu") | ||
| nonzero_dtype = types.canonical_heat_type(lcl_nonzero.dtype) | ||
|
|
||
| # global nonzero_size | ||
| x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM) | ||
| # correct indices along split axis | ||
| _, displs = x.counts_displs() | ||
| lcl_nonzero[:, x.split] += displs[x.comm.rank] | ||
|
|
||
| if x.split == 0: | ||
| # for split=0, the local nonzero indices are already globally ordered along the split axis | ||
| if not as_tuple: | ||
| # return indices as single 2D DNDarray | ||
| return DNDarray( | ||
| lcl_nonzero, | ||
| gshape=(nonzero_size.item(), x.ndim), | ||
| dtype=nonzero_dtype, | ||
| split=0, | ||
| device=x.device, | ||
| comm=x.comm, | ||
| balanced=False, | ||
| ) | ||
| # return indices as tuple of 1D DNDarrays | ||
| lcl_nonzero = lcl_nonzero.unbind(dim=1) | ||
| return tuple( | ||
| DNDarray( | ||
| nz_tensor, | ||
| gshape=(nonzero_size.item(),), | ||
| dtype=nonzero_dtype, | ||
| split=0, | ||
| device=x.device, | ||
| comm=x.comm, | ||
| balanced=False, | ||
| ) | ||
| for nz_tensor in lcl_nonzero | ||
| ) | ||
| else: | ||
| gout[0] = x.comm.allreduce(gout[0], MPI.SUM) | ||
| is_split = 0 | ||
|
|
||
| return DNDarray( | ||
| lcl_nonzero, | ||
| gshape=tuple(gout), | ||
| dtype=types.canonical_heat_type(lcl_nonzero.dtype), | ||
| split=is_split, | ||
| device=x.device, | ||
| comm=x.comm, | ||
| balanced=False, | ||
| ) | ||
| # construct global 2D DNDarray of nz indices: | ||
| shape_2d = (nonzero_size.item(), x.ndim) | ||
| global_nonzero = DNDarray( | ||
| lcl_nonzero, | ||
| gshape=shape_2d, | ||
| dtype=nonzero_dtype, | ||
| split=0, | ||
| device=x.device, | ||
| comm=x.comm, | ||
| balanced=False, | ||
| ) | ||
| # vectorized sorting of nz indices along axis 0 | ||
| global_nonzero.balance_() | ||
| global_nonzero = manipulations.unique(global_nonzero, axis=0) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? Seems like duplicate entries would be a bug at this point. Or are there some side effects of |
||
| if not as_tuple: | ||
| # return indices as single 2D DNDarray | ||
| return global_nonzero | ||
| # return indices as tuple of 1D DNDarrays | ||
| lcl_nonzero = global_nonzero.larray.unbind(dim=1) | ||
| return tuple( | ||
| DNDarray( | ||
| nz_tensor, | ||
| gshape=(nonzero_size.item(),), | ||
| dtype=nonzero_dtype, | ||
| split=0, | ||
| device=x.device, | ||
| comm=x.comm, | ||
| balanced=True, | ||
| ) | ||
| for nz_tensor in lcl_nonzero | ||
| ) | ||
|
|
||
|
|
||
| DNDarray.nonzero = lambda self: nonzero(self) | ||
| DNDarray.nonzero = lambda self: nonzero(self, as_tuple=True) | ||
| DNDarray.nonzero.__doc__ = nonzero.__doc__ | ||
|
|
||
|
|
||
| def where( | ||
| cond: DNDarray, | ||
| x: Union[None, int, float, DNDarray] = None, | ||
| y: Union[None, int, float, DNDarray] = None, | ||
| x: None | int | float | DNDarray = None, | ||
| y: None | int | float | DNDarray = None, | ||
| ) -> DNDarray: | ||
| """ | ||
| Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition. | ||
|
|
@@ -128,20 +187,52 @@ def where( | |
| [ 0, 2, -1], | ||
| [ 0, 3, -1]], dtype=ht.int64, device=cpu:0, split=None) | ||
| """ | ||
| # ---- binary where(cond, x, y) branch ------------------------------------ | ||
| if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)): | ||
| if (isinstance(x, DNDarray) and cond.split != x.split) or ( | ||
| isinstance(y, DNDarray) and cond.split != y.split | ||
| ): | ||
| if len(y.shape) >= 1 and y.shape[0] > 1: | ||
| # Only raise if the "other" array has a meaningful first dimension. | ||
| if isinstance(y, DNDarray) and len(y.shape) >= 1 and y.shape[0] > 1: | ||
| raise NotImplementedError("binary op not implemented for different split axes") | ||
|
|
||
| if isinstance(x, (DNDarray, int, float)) and isinstance(y, (DNDarray, int, float)): | ||
| # Simple elementwise selection using arithmetic: | ||
| # cond == 0 -> take y, cond == 1 -> take x | ||
| for var in [x, y]: | ||
| if isinstance(var, int): | ||
| var = float(var) | ||
| return cond.dtype(cond == 0) * y + cond * x | ||
|
|
||
| # ---- where(cond) "indices only" branch ---------------------------------- | ||
| elif x is None and y is None: | ||
| return nonzero(cond) | ||
| # General rule: delegate to nonzero(cond), and only wrap into a 2-D | ||
| # coordinate matrix in the special distributed case where the array | ||
| # is split along a non-zero axis. | ||
| nz = nonzero(cond) # tuple of DNDarrays, one per dimension | ||
|
|
||
| # 1) Non-distributed: behave exactly like ht.nonzero(cond) | ||
| if cond.split is None: | ||
| return nz | ||
|
|
||
| # 2) Distributed along axis 0: keep the legacy tuple-of-indices API. | ||
| # This is relied upon in several parts of the code base (e.g. KMeans). | ||
| if cond.split == 0: | ||
| return nz | ||
|
|
||
| # 3) Distributed along a non-zero axis (split > 0) | ||
| coords = manipulations.stack(nz, axis=1) | ||
| coords = coords.astype(types.int64, copy=False) | ||
|
|
||
| # Ensure indices are split along axis 0 for stable distributed behavior | ||
| if coords.split is None: | ||
| coords.resplit_(0) | ||
|
|
||
| return coords | ||
|
|
||
| # ---- invalid combinations ---------------------------------------------- | ||
| else: | ||
| raise TypeError( | ||
| f"either both or neither x and y must be given and both must be DNDarrays or numerical scalars({type(x)}, {type(y)})" | ||
| "either both or neither x and y must be given and both must be " | ||
| f"DNDarrays or numerical scalars (got {type(x)}, {type(y)})" | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,48 +1,75 @@ | ||
| import heat as ht | ||
| from heat.testing.basic_test import TestCase | ||
|
|
||
| import torch | ||
| import numpy as np | ||
|
|
||
| class TestIndexing(TestCase): | ||
| def test_nonzero(self): | ||
| # cases to test: | ||
| # not split | ||
| a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None) | ||
| cond = a > 3 | ||
| nz = ht.nonzero(cond) | ||
| self.assertEqual(nz.gshape, (5, 2)) | ||
| self.assertEqual(nz.dtype, ht.int64) | ||
| self.assertEqual(nz.split, None) | ||
| for split in [None, 0, 1]: | ||
| for cond_type in ['mean', 'max']: | ||
| a = ht.random.random((2*self.comm.size, 3*self.comm.size, 4*self.comm.size)) | ||
| if cond_type == 'mean': | ||
| cond = a > a.mean() / 2 | ||
| elif cond_type == 'max': | ||
| cond = a == a.max() | ||
| else: | ||
| raise NotImplementedError | ||
|
|
||
| nz_as_tuple = ht.nonzero(cond, as_tuple=True) | ||
| nz_as_tuple_ref = np.nonzero(cond.numpy()) | ||
| for i in range(len(nz_as_tuple)): | ||
| self.assertEqual(nz_as_tuple[i].dtype, ht.int64) | ||
| self.assertTrue(np.allclose(nz_as_tuple[i].numpy(), nz_as_tuple_ref[i])) | ||
|
|
||
| nz_no_tuple = ht.nonzero(cond, as_tuple=False) | ||
| nz_no_tuple_ref = torch.nonzero(cond.resplit(None), as_tuple=False) | ||
| self.assertEqual(nz_no_tuple.dtype, ht.int64) | ||
| self.assertTrue(np.allclose(nz_no_tuple.numpy(), nz_no_tuple_ref.numpy())) | ||
|
|
||
| if cond_type == 'max': | ||
| self.assertEqual(len(cond[cond]), 1) | ||
| for me in nz_as_tuple: | ||
| self.assertEqual(me.shape, (1,)) | ||
| self.assertEqual(nz_no_tuple.shape, (1, a.ndim)) | ||
|
|
||
| # split | ||
| a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) | ||
| cond = a > 3 | ||
| nz = cond.nonzero() | ||
| self.assertEqual(nz.gshape, (6, 2)) | ||
| self.assertEqual(nz.dtype, ht.int64) | ||
| self.assertEqual(nz.split, 0) | ||
| a[nz] = 10.0 | ||
| self.assertEqual(ht.all(a[nz] == 10), 1) | ||
|
|
||
| # edge case: single non-zero element | ||
| for split in [None, 0, 1]: | ||
| a = ht.zeros((4, 3), dtype=ht.bool, split=split) | ||
| a[1, 2] = True | ||
| nz = ht.indexing.nonzero(a) | ||
| a.resplit_(None) | ||
| nz.resplit_(None) | ||
| self.assertEqual(nz.gshape, (1, 2)) | ||
| nz = ht.indexing.nonzero(a, as_tuple=False) | ||
| self.assertTrue(ht.allclose(a[nz], a[a])) | ||
| a.comm.Barrier() | ||
|
|
||
| # as_tuple = False (torch-style output) | ||
| a = ht.array([[1, 0, 0], [0, 4, 1], [0, 6, 0]], split=1) | ||
| nz = ht.nonzero(a, as_tuple=False) | ||
| self.assertEqual(nz.gshape, (4, 2)) | ||
| self.assertEqual(nz.dtype, ht.int64) | ||
| if a.is_distributed(): | ||
| self.assertEqual(nz.split, 0) | ||
| else: | ||
| self.assertEqual(nz.split, None) | ||
| t_a = a.resplit_(None).larray | ||
| t_nz = torch.nonzero(t_a, as_tuple=False) | ||
| self.assertTrue(ht.equal(nz, ht.array(t_nz))) | ||
|
|
||
| # attribute error | ||
| a = a.numpy() | ||
| with self.assertRaises(TypeError): | ||
| ht.nonzero(a) | ||
|
|
||
| def test_where(self): | ||
| # cases to test | ||
| # no x and y | ||
| a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=None) | ||
| cond = a > 3 | ||
| wh = ht.where(cond) | ||
| self.assertEqual(wh.gshape, (6, 2)) | ||
| self.assertEqual(wh.dtype, ht.int64) | ||
| self.assertEqual(wh.split, None) | ||
| self.assertEqual(len(wh), 2) | ||
| self.assertEqual(wh[0].gshape[0], 6) | ||
| self.assertEqual(wh[0].dtype, ht.int64) | ||
| self.assertEqual(wh[0].split, None) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this could be compared to numpy and torch too? |
||
| # split | ||
| a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1) | ||
| cond = a > 3 | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.