diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bd8162561..5144eca844 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,10 +10,11 @@ - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when axis and keepdim were set. - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min`, `max` where DNDarrays with empty processes can't be computed. - [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element. +- [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN) ## Feature Additions -### Linear Algebra -- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot` +- [#867](https://github.com/helmholtz-analytics/heat/pull/867) Support torch 1.9.0 +- [#884](https://github.com/helmholtz-analytics/heat/pull/884) Support PyTorch 1.10.0, this is now the recommended version to use. 
## Feature additions ### Communication @@ -22,9 +23,12 @@ - [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__` - [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj` +### Factories +- [#749](https://github.com/helmholtz-analytics/heat/pull/749) `ht.array(copy=False)` behaviour now more in line with `np.array(copy=False)`, reduced memory footprint # Feature additions ### Linear Algebra - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()` +- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot` - [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm` - [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross` - [#877](https://github.com/helmholtz-analytics/heat/pull/877) New feature `det` @@ -32,6 +36,7 @@ ### Logical - [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit` ### Manipulations +- [#749](https://github.com/helmholtz-analytics/heat/pull/749) Distributed sorted `ht.unique` - [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll` - [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes` - [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis` @@ -43,6 +48,7 @@ ### Rounding - [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` + # v1.1.1 - [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range. 
@@ -104,6 +110,9 @@ Example on 2 processes: ### Linear Algebra - [#718](https://github.com/helmholtz-analytics/heat/pull/718) New feature: `trace()` - [#768](https://github.com/helmholtz-analytics/heat/pull/768) New feature: unary positive and negative operations + +### Manipulations +- [#820](https://github.com/helmholtz-analytics/heat/pull/820) `dot` can handle matrix vector operation now - [#820](https://github.com/helmholtz-analytics/heat/pull/820) `dot` can handle matrix-vector operation now ### Manipulations @@ -199,6 +208,7 @@ Example on 2 processes: ### Manipulations - [#690](https://github.com/helmholtz-analytics/heat/pull/690) Enhancement: reshape accepts shape arguments with one unknown dimension. - [#706](https://github.com/helmholtz-analytics/heat/pull/706) Bug fix: prevent `__setitem__`, `__getitem__` from modifying key in place +- [#744](https://github.com/helmholtz-analytics/heat/pull/744) Fix split semantics for reduction operations ### Unit testing / CI - [#717](https://github.com/helmholtz-analytics/heat/pull/717) Switch CPU CI over to Jenkins and pre-commit to GitHub action. - [#720](https://github.com/helmholtz-analytics/heat/pull/720) Ignore test files in codecov report and allow drops in code coverage. diff --git a/heat/core/communication.py b/heat/core/communication.py index 388949bcd8..8297ccc851 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -170,21 +170,21 @@ def chunk( Parameters ---------- shape : Tuple[int,...] - The global shape of the data to be split + The global shape of the data to be split. split : int - The axis along which to chunk the data + The axis along which to chunk the data. Must be within the range of ``shape``. rank : int, optional Process for which the chunking is calculated for, defaults to ``self.rank``. - Intended for creating chunk maps without communication + Intended for creating chunk maps without communication. w_size : int, optional The MPI world size, defaults to ``self.size``. 
- Intended for creating chunk maps without communication - + Intended for creating chunk maps without communication. """ - # ensure the split axis is valid, we actually do not need it - split = sanitize_axis(shape, split) if split is None: return 0, shape, tuple(slice(0, end) for end in shape) + if split < 0: + split = len(shape) + split + rank = self.rank if rank is None else rank w_size = self.size if w_size is None else w_size if not isinstance(rank, int) or not isinstance(w_size, int): @@ -212,7 +212,7 @@ def counts_displs_shape( self, shape: Tuple[int], axis: int ) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: """ - Calculates the item counts, displacements and output shape for a variable sized all-to-all MPI-call (e.g. + Calculates the item counts, displacements and output shape for a variable-sized all-to-all MPI-call (e.g. ``MPI_Alltoallv``). The passed shape is regularly chunk along the given axis and for all nodes. Parameters diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 2580d09090..e438677894 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -368,10 +368,6 @@ def get_halo(self, halo_size) -> torch.Tensor: halo_size : int Size of the halo. 
""" - if not self.is_balanced(): - raise RuntimeError( - "halo cannot be created for unbalanced tensors, running the .balance_() function is recommended" - ) if not isinstance(halo_size, int): raise TypeError( "halo_size needs to be of Python type integer, {} given".format(type(halo_size)) @@ -381,30 +377,43 @@ def get_halo(self, halo_size) -> torch.Tensor: "halo_size needs to be a positive Python integer, {} given".format(type(halo_size)) ) - if self.comm.is_distributed() and self.split is not None: + if self.is_distributed(): # gather lshapes lshape_map = self.create_lshape_map() rank = self.comm.rank size = self.comm.size + + first_rank = 0 next_rank = rank + 1 prev_rank = rank - 1 last_rank = size - 1 - # if local shape is zero and it's the last process + if not self.balanced: + populated_ranks = torch.nonzero(lshape_map[:, 0]).squeeze().tolist() + if rank in populated_ranks: + first_rank = populated_ranks[0] + last_rank = populated_ranks[-1] + next_rank = rank + 1 + prev_rank = rank - 1 + if rank != last_rank: + next_rank = populated_ranks[populated_ranks.index(rank) + 1] + if rank != first_rank: + prev_rank = populated_ranks[populated_ranks.index(rank) - 1] + + # if local shape is zero if self.lshape[self.split] == 0: return # if process has no data we ignore it if halo_size > self.lshape[self.split]: # if on at least one process the halo_size is larger than the local size throw ValueError raise ValueError( - "halo_size {} needs to be smaller than chunck-size {} )".format( + "halo_size {} needs to be smaller than chunk-size {} )".format( halo_size, self.lshape[self.split] ) ) a_prev = self.__prephalo(0, halo_size) a_next = self.__prephalo(-halo_size, None) - res_prev = None res_next = None @@ -418,7 +427,7 @@ def get_halo(self, halo_size) -> torch.Tensor: ) req_list.append(self.comm.Irecv(res_prev, source=next_rank)) - if rank != 0: + if rank != first_rank: self.comm.Isend(a_prev, prev_rank) res_next = torch.zeros( a_next.size(), dtype=a_next.dtype, 
device=self.device.torch_device diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index 6445930800..c81cf74e6e 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -80,7 +80,6 @@ def test_qr(self): self.assertTrue( ht.allclose(ht.eye(m, dtype=ht.double), qr2.Q @ qr2.Q.T, rtol=1e-5, atol=1e-5) ) - # test if calc R alone works a2_0 = ht.array(st2, split=0) a2_1 = ht.array(st2, split=1) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 37666abbf5..a5520257fc 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -22,6 +22,7 @@ from . import types from . import _operations + __all__ = [ "balance", "column_stack", @@ -300,7 +301,10 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: # no splits, local concat if s0 is None and s1 is None: return factories.array( - torch.cat((arr0.larray, arr1.larray), dim=axis), device=arr0.device, comm=arr0.comm + torch.cat((arr0.larray, arr1.larray), dim=axis), + device=arr0.device, + comm=arr0.comm, + copy=False, ) # non-matching splits when both arrays are split @@ -319,6 +323,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: is_split=s1 if s1 is not None else s0, device=arr1.device, comm=arr0.comm, + copy=False, ) return out @@ -334,6 +339,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: is_split=s0, device=arr0.device, comm=arr0.comm, + copy=False, ) return out @@ -504,6 +510,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: dtype=out_dtype, device=arr0.device, comm=arr0.comm, + copy=False, ) return out @@ -581,7 +588,9 @@ def diag(a: DNDarray, offset: int = 0) -> DNDarray: local = torch.zeros(lshape, dtype=a.dtype.torch_type(), device=a.device.torch_device) local[indices_x, indices_y] = a.larray[indices_x] - return factories.array(local, dtype=a.dtype, is_split=a.split, device=a.device, 
comm=a.comm) + return factories.array( + local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm, copy=False + ) def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray: @@ -655,7 +664,9 @@ def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDa vz = 1 if a.split == dim1 else -1 off, _, _ = a.comm.chunk(a.shape, a.split) result = torch.diagonal(a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2) - return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) + return factories.array( + result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm, copy=True + ) def dsplit(x: Sequence[DNDarray, ...], indices_or_sections: Iterable) -> List[DNDarray, ...]: @@ -807,14 +818,24 @@ def flatten(a: DNDarray) -> DNDarray: if a.split is None: return factories.array( - torch.flatten(a.larray), dtype=a.dtype, is_split=None, device=a.device, comm=a.comm + torch.flatten(a.larray), + dtype=a.dtype, + is_split=None, + device=a.device, + comm=a.comm, + copy=False, ) if a.split > 0: a = resplit(a, 0) a = factories.array( - torch.flatten(a.larray), dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + torch.flatten(a.larray), + dtype=a.dtype, + is_split=a.split, + device=a.device, + comm=a.comm, + copy=False, ) a.balance_() @@ -865,7 +886,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: if a.split not in axis: return factories.array( - flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm, copy=False ) # Need to redistribute tensors on split axis @@ -1445,6 +1466,7 @@ def pad( is_split=array.split, device=array.device, comm=array.comm, + copy=False, ) padded_tensor.balance_() @@ -1604,9 +1626,9 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr # sanitation `a` if not isinstance(a, DNDarray): if isinstance(a, (int, 
float)): - a = factories.array([a]) + a = factories.array([a], copy=False) elif isinstance(a, (tuple, list, np.ndarray)): - a = factories.array(a) + a = factories.array(a, copy=False) else: raise TypeError( "`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {}".format( @@ -2260,103 +2282,112 @@ def shape(a: DNDarray) -> Tuple[int, ...]: return a.gshape -def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DNDarray] = None): +def __pivot_sorting( + a: DNDarray, sort_op: Callable, axis: Optional[int] = None, descending: bool = False, **kwargs +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - Sorts the elements of `a` along the given dimension (by default in ascending order) by their value. - The sorting is not stable which means that equal elements in the result may have a different ordering than in the - original array. - Sorting where `axis==a.split` needs a lot of communication between the processes of MPI. - Returns a tuple `(values, indices)` with the sorted local results and the indices of the elements in the original data + Parallel sorting function for :func:`sort` and :func:`unique`, based on [1]. Parameters ---------- a : DNDarray - Input array to be sorted. - axis : int, optional - The dimension to sort along. - Default is the last axis. - descending : bool, optional - If set to `True`, values are sorted in descending order. - out : DNDarray, optional - A location in which to store the results. If provided, it must have a broadcastable shape. If not provided - or set to `None`, a fresh array is allocated. - - Raises - ------ - ValueError - If `axis` is not consistent with the available dimensions. + Distributed input data + axis : int or None + Axis along which the operation will be performed. + sort_op : torch operation + torch.sort or torch.unique + descending : bool + Whether :func:`sort` will return elements sorted in descending order. Default: `False`. 
- Examples - -------- - >>> x = ht.array([[4, 1], [2, 3]], split=0) - >>> x.shape - (1, 2) - (1, 2) - >>> y = ht.sort(x, axis=0) - >>> y - (array([[2, 1]], array([[1, 0]])) - (array([[4, 3]], array([[0, 1]])) - >>> ht.sort(x, descending=True) - (array([[4, 1]], array([[0, 1]])) - (array([[3, 2]], array([[1, 0]])) + References + ---------- + [1] Li et al., 1993, "On the versatility of parallel sorting by regular sampling", Parallel Computing, Volume 19, Issue 10, pages 1079-1103 """ - stride_tricks.sanitize_axis(a.shape, axis) - - if a.split is None or axis != a.split: - # sorting is not affected by split -> we can just sort along the axis - final_result, final_indices = torch.sort(a.larray, dim=axis, descending=descending) - - else: - # sorting is affected by split, processes need to communicate results - # transpose so we can work along the 0 axis - transposed = a.larray.transpose(axis, 0) - local_sorted, local_indices = torch.sort(transposed, dim=0, descending=descending) - - size = a.comm.Get_size() - rank = a.comm.Get_rank() + size = a.comm.Get_size() + rank = a.comm.Get_rank() + transposed = a.larray.transpose(axis, 0) + if sort_op is torch.sort: counts, disp, _ = a.comm.counts_displs_shape(a.gshape, axis=axis) - + local_sorted, local_indices = sort_op(transposed, dim=0, descending=descending) actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank] - - length = local_sorted.size()[0] - - # Separate the sorted tensor into size + 1 equal length partitions - partitions = [x * length // (size + 1) for x in range(1, size + 1)] - local_pivots = ( - local_sorted[partitions] - if counts[rank] - else torch.empty((0,) + local_sorted.size()[1:], dtype=local_sorted.dtype) + elif sort_op is torch.unique: + local_sorted = sort_op(transposed, dim=0, **kwargs)[0] + if 0 in local_sorted.shape: + local_shape = list(transposed.shape) + local_shape[0] = 0 + local_sorted = local_sorted.reshape(local_shape) + g_local_sorted = factories.array(local_sorted, 
is_split=0, device=a.device, copy=False) + counts, _ = g_local_sorted.counts_displs() + local_sorted = g_local_sorted.larray + + unique_along_axis = True if sort_op is torch.unique and axis is not None else False + + length = local_sorted.shape[0] + + # Separate the sorted tensor into size + 1 equal length partitions + partitions = [x * length // (size + 1) for x in range(1, size + 1)] + local_pivots = ( + local_sorted[partitions] + if counts[rank] + else torch.empty( + (0,) + local_sorted.shape[1:], dtype=local_sorted.dtype, device=local_sorted.device ) + ) - # Only processes with elements should share their pivots - gather_counts = [int(x > 0) * size for x in counts] - gather_displs = (0,) + tuple(np.cumsum(gather_counts[:-1])) + # Only processes with elements should share their pivots + gather_counts = [int(x > 0) * size for x in counts] + gather_displs = (0,) + tuple(torch.cumsum(torch.tensor(gather_counts[:-1]), dim=0).tolist()) + pivot_dim = list(transposed.shape) + pivot_dim[0] = size * sum([1 for x in counts if x > 0]) - pivot_dim = list(transposed.size()) - pivot_dim[0] = size * sum([1 for x in counts if x > 0]) + # share the local pivots with root process + pivot_buffer = torch.empty(pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device) + a.comm.Gatherv(local_pivots, (pivot_buffer, gather_counts, gather_displs), root=0) - # share the local pivots with root process - pivot_buffer = torch.empty( - pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device - ) - a.comm.Gatherv(local_pivots, (pivot_buffer, gather_counts, gather_displs), root=0) + pivot_dim[0] = size - 1 + global_pivots = torch.empty(pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device) - pivot_dim[0] = size - 1 - global_pivots = torch.empty( - pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device + # root process creates new pivots and shares them with other processes + if rank == 0: + if sort_op is torch.sort: + sorted_pivots, _ = 
sort_op(pivot_buffer, dim=0, descending=descending) + else: + sorted_pivots = sort_op(pivot_buffer, dim=0, **kwargs)[0] + length = sorted_pivots.shape[0] + global_partitions = [x * length // size for x in range(1, size)] + global_pivots = sorted_pivots[global_partitions] + + a.comm.Bcast(global_pivots, root=0) + # special case: unique along axis + if unique_along_axis: + # find position of global pivots in local sorted uniques + local_sorted, local_inverse_ind = torch.cat((local_sorted, global_pivots), dim=0).unique( + dim=0, + sorted=kwargs.get("sorted") if kwargs.get("sorted") else True, + return_inverse=True, ) + # Use the inverse indices of the global pivots to work out the local partition slices + local_slices = torch.zeros(size + 1, dtype=torch.int64, device=local_sorted.device) + local_slices[1:-1] = local_inverse_ind[-global_pivots.shape[0] :] + 1 + local_slices[-1] = torch.tensor([local_sorted.shape[0]]) + # how many rows will be sent and received where + send_matrix = torch.tensor( + [local_slices[i] - local_slices[i - 1] for i in range(1, size + 1)] + ) + recv_matrix = torch.zeros(size, dtype=torch.int64, device=local_sorted.device) + a.comm.Alltoall(send_matrix, recv_matrix) + # reshape send/recv_matrix into column to match sort() alltoall scheme + for matrix in [send_matrix, recv_matrix]: + matrix = matrix.reshape(1, matrix.numel()) - # root process creates new pivots and shares them with other processes - if rank == 0: - sorted_pivots, _ = torch.sort(pivot_buffer, descending=descending, dim=0) - length = sorted_pivots.size()[0] - global_partitions = [x * length // size for x in range(1, size)] - global_pivots = sorted_pivots[global_partitions] - - a.comm.Bcast(global_pivots, root=0) - - lt_partitions = torch.empty((size,) + local_sorted.shape, dtype=torch.int64) - last = torch.zeros_like(local_sorted, dtype=torch.int64) + scounts = send_matrix + rcounts = recv_matrix + shape = (recv_matrix.sum(dim=0),) + local_sorted.shape[1:] + else: + lt_partitions 
= torch.empty( + (size,) + local_sorted.shape, dtype=torch.int64, device=local_sorted.device + ) + last = torch.zeros_like(local_sorted, dtype=torch.int64, device=local_sorted.device) comp_op = torch.gt if descending else torch.lt # Iterate over all pivots and store which pivot is the first greater than the elements value for idx, p in enumerate(global_pivots): @@ -2366,21 +2397,26 @@ def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DN else: lt_partitions[idx] = lt last = lt - lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last + lt_partitions[size - 1] = ( + torch.ones_like(local_sorted, dtype=last.dtype, device=local_sorted.device) - last + ) # Matrix holding information how many values will be sent where local_partitions = torch.sum(lt_partitions, dim=1) - - partition_matrix = torch.empty_like(local_partitions) + partition_matrix = torch.empty_like(local_partitions, device=local_partitions.device) a.comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM) # Matrix that holds information which value will be shipped where - index_matrix = torch.empty_like(local_sorted, dtype=torch.int64) + index_matrix = torch.empty_like(local_sorted, dtype=torch.int64, device=local_sorted.device) # Matrix holding information which process get how many values from where - shape = (size,) + transposed.size()[1:] - send_matrix = torch.zeros(shape, dtype=partition_matrix.dtype) - recv_matrix = torch.zeros(shape, dtype=partition_matrix.dtype) + shape = (size,) + transposed.shape[1:] + send_matrix = torch.zeros( + shape, dtype=partition_matrix.dtype, device=partition_matrix.device + ) + recv_matrix = torch.zeros( + shape, dtype=partition_matrix.dtype, device=partition_matrix.device + ) for i, x in enumerate(lt_partitions): index_matrix[x > 0] = i @@ -2391,125 +2427,207 @@ def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DN scounts = local_partitions rcounts = recv_matrix - shape = 
(partition_matrix[rank].max(),) + transposed.size()[1:] - first_result = torch.empty(shape, dtype=local_sorted.dtype) - first_indices = torch.empty_like(first_result) + shape = (partition_matrix[rank].max(),) + transposed.shape[1:] - # Iterate through one layer and send values with alltoallv - for idx in np.ndindex(local_sorted.shape[1:]): - idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] + first_result = torch.empty(shape, dtype=local_sorted.dtype, device=local_sorted.device) + if sort_op is torch.sort: + first_indices = torch.empty_like(first_result, device=first_result.device) - send_count = scounts[idx_slice].reshape(-1).tolist() - send_disp = [0] + list(np.cumsum(send_count[:-1])) - s_val = local_sorted[idx_slice].clone() - s_ind = actual_indices[idx_slice].clone().to(dtype=local_sorted.dtype) + # Iterate through one layer and send values with alltoallv + if unique_along_axis: + iterator = range(1) + else: + iterator = np.ndindex(local_sorted.shape[1:]) + + for idx in iterator: + if unique_along_axis: + idx_slice = [slice(None)] + else: + idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - recv_count = rcounts[idx_slice].reshape(-1).tolist() - recv_disp = [0] + list(np.cumsum(recv_count[:-1])) - rcv_length = rcounts[idx_slice].sum().item() - r_val = torch.empty((rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype) - r_ind = torch.empty_like(r_val) + send_count = scounts[idx_slice].reshape(-1) + send_disp = [0] + torch.cumsum(send_count[:-1], dim=0).tolist() + send_count = send_count.tolist() + s_val = local_sorted[idx_slice].clone() + + recv_count = rcounts[idx_slice].reshape(-1) + recv_disp = [0] + torch.cumsum(recv_count[:-1], dim=0).tolist() + recv_count = recv_count.tolist() + rcv_length = rcounts[idx_slice].sum().item() + r_val = torch.empty( + (rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype, device=local_sorted.device + ) + a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) + 
first_result[idx_slice][:rcv_length] = r_val - a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) + if sort_op is torch.sort: + s_ind = actual_indices[idx_slice].clone().to(dtype=local_sorted.dtype) + r_ind = torch.empty_like(r_val, device=r_val.device) a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) - first_result[idx_slice][:rcv_length] = r_val first_indices[idx_slice][:rcv_length] = r_ind - # The process might not have the correct number of values therefore the tensors need to be rebalanced - send_vec = torch.zeros(local_sorted.shape[1:] + (size, size), dtype=torch.int64) - target_cumsum = np.cumsum(counts) - for idx in np.ndindex(local_sorted.shape[1:]): - idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - current_counts = partition_matrix[idx_slice].reshape(-1).tolist() - current_cumsum = list(np.cumsum(current_counts)) - for proc in range(size): - if current_cumsum[proc] > target_cumsum[proc]: - # process has to many values which will be sent to higher ranks - first = next(i for i in range(size) if send_vec[idx][:, i].sum() < counts[i]) - last = next( - i - for i in range(size + 1) - if i == size or current_cumsum[proc] < target_cumsum[i] - ) - sent = 0 - for i, x in enumerate(counts[first:last]): - # Each following process gets as many elements as it needs - amount = int(x - send_vec[idx][:, first + i].sum()) - send_vec[idx][proc][first + i] = amount - current_counts[first + i] += amount - sent = send_vec[idx][proc][: first + i + 1].sum().item() - if last < size: - # Send all left over values to the highest last process - amount = partition_matrix[proc][idx] - send_vec[idx][proc][last] = int(amount - sent) - current_counts[last] += int(amount - sent) - elif current_cumsum[proc] < target_cumsum[proc]: - # process needs values from higher rank - first = ( - 0 - if proc == 0 - else next( - i for i, x in enumerate(current_cumsum) if target_cumsum[proc - 1] < x - ) - ) - last = next(i 
for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) - for i, x in enumerate(partition_matrix[idx_slice][first:last]): - # Taking as many elements as possible from each following process - send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum()) - current_counts[first + i] = 0 - # Taking just enough elements from the last element to fill the current processes tensor - send_vec[idx][last][proc] = int(target_cumsum[proc] - current_cumsum[last - 1]) - current_counts[last] -= int(target_cumsum[proc] - current_cumsum[last - 1]) - else: - # process doesn't need more values - send_vec[idx][proc][proc] = ( - partition_matrix[proc][idx] - send_vec[idx][proc].sum() + if sort_op is torch.unique: + # early out for unique + return first_result + + # The process might not have the correct number of values therefore the tensors need to be rebalanced + send_vec = torch.zeros( + local_sorted.shape[1:] + (size, size), dtype=torch.int64, device=local_sorted.device + ) + target_cumsum = torch.cumsum(torch.tensor(counts), dim=0) + for idx in np.ndindex(local_sorted.shape[1:]): + idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] + current_counts = partition_matrix[idx_slice].reshape(-1) + current_cumsum = torch.cumsum(current_counts, dim=0).tolist() + current_counts = current_counts.tolist() + for proc in range(size): + # process has too many values which will be sent to higher ranks + if current_cumsum[proc] > target_cumsum[proc]: + first = next(i for i in range(size) if send_vec[idx][:, i].sum() < counts[i]) + last = next( + i + for i in range(size + 1) + if i == size or current_cumsum[proc] < target_cumsum[i] + ) + sent = 0 + for i, x in enumerate(counts[first:last]): + # Each following process gets as many elements as it needs + amount = int(x - send_vec[idx][:, first + i].sum()) + send_vec[idx][proc][first + i] = amount + current_counts[first + i] += amount + sent = send_vec[idx][proc][: first + i + 1].sum().item() + if last < size: + # 
Send all left over values to the highest last process + amount = partition_matrix[proc][idx] + send_vec[idx][proc][last] = int(amount - sent) + current_counts[last] += int(amount - sent) + elif current_cumsum[proc] < target_cumsum[proc]: + # process needs values from higher rank + first = ( + 0 + if proc == 0 + else next( + i for i, x in enumerate(current_cumsum) if target_cumsum[proc - 1] < x ) - current_counts[proc] = counts[proc] - current_cumsum = list(np.cumsum(current_counts)) + ) + last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) + for i, x in enumerate(partition_matrix[idx_slice][first:last]): + # Taking as many elements as possible from each following process + send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum()) + current_counts[first + i] = 0 + # Taking just enough elements from the last element to fill the current processes tensor + send_vec[idx][last][proc] = int(target_cumsum[proc] - current_cumsum[last - 1]) + current_counts[last] -= int(target_cumsum[proc] - current_cumsum[last - 1]) + else: + # process doesn't need more values + send_vec[idx][proc][proc] = partition_matrix[proc][idx] - send_vec[idx][proc].sum() + current_counts[proc] = counts[proc] + current_cumsum = torch.cumsum(torch.tensor(current_counts), dim=0).tolist() + + # Iterate through one layer again to create the final balanced local tensors + second_result = torch.empty_like(local_sorted, device=local_sorted.device) + second_indices = torch.empty_like(second_result, device=second_result.device) + for idx in np.ndindex(local_sorted.shape[1:]): + idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] + + send_count = send_vec[idx][rank] + send_disp = [0] + torch.cumsum(send_count[:-1], dim=0).tolist() + + recv_count = send_vec[idx][:, rank] + recv_disp = [0] + torch.cumsum(recv_count[:-1], dim=0).tolist() + + end = partition_matrix[rank][idx] + s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0) 
+ r_val = torch.empty( + (counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype, device=local_sorted.device + ) + a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) + second_result[idx_slice] = r_val + + s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val) + r_ind = torch.empty_like(r_val) + a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) + second_indices[idx_slice] = r_ind + + second_result, tmp_indices = sort_op(second_result, dim=0, descending=descending) + final_result = second_result.transpose(0, axis) + final_indices = torch.empty_like(second_indices, device=second_indices.device) + # Update the indices in case the ordering changed during the last sort + for idx in np.ndindex(tmp_indices.shape): + val = tmp_indices[idx] + final_indices[idx] = second_indices[val.item()][idx[1:]] + final_indices = final_indices.transpose(0, axis) + return final_result, final_indices + + +def sort( + a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DNDarray] = None +) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: + """ + Sorts the elements of `a` along the given dimension (by default in ascending order) by their value. + Returns a tuple `(values, indices)` with the sorted local results and the indices of the elements in the original data. + The sorting is not stable, which means that equal elements in the result may have a different ordering than in the + original array. + Distributed sorting where `axis==a.split` is based on [1]. 
- # Iterate through one layer again to create the final balanced local tensors - second_result = torch.empty_like(local_sorted) - second_indices = torch.empty_like(second_result) - for idx in np.ndindex(local_sorted.shape[1:]): - idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] + References + ---------- + [1] Li et al., 1993, "On the versatility of parallel sorting by regular sampling", Parallel Computing, Volume 19, Issue 10, pages 1079-1103 - send_count = send_vec[idx][rank] - send_disp = [0] + list(np.cumsum(send_count[:-1])) + Parameters + ---------- + a : DNDarray + Input array to be sorted. + axis : int, optional + The dimension to sort along. + Default is the last axis. + descending : bool, optional + If set to `True`, values are sorted in descending order. + out : DNDarray, optional + A location in which to store the results. If provided, it must have a broadcastable shape. If not provided + or set to `None`, a fresh array is allocated. - recv_count = send_vec[idx][:, rank] - recv_disp = [0] + list(np.cumsum(recv_count[:-1])) + Raises + ------ + ValueError + If `axis` is not consistent with the available dimensions. 
- end = partition_matrix[rank][idx] - s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0) - s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val) + Examples + -------- + >>> x = ht.array([[4, 1], [2, 3]], split=0) + >>> x.shape + (1, 2) + (1, 2) + >>> y = ht.sort(x, axis=0) + >>> y + (array([[2, 1]], array([[1, 0]])) + (array([[4, 3]], array([[0, 1]])) + >>> ht.sort(x, descending=True) + (array([[4, 1]], array([[0, 1]])) + (array([[3, 2]], array([[1, 0]])) + """ + # default: using last axis + if axis is None: + axis = len(a.shape) - 1 - r_val = torch.empty((counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype) - r_ind = torch.empty_like(r_val) + stride_tricks.sanitize_axis(a.shape, axis) - a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) - a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) + if a.split is None or axis != a.split: + # sorting is not affected by split -> we can just sort along the axis + final_result, final_indices = torch.sort(a.larray, dim=axis, descending=descending) + + else: + final_result, final_indices = __pivot_sorting(a, torch.sort, axis, descending=descending) - second_result[idx_slice] = r_val - second_indices[idx_slice] = r_ind - - second_result, tmp_indices = second_result.sort(dim=0, descending=descending) - final_result = second_result.transpose(0, axis) - final_indices = torch.empty_like(second_indices) - # Update the indices in case the ordering changed during the last sort - for idx in np.ndindex(tmp_indices.shape): - val = tmp_indices[idx] - final_indices[idx] = second_indices[val.item()][idx[1:]] - final_indices = final_indices.transpose(0, axis) return_indices = factories.array( - final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm + final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm, copy=False ) if out is not None: out.larray = final_result return 
return_indices else: tensor = factories.array( - final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm, copy=False ) return tensor, return_indices @@ -3080,63 +3198,111 @@ def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray: def unique( - a: DNDarray, sorted: bool = False, return_inverse: bool = False, axis: int = None -) -> Tuple[DNDarray, torch.tensor]: + a: DNDarray, return_inverse: bool = False, axis: Optional[int] = None +) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: """ - Finds and returns the unique elements of a `DNDarray`. - If return_inverse is `True`, the second tensor will hold the list of inverse indices - If distributed, it is most efficient if `axis!=a.split`. + Returns the sorted unique elements of an array. + + If `a` is distributed, and unique elements along a specific `axis` are required, + then `a` must be distributed along `axis`. + + The distributed implementation is based on [1]. Parameters ---------- a : DNDarray Input array. - sorted : bool, optional - Whether the found elements should be sorted before returning as output. - Warning: sorted is not working if `axis!=None and axis!=a.split` - return_inverse : bool, optional - Whether to also return the indices for where elements in the original input ended up in the returned - unique list. axis : int, optional - Axis along which unique elements should be found. Default to `None`, which will return a one dimensional list of - unique values. + The axis to operate on. If None, `a` will be flattened. + return_inverse : bool, optional + Return the indices of the unique array (for the specified `axis`, if provided) + that can be used to reconstruct `a`. + Default: False + + Returns + ------- + unique : DNDarray + The sorted unique elements of `a`. Whether `unique` is distributed depends + on the size of the unique elements with respect to the (process-local) data. + See Notes below. 
+ inverse_indices : DNDarray + The global indices to reconstruct the original (possibly distributed) array + from `unique`. `inverse_indices` is distributed like `a`. See Notes below + on reconstructing the original array from a distributed `unique` array. + + References + ---------- + [1] Li et al., 1993, "On the versatility of parallel sorting by regular sampling", Parallel Computing, Volume 19, Issue 10, pages 1079-1103 + + Notes + ----- + The resulting `unique` will not be distributed (`unique.split=None`) if the collective + size of the unique values, in bytes, does not exceed a certain threshold + (arbitrarily defined as the size of the process-local input `a`). Otherwise, + `unique` will be distributed along 0, if `axis` is specified, or along `a.split`, + if `axis` is None. + + Warnings + -------- + `inverse_indices` will always be distributed like the original data + (if `axis is None`) or along 0, and contains the GLOBAL indices to recreate the + LOCAL portion of `a`. Before reconstructing an array based on `unique[inverse_indices]`, + make sure that `unique` is local (with `unique.resplit_(axis=None)`, see :func:`resplit`). 
Examples -------- >>> x = ht.array([[3, 2], [1, 3]]) - >>> ht.unique(x, sorted=True) + >>> ht.unique(x) array([1, 2, 3]) - >>> ht.unique(x, sorted=True, axis=0) + + >>> ht.unique(x, axis=0) array([[1, 3], [2, 3]]) - >>> ht.unique(x, sorted=True, axis=1) + + >>> ht.unique(x, axis=1) array([[2, 3], [3, 1]]) """ - if a.split is None: - torch_output = torch.unique( - a.larray, sorted=sorted, return_inverse=return_inverse, dim=axis - ) + if not a.is_distributed(): + torch_output = torch.unique(a.larray, sorted=True, return_inverse=return_inverse, dim=axis) if isinstance(torch_output, tuple): heat_output = tuple( - factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output + factories.array( + i, + dtype=types.canonical_heat_type(i.dtype), + split=None, + device=a.device, + copy=False, + ) + for i in torch_output ) else: - heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) + heat_output = factories.array( + torch_output, dtype=a.dtype, split=None, device=a.device, copy=False + ) return heat_output + rank = a.comm.rank + size = a.comm.size + local_data = a.larray + inv_shape = local_data.shape if axis is None else (local_data.shape[axis],) unique_axis = None - inverse_indices = None - if axis is not None: - # transpose so we can work along the 0 axis - local_data = local_data.transpose(0, axis) + if axis != a.split: + raise NotImplementedError( + "Not implemented yet: Operation axis differs from distribution axis: axis is {}, array.split is {}".format( + axis, a.split + ) + ) + if axis != 0: + # transpose so we can work along the 0 axis + local_data = local_data.transpose(0, axis) unique_axis = 0 - # Calculate the unique on the local values + # Calculate local uniques if a.lshape[a.split] == 0: - # Passing an empty vector to torch throws exception + # address empty local tensor if axis is None: res_shape = [0] inv_shape = list(a.gshape) @@ -3146,146 +3312,93 @@ def unique( res_shape[0] = 0 inv_shape = [0] lres 
= torch.empty(res_shape, dtype=a.dtype.torch_type()) - inverse_pos = torch.empty(inv_shape, dtype=torch.int64) - else: - lres, inverse_pos = torch.unique( - local_data, sorted=sorted, return_inverse=True, dim=unique_axis - ) - - # Share and gather the results with the other processes - uniques = torch.tensor([lres.shape[0]]).to(torch.int32) - uniques_buf = torch.empty((a.comm.Get_size(),), dtype=torch.int32) - a.comm.Allgather(uniques, uniques_buf) - - if axis is None or axis == a.split: - is_split = None - split = a.split - - output_dim = list(lres.shape) - output_dim[0] = uniques_buf.sum().item() - - # Gather all unique vectors - counts = list(uniques_buf.tolist()) - displs = list([0] + uniques_buf.cumsum(0).tolist()[:-1]) - gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device) - a.comm.Allgatherv(lres, (gres_buf, counts, displs), recv_axis=0) - - if return_inverse: - # Prepare some information to generated the inverse indices list - avg_len = a.gshape[a.split] // a.comm.Get_size() - rem = a.gshape[a.split] % a.comm.Get_size() - - # Share the local reverse indices with other processes - counts = [avg_len] * a.comm.Get_size() - add_vec = [1] * rem + [0] * (a.comm.Get_size() - rem) - inverse_counts = [sum(x) for x in zip(counts, add_vec)] - inverse_displs = [0] + list(np.cumsum(inverse_counts[:-1])) - inverse_dim = list(inverse_pos.shape) - inverse_dim[a.split] = a.gshape[a.split] - inverse_buf = torch.empty(inverse_dim, dtype=inverse_pos.dtype) - - # Transpose data and buffer so we can use Allgatherv along axis=0 (axis=1 does not work properly yet) - inverse_pos = inverse_pos.transpose(0, a.split) - inverse_buf = inverse_buf.transpose(0, a.split) - a.comm.Allgatherv( - inverse_pos, (inverse_buf, inverse_counts, inverse_displs), recv_axis=0 - ) - inverse_buf = inverse_buf.transpose(0, a.split) - - # Run unique a second time - gres = torch.unique(gres_buf, sorted=sorted, return_inverse=return_inverse, dim=unique_axis) - if 
return_inverse: - # Use the previously gathered information to generate global inverse_indices - g_inverse = gres[1] - gres = gres[0] - if axis is None: - # Calculate how many elements we have in each layer along the split axis - elements_per_layer = 1 - for num, val in enumerate(a.gshape): - if not num == a.split: - elements_per_layer *= val - - # Create the displacements for the flattened inverse indices array - local_elements = [displ * elements_per_layer for displ in inverse_displs][1:] + [ - float("inf") - ] - - # Flatten the inverse indices array every element can be updated to represent a global index - transposed = inverse_buf.transpose(0, a.split) - transposed_shape = transposed.shape - flatten_inverse = transposed.flatten() - - # Update the index elements iteratively - cur_displ = 0 - inverse_indices = [0] * len(flatten_inverse) - for num in range(len(inverse_indices)): - if num >= local_elements[cur_displ]: - cur_displ += 1 - index = flatten_inverse[num] + displs[cur_displ] - inverse_indices[num] = g_inverse[index].tolist() - - # Convert the flattened array back to the correct global shape of a - inverse_indices = torch.tensor(inverse_indices).reshape(transposed_shape) - inverse_indices = inverse_indices.transpose(0, a.split) - - else: - inverse_indices = torch.zeros_like(inverse_buf) - steps = displs + [None] - - # Algorithm that creates the correct list for the reverse_indices - for i in range(len(steps) - 1): - begin = steps[i] - end = steps[i + 1] - for num, x in enumerate(inverse_buf[begin:end]): - inverse_indices[begin + num] = g_inverse[begin + x] - + lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) + gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device, copy=False) + + # calculate size (bytes) of local unique. 
If less than local_data, gather and run everything locally + data_max_lbytes = torch.prod(a.lshape_map[0]) * a.larray.element_size() + if gres.nbytes <= data_max_lbytes: + # gather local uniques + gres.resplit_(None) + # final round of torch.unique + lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + lres_split = None + gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device, copy=False) else: - # Tensor is already split and does not need to be redistributed afterward - split = None - is_split = a.split - max_uniques, max_pos = uniques_buf.max(0) - # find indices of vectors - if a.comm.Get_rank() == max_pos.item(): - # Get indices of the unique vectors to share with all over processes - indices = inverse_pos.reshape(-1).unique() - else: - indices = torch.empty((max_uniques.item(),), dtype=inverse_pos.dtype) - - a.comm.Bcast(indices, root=max_pos) - - gres = local_data[indices.tolist()] + # global sorted unique + lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) + # second local unique + if 0 not in lres.shape: + lres = torch.unique(lres, sorted=True, dim=unique_axis) + lres_split = 0 - inverse_indices = indices - if sorted: - raise ValueError( - "Sorting with axis != split is not supported yet. 
" - "See https://github.com/helmholtz-analytics/heat/issues/363" - ) + gres = factories.array(lres, dtype=a.dtype, is_split=lres_split, device=a.device, copy=False) + gres.balance_() - if axis is not None: - # transpose matrix back - gres = gres.transpose(0, axis) + if return_inverse: + # inverse indices + # allocate local tensors and global DNDarray + inverse = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) + if a.is_distributed(): + inv_split = 0 if inverse.ndim == 1 else a.split + else: + inv_split = None + global_inverse = factories.array( + inverse, is_split=inv_split, device=gres.device, copy=False + ) - split = split if a.split < len(gres.shape) else None - result = factories.array( - gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split - ) - if split is not None: - result.resplit_(a.split) + unique_ranks = size if gres.is_distributed() else 1 + gres_map = gres.lshape_map + if unique_ranks > 1: + _, gres_offsets = gres.counts_displs() + gres_offsets = torch.tensor(gres_offsets, device=gres_map.device) + else: + gres_offsets = torch.tensor([0], device=gres_map.device) + lres = gres.larray + for p in range(unique_ranks): + if unique_ranks == 1: + incoming_offset = 0 + else: + origin = (rank - p) % size + incoming_offset = gres_offsets[origin] + tmp = torch.empty( + gres_map[0].tolist(), dtype=local_data.dtype, device=local_data.device + ) + # loop through unique elements, find matching position in data + for i, el in enumerate(lres): + counts = torch.zeros_like(local_data, dtype=torch.int32, device=local_data.device) + counts[torch.where(local_data == el)] = 1 + if lres.ndim > 1: + counts = torch.sum(counts, dim=tuple(range(lres.ndim))[1:]) + cond = torch.where(counts == el.numel()) + global_inverse.larray[cond] = i + incoming_offset + # if necessary, prepare to send lres to rank+1 and receive from rank-1 + if unique_ranks > 1: + dest_rank = (rank + 1) % unique_ranks + tmp[slice(None, lres.shape[0])] = lres + 
queue = gres.comm.Isend(tmp, dest_rank, tag=rank) + recv_from_rank = (rank - 1) % unique_ranks + next_origin = (origin - 1) % unique_ranks + incoming_size = gres_map[next_origin].tolist()[0] + queue.Wait() + gres.comm.Recv(tmp, recv_from_rank, tag=recv_from_rank) + lres = tmp[slice(None, incoming_size)] + gres.larray = lres + + if axis is not None and axis != 0: + # transpose back to original + gres = linalg.basics.transpose(gres, (axis, 0)) - return_value = result if return_inverse: - return_value = [return_value, inverse_indices.to(a.device.torch_device)] + return gres, global_inverse - return return_value + return gres DNDarray.unique: Callable[ - [DNDarray, bool, bool, int], Tuple[DNDarray, torch.tensor] -] = lambda self, sorted=False, return_inverse=False, axis=None: unique( - self, sorted, return_inverse, axis -) + [DNDarray, bool, int], Union[DNDarray, Tuple[DNDarray, DNDarray]] +] = lambda self, return_inverse=False, axis=None: unique(self, return_inverse, axis) DNDarray.unique.__doc__ = unique.__doc__ diff --git a/heat/core/memory.py b/heat/core/memory.py index 99554c88ed..07b864e9de 100644 --- a/heat/core/memory.py +++ b/heat/core/memory.py @@ -58,30 +58,24 @@ def sanitize_memory_layout(x: torch.Tensor, order: str = "C") -> torch.Tensor: if x.ndim < 2 or x.numel() == 0: # do nothing return x - dims = list(range(x.ndim)) stride = torch.tensor(x.stride()) # since strides can get a bit wonky with operations like transpose # we should assume that the tensors are row major or are distributed the default way - sdiff = stride[1:] - stride[:-1] - column_major = all(sdiff >= 0) - row_major = True if not column_major else False - if (order == "C" and row_major) or (order == "F" and column_major): + column_major = (stride[1:] - stride[:-1] >= 0).all() + if (order == "C" and not column_major) or (order == "F" and column_major): # do nothing return x - elif (order == "C" and column_major) or (order == "F" and row_major): - dims = tuple(reversed(dims)) - y = 
torch.empty_like(x) - permutation = x.permute(dims).contiguous() - y = y.set_( - permutation.storage(), - x.storage_offset(), - x.shape, - tuple(reversed(permutation.stride())), - ) - return y - else: - raise ValueError( - "combination of order and layout not permitted, order: {} column major: {} row major: {}".format( - order, column_major, row_major - ) + if (order == "C" and column_major) or (order == "F" and not column_major): + dims = tuple(range(x.ndim - 1, -1, -1)) + storage_offset = x.storage_offset() + shape = x.shape + x = x.permute(dims).contiguous() + reversed_stride = tuple(reversed(x.stride())) + x.set_(x.storage(), storage_offset, shape, reversed_stride) + return x + + raise ValueError( + "combination of order and layout not permitted, order: {} column major: {} row major: {}".format( + order, column_major, not column_major ) + ) diff --git a/heat/core/stride_tricks.py b/heat/core/stride_tricks.py index f07af2418a..7e3c5db571 100644 --- a/heat/core/stride_tricks.py +++ b/heat/core/stride_tricks.py @@ -97,14 +97,14 @@ def sanitize_axis( axis = None if axis is not None: - if not isinstance(axis, int) and not isinstance(axis, tuple): + if isinstance(axis, tuple): + axis = tuple(dim + len(shape) if dim < 0 else dim for dim in axis) + for dim in axis: + if dim < 0 or dim >= len(shape): + raise ValueError("axis {} is out of bounds for shape {}".format(axis, shape)) + return axis + if not isinstance(axis, int): raise TypeError("axis must be None or int or tuple, but was {}".format(type(axis))) - if isinstance(axis, tuple): - axis = tuple(dim + len(shape) if dim < 0 else dim for dim in axis) - for dim in axis: - if dim < 0 or dim >= len(shape): - raise ValueError("axis {} is out of bounds for shape {}".format(axis, shape)) - return axis if axis is None or 0 <= axis < len(shape): return axis @@ -113,7 +113,6 @@ def sanitize_axis( if axis < 0 or axis >= len(shape): raise ValueError("axis {} is out of bounds for shape {}".format(axis, shape)) - return axis diff 
--git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py index 1203b57512..43cf205008 100644 --- a/heat/core/tests/test_arithmetics.py +++ b/heat/core/tests/test_arithmetics.py @@ -515,6 +515,15 @@ def test_prod(self): self.assertEqual(no_axis_prod.split, None) self.assertEqual(no_axis_prod.larray, 134217728) + # check split semantics + shape_noaxis_split_axis = ht.ones((3, 3, 3), split=2) + split_axis_sum = shape_noaxis_split_axis.sum(axis=1) + self.assertIsInstance(split_axis_sum, ht.DNDarray) + self.assertEqual(split_axis_sum.shape, (3, 3)) + self.assertEqual(split_axis_sum.dtype, ht.float32) + self.assertEqual(split_axis_sum._DNDarray__array.dtype, torch.float32) + self.assertEqual(split_axis_sum.split, 1) + out_noaxis = ht.zeros((1,)) ht.prod(shape_noaxis, out=out_noaxis) self.assertEqual(out_noaxis.larray, 134217728) diff --git a/heat/core/tests/test_communication.py b/heat/core/tests/test_communication.py index 1410eaf9cc..504fd4161a 100644 --- a/heat/core/tests/test_communication.py +++ b/heat/core/tests/test_communication.py @@ -23,10 +23,6 @@ def setUpClass(cls): def test_self_communicator(self): comm = ht.core.communication.MPI_SELF - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=2) - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=-3) with self.assertRaises(TypeError): comm.chunk(self.data.shape, split=0, rank="dicndjh") @@ -47,11 +43,6 @@ def test_mpi_communicator(self): comm = ht.core.communication.MPI_WORLD self.assertLess(comm.rank, comm.size) - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=2) - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=-3) - offset, lshape, chunks = comm.chunk(self.data.shape, split=0) self.assertIsInstance(offset, int) self.assertGreaterEqual(offset, 0) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 0bfb1dbfdb..712867a17a 100644 --- 
a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -123,12 +123,6 @@ def test_gethalo(self): # exception for too large halos with self.assertRaises(ValueError): data.get_halo(4) - # exception on non balanced tensor - with self.assertRaises(RuntimeError): - data_nobalance = ht.array( - torch.empty(((data.comm.rank + 1) * 2, 3, 4)), is_split=0, device=data.device - ) - data_nobalance.get_halo(1) # test no data on process data_np = np.arange(2 * 12).reshape(2, 12) data = ht.array(data_np, split=0) @@ -167,6 +161,42 @@ def test_gethalo(self): self.assertTrue(data.halo_next is None) self.assertEqual(data_with_halos.shape, (12, 0)) + # test halo of imbalanced dndarray + if data.comm.size > 2: + t_data = torch.arange( + 5 * data.comm.rank, dtype=torch.float64, device=data.larray.device + ).reshape(data.comm.rank, 5) + if data.comm.rank > 0: + prev_data = torch.arange( + 5 * (data.comm.rank - 1), dtype=torch.float64, device=data.larray.device + ).reshape(data.comm.rank - 1, 5) + if data.comm.rank < data.comm.size - 1: + next_data = torch.arange( + 5 * (data.comm.rank + 1), dtype=torch.float64, device=data.larray.device + ).reshape(data.comm.rank + 1, 5) + data = ht.array(t_data, is_split=0) + data.get_halo(1) + data_with_halos = data.array_with_halos + if data.comm.rank == 0: + prev_halo = None + next_halo = None + new_split_size = 0 + elif data.comm.rank == 1: + prev_halo = None + next_halo = next_data[0] + new_split_size = data.larray.shape[0] + 1 + elif data.comm.rank == data.comm.size - 1: + prev_halo = prev_data[-1] + next_halo = None + new_split_size = data.larray.shape[0] + 1 + else: + prev_halo = prev_data[-1] + next_halo = next_data[0] + new_split_size = data.larray.shape[0] + 2 + self.assertEqual(data_with_halos.shape, (new_split_size, 5)) + self.assertTrue(data.halo_prev is prev_halo or (data.halo_prev == prev_halo).all()) + self.assertTrue(data.halo_next is next_halo or (data.halo_next == next_halo).all()) + def test_larray(self): # 
undistributed case x = ht.arange(6 * 7 * 8).reshape((6, 7, 8)) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index ef74b39a4d..5b6dc21a99 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2618,126 +2618,181 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size - rank = ht.MPI_WORLD.rank - tensor = ( - torch.arange(size, device=self.device.torch_device).repeat(size).reshape(size, size) - ) - + torch.manual_seed(42) + tensor_3d = torch.randint(0, 10 * size, (size, size, size), device=self.device.torch_device) + tensor = tensor_3d[0] + # sort along axis 0, split None data = ht.array(tensor, split=None) result, result_indices = ht.sort(data, axis=0, descending=True) - expected, exp_indices = torch.sort(tensor, dim=0, descending=True) - self.assertTrue(torch.equal(result.larray, expected)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - + expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) + self.assertTrue(torch.equal(result.larray, expected_dim0)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices_dim0).numel() == exp_indices_dim0.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) + # sort along axis 1, split None result, result_indices = ht.sort(data, axis=1, descending=True) - expected, exp_indices = torch.sort(tensor, dim=1, descending=True) - self.assertTrue(torch.equal(result.larray, expected)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - + expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) + self.assertTrue(torch.equal(result.larray, expected_dim1)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices_dim1).numel() == exp_indices_dim1.numel() + or result_indices.larray.is_cuda is False + 
): + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) + # sort along axis 0, split 0 data = ht.array(tensor, split=0) - - exp_axis_zero = torch.arange(size, device=self.device.torch_device).reshape(1, size) - exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device) result, result_indices = ht.sort(data, descending=True, axis=0) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - - exp_axis_one, exp_indices = ( - torch.arange(size, device=self.device.torch_device) - .reshape(1, size) - .sort(dim=1, descending=True) - ) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # sort along axis 1, split 0 result, result_indices = ht.sort(data, descending=True, axis=1) + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - - result1 = ht.sort(data, axis=1, descending=True) - result2 = ht.sort(data, descending=True) - self.assertTrue(ht.equal(result1[0], result2[0])) - self.assertTrue(ht.equal(result1[1], result2[1])) - + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + 
self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + + # sort along axis 0, split 1 data = ht.array(tensor, split=1) - - exp_axis_zero = ( - torch.tensor(rank, device=self.device.torch_device).repeat(size).reshape(size, 1) - ) - indices_axis_zero = torch.arange( - size, dtype=torch.int64, device=self.device.torch_device - ).reshape(size, 1) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] result, result_indices = ht.sort(data, axis=0, descending=True) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - # comparison value is only true on CPU - if result_indices.larray.is_cuda is False: - self.assertTrue(torch.equal(result_indices.larray, indices_axis_zero.int())) - - exp_axis_one = ( - torch.tensor(size - rank - 1, device=self.device.torch_device) - .repeat(size) - .reshape(size, 1) - ) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 1, split 1 + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] result, result_indices = ht.sort(data, descending=True, axis=1) self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_axis_one.int())) - - tensor = torch.tensor( - [ - [[2, 8, 5], [7, 2, 3]], - [[6, 5, 2], [1, 8, 7]], - [[9, 3, 0], [1, 2, 4]], - [[8, 4, 7], [0, 8, 9]], - ], - dtype=torch.int32, - device=self.device.torch_device, - ) - + # indices unstable on GPU if sorting non-unique values + if ( + 
torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # 3D array + tensor = tensor_3d + expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) + expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) + expected_dim2, exp_indices_dim2 = torch.sort(tensor, dim=2, descending=True) + # sort along axis 0, split 0 data = ht.array(tensor, split=0) - exp_axis_zero = torch.tensor( - [[2, 3, 0], [0, 2, 3]], dtype=torch.int32, device=self.device.torch_device - ) - if torch.cuda.is_available() and data.device == ht.gpu and size < 4: - indices_axis_zero = torch.tensor( - [[0, 2, 2], [3, 2, 0]], dtype=torch.int32, device=self.device.torch_device - ) - else: - indices_axis_zero = torch.tensor( - [[0, 2, 2], [3, 0, 0]], dtype=torch.int32, device=self.device.torch_device - ) - result, result_indices = ht.sort(data, axis=0) - first = result[0].larray - first_indices = result_indices[0].larray - if rank == 0: - self.assertTrue(torch.equal(first, exp_axis_zero)) - self.assertTrue(torch.equal(first_indices, indices_axis_zero)) - + result, result_indices = ht.sort(data, descending=True, axis=0) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 1, split 0 + result, result_indices = ht.sort(data, descending=True, axis=1) + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) + _, _, local_slice_ind = 
data.comm.chunk(exp_indices_dim1.shape, split=0) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 0, split 1 data = ht.array(tensor, split=1) - exp_axis_one = torch.tensor([[2, 2, 3]], dtype=torch.int32, device=self.device.torch_device) - indices_axis_one = torch.tensor( - [[0, 1, 1]], dtype=torch.int32, device=self.device.torch_device - ) - result, result_indices = ht.sort(data, axis=1) - first = result[0].larray[:1] - first_indices = result_indices[0].larray[:1] - if rank == 0: - self.assertTrue(torch.equal(first, exp_axis_one)) - self.assertTrue(torch.equal(first_indices, indices_axis_one)) - + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + result, result_indices = ht.sort(data, axis=0, descending=True) + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 1, split 1 + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] + result, result_indices = ht.sort(data, descending=True, axis=1) + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # indices unstable 
on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + + # sort along axis 0, split 2 data = ht.array(tensor, split=2) - exp_axis_two = torch.tensor([[2], [2]], dtype=torch.int32, device=self.device.torch_device) - indices_axis_two = torch.tensor( - [[0], [1]], dtype=torch.int32, device=self.device.torch_device - ) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=2) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=2) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + result, result_indices = ht.sort(data, axis=0, descending=True) + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 2, split 2 + _, _, local_slice = data.comm.chunk(expected_dim2.shape, split=2) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim2.shape, split=2) + exp_axis_one = expected_dim2[local_slice] + exp_indices = exp_indices_dim2[local_slice_ind] + result, result_indices = ht.sort(data, descending=True, axis=2) + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + + # test out, descending=False result, result_indices = ht.sort(data, axis=2) - first = result[0].larray[:, :1] - first_indices = result_indices[0].larray[:, :1] - if rank == 0: - self.assertTrue(torch.equal(first, exp_axis_two)) - 
self.assertTrue(torch.equal(first_indices, indices_axis_two)) - # out = ht.empty_like(data) indices = ht.sort(data, axis=2, out=out) self.assertTrue(ht.equal(out, result)) self.assertTrue(ht.equal(indices, result_indices)) + # test exceptions with self.assertRaises(ValueError): ht.sort(data, axis=3) with self.assertRaises(TypeError): @@ -3428,76 +3483,75 @@ def test_topk(self): def test_unique(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - torch_array = torch.arange(size, dtype=torch.int32, device=self.device.torch_device).expand( - size, size - ) - split_zero = ht.array(torch_array, split=0) - - exp_axis_none = ht.array([rank], dtype=ht.int32) - res = split_zero.unique(sorted=True) - self.assertTrue((res.larray == exp_axis_none.larray).all()) - - exp_axis_zero = ht.arange(size, dtype=ht.int32).expand_dims(0) - res = ht.unique(split_zero, sorted=True, axis=0) - self.assertTrue((res.larray == exp_axis_zero.larray).all()) - - exp_axis_one = ht.array([rank], dtype=ht.int32).expand_dims(0) - split_zero_transposed = ht.array(torch_array.transpose(0, 1), split=0) - res = ht.unique(split_zero_transposed, sorted=False, axis=1) - self.assertTrue((res.larray == exp_axis_one.larray).all()) - - split_one = ht.array(torch_array, dtype=ht.int32, split=1) - - exp_axis_none = ht.arange(size, dtype=ht.int32) - res = ht.unique(split_one, sorted=True) - self.assertTrue((res.larray == exp_axis_none.larray).all()) - - exp_axis_zero = ht.array([rank], dtype=ht.int32).expand_dims(0) - res = ht.unique(split_one, sorted=False, axis=0) - self.assertTrue((res.larray == exp_axis_zero.larray).all()) - - exp_axis_one = ht.array([rank] * size, dtype=ht.int32).expand_dims(1) - res = ht.unique(split_one, sorted=True, axis=1) - self.assertTrue((res.larray == exp_axis_one.larray).all()) - - torch_array = torch.tensor( - [[1, 2], [2, 3], [1, 2], [2, 3], [1, 2]], - dtype=torch.int32, - device=self.device.torch_device, - ) - data = ht.array(torch_array, split=0) - - res, inv = 
ht.unique(data, return_inverse=True, axis=0) - _, exp_inv = torch_array.unique(dim=0, return_inverse=True, sorted=True) - self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + # "sparse" data + sparse_data = ht.array( + torch.zeros(10, 4, dtype=torch.int32, device=self.device.torch_device), is_split=0 + ) + random_ranks = torch.randint(size, size=(size // 2 + 1,)).tolist() + if rank in random_ranks: + random_row = torch.randint(10, size=(10,)) + random_col = torch.randint(4, size=(10,)) + sparse_data.larray[random_row, random_col] = 1 + t_sparse = ht.resplit(sparse_data, axis=None).larray + + # "dense" data + dense_data = ht.random.randint(0, 25, (50, 3), dtype=ht.int64, split=0) + t_dense = ht.resplit(dense_data, axis=None).larray + + datasets = [sparse_data, dense_data] + comps = [t_sparse, t_dense] + + for data, comp in zip(datasets, comps): + _, _, local_slice = data.comm.chunk(data.gshape, data.split) + # axis is None + unique, inverse = ht.unique(data, return_inverse=True) + unique.resplit_(None) + t_unique, t_inverse = torch.unique(comp, sorted=True, return_inverse=True) + self.assertTrue((unique.larray == t_unique).all()) + self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) + if data.is_distributed(): + self.assertTrue(unique.split is None or unique.split == 0) + else: + self.assertTrue(unique.split is None) + # test inverse indices on "gathered" unique + self.assertTrue((unique[inverse.larray].larray == data.larray).all()) + + # axis not None + axis = 0 + unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) + unique0.resplit_(None) + t_unique0, t_inverse0 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) + self.assertTrue((unique0.larray == t_unique0).all()) + self.assertTrue((inverse0.larray == t_inverse0[local_slice[axis]]).all()) + if data.is_distributed(): + self.assertTrue(unique0.split is None or unique0.split == axis) + else: + self.assertTrue(unique0.split is None) + # test inverse indices on 
"gathered" unique + self.assertTrue((unique0[inverse0.larray].larray == data.larray).all()) + + # axis == split != 0 + data = ht.array(comp, split=1) + _, _, local_slice = data.comm.chunk(data.gshape, data.split) + axis = 1 + unique1, inverse1 = ht.unique(data, return_inverse=True, axis=axis) + unique1.resplit_(None) + t_unique1, t_inverse1 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) + self.assertTrue((unique1.larray == t_unique1).all()) + self.assertTrue((inverse1.larray == t_inverse1[local_slice[axis]]).all()) + if data.is_distributed(): + self.assertTrue(unique1.split is None or unique1.split == axis) + else: + self.assertTrue(unique1.split is None) + # test inverse indices on "gathered" unique + self.assertTrue((unique1[:, inverse1.larray].larray == data.larray).all()) - res, inv = ht.unique(data, return_inverse=True, axis=1) - _, exp_inv = torch_array.unique(dim=1, return_inverse=True, sorted=True) - self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + # test unique on sorted data - torch_array = torch.tensor( - [[1, 1, 2], [1, 2, 2], [2, 1, 2], [1, 3, 2], [0, 1, 2]], - dtype=torch.int32, - device=self.device.torch_device, - ) - exp_res, exp_inv = torch_array.unique(return_inverse=True, sorted=True) - - data_split_none = ht.array(torch_array) - res = ht.unique(data_split_none, sorted=True) - self.assertIsInstance(res, ht.DNDarray) - self.assertEqual(res.split, None) - self.assertEqual(res.dtype, data_split_none.dtype) - self.assertEqual(res.device, data_split_none.device) - res, inv = ht.unique(data_split_none, return_inverse=True, sorted=True) - self.assertIsInstance(inv, ht.DNDarray) - self.assertEqual(inv.split, None) - self.assertEqual(inv.dtype, data_split_none.dtype) - self.assertEqual(inv.device, data_split_none.device) - self.assertTrue(torch.equal(inv.larray, exp_inv.int())) - - data_split_zero = ht.array(torch_array, split=0) - res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True) - 
self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + # test exceptions + if dense_data.is_distributed(): + with self.assertRaises(NotImplementedError): + ht.unique(dense_data, axis=1) def test_vsplit(self): # for further testing, see test_split diff --git a/heat/naive_bayes/gaussianNB.py b/heat/naive_bayes/gaussianNB.py index 0cc18f78bb..836e906604 100644 --- a/heat/naive_bayes/gaussianNB.py +++ b/heat/naive_bayes/gaussianNB.py @@ -57,7 +57,7 @@ class GaussianNB(ht.ClassificationMixin, ht.BaseEstimator): >>> print(clf.predict(ht.array([[-0.8, -1]]))) tensor([1]) >>> clf_pf = GaussianNB() - >>> clf_pf.partial_fit(X, Y, ht.unique(Y, sorted=True)) + >>> clf_pf.partial_fit(X, Y, ht.unique(Y)) >>> print(clf_pf.predict(ht.array([[-0.8, -1]]))) tensor([1]) @@ -95,7 +95,7 @@ def fit(self, x: DNDarray, y: DNDarray, sample_weight: Optional[DNDarray] = None type(sample_weight) ) ) - classes = ht.unique(y, sorted=True) + classes = ht.unique(y) if classes.split is not None: classes = ht.resplit(classes, axis=None) @@ -335,7 +335,9 @@ def __partial_fit( classes = self.classes_ - unique_y = ht.unique(y, sorted=True).resplit_(None) + unique_y = ht.unique(y) + if unique_y.split is not None: + unique_y = ht.resplit(unique_y, axis=None) unique_y_in_classes = ht.eq(unique_y, classes) if not ht.all(unique_y_in_classes):