Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr_update.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
quick-tests:
runs-on: ubuntu-latest
strategy:
fail-fast: true
fail-fast: false
matrix:
py-version:
- '3.11' # Oldest supported
Expand Down
4 changes: 2 additions & 2 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
# TODO: remove this resplit!!
key = manipulations.resplit(key)
if key.larray.dtype in [torch.bool, torch.uint8]:
key = indexing.nonzero(key)
key = indexing.nonzero(key, as_tuple=False)

if key.ndim > 1:
key = list(key.larray.split(1, dim=1))
Expand Down Expand Up @@ -1626,7 +1626,7 @@ def __setitem__(
to be used."""
key = manipulations.resplit(key)
if key.larray.dtype in [torch.bool, torch.uint8]:
key = indexing.nonzero(key)
key = indexing.nonzero(key, as_tuple=False)

if key.ndim > 1:
key = list(key.larray.split(1, dim=1))
Expand Down
177 changes: 134 additions & 43 deletions heat/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,40 @@
"""

import torch
from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence

from .communication import MPI
from .dndarray import DNDarray
from . import sanitation
from . import factories
Comment thread
brownbaerchen marked this conversation as resolved.
from .sanitation import sanitize_in
from . import types
from . import manipulations

__all__ = ["nonzero", "where"]


def nonzero(x: DNDarray) -> DNDarray:
def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarray:
"""
Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``).
If ``x`` is split then the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray`
Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
containing the indices of the non-zero elements in that dimension. If ``x`` is split then
the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray`
can be UNBALANCED as it contains the indices of the non-zero elements on each node.
Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension.
The values in ``x`` are always tested and returned in row-major, C-style order.
The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``.

Parameters
----------
x: DNDarray
Input array
as_tuple: bool, optional
Default is True for numpy-style nonzero output. If False, the output is a torch-style single 2D ``DNDarray`` of shape `(num_nonzero, ndim)` containing the indices of the non-zero elements.

Examples
--------
>>> import heat as ht
>>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0)
>>> ht.nonzero(x)
DNDarray([[0, 0],
[1, 1],
[1, 2],
[2, 1]], dtype=ht.int64, device=cpu:0, split=0)
(DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None),
DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None))
>>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0)
>>> y > 3
DNDarray([[False, False, False],
Expand All @@ -48,48 +49,106 @@ def nonzero(x: DNDarray) -> DNDarray:
[2, 0],
[2, 1],
[2, 2]], dtype=ht.int64, device=cpu:0, split=0)
(DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None),
DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None))
>>> y[ht.nonzero(y > 3)]
DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0)
"""
sanitation.sanitize_in(x)

sanitize_in(x)

if not x.is_distributed():
# nonzero indices: tuple of 1D tensors if as_tuple, otherwise a single 2D tensor
nonzero = torch.nonzero(input=x.larray, as_tuple=as_tuple)
# bookkeeping for final DNDarray construct
if as_tuple:
nonzero = list(nonzero)
for i, nz_tensor in enumerate(nonzero):
nonzero[i] = factories.array(nz_tensor, device=x.device, comm=x.comm)
return tuple(nonzero)
else:
# nonzero indices as single 2D DNDarray
return factories.array(nonzero, device=x.device, comm=x.comm)

# distributed case
lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False)

# add offsets mapping from local indices to global indices if x is split
if x.split is not None:
_, _, slices = x.comm.chunk(x.shape, x.split)
lcl_nonzero[..., x.split] += slices[x.split].start

if x.ndim == 1:
lcl_nonzero = lcl_nonzero.squeeze(dim=1)

# compute global shape of the index array
gout = list(lcl_nonzero.shape)
if x.split is None:
is_split = None
nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device="cpu")
nonzero_dtype = types.canonical_heat_type(lcl_nonzero.dtype)

# global nonzero_size
x.comm.Allreduce(MPI.IN_PLACE, nonzero_size, MPI.SUM)
# correct indices along split axis
_, displs = x.counts_displs()
lcl_nonzero[:, x.split] += displs[x.comm.rank]

if x.split == 0:
# for split=0, the local nonzero indices are already globally ordered along the split axis
if not as_tuple:
# return indices as single 2D DNDarray
return DNDarray(
lcl_nonzero,
gshape=(nonzero_size.item(), x.ndim),
dtype=nonzero_dtype,
split=0,
device=x.device,
comm=x.comm,
balanced=False,
)
# return indices as tuple of 1D DNDarrays
lcl_nonzero = lcl_nonzero.unbind(dim=1)
return tuple(
DNDarray(
nz_tensor,
gshape=(nonzero_size.item(),),
dtype=nonzero_dtype,
split=0,
device=x.device,
comm=x.comm,
balanced=False,
)
for nz_tensor in lcl_nonzero
)
else:
gout[0] = x.comm.allreduce(gout[0], MPI.SUM)
is_split = 0

return DNDarray(
lcl_nonzero,
gshape=tuple(gout),
dtype=types.canonical_heat_type(lcl_nonzero.dtype),
split=is_split,
device=x.device,
comm=x.comm,
balanced=False,
)
# construct global 2D DNDarray of nz indices:
shape_2d = (nonzero_size.item(), x.ndim)
global_nonzero = DNDarray(
lcl_nonzero,
gshape=shape_2d,
dtype=nonzero_dtype,
split=0,
device=x.device,
comm=x.comm,
balanced=False,
)
# vectorized sorting of nz indices along axis 0
global_nonzero.balance_()
global_nonzero = manipulations.unique(global_nonzero, axis=0)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Seems like duplicate entries would be a bug at this point. Or are there some side effects of unique here?

if not as_tuple:
# return indices as single 2D DNDarray
return global_nonzero
# return indices as tuple of 1D DNDarrays
lcl_nonzero = global_nonzero.larray.unbind(dim=1)
return tuple(
DNDarray(
nz_tensor,
gshape=(nonzero_size.item(),),
dtype=nonzero_dtype,
split=0,
device=x.device,
comm=x.comm,
balanced=True,
)
for nz_tensor in lcl_nonzero
)


DNDarray.nonzero = lambda self: nonzero(self)
DNDarray.nonzero = lambda self: nonzero(self, as_tuple=True)
DNDarray.nonzero.__doc__ = nonzero.__doc__


def where(
cond: DNDarray,
x: Union[None, int, float, DNDarray] = None,
y: Union[None, int, float, DNDarray] = None,
x: None | int | float | DNDarray = None,
y: None | int | float | DNDarray = None,
) -> DNDarray:
"""
Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition.
Expand Down Expand Up @@ -128,20 +187,52 @@ def where(
[ 0, 2, -1],
[ 0, 3, -1]], dtype=ht.int64, device=cpu:0, split=None)
"""
# ---- binary where(cond, x, y) branch ------------------------------------
if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)):
if (isinstance(x, DNDarray) and cond.split != x.split) or (
isinstance(y, DNDarray) and cond.split != y.split
):
if len(y.shape) >= 1 and y.shape[0] > 1:
# Only raise if the "other" array has a meaningful first dimension.
if isinstance(y, DNDarray) and len(y.shape) >= 1 and y.shape[0] > 1:
raise NotImplementedError("binary op not implemented for different split axes")

if isinstance(x, (DNDarray, int, float)) and isinstance(y, (DNDarray, int, float)):
# Simple elementwise selection using arithmetic:
# cond == 0 -> take y, cond == 1 -> take x
for var in [x, y]:
if isinstance(var, int):
var = float(var)
return cond.dtype(cond == 0) * y + cond * x

# ---- where(cond) "indices only" branch ----------------------------------
elif x is None and y is None:
return nonzero(cond)
# General rule: delegate to nonzero(cond), and only wrap into a 2-D
# coordinate matrix in the special distributed case where the array
# is split along a non-zero axis.
nz = nonzero(cond) # tuple of DNDarrays, one per dimension

# 1) Non-distributed: behave exactly like ht.nonzero(cond)
if cond.split is None:
return nz

# 2) Distributed along axis 0: keep the legacy tuple-of-indices API.
# This is relied upon in several parts of the code base (e.g. KMeans).
if cond.split == 0:
return nz

# 3) Distributed along a non-zero axis (split > 0)
coords = manipulations.stack(nz, axis=1)
coords = coords.astype(types.int64, copy=False)

# Ensure indices are split along axis 0 for stable distributed behavior
if coords.split is None:
coords.resplit_(0)

return coords

# ---- invalid combinations ----------------------------------------------
else:
raise TypeError(
f"either both or neither x and y must be given and both must be DNDarrays or numerical scalars({type(x)}, {type(y)})"
"either both or neither x and y must be given and both must be "
f"DNDarrays or numerical scalars (got {type(x)}, {type(y)})"
)
75 changes: 51 additions & 24 deletions tests/core/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,75 @@
import heat as ht
from heat.testing.basic_test import TestCase

import torch
import numpy as np

class TestIndexing(TestCase):
def test_nonzero(self):
# cases to test:
# not split
a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None)
cond = a > 3
nz = ht.nonzero(cond)
self.assertEqual(nz.gshape, (5, 2))
self.assertEqual(nz.dtype, ht.int64)
self.assertEqual(nz.split, None)
for split in [None, 0, 1]:
for cond_type in ['mean', 'max']:
a = ht.random.random((2*self.comm.size, 3*self.comm.size, 4*self.comm.size))
if cond_type == 'mean':
cond = a > a.mean() / 2
elif cond_type == 'max':
cond = a == a.max()
else:
raise NotImplementedError

nz_as_tuple = ht.nonzero(cond, as_tuple=True)
nz_as_tuple_ref = np.nonzero(cond.numpy())
for i in range(len(nz_as_tuple)):
self.assertEqual(nz_as_tuple[i].dtype, ht.int64)
self.assertTrue(np.allclose(nz_as_tuple[i].numpy(), nz_as_tuple_ref[i]))

nz_no_tuple = ht.nonzero(cond, as_tuple=False)
nz_no_tuple_ref = torch.nonzero(cond.resplit(None), as_tuple=False)
self.assertEqual(nz_no_tuple.dtype, ht.int64)
self.assertTrue(np.allclose(nz_no_tuple.numpy(), nz_no_tuple_ref.numpy()))

if cond_type == 'max':
self.assertEqual(len(cond[cond]), 1)
for me in nz_as_tuple:
self.assertEqual(me.shape, (1,))
self.assertEqual(nz_no_tuple.shape, (1, a.ndim))

# split
a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1)
cond = a > 3
nz = cond.nonzero()
self.assertEqual(nz.gshape, (6, 2))
self.assertEqual(nz.dtype, ht.int64)
self.assertEqual(nz.split, 0)
a[nz] = 10.0
self.assertEqual(ht.all(a[nz] == 10), 1)

# edge case: single non-zero element
for split in [None, 0, 1]:
a = ht.zeros((4, 3), dtype=ht.bool, split=split)
a[1, 2] = True
nz = ht.indexing.nonzero(a)
a.resplit_(None)
nz.resplit_(None)
self.assertEqual(nz.gshape, (1, 2))
nz = ht.indexing.nonzero(a, as_tuple=False)
self.assertTrue(ht.allclose(a[nz], a[a]))
a.comm.Barrier()

# as_tuple = False (torch-style output)
a = ht.array([[1, 0, 0], [0, 4, 1], [0, 6, 0]], split=1)
nz = ht.nonzero(a, as_tuple=False)
self.assertEqual(nz.gshape, (4, 2))
self.assertEqual(nz.dtype, ht.int64)
if a.is_distributed():
self.assertEqual(nz.split, 0)
else:
self.assertEqual(nz.split, None)
t_a = a.resplit_(None).larray
t_nz = torch.nonzero(t_a, as_tuple=False)
self.assertTrue(ht.equal(nz, ht.array(t_nz)))

# type error: input is a NumPy array, not a DNDarray
a = a.numpy()
with self.assertRaises(TypeError):
ht.nonzero(a)

def test_where(self):
# cases to test
# no x and y
a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=None)
cond = a > 3
wh = ht.where(cond)
self.assertEqual(wh.gshape, (6, 2))
self.assertEqual(wh.dtype, ht.int64)
self.assertEqual(wh.split, None)
self.assertEqual(len(wh), 2)
self.assertEqual(wh[0].gshape[0], 6)
self.assertEqual(wh[0].dtype, ht.int64)
self.assertEqual(wh[0].split, None)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this could be compared to numpy and torch too?

# split
a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1)
cond = a > 3
Expand Down
Loading