From 976bec2f2201a6db6beb9c915e548509efe47008 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 19 Feb 2021 12:43:28 +0100 Subject: [PATCH 01/87] Modify get_halo to work with non-balanced DNDarray --- heat/core/dndarray.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 785bf21ed0..6a3c968196 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -413,10 +413,6 @@ def get_halo(self, halo_size): halo_size : int Size of the halo. """ - if not self.is_balanced(): - raise RuntimeError( - "halo cannot be created for unbalanced tensors, running the .balance_() function is recommended" - ) if not isinstance(halo_size, int): raise TypeError( "halo_size needs to be of Python type integer, {} given".format(type(halo_size)) @@ -426,23 +422,30 @@ def get_halo(self, halo_size): "halo_size needs to be a positive Python integer, {} given".format(type(halo_size)) ) - if self.comm.is_distributed() and self.split is not None: + if self.is_distributed(): # gather lshapes lshape_map = self.create_lshape_map() rank = self.comm.rank size = self.comm.size - next_rank = rank + 1 - prev_rank = rank - 1 - last_rank = size - 1 - # if local shape is zero and it's the last process + if not self.balanced: + populated_ranks = torch.nonzero(lshape_map[:, 0]).squeeze().tolist() + next_rank = populated_ranks.index(rank) + 1 + prev_rank = populated_ranks.index(rank) - 1 + last_rank = populated_ranks[-1] + else: + next_rank = rank + 1 + prev_rank = rank - 1 + last_rank = size - 1 + + # if local shape is zero if self.lshape[self.split] == 0: return # if process has no data we ignore it if halo_size > self.lshape[self.split]: # if on at least one process the halo_size is larger than the local size throw ValueError raise ValueError( - "halo_size {} needs to be smaller than chunck-size {} )".format( + "halo_size {} needs to be smaller than chunk-size {} )".format( halo_size, self.lshape[self.split] ) ) From 30b3ec3fa63f728d8d4ff1004885df7fbd62c5b4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 5 Mar 2021 06:26:10 +0100 Subject: [PATCH 02/87] Create lshape_map without communication if DNDarray is balanced. --- heat/core/dndarray.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 6a3c968196..23b8df3a0b 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1142,11 +1142,22 @@ def create_lshape_map(self): lshape_map : torch.Tensor Units -> (process rank, lshape) """ + if not self.is_distributed: + return torch.tensor(self.gshape) + lshape_map = torch.zeros( (self.comm.size, len(self.gshape)), dtype=torch.int, device=self.device.torch_device ) - lshape_map[self.comm.rank, :] = torch.tensor(self.lshape, device=self.device.torch_device) - self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM) + if self.is_balanced(): + for i in range(self.comm.size): + _, lshape, _ = self.comm.chunk(self.gshape, self.split, rank=i) + lshape_map[i, :] = torch.tensor(lshape, device=self.device.torch_device) + else: + lshape_map[self.comm.rank, :] = torch.tensor( + self.lshape, device=self.device.torch_device + ) + self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM) + return lshape_map def __eq__(self, other): From db4ebed91d49de3f338aa110ed2b5b9e1dd0aaa5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 5 Mar 2021 17:08:02 +0100 Subject: [PATCH 03/87] in-place resplit to work in imbalanced DNDarrays as well --- heat/core/dndarray.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 23b8df3a0b..9d21335fe7 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1143,7 +1143,7 @@ def create_lshape_map(self): Units -> (process rank, lshape) """ if not self.is_distributed: - return torch.tensor(self.gshape) + return torch.tensor(self.gshape).reshape(1, self.ndim) lshape_map = torch.zeros( (self.comm.size, len(self.gshape)), dtype=torch.int, device=self.device.torch_device @@ -2874,7 +2874,13 @@ def resplit_(self, axis=None): gathered = torch.empty( self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device ) - counts, displs, _ = self.comm.counts_displs_shape(self.shape, self.split) + if self.is_balanced(): + counts, displs, _ = self.comm.counts_displs_shape(self.shape, self.split) + else: + counts = self.create_lshape_map()[self.split] + displs = torch.cumsum( + torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 + ) self.comm.Allgatherv(self.__array, (gathered, counts, displs), recv_axis=self.split) self.__array = gathered self.__split = axis From 060b48a9d9cbd9990cdec586fbaa2eb893f560ed Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 5 Mar 2021 17:54:54 +0100 Subject: [PATCH 04/87] Implement distributed unique, return inverse indices --- heat/core/manipulations.py | 234 +++++++++++++++---------------------- 1 file changed, 96 insertions(+), 138 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 96848c277f..f4362f4357 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2682,9 +2682,9 @@ def stack(arrays, axis=0, out=None): return stacked -def unique(a, sorted=False, return_inverse=False, axis=None): +def unique(a, return_inverse=False, axis=None): """ - Finds and returns the unique elements of an array. + Finds and returns the sorted unique elements of an array. Works most effective if axis != a.split. @@ -2725,10 +2725,8 @@ def unique(a, sorted=False, return_inverse=False, axis=None): array([[2, 3], [3, 1]]) """ - if a.split is None: - torch_output = torch.unique( - a.larray, sorted=sorted, return_inverse=return_inverse, dim=axis - ) + if not a.is_distributed: + torch_output = torch.unique(a.larray, sorted=True, return_inverse=return_inverse, dim=axis) if isinstance(torch_output, tuple): heat_output = tuple( factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output @@ -2737,18 +2735,24 @@ def unique(a, sorted=False, return_inverse=False, axis=None): heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) return heat_output + rank = a.comm.rank + size = a.comm.size + local_data = a.larray + unique_axis = None - inverse_indices = None if axis is not None: - # transpose so we can work along the 0 axis - local_data = local_data.transpose(0, axis) - unique_axis = 0 + if axis != 0: + # transpose so we can work along the 0 axis + local_data = local_data.transpose(0, axis) + unique_axis = 0 + else: + unique_axis = axis - # Calculate the unique on the local values + # Calculate local uniques if a.lshape[a.split] == 0: - # Passing an empty vector to torch throws exception + # address empty local tensor if axis is None: res_shape = [0] inv_shape = list(a.gshape) @@ -2758,139 +2762,87 @@ def unique(a, sorted=False, return_inverse=False, axis=None): res_shape[0] = 0 inv_shape = [0] lres = torch.empty(res_shape, dtype=a.dtype.torch_type()) - inverse_pos = torch.empty(inv_shape, dtype=torch.int64) - else: - lres, inverse_pos = torch.unique( - local_data, sorted=sorted, return_inverse=True, dim=unique_axis - ) - - # Share and gather the results with the other processes - uniques = torch.tensor([lres.shape[0]]).to(torch.int32) - uniques_buf = torch.empty((a.comm.Get_size(),), dtype=torch.int32) - a.comm.Allgather(uniques, uniques_buf) - - if axis is None or axis == a.split: - is_split = None - split = a.split - - output_dim = list(lres.shape) - output_dim[0] = uniques_buf.sum().item() - - # Gather all unique vectors - counts = list(uniques_buf.tolist()) - displs = list([0] + uniques_buf.cumsum(0).tolist()[:-1]) - gres_buf = torch.empty(output_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device) - a.comm.Allgatherv(lres, (gres_buf, counts, displs), recv_axis=0) - - if return_inverse: - # Prepare some information to generated the inverse indices list - avg_len = a.gshape[a.split] // a.comm.Get_size() - rem = a.gshape[a.split] % a.comm.Get_size() - - # Share the local reverse indices with other processes - counts = [avg_len] * a.comm.Get_size() - add_vec = [1] * rem + [0] * (a.comm.Get_size() - rem) - inverse_counts = [sum(x) for x in zip(counts, add_vec)] - inverse_displs = [0] + list(np.cumsum(inverse_counts[:-1])) - inverse_dim = list(inverse_pos.shape) - inverse_dim[a.split] = a.gshape[a.split] - inverse_buf = torch.empty(inverse_dim, dtype=inverse_pos.dtype) - - # Transpose data and buffer so we can use Allgatherv along axis=0 (axis=1 does not work properly yet) - inverse_pos = inverse_pos.transpose(0, a.split) - inverse_buf = inverse_buf.transpose(0, a.split) - a.comm.Allgatherv( - inverse_pos, (inverse_buf, inverse_counts, inverse_displs), recv_axis=0 - ) - inverse_buf = inverse_buf.transpose(0, a.split) - - # Run unique a second time - gres = torch.unique(gres_buf, sorted=sorted, return_inverse=return_inverse, dim=unique_axis) - if return_inverse: - # Use the previously gathered information to generate global inverse_indices - g_inverse = gres[1] - gres = gres[0] - if axis is None: - # Calculate how many elements we have in each layer along the split axis - elements_per_layer = 1 - for num, val in enumerate(a.gshape): - if not num == a.split: - elements_per_layer *= val - - # Create the displacements for the flattened inverse indices array - local_elements = [displ * elements_per_layer for displ in inverse_displs][1:] + [ - float("inf") - ] - - # Flatten the inverse indices array every element can be updated to represent a global index - transposed = inverse_buf.transpose(0, a.split) - transposed_shape = transposed.shape - flatten_inverse = transposed.flatten() - - # Update the index elements iteratively - cur_displ = 0 - inverse_indices = [0] * len(flatten_inverse) - for num in range(len(inverse_indices)): - if num >= local_elements[cur_displ]: - cur_displ += 1 - index = flatten_inverse[num] + displs[cur_displ] - inverse_indices[num] = g_inverse[index].tolist() - - # Convert the flattened array back to the correct global shape of a - inverse_indices = torch.tensor(inverse_indices).reshape(transposed_shape) - inverse_indices = inverse_indices.transpose(0, a.split) - - else: - inverse_indices = torch.zeros_like(inverse_buf) - steps = displs + [None] - - # Algorithm that creates the correct list for the reverse_indices - for i in range(len(steps) - 1): - begin = steps[i] - end = steps[i + 1] - for num, x in enumerate(inverse_buf[begin:end]): - inverse_indices[begin + num] = g_inverse[begin + x] - + lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) + inv_shape = local_data.shape if axis is None else (local_data.shape[unique_axis],) + gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) + + # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally + _, data_max_lshape, _ = a.comm.chunk(a.gshape, a.split, rank=0) + data_max_lbytes = torch.prod(torch.tensor(data_max_lshape)) * a.larray.element_size() + if gres.nbytes <= data_max_lbytes: + print("RUNNING SPARSE UNIQUE") + # gather local uniques + gres.resplit_(None) + # final round of torch.unique + lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) else: - # Tensor is already split and does not need to be redistributed afterward - split = None - is_split = a.split - max_uniques, max_pos = uniques_buf.max(0) - # find indices of vectors - if a.comm.Get_rank() == max_pos.item(): - # Get indices of the unique vectors to share with all over processes - indices = inverse_pos.reshape(-1).unique() + print("RUNNING DENSE UNIQUE") + # balance gres if needed + gres.balance_() + # global sort + gres, sorted_gindices = sort(gres, axis=unique_axis) + # second local unique + lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) + # get rid of doubles at the edges + gres.get_halo(1) + if gres.halo_prev is not None and (gres.halo_prev == lres[0]).all(): + lres = lres[1:] + gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) + gres.balance_() + lres = gres.larray + + # inverse indices + if return_inverse: + # allocate local tensors + inverse_pos = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) + unique_ranks = size if gres.is_distributed() else 1 + if unique_ranks > 1: + gres_map = gres.create_lshape_map() + gres_offsets = torch.cat( + (torch.tensor([0], device=gres_map.device), gres_map[:-1, gres.split]) + ).cumsum(dim=0) else: - indices = torch.empty((max_uniques.item(),), dtype=inverse_pos.dtype) - - a.comm.Bcast(indices, root=max_pos) - - gres = local_data[indices.tolist()] - - inverse_indices = indices - if sorted: - raise ValueError( - "Sorting with axis != split is not supported yet. " - "See https://github.com/helmholtz-analytics/heat/issues/363" - ) + gres_map = torch.tensor(gres.gshape, device=inverse_pos.device) + gres_offsets = torch.tensor([0], device=gres_map.device) + lres = gres.larray + for p in range(unique_ranks): + origin = rank + p if rank + p < unique_ranks else rank + p - unique_ranks + incoming_offset = gres_offsets[origin] + # loop through unique elements, find matching position in data + for i, el in enumerate(lres.split(1, dim=0)): + counts = torch.zeros_like(local_data, dtype=torch.int8, device=local_data.device) + counts[torch.where(local_data == el)] = 1 + if lres.ndim > 1: + counts = torch.sum(counts, dim=tuple(range(lres.ndim))[1:]) + cond = torch.where(counts == el.numel()) + inverse_pos[cond] = i + incoming_offset + # if necessary, prepare to send lres to rank-1 and receive from rank+1 + if unique_ranks > 1: + dest_rank = rank - 1 if rank != 0 else size - 1 + gres.comm.Send(lres, dest_rank) + next_origin = origin + 1 if origin + 1 < unique_ranks else origin + 1 - unique_ranks + incoming_shape = gres_map[next_origin].tolist() + if incoming_shape != lres.shape: + lres = torch.empty( + incoming_shape, dtype=local_data.dtype, device=local_data.device + ) + recv_from_rank = rank + 1 if rank != size - 1 else 0 + gres.comm.Recv(lres, recv_from_rank) + inverse = factories.array(inverse_pos, is_split=0, device=gres.device) - if axis is not None: - # transpose matrix back + if axis is not None and axis != 0: + # transpose back to original dimensions gres = gres.transpose(0, axis) + if return_inverse: + inverse = inverse.transpose(0, axis) - split = split if a.split < len(gres.shape) else None - result = factories.array( - gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split - ) - if split is not None: - result.resplit_(a.split) - - return_value = result if return_inverse: - return_value = [return_value, inverse_indices.to(a.device.torch_device)] + return (gres, inverse) - return return_value + return gres def vsplit(ary, indices_or_sections): @@ -3019,7 +2971,13 @@ def resplit(arr, axis=None): gathered = torch.empty( arr.shape, dtype=arr.dtype.torch_type(), device=arr.device.torch_device ) - counts, displs, _ = arr.comm.counts_displs_shape(arr.shape, arr.split) + if arr.is_balanced(): + counts, displs, _ = arr.comm.counts_displs_shape(arr.shape, arr.split) + else: + counts = arr.create_lshape_map()[arr.split] + displs = torch.cumsum( + torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 + ) arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split) new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype) return new_arr From 1d9f71fb1c5057d3274ffe537aa27f2ab9389934 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 05:49:23 +0100 Subject: [PATCH 05/87] Debugging unique --- heat/core/manipulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index f4362f4357..701c658ce5 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2772,6 +2772,7 @@ def unique(a, return_inverse=False, axis=None): data_max_lbytes = torch.prod(torch.tensor(data_max_lshape)) * a.larray.element_size() if gres.nbytes <= data_max_lbytes: print("RUNNING SPARSE UNIQUE") + print("DEBUGGING: gres.lshape = ", gres.lshape) # gather local uniques gres.resplit_(None) # final round of torch.unique From 77fb0d467f36bc24a24b49c6677a992aa6fafed6 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 06:14:44 +0100 Subject: [PATCH 06/87] Fix error in counts, displs for unbalanced resplit_(None) --- heat/core/communication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/communication.py b/heat/core/communication.py index 88a837d98e..b852461ade 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -139,7 +139,7 @@ def chunk(self, shape, split, rank=None, w_size=None): def counts_displs_shape(self, shape, axis): """ - Calculates the item counts, displacements and output shape for a variable sized all-to-all MPI-call (e.g. + Calculates the item counts, displacements and output shape for a variable-sized all-to-all MPI-call (e.g. MPI_Alltoallv). The passed shape is regularly chunk along the given axis and for all nodes. Parameters From 63c96c42168b4608e7a32f7239466e6c91a77255 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 06:18:13 +0100 Subject: [PATCH 07/87] Fix error in counts, displs for unbalanced resplit_(None) --- heat/core/dndarray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 9d21335fe7..32be10491b 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -2877,10 +2877,11 @@ def resplit_(self, axis=None): if self.is_balanced(): counts, displs, _ = self.comm.counts_displs_shape(self.shape, self.split) else: - counts = self.create_lshape_map()[self.split] + counts = self.create_lshape_map()[:, self.split] displs = torch.cumsum( torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 ) + counts, displs = tuple(counts.tolist()), tuple(displs.tolist()) self.comm.Allgatherv(self.__array, (gathered, counts, displs), recv_axis=self.split) self.__array = gathered self.__split = axis From a6620da16161bd98bc5ab8907b8015b16107a762 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 06:21:01 +0100 Subject: [PATCH 08/87] Fix imbalanced resplit() --- heat/core/manipulations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 701c658ce5..e38d3ab8b6 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2772,7 +2772,6 @@ def unique(a, return_inverse=False, axis=None): data_max_lbytes = torch.prod(torch.tensor(data_max_lshape)) * a.larray.element_size() if gres.nbytes <= data_max_lbytes: print("RUNNING SPARSE UNIQUE") - print("DEBUGGING: gres.lshape = ", gres.lshape) # gather local uniques gres.resplit_(None) # final round of torch.unique @@ -2975,10 +2974,11 @@ def resplit(arr, axis=None): if arr.is_balanced(): counts, displs, _ = arr.comm.counts_displs_shape(arr.shape, arr.split) else: - counts = arr.create_lshape_map()[arr.split] + counts = arr.create_lshape_map()[:, arr.split] displs = torch.cumsum( torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 ) + counts, displs = tuple(counts.tolist()), tuple(displs.tolist()) arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split) new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype) return new_arr From 1fe8428b9d35af59aee113a9c3afc089c1c93322 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 06:35:46 +0100 Subject: [PATCH 09/87] Debugging unique --- heat/core/manipulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index e38d3ab8b6..21c5908024 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2797,6 +2797,7 @@ def unique(a, return_inverse=False, axis=None): # inverse indices if return_inverse: # allocate local tensors + print("DEBUGGING: gres is distributed: ", gres.is_distributed()) inverse_pos = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) unique_ranks = size if gres.is_distributed() else 1 if unique_ranks > 1: From 8dd0d82f82fd5cd6a28ccee43d81f1c17e0fc266 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 06:45:24 +0100 Subject: [PATCH 10/87] Fix incoming_offset error in sparse unique --- heat/core/manipulations.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 21c5908024..13af4d1f01 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2810,8 +2810,11 @@ def unique(a, return_inverse=False, axis=None): gres_offsets = torch.tensor([0], device=gres_map.device) lres = gres.larray for p in range(unique_ranks): - origin = rank + p if rank + p < unique_ranks else rank + p - unique_ranks - incoming_offset = gres_offsets[origin] + if unique_ranks == 1: + incoming_offset = 0 + else: + origin = rank + p if rank + p < unique_ranks else rank + p - unique_ranks + incoming_offset = gres_offsets[origin] # loop through unique elements, find matching position in data for i, el in enumerate(lres.split(1, dim=0)): counts = torch.zeros_like(local_data, dtype=torch.int8, device=local_data.device) From c35c72a093f737a3dd3c534d2a52bdf782b6d953 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 10 Mar 2021 13:30:02 +0100 Subject: [PATCH 11/87] Updated documentation, fixed some split errors. --- heat/core/manipulations.py | 68 ++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 13af4d1f01..85a93592b0 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2682,34 +2682,48 @@ def stack(arrays, axis=0, out=None): return stacked -def unique(a, return_inverse=False, axis=None): +def unique(a, sort=True, return_inverse=False, axis=None): """ - Finds and returns the sorted unique elements of an array. + Returns the sorted unique elements of an array. - Works most effective if axis != a.split. + If `a` is distributed, and unique elements along a specific `axis` are required, + then `a` must be distributed along `axis`. Parameters ---------- a : ht.DNDarray - Input array where unique elements should be found. - sorted : bool, optional - Whether the found elements should be sorted before returning as output. - Warning: sorted is not working if 'axis != None and axis != a.split' - Default: False + Input array. + axis : int, optional + The axis to operate on. If None, `a` will be flattened. + sort : bool, optional + Sort the array in ascending order before finding the unique elements. + Set `sorted=False` only if `a` is already sorted (in whichever order). + Default: True return_inverse : bool, optional - Whether to also return the indices for where elements in the original input ended up in the returned - unique list. + Return the indices of the unique array (for the specified `axis`, if provided) + that can be used to reconstruct `a`. Default: False - axis : int, optional - Axis along which unique elements should be found. Default to None, which will return a one dimensional list of - unique values. Returns ------- - res : ht.DNDarray - Output array. The unique elements. Elements are distributed the same way as the input tensor. - inverse_indices : torch.tensor (optional) - If return_inverse is True, this tensor will hold the list of inverse indices + unique : ht.DNDarray + The sorted unique elements of `a`. Whether `unique` is distributed depends on the + size of the unique elements with respect to the (process-local) data. See Notes below. + inverse_indices : ht.DNDarray + The global indices to reconstruct the original (possibly distributed) array from `unique`. + `inverse_indices` is distributed like `a`. See Notes below on reconstructing the original array + from a distributed `unique` array. + + Notes + ----- + Distributed sorting is a communication-intensive operation. `ht.unique()` performs the unique/sort operation + locally if the collective size of the unique values, in bytes, does not exceed the size of + the process-local input `a`. In this case, the resulting `unique` will not be distributed (`unique.split=None`). + Otherwise, `unique` will be distributed along 0, if `axis` is specified, or along `a.split`, if `axis` is None. + + WARNING: `inverse_indices` will always be distributed like the original data, and contains the GLOBAL indices + to recreate the LOCAL portion of `a`. Before you reconstruct an array based on `unique[inverse_indices]`, make sure + that `unique` is local (for example with `unique.resplit_(axis=None)`, see `ht.resplit`). Examples -------- @@ -2771,21 +2785,20 @@ def unique(a, return_inverse=False, axis=None): _, data_max_lshape, _ = a.comm.chunk(a.gshape, a.split, rank=0) data_max_lbytes = torch.prod(torch.tensor(data_max_lshape)) * a.larray.element_size() if gres.nbytes <= data_max_lbytes: - print("RUNNING SPARSE UNIQUE") # gather local uniques gres.resplit_(None) # final round of torch.unique lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) else: - print("RUNNING DENSE UNIQUE") - # balance gres if needed - gres.balance_() - # global sort - gres, sorted_gindices = sort(gres, axis=unique_axis) - # second local unique - lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) - gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) + if sort: + # balance gres if needed + gres.balance_() + # global sort + gres, sorted_gindices = sort(gres, axis=unique_axis) + # second local unique + lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) # get rid of doubles at the edges gres.get_halo(1) if gres.halo_prev is not None and (gres.halo_prev == lres[0]).all(): @@ -2797,7 +2810,6 @@ def unique(a, return_inverse=False, axis=None): # inverse indices if return_inverse: # allocate local tensors - print("DEBUGGING: gres is distributed: ", gres.is_distributed()) inverse_pos = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) unique_ranks = size if gres.is_distributed() else 1 if unique_ranks > 1: @@ -2835,7 +2847,7 @@ def unique(a, return_inverse=False, axis=None): ) recv_from_rank = rank + 1 if rank != size - 1 else 0 gres.comm.Recv(lres, recv_from_rank) - inverse = factories.array(inverse_pos, is_split=0, device=gres.device) + inverse = factories.array(inverse_pos, is_split=a.split, device=gres.device) if axis is not None and axis != 0: # transpose back to original dimensions From ae56b86d031d0657c8e776b22a9b76e056769bb0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 12 Mar 2021 14:14:25 +0100 Subject: [PATCH 12/87] Skip non-populated ranks in imbalanced gethalo --- heat/core/dndarray.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 32be10491b..9e71e8c7de 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -430,9 +430,10 @@ def get_halo(self, halo_size): if not self.balanced: populated_ranks = torch.nonzero(lshape_map[:, 0]).squeeze().tolist() - next_rank = populated_ranks.index(rank) + 1 - prev_rank = populated_ranks.index(rank) - 1 - last_rank = populated_ranks[-1] + if rank in populated_ranks: + next_rank = populated_ranks.index(rank) + 1 + prev_rank = populated_ranks.index(rank) - 1 + last_rank = populated_ranks[-1] else: next_rank = rank + 1 prev_rank = rank - 1 From fa597a815a77f709bdcb61747214fc6c9fe3024e Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 17 Mar 2021 09:00:25 +0100 Subject: [PATCH 13/87] Merge changes to reduce_op --- CHANGELOG.md | 1 + heat/core/_operations.py | 13 ++++++++----- heat/core/tests/test_arithmetics.py | 9 +++++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0219e17184..083652b7d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ # v0.5.2 - [#706](https://github.com/helmholtz-analytics/heat/pull/706) Bug fix: prevent `__setitem__`, `__getitem__` from modifying key in place +- [#744](https://github.com/helmholtz-analytics/heat/pull/744) Fix split semantics for reduction operations # v0.5.1 diff --git a/heat/core/_operations.py b/heat/core/_operations.py index 8932ede239..8a2fb60900 100644 --- a/heat/core/_operations.py +++ b/heat/core/_operations.py @@ -416,11 +416,14 @@ def __reduce_op(x, partial_op, reduction_op, neutral=None, **kwargs): if len(lshape_losedim) > 0: partial = partial.reshape(lshape_losedim) # perform a reduction operation in case the tensor is distributed across the reduction axis - if x.split is not None and (axis is None or (x.split in axis)): - split = None - balanced = True - if x.comm.is_distributed(): - x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op) + if x.split is not None: + if axis is None or (x.split in axis): + split = None + if x.comm.is_distributed(): + x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op) + elif axis is not None: + down_dims = len(tuple(dim for dim in axis if dim < x.split)) + split -= down_dims ARG_OPS = [statistics.MPI_ARGMAX, statistics.MPI_ARGMIN] arg_op = False diff --git a/heat/core/tests/test_arithmetics.py b/heat/core/tests/test_arithmetics.py index 54942553b5..bbeb21abad 100644 --- a/heat/core/tests/test_arithmetics.py +++ b/heat/core/tests/test_arithmetics.py @@ -478,6 +478,15 @@ def test_prod(self): self.assertEqual(no_axis_prod.split, None) self.assertEqual(no_axis_prod.larray, 134217728) + # check split semantics + shape_noaxis_split_axis = ht.ones((3, 3, 3), split=2) + split_axis_sum = shape_noaxis_split_axis.sum(axis=1) + self.assertIsInstance(split_axis_sum, ht.DNDarray) + self.assertEqual(split_axis_sum.shape, (3, 3)) + self.assertEqual(split_axis_sum.dtype, ht.float32) + self.assertEqual(split_axis_sum._DNDarray__array.dtype, torch.float32) + self.assertEqual(split_axis_sum.split, 1) + out_noaxis = ht.zeros((1,)) ht.prod(shape_noaxis, out=out_noaxis) self.assertEqual(out_noaxis.larray, 134217728) From 39050f0281202d742be98c64f314c575f68afe54 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 20 Mar 2021 14:03:26 +0100 Subject: [PATCH 14/87] Modify tests for imbalanced gethalo() --- heat/core/tests/test_dndarray.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 7174647237..fe03daf2a4 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -123,12 +123,6 @@ def test_gethalo(self): # exception for too large halos with self.assertRaises(ValueError): data.get_halo(4) - # exception on non balanced tensor - with self.assertRaises(RuntimeError): - data_nobalance = ht.array( - torch.empty(((data.comm.rank + 1) * 2, 3, 4)), is_split=0, device=data.device - ) - data_nobalance.get_halo(1) # test no data on process data_np = np.arange(2 * 12).reshape(2, 12) data = ht.array(data_np, split=0) From e57235af4a5372f2ec564d6ef44e3e867850ffec Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 20 Mar 2021 14:06:12 +0100 Subject: [PATCH 15/87] Generalize sort() implementation into helper function _pivot_sorting for both sort() and unique() --- heat/core/manipulations.py | 547 ++++++++++++++++++++++--------------- 1 file changed, 321 insertions(+), 226 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 85a93592b0..a478f0596a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1890,116 +1890,102 @@ def shape(a): return a.gshape -def sort(a, axis=None, descending=False, out=None): - """ - Sorts the elements of the DNDarray a along the given dimension (by default in ascending order) by their value. - - The sorting is not stable which means that equal elements in the result may have a different ordering than in the - original array. - - Sorting where `axis == a.split` needs a lot of communication between the processes of MPI. - - Parameters - ---------- - a : ht.DNDarray - Input array to be sorted. - axis : int, optional - The dimension to sort along. - Default is the last axis. - descending : bool, optional - If set to true values are sorted in descending order - Default is false - out : ht.DNDarray or None, optional - A location in which to store the results. If provided, it must have a broadcastable shape. If not provided - or set to None, a fresh tensor is allocated. - - Returns - ------- - values : ht.DNDarray - The sorted local results. - indices - The indices of the elements in the original data - - Raises - ------ - ValueError - If the axis is not in range of the axes. - - Examples - -------- - >>> x = ht.array([[4, 1], [2, 3]], split=0) - >>> x.shape - (1, 2) - (1, 2) - - >>> y = ht.sort(x, axis=0) - >>> y - (array([[2, 1]], array([[1, 0]])) - (array([[4, 3]], array([[0, 1]])) - - >>> ht.sort(x, descending=True) - (array([[4, 1]], array([[0, 1]])) - (array([[3, 2]], array([[1, 0]])) - """ - # default: using last axis - if axis is None: - axis = len(a.shape) - 1 - - stride_tricks.sanitize_axis(a.shape, axis) - - if a.split is None or axis != a.split: - # sorting is not affected by split -> we can just sort along the axis - final_result, final_indices = torch.sort(a.larray, dim=axis, descending=descending) - - else: - # sorting is affected by split, processes need to communicate results - # transpose so we can work along the 0 axis - transposed = a.larray.transpose(axis, 0) - local_sorted, local_indices = torch.sort(transposed, dim=0, descending=descending) - - size = a.comm.Get_size() - rank = a.comm.Get_rank() +def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): + size = a.comm.Get_size() + rank = a.comm.Get_rank() + transposed = a.larray.transpose(axis, 0) + if sort_op is torch.sort: counts, disp, _ = a.comm.counts_displs_shape(a.gshape, axis=axis) - + local_sorted, local_indices = sort_op(transposed, dim=0, descending=descending) actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank] - - length = local_sorted.size()[0] - - # Separate the sorted tensor into size + 1 equal length partitions - partitions = [x * length // (size + 1) for x in range(1, size + 1)] - local_pivots = ( - local_sorted[partitions] - if counts[rank] - else torch.empty((0,) + local_sorted.size()[1:], dtype=local_sorted.dtype) + elif sort_op is torch.unique: + local_sorted = sort_op(transposed, dim=0, **kwargs)[0] + lshape_map = torch.empty( + (size, local_sorted.ndim), dtype=torch.int64, device=local_sorted.device ) + a.comm.Allgather(torch.tensor(local_sorted.shape), lshape_map) + counts = lshape_map[:, 0] + displs = torch.cumsum( + torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 + ) + counts, displs = tuple(counts.tolist()), tuple(displs.tolist()) + else: + raise ValueError( + "sorting operation can be torch.sort or torch.unique, was {}".format(sort_op) + ) + unique_along_axis = True if sort_op is torch.unique and axis is not None else False - # Only processes with elements should share their pivots - gather_counts = [int(x > 0) * size for x in counts] - gather_displs = (0,) + tuple(np.cumsum(gather_counts[:-1])) - - pivot_dim = list(transposed.size()) - pivot_dim[0] = size * sum([1 for x in counts if x > 0]) + length = local_sorted.size()[0] - # share the local pivots with root process - pivot_buffer = torch.empty( - pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device + # Separate the sorted tensor into size + 1 equal length partitions + partitions = [x * length // (size + 1) for x in range(1, size + 1)] + print("DEBUGGING: partitions = ", partitions) + local_pivots = ( + local_sorted[partitions] + if counts[rank] + else torch.empty((0,) + local_sorted.size()[1:], dtype=local_sorted.dtype) + ) + print("DEBUGGING: local_pivots = ", local_pivots) + + # Only processes with elements should share their pivots + gather_counts = [int(x > 0) * size for x in counts] + gather_displs = (0,) + tuple(np.cumsum(gather_counts[:-1])) + # print("DEBUGGING: gather_counts, gather_displs = ", gather_counts, gather_displs) + pivot_dim = list(transposed.size()) + pivot_dim[0] = size * sum([1 for x in counts if x > 0]) + + # share the local pivots with root process + pivot_buffer = torch.empty(pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device) + # print("DEBUGGING: local_pivots.shape, local_pivots.dtype = ", local_pivots.shape, local_pivots.dtype) + # print("DEBUGGING: pivot_buffer.shape, pivot_buffer.dtype = ", pivot_buffer.shape, pivot_buffer.dtype) + a.comm.Gatherv(local_pivots, (pivot_buffer, gather_counts, gather_displs), root=0) + + pivot_dim[0] = size - 1 + global_pivots = torch.empty(pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device) + + # root process creates new pivots and shares them with other processes + if rank == 0: + print("DEBUGGING: pivot_buffer = ", pivot_buffer) + if sort_op is torch.sort: + sorted_pivots, _ = sort_op(pivot_buffer, dim=0, descending=descending) + else: + sorted_pivots = sort_op(pivot_buffer, dim=0, **kwargs)[0] + print("DEBUGGING: sorted_pivots = ", sorted_pivots) + length = sorted_pivots.size()[0] + global_partitions = [x * length // size for x in range(1, size)] + global_pivots = sorted_pivots[global_partitions] + + a.comm.Bcast(global_pivots, root=0) + print("DEBUGGING: global_pivots = ", global_pivots) + print("DEBUGGING: local_sorted = ", local_sorted) + # special case: unique with axis not None + if unique_along_axis: + # find position of global pivots in local sorted uniques + local_sorted, local_inv = torch.cat((local_sorted, global_pivots), dim=0).unique( + dim=0, + sorted=kwargs.get("sorted") if kwargs.get("sorted") else True, + return_inverse=True, ) - a.comm.Gatherv(local_pivots, (pivot_buffer, gather_counts, gather_displs), root=0) - - pivot_dim[0] = size - 1 - global_pivots = torch.empty( - pivot_dim, dtype=a.dtype.torch_type(), device=a.device.torch_device + # Use the inverse indices of the global pivots to work out the local partition slices + local_slices = torch.zeros(size + 1, dtype=torch.int64, device=local_sorted.device) + local_slices[1:-1] = local_inv[-global_pivots.shape[0] :] + 1 + local_slices[-1] = torch.tensor([local_sorted.shape[0]]) + # how many rows will be sent and received where + send_matrix = torch.tensor( + [local_slices[i] - local_slices[i - 1] for i in range(1, size + 1)] ) + recv_matrix = torch.zeros(size, dtype=torch.int64, device=local_sorted.device) + a.comm.Alltoall(send_matrix, recv_matrix) + for matrix in [send_matrix, recv_matrix]: + matrix = matrix.reshape(1, matrix.numel()) - # root process creates new pivots and shares them with other processes - if rank == 0: - sorted_pivots, _ = torch.sort(pivot_buffer, descending=descending, dim=0) - length = sorted_pivots.size()[0] - global_partitions = [x * length // size for x in range(1, size)] - global_pivots = sorted_pivots[global_partitions] - - a.comm.Bcast(global_pivots, root=0) + print("DEBUGGING: after alltoall: send_matrix = ", send_matrix) + print("DEBUGGING: after alltoall: recv_matrix = ", recv_matrix) + scounts = send_matrix + rcounts = recv_matrix + shape = (recv_matrix.sum(dim=0),) + local_sorted.shape[1:] + else: lt_partitions = torch.empty((size,) + local_sorted.shape, dtype=torch.int64) last = torch.zeros_like(local_sorted, dtype=torch.int64) comp_op = torch.gt if descending else torch.lt @@ -2012,10 +1998,8 @@ def sort(a, axis=None, descending=False, out=None): lt_partitions[idx] = lt last = lt lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last - # Matrix holding information how many values will be sent where local_partitions = torch.sum(lt_partitions, dim=1) - partition_matrix = torch.empty_like(local_partitions) a.comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM) @@ -2033,112 +2017,129 @@ def sort(a, axis=None, descending=False, out=None): a.comm.Alltoall(send_matrix, recv_matrix) + print("DEBUGGING: after alltoall: send_matrix = ", send_matrix) + print("DEBUGGING: after alltoall: recv_matrix = ", recv_matrix) scounts = local_partitions rcounts = recv_matrix shape = (partition_matrix[rank].max(),) + transposed.size()[1:] - first_result = torch.empty(shape, dtype=local_sorted.dtype) + + first_result = torch.empty(shape, dtype=local_sorted.dtype) + if sort_op is torch.sort: first_indices = torch.empty_like(first_result) - # Iterate through one layer and send values with alltoallv - for idx in np.ndindex(local_sorted.shape[1:]): + # Iterate through one layer and send values with alltoallv + if unique_along_axis: + iterator = range(1) + else: + iterator = np.ndindex(local_sorted.shape[1:]) + + for idx in iterator: + if unique_along_axis: + idx_slice = [slice(None)] # + [slice(ind, ind + 1) for ind in range(idx)] + else: idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - send_count = scounts[idx_slice].reshape(-1).tolist() - send_disp = [0] + list(np.cumsum(send_count[:-1])) - s_val = local_sorted[idx_slice].clone() - s_ind = actual_indices[idx_slice].clone().to(dtype=local_sorted.dtype) + send_count = scounts[idx_slice].reshape(-1).tolist() + send_disp = [0] + list(np.cumsum(send_count[:-1])) + s_val = local_sorted[idx_slice].clone() - recv_count = rcounts[idx_slice].reshape(-1).tolist() - recv_disp = [0] + list(np.cumsum(recv_count[:-1])) - rcv_length = rcounts[idx_slice].sum().item() - r_val = torch.empty((rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype) - r_ind = torch.empty_like(r_val) + recv_count = rcounts[idx_slice].reshape(-1).tolist() + recv_disp = [0] + list(np.cumsum(recv_count[:-1])) + rcv_length = rcounts[idx_slice].sum().item() + r_val = torch.empty((rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype) + a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) + first_result[idx_slice][:rcv_length] = r_val - a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) + if sort_op is torch.sort: + s_ind = actual_indices[idx_slice].clone().to(dtype=local_sorted.dtype) + r_ind = torch.empty_like(r_val) a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) - first_result[idx_slice][:rcv_length] = r_val first_indices[idx_slice][:rcv_length] = r_ind - # The process might not have the correct number of values therefore the tensors need to be rebalanced - send_vec = torch.zeros(local_sorted.shape[1:] + (size, size), dtype=torch.int64) - target_cumsum = np.cumsum(counts) - for idx in np.ndindex(local_sorted.shape[1:]): - idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - current_counts = partition_matrix[idx_slice].reshape(-1).tolist() - current_cumsum = list(np.cumsum(current_counts)) - for proc in range(size): - if current_cumsum[proc] > target_cumsum[proc]: - # process has to many values which will be sent to higher ranks - first = next(i for i in range(size) if send_vec[idx][:, i].sum() < counts[i]) - last = next( - i - for i in range(size + 1) - if i == size or current_cumsum[proc] < target_cumsum[i] - ) - sent = 0 - for i, x in enumerate(counts[first:last]): - # Each following process gets as many elements as it needs - amount = int(x - send_vec[idx][:, first + i].sum()) - send_vec[idx][proc][first + i] = amount - current_counts[first + i] += amount - sent += amount - if last < size: - # Send all left over values to the highest last process - amount = partition_matrix[proc][idx] - send_vec[idx][proc][last] = int(amount - sent) - current_counts[last] += int(amount - sent) - elif current_cumsum[proc] < target_cumsum[proc]: - # process needs values from higher rank - first = ( - 0 - if proc == 0 - else next( - i for i, x in enumerate(current_cumsum) if target_cumsum[proc - 1] < x - ) - ) - last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) - for i, x in enumerate(partition_matrix[idx_slice][first:last]): - # Taking as many elements as possible from each following process - send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum()) - current_counts[first + i] = 0 - # Taking just enough elements from the last element to fill the current processes tensor - send_vec[idx][last][proc] = int(target_cumsum[proc] - current_cumsum[last - 1]) - current_counts[last] -= int(target_cumsum[proc] - current_cumsum[last - 1]) - else: - # process doesn't need more values - send_vec[idx][proc][proc] = ( - partition_matrix[proc][idx] - send_vec[idx][proc].sum() + if sort_op is torch.unique: + # early out for unique + return first_result + + # The process might not have the correct number of values therefore the tensors need to be rebalanced + send_vec = torch.zeros(local_sorted.shape[1:] + (size, size), dtype=torch.int64) + target_cumsum = np.cumsum(counts) + for idx in np.ndindex(local_sorted.shape[1:]): + idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] + current_counts = partition_matrix[idx_slice].reshape(-1).tolist() + current_cumsum = list(np.cumsum(current_counts)) + for proc in range(size): + # process has to many values which will be sent to higher ranks + if current_cumsum[proc] > target_cumsum[proc]: + first = next(i for i in range(size) if send_vec[idx][:, i].sum() < counts[i]) + last = next( + i + for i in range(size + 1) + if i == size or current_cumsum[proc] < target_cumsum[i] + ) + sent = 0 + for i, x in enumerate(counts[first:last]): + # Each following process gets as many elements as it needs + amount = int(x - send_vec[idx][:, first + i].sum()) + send_vec[idx][proc][first + i] = amount + current_counts[first + i] += amount + sent += amount + if last < size: + # Send all left over values to the highest last process + amount = partition_matrix[proc][idx] + send_vec[idx][proc][last] = int(amount - sent) + current_counts[last] += int(amount - sent) + elif current_cumsum[proc] < target_cumsum[proc]: + # process needs values from higher rank + first = ( + 0 + if proc == 0 + else next( + i for i, x in enumerate(current_cumsum) if target_cumsum[proc - 1] < x ) - current_counts[proc] = counts[proc] - current_cumsum = list(np.cumsum(current_counts)) + ) + # print("DEBUGGING: current_cumsum, target_cumsum[proc] = ", current_cumsum, target_cumsum[proc]) + last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) + for i, x in enumerate(partition_matrix[idx_slice][first:last]): + # Taking as many elements as possible from each following process + send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum()) + current_counts[first + i] = 0 + # Taking just enough elements from the last element to fill the current processes tensor + send_vec[idx][last][proc] = int(target_cumsum[proc] - current_cumsum[last - 1]) + current_counts[last] -= int(target_cumsum[proc] - current_cumsum[last - 1]) + else: + # process doesn't need more values + send_vec[idx][proc][proc] = partition_matrix[proc][idx] - send_vec[idx][proc].sum() + current_counts[proc] = counts[proc] + current_cumsum = list(np.cumsum(current_counts)) - # Iterate through one layer again to create the final balanced local tensors - second_result = torch.empty_like(local_sorted) + # Iterate through one layer again to create the final balanced local tensors + second_result = torch.empty_like(local_sorted) + if sort_op is torch.sort: second_indices = torch.empty_like(second_result) - for idx in np.ndindex(local_sorted.shape[1:]): - idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] + for idx in np.ndindex(local_sorted.shape[1:]): + idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - send_count = send_vec[idx][rank] - send_disp = [0] + list(np.cumsum(send_count[:-1])) + send_count = send_vec[idx][rank] + send_disp = [0] + list(np.cumsum(send_count[:-1])) - recv_count = send_vec[idx][:, rank] - recv_disp = [0] + list(np.cumsum(recv_count[:-1])) + recv_count = send_vec[idx][:, rank] + recv_disp = [0] + list(np.cumsum(recv_count[:-1])) - end = partition_matrix[rank][idx] - s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0) - s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val) + end = partition_matrix[rank][idx] + s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0) + r_val = torch.empty((counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype) + a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) + second_result[idx_slice] = r_val - r_val = torch.empty((counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype) + if sort_op is torch.sort: + s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val) r_ind = torch.empty_like(r_val) - - a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) - - second_result[idx_slice] = r_val second_indices[idx_slice] = r_ind - second_result, tmp_indices = second_result.sort(dim=0, descending=descending) + if sort_op is torch.sort: + second_result, tmp_indices = sort_op(second_result, dim=0, descending=descending) final_result = second_result.transpose(0, axis) final_indices = torch.empty_like(second_indices) # Update the indices in case the ordering changed during the last sort @@ -2146,6 +2147,77 @@ def sort(a, axis=None, descending=False, out=None): val = tmp_indices[idx] final_indices[idx] = second_indices[val.item()][idx[1:]] final_indices = final_indices.transpose(0, axis) + return final_result, final_indices + else: + second_result = sort_op(second_result, dim=0, **kwargs)[0] + final_result = second_result.transpose(0, axis) + return final_result + + +def sort(a, axis=None, descending=False, out=None): + """ + Sorts the elements of the DNDarray a along the given dimension (by default in ascending order) by their value. + + The sorting is not stable which means that equal elements in the result may have a different ordering than in the + original array. + + Sorting where `axis == a.split` needs a lot of communication between the processes of MPI. + + Parameters + ---------- + a : ht.DNDarray + Input array to be sorted. + axis : int, optional + The dimension to sort along. + Default is the last axis. + descending : bool, optional + If set to true values are sorted in descending order + Default is false + out : ht.DNDarray or None, optional + A location in which to store the results. If provided, it must have a broadcastable shape. If not provided + or set to None, a fresh tensor is allocated. + + Returns + ------- + values : ht.DNDarray + The sorted local results. + indices + The indices of the elements in the original data + + Raises + ------ + ValueError + If the axis is not in range of the axes. + + Examples + -------- + >>> x = ht.array([[4, 1], [2, 3]], split=0) + >>> x.shape + (1, 2) + (1, 2) + + >>> y = ht.sort(x, axis=0) + >>> y + (array([[2, 1]], array([[1, 0]])) + (array([[4, 3]], array([[0, 1]])) + + >>> ht.sort(x, descending=True) + (array([[4, 1]], array([[0, 1]])) + (array([[3, 2]], array([[1, 0]])) + """ + # default: using last axis + if axis is None: + axis = len(a.shape) - 1 + + stride_tricks.sanitize_axis(a.shape, axis) + + if a.split is None or axis != a.split: + # sorting is not affected by split -> we can just sort along the axis + final_result, final_indices = torch.sort(a.larray, dim=axis, descending=descending) + + else: + final_result, final_indices = _pivot_sorting(a, axis, torch.sort, descending=descending) + return_indices = factories.array( final_indices, dtype=dndarray.types.int32, is_split=a.split, device=a.device, comm=a.comm ) @@ -2682,7 +2754,7 @@ def stack(arrays, axis=0, out=None): return stacked -def unique(a, sort=True, return_inverse=False, axis=None): +def unique(a, sorted=True, return_inverse=False, axis=None): """ Returns the sorted unique elements of an array. @@ -2695,7 +2767,7 @@ def unique(a, sort=True, return_inverse=False, axis=None): Input array. axis : int, optional The axis to operate on. If None, `a` will be flattened. - sort : bool, optional + sorted : bool, optional Sort the array in ascending order before finding the unique elements. Set `sorted=False` only if `a` is already sorted (in whichever order). Default: True @@ -2707,23 +2779,26 @@ def unique(a, sort=True, return_inverse=False, axis=None): Returns ------- unique : ht.DNDarray - The sorted unique elements of `a`. Whether `unique` is distributed depends on the - size of the unique elements with respect to the (process-local) data. See Notes below. + The sorted unique elements of `a`. Whether `unique` is distributed depends + on the size of the unique elements with respect to the (process-local) data. + See Notes below. inverse_indices : ht.DNDarray - The global indices to reconstruct the original (possibly distributed) array from `unique`. - `inverse_indices` is distributed like `a`. See Notes below on reconstructing the original array - from a distributed `unique` array. + The global indices to reconstruct the original (possibly distributed) array + from `unique`. `inverse_indices` is distributed like `a`. See Notes below + on reconstructing the original array from a distributed `unique` array. Notes ----- - Distributed sorting is a communication-intensive operation. `ht.unique()` performs the unique/sort operation - locally if the collective size of the unique values, in bytes, does not exceed the size of - the process-local input `a`. In this case, the resulting `unique` will not be distributed (`unique.split=None`). - Otherwise, `unique` will be distributed along 0, if `axis` is specified, or along `a.split`, if `axis` is None. + The resulting `unique` will not be distributed (`unique.split=None`) if the collective + size of the unique values, in bytes, does not exceed a certain threshold + (arbitrarily defined as the size of the process-local input `a`). Otherwise, + `unique` will be distributed along 0, if `axis` is specified, or along `a.split`, + if `axis` is None. - WARNING: `inverse_indices` will always be distributed like the original data, and contains the GLOBAL indices - to recreate the LOCAL portion of `a`. Before you reconstruct an array based on `unique[inverse_indices]`, make sure - that `unique` is local (for example with `unique.resplit_(axis=None)`, see `ht.resplit`). + WARNING: `inverse_indices` will always be distributed like the original data + (if `axis is None`) or along 0, and contains the GLOBAL indices to recreate the + LOCAL portion of `a`. Before reconstructing an array based on `unique[inverse_indices]`, + make sure that `unique` is local (with `unique.resplit_(axis=None)`, see `ht.resplit`). Examples -------- @@ -2739,8 +2814,10 @@ def unique(a, sort=True, return_inverse=False, axis=None): array([[2, 3], [3, 1]]) """ - if not a.is_distributed: - torch_output = torch.unique(a.larray, sorted=True, return_inverse=return_inverse, dim=axis) + if not a.is_distributed(): + torch_output = torch.unique( + a.larray, sorted=sorted, return_inverse=return_inverse, dim=axis + ) if isinstance(torch_output, tuple): heat_output = tuple( factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output @@ -2757,6 +2834,12 @@ def unique(a, sort=True, return_inverse=False, axis=None): unique_axis = None if axis is not None: + if axis != a.split: + raise ValueError( + "Operation axis must match distribution axis of the array: axis is {}, array.split is {}".format( + axis, a.split + ) + ) if axis != 0: # transpose so we can work along the 0 axis local_data = local_data.transpose(0, axis) @@ -2777,7 +2860,7 @@ def unique(a, sort=True, return_inverse=False, axis=None): inv_shape = [0] lres = torch.empty(res_shape, dtype=a.dtype.torch_type()) else: - lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) + lres = torch.unique(local_data, sorted=sorted, return_inverse=False, dim=unique_axis) inv_shape = local_data.shape if axis is None else (local_data.shape[unique_axis],) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) @@ -2788,29 +2871,31 @@ def unique(a, sort=True, return_inverse=False, axis=None): # gather local uniques gres.resplit_(None) # final round of torch.unique - lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + lres = torch.unique(gres.larray, sorted=sorted, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) else: - if sort: - # balance gres if needed - gres.balance_() - # global sort - gres, sorted_gindices = sort(gres, axis=unique_axis) - # second local unique - lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) - gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) - # get rid of doubles at the edges - gres.get_halo(1) - if gres.halo_prev is not None and (gres.halo_prev == lres[0]).all(): - lres = lres[1:] + # print("DEBUGGING: before sorting: gres.larray = ", gres.larray ) + # balance gres if needed + gres.balance_() + # global sort + # print("DEBUGGING: before sorting after balancing: gres.larray = ", gres.larray ) + # gres, sorted_gindices = sort(gres, axis=unique_axis) + lres = _pivot_sorting(gres, 0, torch.unique, sorted=sorted, return_inverse=True) + print("DEBUGGING: after pivot_sorting: lres = ", lres) + # print("DEBUGGING: after sorting before unique: gres.larray = ", gres.larray ) + # second local unique + lres = torch.unique(lres, sorted=sorted, dim=unique_axis) + print("DEBUGGING: after second unique: lres = ", lres) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) + # print("DEBUGGING: after pivot_sorting unique: gres.larray = ", gres.larray ) gres.balance_() - lres = gres.larray # inverse indices if return_inverse: - # allocate local tensors - inverse_pos = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) + # allocate local tensors and global DNDarray + inverse = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) + global_inverse = factories.array(inverse, is_split=a.split, device=gres.device) + unique_ranks = size if gres.is_distributed() else 1 if unique_ranks > 1: gres_map = gres.create_lshape_map() @@ -2818,23 +2903,30 @@ def unique(a, sort=True, return_inverse=False, axis=None): (torch.tensor([0], device=gres_map.device), gres_map[:-1, gres.split]) ).cumsum(dim=0) else: - gres_map = torch.tensor(gres.gshape, device=inverse_pos.device) + gres_map = torch.tensor(gres.gshape, device=inverse.device) gres_offsets = torch.tensor([0], device=gres_map.device) lres = gres.larray + print("DEBUGGING: after balancing: lres = ", lres) + + gres_recv = None for p in range(unique_ranks): if unique_ranks == 1: incoming_offset = 0 else: origin = rank + p if rank + p < unique_ranks else rank + p - unique_ranks incoming_offset = gres_offsets[origin] + if gres_recv: + gres_recv.Wait() # loop through unique elements, find matching position in data - for i, el in enumerate(lres.split(1, dim=0)): + for i, el in enumerate(lres): counts = torch.zeros_like(local_data, dtype=torch.int8, device=local_data.device) + print("DEBUGGING: el = ", el) + # print("DEBUGGING: local_data = ", local_data) counts[torch.where(local_data == el)] = 1 if lres.ndim > 1: counts = torch.sum(counts, dim=tuple(range(lres.ndim))[1:]) cond = torch.where(counts == el.numel()) - inverse_pos[cond] = i + incoming_offset + global_inverse.larray[cond] = i + incoming_offset # if necessary, prepare to send lres to rank-1 and receive from rank+1 if unique_ranks > 1: dest_rank = rank - 1 if rank != 0 else size - 1 @@ -2842,21 +2934,24 @@ def unique(a, sort=True, return_inverse=False, axis=None): next_origin = origin + 1 if origin + 1 < unique_ranks else origin + 1 - unique_ranks incoming_shape = gres_map[next_origin].tolist() if incoming_shape != lres.shape: + print( + "DEBUGGING: CHANGING LRES SHAPE!! incoming_shape, lres.shape = ", + incoming_shape, + lres.shape, + ) lres = torch.empty( incoming_shape, dtype=local_data.dtype, device=local_data.device ) recv_from_rank = rank + 1 if rank != size - 1 else 0 - gres.comm.Recv(lres, recv_from_rank) - inverse = factories.array(inverse_pos, is_split=a.split, device=gres.device) + gres_recv = gres.comm.Recv(lres, recv_from_rank) + print("DEBUGGING: rank, p, lres = ", rank, p, lres) if axis is not None and axis != 0: # transpose back to original dimensions gres = gres.transpose(0, axis) - if return_inverse: - inverse = inverse.transpose(0, axis) if return_inverse: - return (gres, inverse) + return (gres, global_inverse) return gres From 20a99301c0911619418489187817b765ec3604b3 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 20 Mar 2021 14:10:14 +0100 Subject: [PATCH 16/87] Update test_unique based on new distributed implementation --- heat/core/tests/test_manipulations.py | 170 +++++++++++++++----------- 1 file changed, 100 insertions(+), 70 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index ed7882b6d5..76546434a7 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2907,76 +2907,106 @@ def test_topk(self): def test_unique(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - torch_array = torch.arange(size, dtype=torch.int32, device=self.device.torch_device).expand( - size, size - ) - split_zero = ht.array(torch_array, split=0) - - exp_axis_none = ht.array([rank], dtype=ht.int32) - res = split_zero.unique(sorted=True) - self.assertTrue((res.larray == exp_axis_none.larray).all()) - - exp_axis_zero = ht.arange(size, dtype=ht.int32).expand_dims(0) - res = ht.unique(split_zero, sorted=True, axis=0) - self.assertTrue((res.larray == exp_axis_zero.larray).all()) - - exp_axis_one = ht.array([rank], dtype=ht.int32).expand_dims(0) - split_zero_transposed = ht.array(torch_array.transpose(0, 1), split=0) - res = ht.unique(split_zero_transposed, sorted=False, axis=1) - self.assertTrue((res.larray == exp_axis_one.larray).all()) - - split_one = ht.array(torch_array, dtype=ht.int32, split=1) - - exp_axis_none = ht.arange(size, dtype=ht.int32) - res = ht.unique(split_one, sorted=True) - self.assertTrue((res.larray == exp_axis_none.larray).all()) - - exp_axis_zero = ht.array([rank], dtype=ht.int32).expand_dims(0) - res = ht.unique(split_one, sorted=False, axis=0) - self.assertTrue((res.larray == exp_axis_zero.larray).all()) - - exp_axis_one = ht.array([rank] * size, dtype=ht.int32).expand_dims(1) - res = ht.unique(split_one, sorted=True, axis=1) - self.assertTrue((res.larray == exp_axis_one.larray).all()) - - torch_array = torch.tensor( - [[1, 2], [2, 3], [1, 2], [2, 3], [1, 2]], - dtype=torch.int32, - device=self.device.torch_device, - ) - data = ht.array(torch_array, split=0) - - res, inv = ht.unique(data, return_inverse=True, axis=0) - _, exp_inv = torch_array.unique(dim=0, return_inverse=True, sorted=True) - self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) - - res, inv = ht.unique(data, return_inverse=True, axis=1) - _, exp_inv = torch_array.unique(dim=1, return_inverse=True, sorted=True) - self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) - - torch_array = torch.tensor( - [[1, 1, 2], [1, 2, 2], [2, 1, 2], [1, 3, 2], [0, 1, 2]], - dtype=torch.int32, - device=self.device.torch_device, - ) - exp_res, exp_inv = torch_array.unique(return_inverse=True, sorted=True) - - data_split_none = ht.array(torch_array) - res = ht.unique(data_split_none, sorted=True) - self.assertIsInstance(res, ht.DNDarray) - self.assertEqual(res.split, None) - self.assertEqual(res.dtype, data_split_none.dtype) - self.assertEqual(res.device, data_split_none.device) - res, inv = ht.unique(data_split_none, return_inverse=True, sorted=True) - self.assertIsInstance(inv, ht.DNDarray) - self.assertEqual(inv.split, None) - self.assertEqual(inv.dtype, data_split_none.dtype) - self.assertEqual(inv.device, data_split_none.device) - self.assertTrue(torch.equal(inv.larray, exp_inv.int())) - - data_split_zero = ht.array(torch_array, split=0) - res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True) - self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + # test "sparse" unique + data = ht.array( + torch.zeros(10, 4, dtype=torch.int32, device=self.device.torch_device), is_split=0 + ) + _, _, local_slice = data.comm.chunk(data.gshape, data.split) + random_ranks = torch.randint(size, size=(size // 2 + 1,)).tolist() + if rank in random_ranks: + random_row = torch.randint(10, size=(10,)) + random_col = torch.randint(4, size=(10,)) + data.larray[random_row, random_col] = 1 + t_comp = ht.resplit(data, axis=None).larray + # axis is None + unique, inverse = ht.unique(data, return_inverse=True) + unique.resplit_(None) + t_unique, t_inverse = torch.unique(t_comp, sorted=True, return_inverse=True) + self.assertTrue((unique.larray == t_unique).all()) + self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) + # axis not None + axis = 0 + unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) + unique0.resplit_(None) + t_unique0, t_inverse0 = torch.unique(t_comp, sorted=True, return_inverse=True, dim=axis) + # print("DEBUGGING: unique0, inverse0 = ", inverse0.larray) + # print("DEBUGGING: t_unique0, t_inverse0 = ", t_unique0, t_inverse0, t_inverse0.shape) + self.assertTrue((unique0.larray == t_unique0).all()) + self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) + + # test "sparse" unique, distributed, axis + + # torch_array = torch.arange(size, dtype=torch.int32, device=self.device.torch_device).expand( + # size, size + # ) + # split_zero = ht.array(torch_array, split=0) + + # exp_axis_none = ht.array([rank], dtype=ht.int32) + # res = split_zero.unique()[rank] + # print("res, exp_axis_none = ", res.larray, exp_axis_none.larray) + # self.assertTrue((res.larray == exp_axis_none.larray).all()) + + # exp_axis_zero = ht.arange(size, dtype=ht.int32).expand_dims(0) + # res = ht.unique(split_zero, axis=0) + # self.assertTrue((res.larray == exp_axis_zero.larray).all()) + + # exp_axis_one = ht.array([rank], dtype=ht.int32).expand_dims(0) + # split_zero_transposed = ht.array(torch_array.transpose(0, 1), split=0) + # res = ht.unique(split_zero_transposed, sorted=False, axis=1) + # self.assertTrue((res.larray == exp_axis_one.larray).all()) + + # split_one = ht.array(torch_array, dtype=ht.int32, split=1) + + # exp_axis_none = ht.arange(size, dtype=ht.int32) + # res = ht.unique(split_one, sorted=True) + # self.assertTrue((res.larray == exp_axis_none.larray).all()) + + # exp_axis_zero = ht.array([rank], dtype=ht.int32).expand_dims(0) + # res = ht.unique(split_one, sorted=False, axis=0) + # self.assertTrue((res.larray == exp_axis_zero.larray).all()) + + # exp_axis_one = ht.array([rank] * size, dtype=ht.int32).expand_dims(1) + # res = ht.unique(split_one, sorted=True, axis=1) + # self.assertTrue((res.larray == exp_axis_one.larray).all()) + + # torch_array = torch.tensor( + # [[1, 2], [2, 3], [1, 2], [2, 3], [1, 2]], + # dtype=torch.int32, + # device=self.device.torch_device, + # ) + # data = ht.array(torch_array, split=0) + + # res, inv = ht.unique(data, return_inverse=True, axis=0) + # _, exp_inv = torch_array.unique(dim=0, return_inverse=True, sorted=True) + # self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + + # res, inv = ht.unique(data, return_inverse=True, axis=1) + # _, exp_inv = torch_array.unique(dim=1, return_inverse=True, sorted=True) + # self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + + # torch_array = torch.tensor( + # [[1, 1, 2], [1, 2, 2], [2, 1, 2], [1, 3, 2], [0, 1, 2]], + # dtype=torch.int32, + # device=self.device.torch_device, + # ) + # exp_res, exp_inv = torch_array.unique(return_inverse=True, sorted=True) + + # data_split_none = ht.array(torch_array) + # res = ht.unique(data_split_none, sorted=True) + # self.assertIsInstance(res, ht.DNDarray) + # self.assertEqual(res.split, None) + # self.assertEqual(res.dtype, data_split_none.dtype) + # self.assertEqual(res.device, data_split_none.device) + # res, inv = ht.unique(data_split_none, return_inverse=True, sorted=True) + # self.assertIsInstance(inv, ht.DNDarray) + # self.assertEqual(inv.split, None) + # self.assertEqual(inv.dtype, data_split_none.dtype) + # self.assertEqual(inv.device, data_split_none.device) + # self.assertTrue(torch.equal(inv.larray, exp_inv.int())) + + # data_split_zero = ht.array(torch_array, split=0) + # res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True) + # self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) def test_vsplit(self): # for further testing, see test_split From 9290c40352e29ea80c142f47373f4ae11a9f5079 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Mar 2021 17:41:34 +0100 Subject: [PATCH 17/87] Fix write-out bug in MPI ring --- heat/core/manipulations.py | 101 ++++++++++++------------------------- 1 file changed, 33 insertions(+), 68 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index a478f0596a..73dd5ddc4e 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1919,13 +1919,11 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): # Separate the sorted tensor into size + 1 equal length partitions partitions = [x * length // (size + 1) for x in range(1, size + 1)] - print("DEBUGGING: partitions = ", partitions) local_pivots = ( local_sorted[partitions] if counts[rank] else torch.empty((0,) + local_sorted.size()[1:], dtype=local_sorted.dtype) ) - print("DEBUGGING: local_pivots = ", local_pivots) # Only processes with elements should share their pivots gather_counts = [int(x > 0) * size for x in counts] @@ -1945,20 +1943,16 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): # root process creates new pivots and shares them with other processes if rank == 0: - print("DEBUGGING: pivot_buffer = ", pivot_buffer) if sort_op is torch.sort: sorted_pivots, _ = sort_op(pivot_buffer, dim=0, descending=descending) else: sorted_pivots = sort_op(pivot_buffer, dim=0, **kwargs)[0] - print("DEBUGGING: sorted_pivots = ", sorted_pivots) length = sorted_pivots.size()[0] global_partitions = [x * length // size for x in range(1, size)] global_pivots = sorted_pivots[global_partitions] a.comm.Bcast(global_pivots, root=0) - print("DEBUGGING: global_pivots = ", global_pivots) - print("DEBUGGING: local_sorted = ", local_sorted) - # special case: unique with axis not None + # special case: unique along axis if unique_along_axis: # find position of global pivots in local sorted uniques local_sorted, local_inv = torch.cat((local_sorted, global_pivots), dim=0).unique( @@ -1976,12 +1970,10 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): ) recv_matrix = torch.zeros(size, dtype=torch.int64, device=local_sorted.device) a.comm.Alltoall(send_matrix, recv_matrix) + # reshape send/recv_matrix into column to match sort() alltoall scheme for matrix in [send_matrix, recv_matrix]: matrix = matrix.reshape(1, matrix.numel()) - print("DEBUGGING: after alltoall: send_matrix = ", send_matrix) - print("DEBUGGING: after alltoall: recv_matrix = ", recv_matrix) - scounts = send_matrix rcounts = recv_matrix shape = (recv_matrix.sum(dim=0),) + local_sorted.shape[1:] @@ -1998,6 +1990,7 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): lt_partitions[idx] = lt last = lt lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last + # Matrix holding information how many values will be sent where local_partitions = torch.sum(lt_partitions, dim=1) partition_matrix = torch.empty_like(local_partitions) @@ -2017,8 +2010,6 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): a.comm.Alltoall(send_matrix, recv_matrix) - print("DEBUGGING: after alltoall: send_matrix = ", send_matrix) - print("DEBUGGING: after alltoall: recv_matrix = ", recv_matrix) scounts = local_partitions rcounts = recv_matrix @@ -2036,7 +2027,7 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): for idx in iterator: if unique_along_axis: - idx_slice = [slice(None)] # + [slice(ind, ind + 1) for ind in range(idx)] + idx_slice = [slice(None)] else: idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] @@ -2098,7 +2089,6 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): i for i, x in enumerate(current_cumsum) if target_cumsum[proc - 1] < x ) ) - # print("DEBUGGING: current_cumsum, target_cumsum[proc] = ", current_cumsum, target_cumsum[proc]) last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) for i, x in enumerate(partition_matrix[idx_slice][first:last]): # Taking as many elements as possible from each following process @@ -2115,8 +2105,7 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): # Iterate through one layer again to create the final balanced local tensors second_result = torch.empty_like(local_sorted) - if sort_op is torch.sort: - second_indices = torch.empty_like(second_result) + second_indices = torch.empty_like(second_result) for idx in np.ndindex(local_sorted.shape[1:]): idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] @@ -2132,26 +2121,20 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) second_result[idx_slice] = r_val - if sort_op is torch.sort: - s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val) - r_ind = torch.empty_like(r_val) - a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) - second_indices[idx_slice] = r_ind + s_ind = first_indices[0:end][idx_slice][indices].reshape_as(s_val) + r_ind = torch.empty_like(r_val) + a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) + second_indices[idx_slice] = r_ind - if sort_op is torch.sort: - second_result, tmp_indices = sort_op(second_result, dim=0, descending=descending) - final_result = second_result.transpose(0, axis) - final_indices = torch.empty_like(second_indices) - # Update the indices in case the ordering changed during the last sort - for idx in np.ndindex(tmp_indices.shape): - val = tmp_indices[idx] - final_indices[idx] = second_indices[val.item()][idx[1:]] - final_indices = final_indices.transpose(0, axis) - return final_result, final_indices - else: - second_result = sort_op(second_result, dim=0, **kwargs)[0] - final_result = second_result.transpose(0, axis) - return final_result + second_result, tmp_indices = sort_op(second_result, dim=0, descending=descending) + final_result = second_result.transpose(0, axis) + final_indices = torch.empty_like(second_indices) + # Update the indices in case the ordering changed during the last sort + for idx in np.ndindex(tmp_indices.shape): + val = tmp_indices[idx] + final_indices[idx] = second_indices[val.item()][idx[1:]] + final_indices = final_indices.transpose(0, axis) + return final_result, final_indices def sort(a, axis=None, descending=False, out=None): @@ -2874,20 +2857,13 @@ def unique(a, sorted=True, return_inverse=False, axis=None): lres = torch.unique(gres.larray, sorted=sorted, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) else: - # print("DEBUGGING: before sorting: gres.larray = ", gres.larray ) # balance gres if needed gres.balance_() - # global sort - # print("DEBUGGING: before sorting after balancing: gres.larray = ", gres.larray ) - # gres, sorted_gindices = sort(gres, axis=unique_axis) + # global sorted unique lres = _pivot_sorting(gres, 0, torch.unique, sorted=sorted, return_inverse=True) - print("DEBUGGING: after pivot_sorting: lres = ", lres) - # print("DEBUGGING: after sorting before unique: gres.larray = ", gres.larray ) # second local unique lres = torch.unique(lres, sorted=sorted, dim=unique_axis) - print("DEBUGGING: after second unique: lres = ", lres) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) - # print("DEBUGGING: after pivot_sorting unique: gres.larray = ", gres.larray ) gres.balance_() # inverse indices @@ -2906,45 +2882,34 @@ def unique(a, sorted=True, return_inverse=False, axis=None): gres_map = torch.tensor(gres.gshape, device=inverse.device) gres_offsets = torch.tensor([0], device=gres_map.device) lres = gres.larray - print("DEBUGGING: after balancing: lres = ", lres) - - gres_recv = None for p in range(unique_ranks): if unique_ranks == 1: incoming_offset = 0 else: - origin = rank + p if rank + p < unique_ranks else rank + p - unique_ranks + origin = (rank - p) % size incoming_offset = gres_offsets[origin] - if gres_recv: - gres_recv.Wait() + tmp = torch.empty( + gres_map[0].tolist(), dtype=local_data.dtype, device=local_data.device + ) # loop through unique elements, find matching position in data for i, el in enumerate(lres): counts = torch.zeros_like(local_data, dtype=torch.int8, device=local_data.device) - print("DEBUGGING: el = ", el) - # print("DEBUGGING: local_data = ", local_data) counts[torch.where(local_data == el)] = 1 if lres.ndim > 1: counts = torch.sum(counts, dim=tuple(range(lres.ndim))[1:]) cond = torch.where(counts == el.numel()) global_inverse.larray[cond] = i + incoming_offset - # if necessary, prepare to send lres to rank-1 and receive from rank+1 + # if necessary, prepare to send lres to rank+1 and receive from rank-1 if unique_ranks > 1: - dest_rank = rank - 1 if rank != 0 else size - 1 - gres.comm.Send(lres, dest_rank) - next_origin = origin + 1 if origin + 1 < unique_ranks else origin + 1 - unique_ranks - incoming_shape = gres_map[next_origin].tolist() - if incoming_shape != lres.shape: - print( - "DEBUGGING: CHANGING LRES SHAPE!! incoming_shape, lres.shape = ", - incoming_shape, - lres.shape, - ) - lres = torch.empty( - incoming_shape, dtype=local_data.dtype, device=local_data.device - ) - recv_from_rank = rank + 1 if rank != size - 1 else 0 - gres_recv = gres.comm.Recv(lres, recv_from_rank) - print("DEBUGGING: rank, p, lres = ", rank, p, lres) + dest_rank = (rank + 1) % unique_ranks + tmp[slice(None, lres.shape[0])] = lres + queue = gres.comm.Isend(tmp, dest_rank, tag=rank) + recv_from_rank = (rank - 1) % unique_ranks + next_origin = (origin - 1) % unique_ranks + incoming_size = gres_map[next_origin].tolist()[0] + queue.Wait() + gres.comm.Recv(tmp, recv_from_rank, tag=recv_from_rank) + lres = tmp[slice(None, incoming_size)] if axis is not None and axis != 0: # transpose back to original dimensions From 79e6219574ee352867629dc04e890b41028b4215 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Mar 2021 17:42:34 +0100 Subject: [PATCH 18/87] Expand "dense unique" tests --- heat/core/tests/test_manipulations.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 76546434a7..fe04860021 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2929,8 +2929,23 @@ def test_unique(self): unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) unique0.resplit_(None) t_unique0, t_inverse0 = torch.unique(t_comp, sorted=True, return_inverse=True, dim=axis) - # print("DEBUGGING: unique0, inverse0 = ", inverse0.larray) - # print("DEBUGGING: t_unique0, t_inverse0 = ", t_unique0, t_inverse0, t_inverse0.shape) + self.assertTrue((unique0.larray == t_unique0).all()) + self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) + # test "dense" unique + torch.random_seed(3) + t_comp = torch.randint(0, 7, (50, 3), dtype=torch.int64, device=self.device.torch_device) + data = ht.array(t_comp, split=0) + # axis is None + unique, inverse = ht.unique(data, return_inverse=True) + unique.resplit_(None) + t_unique, t_inverse = torch.unique(t_comp, sorted=True, return_inverse=True) + self.assertTrue((unique.larray == t_unique).all()) + self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) + # axis not None + axis = 0 + unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) + unique0.resplit_(None) + t_unique0, t_inverse0 = torch.unique(t_comp, sorted=True, return_inverse=True, dim=axis) self.assertTrue((unique0.larray == t_unique0).all()) self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) From 8123b6ada2106d54d5ed00d89c58fae27ec26a01 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 23 Mar 2021 17:05:06 +0100 Subject: [PATCH 19/87] Expand test_unique --- heat/core/tests/test_manipulations.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index fe04860021..9a08b312df 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2932,13 +2932,15 @@ def test_unique(self): self.assertTrue((unique0.larray == t_unique0).all()) self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) # test "dense" unique - torch.random_seed(3) - t_comp = torch.randint(0, 7, (50, 3), dtype=torch.int64, device=self.device.torch_device) - data = ht.array(t_comp, split=0) + data = ht.random.randint(0, 25, (50, 3), dtype=ht.int64, split=0) + _, _, local_slice = data.comm.chunk(data.gshape, data.split) + t_comp = ht.resplit(data, axis=None).larray # axis is None unique, inverse = ht.unique(data, return_inverse=True) unique.resplit_(None) t_unique, t_inverse = torch.unique(t_comp, sorted=True, return_inverse=True) + print("DEBUGGING: unique.larray = ", unique.larray) + print("DEBUGGING: t_unique = ", t_unique) self.assertTrue((unique.larray == t_unique).all()) self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) # axis not None From 290f11d66a6a34d6db679fee1edd541b2632fd41 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 11:31:34 +0100 Subject: [PATCH 20/87] minimize boiler-plate code in test_unique --- heat/core/tests/test_manipulations.py | 138 +++++--------------------- 1 file changed, 27 insertions(+), 111 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index f11a7ba946..b0727577ec 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2924,123 +2924,39 @@ def test_topk(self): def test_unique(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - # test "sparse" unique - data = ht.array( + ## test "sparse" unique + sparse_data = ht.array( torch.zeros(10, 4, dtype=torch.int32, device=self.device.torch_device), is_split=0 ) - _, _, local_slice = data.comm.chunk(data.gshape, data.split) random_ranks = torch.randint(size, size=(size // 2 + 1,)).tolist() if rank in random_ranks: random_row = torch.randint(10, size=(10,)) random_col = torch.randint(4, size=(10,)) - data.larray[random_row, random_col] = 1 - t_comp = ht.resplit(data, axis=None).larray - # axis is None - unique, inverse = ht.unique(data, return_inverse=True) - unique.resplit_(None) - t_unique, t_inverse = torch.unique(t_comp, sorted=True, return_inverse=True) - self.assertTrue((unique.larray == t_unique).all()) - self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) - # axis not None - axis = 0 - unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) - unique0.resplit_(None) - t_unique0, t_inverse0 = torch.unique(t_comp, sorted=True, return_inverse=True, dim=axis) - self.assertTrue((unique0.larray == t_unique0).all()) - self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) - # test "dense" unique - data = ht.random.randint(0, 25, (50, 3), dtype=ht.int64, split=0) - _, _, local_slice = data.comm.chunk(data.gshape, data.split) - t_comp = ht.resplit(data, axis=None).larray - # axis is None - unique, inverse = ht.unique(data, return_inverse=True) - unique.resplit_(None) - t_unique, t_inverse = torch.unique(t_comp, sorted=True, return_inverse=True) - print("DEBUGGING: unique.larray = ", unique.larray) - print("DEBUGGING: t_unique = ", t_unique) - self.assertTrue((unique.larray == t_unique).all()) - self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) - # axis not None - axis = 0 - unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) - unique0.resplit_(None) - t_unique0, t_inverse0 = torch.unique(t_comp, sorted=True, return_inverse=True, dim=axis) - self.assertTrue((unique0.larray == t_unique0).all()) - self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) - - # test "sparse" unique, distributed, axis - - # torch_array = torch.arange(size, dtype=torch.int32, device=self.device.torch_device).expand( - # size, size - # ) - # split_zero = ht.array(torch_array, split=0) - - # exp_axis_none = ht.array([rank], dtype=ht.int32) - # res = split_zero.unique()[rank] - # print("res, exp_axis_none = ", res.larray, exp_axis_none.larray) - # self.assertTrue((res.larray == exp_axis_none.larray).all()) - - # exp_axis_zero = ht.arange(size, dtype=ht.int32).expand_dims(0) - # res = ht.unique(split_zero, axis=0) - # self.assertTrue((res.larray == exp_axis_zero.larray).all()) - - # exp_axis_one = ht.array([rank], dtype=ht.int32).expand_dims(0) - # split_zero_transposed = ht.array(torch_array.transpose(0, 1), split=0) - # res = ht.unique(split_zero_transposed, sorted=False, axis=1) - # self.assertTrue((res.larray == exp_axis_one.larray).all()) - - # split_one = ht.array(torch_array, dtype=ht.int32, split=1) - - # exp_axis_none = ht.arange(size, dtype=ht.int32) - # res = ht.unique(split_one, sorted=True) - # self.assertTrue((res.larray == exp_axis_none.larray).all()) - - # exp_axis_zero = ht.array([rank], dtype=ht.int32).expand_dims(0) - # res = ht.unique(split_one, sorted=False, axis=0) - # self.assertTrue((res.larray == exp_axis_zero.larray).all()) - - # exp_axis_one = ht.array([rank] * size, dtype=ht.int32).expand_dims(1) - # res = ht.unique(split_one, sorted=True, axis=1) - # self.assertTrue((res.larray == exp_axis_one.larray).all()) - - # torch_array = torch.tensor( - # [[1, 2], [2, 3], [1, 2], [2, 3], [1, 2]], - # dtype=torch.int32, - # device=self.device.torch_device, - # ) - # data = ht.array(torch_array, split=0) - - # res, inv = ht.unique(data, return_inverse=True, axis=0) - # _, exp_inv = torch_array.unique(dim=0, return_inverse=True, sorted=True) - # self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) - - # res, inv = ht.unique(data, return_inverse=True, axis=1) - # _, exp_inv = torch_array.unique(dim=1, return_inverse=True, sorted=True) - # self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) - - # torch_array = torch.tensor( - # [[1, 1, 2], [1, 2, 2], [2, 1, 2], [1, 3, 2], [0, 1, 2]], - # dtype=torch.int32, - # device=self.device.torch_device, - # ) - # exp_res, exp_inv = torch_array.unique(return_inverse=True, sorted=True) - - # data_split_none = ht.array(torch_array) - # res = ht.unique(data_split_none, sorted=True) - # self.assertIsInstance(res, ht.DNDarray) - # self.assertEqual(res.split, None) - # self.assertEqual(res.dtype, data_split_none.dtype) - # self.assertEqual(res.device, data_split_none.device) - # res, inv = ht.unique(data_split_none, return_inverse=True, sorted=True) - # self.assertIsInstance(inv, ht.DNDarray) - # self.assertEqual(inv.split, None) - # self.assertEqual(inv.dtype, data_split_none.dtype) - # self.assertEqual(inv.device, data_split_none.device) - # self.assertTrue(torch.equal(inv.larray, exp_inv.int())) - - # data_split_zero = ht.array(torch_array, split=0) - # res, inv = ht.unique(data_split_zero, return_inverse=True, sorted=True) - # self.assertTrue(torch.equal(inv, exp_inv.to(dtype=inv.dtype))) + sparse_data.larray[random_row, random_col] = 1 + t_sparse = ht.resplit(sparse_data, axis=None).larray + + ## test "dense" unique + dense_data = ht.random.randint(0, 25, (50, 3), dtype=ht.int64, split=0) + t_dense = ht.resplit(dense_data, axis=None).larray + + datasets = [sparse_data, dense_data] + comps = [t_sparse, t_dense] + + for data, comp in zip(datasets, comps): + _, _, local_slice = data.comm.chunk(data.gshape, data.split) + # axis is None + unique, inverse = ht.unique(data, return_inverse=True) + unique.resplit_(None) + t_unique, t_inverse = torch.unique(comp, sorted=True, return_inverse=True) + self.assertTrue((unique.larray == t_unique).all()) + self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) + # axis not None + axis = 0 + unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) + unique0.resplit_(None) + t_unique0, t_inverse0 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) + self.assertTrue((unique0.larray == t_unique0).all()) + self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) def test_vsplit(self): # for further testing, see test_split From 4bed18a67a4122eb4e1ead6fdb71d9a0057c3697 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 11:41:03 +0100 Subject: [PATCH 21/87] remove excess ### --- heat/core/tests/test_manipulations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index a3dbb217fe..1239ff5a55 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3071,7 +3071,7 @@ def test_topk(self): def test_unique(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - ## test "sparse" unique + # test "sparse" unique sparse_data = ht.array( torch.zeros(10, 4, dtype=torch.int32, device=self.device.torch_device), is_split=0 ) @@ -3082,7 +3082,7 @@ def test_unique(self): sparse_data.larray[random_row, random_col] = 1 t_sparse = ht.resplit(sparse_data, axis=None).larray - ## test "dense" unique + # test "dense" unique dense_data = ht.random.randint(0, 25, (50, 3), dtype=ht.int64, split=0) t_dense = ht.resplit(dense_data, axis=None).larray From 2c2aa6e1cc9d5aa347335f683dd3000a813d8f25 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 12:03:12 +0100 Subject: [PATCH 22/87] Debugging --- heat/core/tests/test_manipulations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 1239ff5a55..cc1d6f8823 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3102,6 +3102,8 @@ def test_unique(self): unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) unique0.resplit_(None) t_unique0, t_inverse0 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) + print("DEBUGGING: unique0.larray = ", unique0.larray) + print("DEBUGGING: t_unique0 = ", t_unique0) self.assertTrue((unique0.larray == t_unique0).all()) self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) From 5a2e592bea1be7599973fa23693b4f30f25e87e0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 12:09:28 +0100 Subject: [PATCH 23/87] Debugging --- heat/core/tests/test_manipulations.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index cc1d6f8823..76817cdff3 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3102,8 +3102,13 @@ def test_unique(self): unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) unique0.resplit_(None) t_unique0, t_inverse0 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) - print("DEBUGGING: unique0.larray = ", unique0.larray) - print("DEBUGGING: t_unique0 = ", t_unique0) + print( + "DEBUGGING: unique0.larray = ", + unique0.larray, + unique0.larray.dtype, + unique0.larray.device, + ) + print("DEBUGGING: t_unique0 = ", t_unique0, t_unique0.dtype, t_unique0.device) self.assertTrue((unique0.larray == t_unique0).all()) self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) From 3c068e70e9ef2fa767c23b06ae6efc3ba389e393 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 12:19:50 +0100 Subject: [PATCH 24/87] Debugging --- heat/core/tests/test_manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 76817cdff3..ce0bf8e042 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3109,7 +3109,7 @@ def test_unique(self): unique0.larray.device, ) print("DEBUGGING: t_unique0 = ", t_unique0, t_unique0.dtype, t_unique0.device) - self.assertTrue((unique0.larray == t_unique0).all()) + self.assertTrue((unique0.larray == t_unique0).all().item()) self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) def test_vsplit(self): From 3d6a02b744554283de2b92d226bef2b5d4ade294 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 12:26:54 +0100 Subject: [PATCH 25/87] Debugging --- heat/core/tests/test_manipulations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index ce0bf8e042..359df83c14 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3095,6 +3095,13 @@ def test_unique(self): unique, inverse = ht.unique(data, return_inverse=True) unique.resplit_(None) t_unique, t_inverse = torch.unique(comp, sorted=True, return_inverse=True) + print( + "DEBUGGING: unique.larray = ", + unique.larray, + unique.larray.dtype, + unique.larray.device, + ) + print("DEBUGGING: t_unique = ", t_unique, t_unique.dtype, t_unique.device) self.assertTrue((unique.larray == t_unique).all()) self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) # axis not None @@ -3102,14 +3109,7 @@ def test_unique(self): unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) unique0.resplit_(None) t_unique0, t_inverse0 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) - print( - "DEBUGGING: unique0.larray = ", - unique0.larray, - unique0.larray.dtype, - unique0.larray.device, - ) - print("DEBUGGING: t_unique0 = ", t_unique0, t_unique0.dtype, t_unique0.device) - self.assertTrue((unique0.larray == t_unique0).all().item()) + self.assertTrue((unique0.larray == t_unique0).all()) self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) def test_vsplit(self): From b4b47636cf8ba81bdf4f065e4d70e6652df8e868 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Mar 2021 13:08:23 +0100 Subject: [PATCH 26/87] Fix empty Allgather problem --- heat/core/manipulations.py | 1 + heat/core/tests/test_manipulations.py | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 91d9e35fe3..4bc074c219 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3000,6 +3000,7 @@ def unique(a, sorted=True, return_inverse=False, axis=None): queue.Wait() gres.comm.Recv(tmp, recv_from_rank, tag=recv_from_rank) lres = tmp[slice(None, incoming_size)] + gres.larray = lres if axis is not None and axis != 0: # transpose back to original dimensions diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 359df83c14..1239ff5a55 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3095,13 +3095,6 @@ def test_unique(self): unique, inverse = ht.unique(data, return_inverse=True) unique.resplit_(None) t_unique, t_inverse = torch.unique(comp, sorted=True, return_inverse=True) - print( - "DEBUGGING: unique.larray = ", - unique.larray, - unique.larray.dtype, - unique.larray.device, - ) - print("DEBUGGING: t_unique = ", t_unique, t_unique.dtype, t_unique.device) self.assertTrue((unique.larray == t_unique).all()) self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) # axis not None From d8f73ead28e763a2e193bd877d3cb99ffd4078a2 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 25 Mar 2021 05:50:19 +0100 Subject: [PATCH 27/87] Debugging --- heat/core/manipulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 4bc074c219..d3dfc7cb8e 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2953,6 +2953,7 @@ def unique(a, sorted=True, return_inverse=False, axis=None): lres = _pivot_sorting(gres, 0, torch.unique, sorted=sorted, return_inverse=True) # second local unique lres = torch.unique(lres, sorted=sorted, dim=unique_axis) + print("DEBUGGING: lres.shape = ", lres.shape) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) gres.balance_() From f4b7f03306000f0d7845c933ea075a1d3c10f71d Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 25 Mar 2021 10:10:11 +0100 Subject: [PATCH 28/87] Skip second local torch.unique if local tensor is empty --- heat/core/manipulations.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index d3dfc7cb8e..614c67fddf 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2893,10 +2893,13 @@ def unique(a, sorted=True, return_inverse=False, axis=None): ) if isinstance(torch_output, tuple): heat_output = tuple( - factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output + factories.array(i, dtype=a.dtype, split=a.split, device=a.device) + for i in torch_output ) else: - heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) + heat_output = factories.array( + torch_output, dtype=a.dtype, split=a.split, device=a.device + ) return heat_output rank = a.comm.rank @@ -2905,7 +2908,6 @@ def unique(a, sorted=True, return_inverse=False, axis=None): local_data = a.larray unique_axis = None - if axis is not None: if axis != a.split: raise ValueError( @@ -2952,8 +2954,8 @@ def unique(a, sorted=True, return_inverse=False, axis=None): # global sorted unique lres = _pivot_sorting(gres, 0, torch.unique, sorted=sorted, return_inverse=True) # second local unique - lres = torch.unique(lres, sorted=sorted, dim=unique_axis) - print("DEBUGGING: lres.shape = ", lres.shape) + if 0 not in lres.shape: + lres = torch.unique(lres, sorted=sorted, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) gres.balance_() From eb57b5e147b5b8ee91176847a723fc700bf625d4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 28 Mar 2021 08:00:05 +0200 Subject: [PATCH 29/87] Expand tests, fix spit inconsistencies --- heat/core/manipulations.py | 42 +++++++++++++-------------- heat/core/tests/test_manipulations.py | 35 +++++++++++++++++++++- 2 files changed, 54 insertions(+), 23 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 614c67fddf..9d6a0a0065 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1999,10 +1999,7 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 ) counts, displs = tuple(counts.tolist()), tuple(displs.tolist()) - else: - raise ValueError( - "sorting operation can be torch.sort or torch.unique, was {}".format(sort_op) - ) + unique_along_axis = True if sort_op is torch.unique and axis is not None else False length = local_sorted.size()[0] @@ -2893,34 +2890,29 @@ def unique(a, sorted=True, return_inverse=False, axis=None): ) if isinstance(torch_output, tuple): heat_output = tuple( - factories.array(i, dtype=a.dtype, split=a.split, device=a.device) - for i in torch_output + factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output ) else: - heat_output = factories.array( - torch_output, dtype=a.dtype, split=a.split, device=a.device - ) + heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) return heat_output rank = a.comm.rank size = a.comm.size local_data = a.larray - + inv_shape = local_data.shape if axis is None else (local_data.shape[axis],) unique_axis = None if axis is not None: if axis != a.split: - raise ValueError( - "Operation axis must match distribution axis of the array: axis is {}, array.split is {}".format( + raise NotImplementedError( + "Not implemented yet: Operation axis differs from distribution axis: axis is {}, array.split is {}".format( axis, a.split ) ) if axis != 0: # transpose so we can work along the 0 axis local_data = local_data.transpose(0, axis) - unique_axis = 0 - else: - unique_axis = axis + unique_axis = 0 # Calculate local uniques if a.lshape[a.split] == 0: @@ -2936,7 +2928,6 @@ def unique(a, sorted=True, return_inverse=False, axis=None): lres = torch.empty(res_shape, dtype=a.dtype.torch_type()) else: lres = torch.unique(local_data, sorted=sorted, return_inverse=False, dim=unique_axis) - inv_shape = local_data.shape if axis is None else (local_data.shape[unique_axis],) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally @@ -2947,6 +2938,7 @@ def unique(a, sorted=True, return_inverse=False, axis=None): gres.resplit_(None) # final round of torch.unique lres = torch.unique(gres.larray, sorted=sorted, dim=unique_axis) + lres_split = None gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) else: # balance gres if needed @@ -2956,14 +2948,20 @@ def unique(a, sorted=True, return_inverse=False, axis=None): # second local unique if 0 not in lres.shape: lres = torch.unique(lres, sorted=sorted, dim=unique_axis) - gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) - gres.balance_() + lres_split = 0 + + gres = factories.array(lres, dtype=a.dtype, is_split=lres_split, device=a.device) + gres.balance_() - # inverse indices if return_inverse: + # inverse indices # allocate local tensors and global DNDarray inverse = torch.empty(inv_shape, dtype=torch.int64, device=local_data.device) - global_inverse = factories.array(inverse, is_split=a.split, device=gres.device) + if a.is_distributed(): + inv_split = 0 if inverse.ndim == 1 else a.split + else: + inv_split = None + global_inverse = factories.array(inverse, is_split=inv_split, device=gres.device) unique_ranks = size if gres.is_distributed() else 1 if unique_ranks > 1: @@ -3006,8 +3004,8 @@ def unique(a, sorted=True, return_inverse=False, axis=None): gres.larray = lres if axis is not None and axis != 0: - # transpose back to original dimensions - gres = gres.transpose(0, axis) + # transpose back to original + gres = linalg.basics.transpose(gres, (axis, 0)) if return_inverse: return (gres, global_inverse) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 1239ff5a55..0f4d373743 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3097,13 +3097,46 @@ def test_unique(self): t_unique, t_inverse = torch.unique(comp, sorted=True, return_inverse=True) self.assertTrue((unique.larray == t_unique).all()) self.assertTrue((inverse.larray == t_inverse[local_slice]).all()) + if data.is_distributed(): + self.assertTrue(unique.split is None or unique.split == 0) + else: + self.assertTrue(unique.split is None) + # test inverse indices on "gathered" unique + self.assertTrue((unique[inverse.larray].larray == data.larray).all()) + # axis not None axis = 0 unique0, inverse0 = ht.unique(data, return_inverse=True, axis=axis) unique0.resplit_(None) t_unique0, t_inverse0 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) self.assertTrue((unique0.larray == t_unique0).all()) - self.assertTrue((inverse0.larray == t_inverse0[local_slice[0]]).all()) + self.assertTrue((inverse0.larray == t_inverse0[local_slice[axis]]).all()) + if data.is_distributed(): + self.assertTrue(unique0.split is None or unique0.split == axis) + else: + self.assertTrue(unique0.split is None) + # test inverse indices on "gathered" unique + self.assertTrue((unique0[inverse0.larray].larray == data.larray).all()) + + # axis == split != 0 + data = ht.array(comp, split=1) + _, _, local_slice = data.comm.chunk(data.gshape, data.split) + axis = 1 + unique1, inverse1 = ht.unique(data, return_inverse=True, axis=axis) + unique1.resplit_(None) + t_unique1, t_inverse1 = torch.unique(comp, sorted=True, return_inverse=True, dim=axis) + self.assertTrue((unique1.larray == t_unique1).all()) + self.assertTrue((inverse1.larray == t_inverse1[local_slice[axis]]).all()) + if data.is_distributed(): + self.assertTrue(unique1.split is None or unique1.split == axis) + else: + self.assertTrue(unique1.split is None) + # test inverse indices on "gathered" unique + self.assertTrue((unique1[:, inverse1.larray].larray == data.larray).all()) + + # test exceptions + with self.assertRaises(NotImplementedError): + ht.unique(dense_data, axis=1) def test_vsplit(self): # for further testing, see test_split From 32b88578d5a951250f7770ad6d0d1df544296f08 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 29 Mar 2021 12:49:42 +0200 Subject: [PATCH 30/87] Fix inverse indices dtype in non-distributed case --- heat/core/manipulations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 9d6a0a0065..72263fc76b 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2890,7 +2890,10 @@ def unique(a, sorted=True, return_inverse=False, axis=None): ) if isinstance(torch_output, tuple): heat_output = tuple( - factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output + factories.array( + i, dtype=types.canonical_heat_type(i.dtype), split=None, device=a.device + ) + for i in torch_output ) else: heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) From a5752239af6e516113d8458d0e656575a0a79bdc Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 29 Mar 2021 12:50:19 +0200 Subject: [PATCH 31/87] Test NonImplementedError exception in distributed case only --- heat/core/tests/test_manipulations.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 0f4d373743..598502314a 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3135,8 +3135,9 @@ def test_unique(self): self.assertTrue((unique1[:, inverse1.larray].larray == data.larray).all()) # test exceptions - with self.assertRaises(NotImplementedError): - ht.unique(dense_data, axis=1) + if dense_data.is_distributed(): + with self.assertRaises(NotImplementedError): + ht.unique(dense_data, axis=1) def test_vsplit(self): # for further testing, see test_split From 2cba6cc0447eb2615f1a4f3da22396602a8d4ccb Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 29 Mar 2021 13:30:28 +0200 Subject: [PATCH 32/87] Fix lshape_map of local sorted uniques when nodes are empty --- heat/core/manipulations.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 72263fc76b..b3bf63ff5a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1990,10 +1990,13 @@ def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank] elif sort_op is torch.unique: local_sorted = sort_op(transposed, dim=0, **kwargs)[0] + local_shape = local_sorted.shape + if 0 in local_shape: + local_shape = transposed.shape lshape_map = torch.empty( - (size, local_sorted.ndim), dtype=torch.int64, device=local_sorted.device + (size, transposed.ndim), dtype=torch.int64, device=transposed.device ) - a.comm.Allgather(torch.tensor(local_sorted.shape), lshape_map) + a.comm.Allgather(torch.tensor(local_shape), lshape_map) counts = lshape_map[:, 0] displs = torch.cumsum( torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 From 68eb57cbd12fedef88f3bed80a83bf01a8e99da5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 7 Apr 2021 06:54:01 +0200 Subject: [PATCH 33/87] Set dndarray.__balanced to `balanced`, not None --- heat/core/dndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index d1358736ed..8146520735 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -84,7 +84,7 @@ def __init__(self, array, gshape, dtype, split, device, comm, balanced): self.__split = split self.__device = device self.__comm = comm - self.__balanced = None + self.__balanced = balanced self.__ishalo = False self.__halo_next = None self.__halo_prev = None From 519c0209263331346ae57c7593ec1c5764fd1cb6 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 8 Apr 2021 07:13:16 +0200 Subject: [PATCH 34/87] Remove `sorted` option from ht.unique() --- heat/core/manipulations.py | 24 +++++++++--------------- heat/core/tests/test_manipulations.py | 6 ++++-- heat/naive_bayes/gaussianNB.py | 6 +++--- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index b3bf63ff5a..0cbc46cdbf 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2827,7 +2827,7 @@ def stack(arrays, axis=0, out=None): return stacked -def unique(a, sorted=True, return_inverse=False, axis=None): +def unique(a, return_inverse=False, axis=None): """ Returns the sorted unique elements of an array. @@ -2840,10 +2840,6 @@ def unique(a, sorted=True, return_inverse=False, axis=None): Input array. axis : int, optional The axis to operate on. If None, `a` will be flattened. - sorted : bool, optional - Sort the array in ascending order before finding the unique elements. - Set `sorted=False` only if `a` is already sorted (in whichever order). - Default: True return_inverse : bool, optional Return the indices of the unique array (for the specified `axis`, if provided) that can be used to reconstruct `a`. @@ -2876,21 +2872,19 @@ def unique(a, sorted=True, return_inverse=False, axis=None): Examples -------- >>> x = ht.array([[3, 2], [1, 3]]) - >>> ht.unique(x, sorted=True) + >>> ht.unique(x) array([1, 2, 3]) - >>> ht.unique(x, sorted=True, axis=0) + >>> ht.unique(x, axis=0) array([[1, 3], [2, 3]]) - >>> ht.unique(x, sorted=True, axis=1) + >>> ht.unique(x, axis=1) array([[2, 3], [3, 1]]) """ if not a.is_distributed(): - torch_output = torch.unique( - a.larray, sorted=sorted, return_inverse=return_inverse, dim=axis - ) + torch_output = torch.unique(a.larray, sorted=True, return_inverse=return_inverse, dim=axis) if isinstance(torch_output, tuple): heat_output = tuple( factories.array( @@ -2933,7 +2927,7 @@ def unique(a, sorted=True, return_inverse=False, axis=None): inv_shape = [0] lres = torch.empty(res_shape, dtype=a.dtype.torch_type()) else: - lres = torch.unique(local_data, sorted=sorted, return_inverse=False, dim=unique_axis) + lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally @@ -2943,17 +2937,17 @@ def unique(a, sorted=True, return_inverse=False, axis=None): # gather local uniques gres.resplit_(None) # final round of torch.unique - lres = torch.unique(gres.larray, sorted=sorted, dim=unique_axis) + lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) lres_split = None gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) else: # balance gres if needed gres.balance_() # global sorted unique - lres = _pivot_sorting(gres, 0, torch.unique, sorted=sorted, return_inverse=True) + lres = _pivot_sorting(gres, 0, torch.unique, sorted=True, return_inverse=True) # second local unique if 0 not in lres.shape: - lres = torch.unique(lres, sorted=sorted, dim=unique_axis) + lres = torch.unique(lres, sorted=True, dim=unique_axis) lres_split = 0 gres = factories.array(lres, dtype=a.dtype, is_split=lres_split, device=a.device) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 598502314a..ff2847d634 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3071,7 +3071,7 @@ def test_topk(self): def test_unique(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - # test "sparse" unique + # "sparse" data sparse_data = ht.array( torch.zeros(10, 4, dtype=torch.int32, device=self.device.torch_device), is_split=0 ) @@ -3082,7 +3082,7 @@ def test_unique(self): sparse_data.larray[random_row, random_col] = 1 t_sparse = ht.resplit(sparse_data, axis=None).larray - # test "dense" unique + # "dense" data dense_data = ht.random.randint(0, 25, (50, 3), dtype=ht.int64, split=0) t_dense = ht.resplit(dense_data, axis=None).larray @@ -3134,6 +3134,8 @@ def test_unique(self): # test inverse indices on "gathered" unique self.assertTrue((unique1[:, inverse1.larray].larray == data.larray).all()) + # test unique on sorted data + # test exceptions if dense_data.is_distributed(): with self.assertRaises(NotImplementedError): diff --git a/heat/naive_bayes/gaussianNB.py b/heat/naive_bayes/gaussianNB.py index a44e4b6352..fb731b72bd 100644 --- a/heat/naive_bayes/gaussianNB.py +++ b/heat/naive_bayes/gaussianNB.py @@ -51,7 +51,7 @@ class labels known to the classifier >>> print(clf.predict(ht.array([[-0.8, -1]]))) tensor([1]) >>> clf_pf = GaussianNB() - >>> clf_pf.partial_fit(X, Y, ht.unique(Y, sorted=True)) + >>> clf_pf.partial_fit(X, Y, ht.unique(Y)) >>> print(clf_pf.predict(ht.array([[-0.8, -1]]))) tensor([1]) @@ -93,7 +93,7 @@ def fit(self, X, y, sample_weight=None): type(sample_weight) ) ) - classes = ht.unique(y, sorted=True) + classes = ht.unique(y) if classes.split is not None: classes = ht.resplit(classes, axis=None) @@ -334,7 +334,7 @@ def __partial_fit(self, X, y, classes=None, _refit=False, sample_weight=None): classes = self.classes_ - unique_y = ht.unique(y, sorted=True) + unique_y = ht.unique(y) if unique_y.split is not None: unique_y = ht.resplit(unique_y, axis=None) unique_y_in_classes = ht.eq(unique_y, classes) From c5de73f1cb0df9d0b3c25872f291a39d991d6d2d Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 11 Apr 2021 08:47:26 +0200 Subject: [PATCH 35/87] Fix race condition in test_qr --- heat/core/linalg/tests/test_qr.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/heat/core/linalg/tests/test_qr.py b/heat/core/linalg/tests/test_qr.py index 9c3d4f987c..c81cf74e6e 100644 --- a/heat/core/linalg/tests/test_qr.py +++ b/heat/core/linalg/tests/test_qr.py @@ -80,9 +80,13 @@ def test_qr(self): self.assertTrue( ht.allclose(ht.eye(m, dtype=ht.double), qr2.Q @ qr2.Q.T, rtol=1e-5, atol=1e-5) ) - # test if calc R alone works - qr = ht.qr(a2, calc_q=False, overwrite_a=True) - self.assertTrue(qr.Q is None) + # test if calc R alone works + a2_0 = ht.array(st2, split=0) + a2_1 = ht.array(st2, split=1) + qr_0 = ht.qr(a2_0, calc_q=False, overwrite_a=True) + self.assertTrue(qr_0.Q is None) + qr_1 = ht.qr(a2_1, calc_q=False, overwrite_a=True) + self.assertTrue(qr_1.Q is None) m, n = 40, 20 st = torch.randn(m, n, device=self.device.torch_device) From 1964cc85bf5e553f711032010ac0fd5314dec04b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 11 Apr 2021 11:21:51 +0200 Subject: [PATCH 36/87] Fix prev_rank/next_rank indices for imbalanced gethalo, expand tests --- heat/core/dndarray.py | 21 ++++++++++++------- heat/core/tests/test_dndarray.py | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8146520735..d421f0dacd 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -428,16 +428,22 @@ def get_halo(self, halo_size): rank = self.comm.rank size = self.comm.size + first_rank = 0 + next_rank = rank + 1 + prev_rank = rank - 1 + last_rank = size - 1 + if not self.balanced: populated_ranks = torch.nonzero(lshape_map[:, 0]).squeeze().tolist() if rank in populated_ranks: - next_rank = populated_ranks.index(rank) + 1 - prev_rank = populated_ranks.index(rank) - 1 + first_rank = populated_ranks[0] last_rank = populated_ranks[-1] - else: - next_rank = rank + 1 - prev_rank = rank - 1 - last_rank = size - 1 + next_rank = rank + 1 + prev_rank = rank - 1 + if rank != last_rank: + next_rank = populated_ranks[populated_ranks.index(rank) + 1] + if rank != first_rank: + prev_rank = populated_ranks[populated_ranks.index(rank) - 1] # if local shape is zero if self.lshape[self.split] == 0: @@ -453,7 +459,6 @@ def get_halo(self, halo_size): a_prev = self.__prephalo(0, halo_size) a_next = self.__prephalo(-halo_size, None) - res_prev = None res_next = None @@ -467,7 +472,7 @@ def get_halo(self, halo_size): ) req_list.append(self.comm.Irecv(res_prev, source=next_rank)) - if rank != 0: + if rank != first_rank: self.comm.Isend(a_prev, prev_rank) res_next = torch.zeros( a_next.size(), dtype=a_next.dtype, device=self.device.torch_device diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index dac7837cd6..ca98b6f425 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -161,6 +161,42 @@ def test_gethalo(self): self.assertTrue(data.halo_next is None) self.assertEqual(data_with_halos.shape, (12, 0)) + # test halo of imbalanced dndarray + if data.comm.size > 2: + t_data = torch.arange( + 5 * data.comm.rank, dtype=torch.float64, device=data.larray.device + ).reshape(data.comm.rank, 5) + if data.comm.rank > 0: + prev_data = torch.arange( + 5 * (data.comm.rank - 1), dtype=torch.float64, device=data.larray.device + ).reshape(data.comm.rank - 1, 5) + if data.comm.rank < data.comm.size - 1: + next_data = torch.arange( + 5 * (data.comm.rank + 1), dtype=torch.float64, device=data.larray.device + ).reshape(data.comm.rank + 1, 5) + data = ht.array(t_data, is_split=0) + data.get_halo(1) + data_with_halos = data.array_with_halos + if data.comm.rank == 0: + prev_halo = None + next_halo = None + new_split_size = 0 + elif data.comm.rank == 1: + prev_halo = None + next_halo = next_data[0] + new_split_size = data.larray.shape[0] + 1 + elif data.comm.rank == data.comm.size - 1: + prev_halo = prev_data[-1] + next_halo = None + new_split_size = data.larray.shape[0] + 1 + else: + prev_halo = prev_data[-1] + next_halo = next_data[0] + new_split_size = data.larray.shape[0] + 2 + self.assertEqual(data_with_halos.shape, (new_split_size, 5)) + self.assertTrue(data.halo_prev is prev_halo or (data.halo_prev == prev_halo).all()) + self.assertTrue(data.halo_next is next_halo or (data.halo_next == next_halo).all()) + def test_larray(self): # undistributed case x = ht.arange(6 * 7 * 8).reshape((6, 7, 8)) From dd9b797d9a1700a2f29e7d5de9501da67b9d0016 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sun, 11 Apr 2021 11:52:29 +0200 Subject: [PATCH 37/87] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c728f5ec8c..8581b8b26e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,7 @@ - [#690](https://github.com/helmholtz-analytics/heat/pull/690) Enhancement: reshape accepts shape arguments with one unknown dimension. - [#706](https://github.com/helmholtz-analytics/heat/pull/706) Bug fix: prevent `__setitem__`, `__getitem__` from modifying key in place - [#744](https://github.com/helmholtz-analytics/heat/pull/744) Fix split semantics for reduction operations +- [#749](https://github.com/helmholtz-analytics/heat/pull/749) Distributed sorted `ht.unique` # v0.5.1 From 50876a57f26c117322bdefef7c356ebcc87d96c8 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 12 Apr 2021 05:19:04 +0200 Subject: [PATCH 38/87] Documentation update --- heat/core/manipulations.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 2eb149ccdc..df8dd052cc 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -1980,7 +1980,26 @@ def shape(a): return a.gshape -def _pivot_sorting(a, axis, sort_op, descending=False, **kwargs): +def __pivot_sorting(a, axis, sort_op, descending=False, **kwargs): + """ + Parallel sorting function for :func:`sort` and :func:`unique`, based on [1]. + + Parameters + ---------- + + a : DNDarray + Distributed input data + axis : int or None + Axis along which the operation will be performed. + sort_op : torch operation + torch.sort or torch.unique + descending : bool + Whether :func:`sort` will return elements sorted in descending order. Default: `False`. + + References + ---------- + [1] Li et al., 1993, "On the versatility of parallel sorting by regular sampling", Parallel Computing, Volume 19, Issue 10, pages 1079-1103 + """ size = a.comm.Get_size() rank = a.comm.Get_rank() transposed = a.larray.transpose(axis, 0) @@ -2289,7 +2308,7 @@ def sort(a, axis=None, descending=False, out=None): final_result, final_indices = torch.sort(a.larray, dim=axis, descending=descending) else: - final_result, final_indices = _pivot_sorting(a, axis, torch.sort, descending=descending) + final_result, final_indices = __pivot_sorting(a, axis, torch.sort, descending=descending) return_indices = factories.array( final_indices, dtype=dndarray.types.int32, is_split=a.split, device=a.device, comm=a.comm @@ -2864,7 +2883,9 @@ def unique(a, return_inverse=False, axis=None): `unique` will be distributed along 0, if `axis` is specified, or along `a.split`, if `axis` is None. - WARNING: `inverse_indices` will always be distributed like the original data + Warnings + -------- + `inverse_indices` will always be distributed like the original data (if `axis is None`) or along 0, and contains the GLOBAL indices to recreate the LOCAL portion of `a`. Before reconstructing an array based on `unique[inverse_indices]`, make sure that `unique` is local (with `unique.resplit_(axis=None)`, see `ht.resplit`). @@ -2944,7 +2965,7 @@ def unique(a, return_inverse=False, axis=None): # balance gres if needed gres.balance_() # global sorted unique - lres = _pivot_sorting(gres, 0, torch.unique, sorted=True, return_inverse=True) + lres = __pivot_sorting(gres, 0, torch.unique, sorted=True, return_inverse=True) # second local unique if 0 not in lres.shape: lres = torch.unique(lres, sorted=True, dim=unique_axis) From 906c80417a09bd8a9de1981f93dd9cd56e0b7dc0 Mon Sep 17 00:00:00 2001 From: ClaudiaComito Date: Mon, 3 May 2021 05:58:38 +0200 Subject: [PATCH 39/87] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9618eafc63..9d9b66b5eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# v1.0.1 + +## Bug fixes + # v1.0.0 ## New features / Highlights From 1ed470827a3ce25b28e5f586b5b5c1c0668065a7 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 23 Aug 2021 18:28:52 +0200 Subject: [PATCH 40/87] Replace explicit `counts, displs` calculation with `dndarray.counts_displs()` --- heat/core/manipulations.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 90ec668f64..d88ea71f7f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2285,18 +2285,9 @@ def __pivot_sorting( actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank] elif sort_op is torch.unique: local_sorted = sort_op(transposed, dim=0, **kwargs)[0] - local_shape = local_sorted.shape - if 0 in local_shape: - local_shape = transposed.shape - lshape_map = torch.empty( - (size, transposed.ndim), dtype=torch.int64, device=transposed.device - ) - a.comm.Allgather(torch.tensor(local_shape), lshape_map) - counts = lshape_map[:, 0] - displs = torch.cumsum( - torch.cat((torch.tensor([0], device=counts.device), counts[:-1])), dim=0 - ) - counts, displs = tuple(counts.tolist()), tuple(displs.tolist()) + local_sorted = factories.array(local_sorted, is_split=0, device=a.device) + counts, _ = local_sorted.counts_displs() + local_sorted = local_sorted.larray unique_along_axis = True if sort_op is torch.unique and axis is not None else False From 723b37a20b73df2f9613ecb454cabc9044199f15 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 10 Sep 2021 06:21:18 +0200 Subject: [PATCH 41/87] Address review, part I --- heat/core/manipulations.py | 47 +++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index d88ea71f7f..b98db47870 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2291,20 +2291,22 @@ def __pivot_sorting( unique_along_axis = True if sort_op is torch.unique and axis is not None else False - length = local_sorted.size()[0] + length = local_sorted.shape[0] # Separate the sorted tensor into size + 1 equal length partitions partitions = [x * length // (size + 1) for x in range(1, size + 1)] local_pivots = ( local_sorted[partitions] if counts[rank] - else torch.empty((0,) + local_sorted.size()[1:], dtype=local_sorted.dtype) + else torch.empty( + (0,) + local_sorted.shape[1:], dtype=local_sorted.dtype, device=local_sorted.device + ) ) # Only processes with elements should share their pivots gather_counts = [int(x > 0) * size for x in counts] - gather_displs = (0,) + tuple(np.cumsum(gather_counts[:-1])) - pivot_dim = list(transposed.size()) + gather_displs = (0,) + tuple(torch.cumsum(torch.tensor(gather_counts[:-1]), dim=0).tolist()) + pivot_dim = list(transposed.shape) pivot_dim[0] = size * sum([1 for x in counts if x > 0]) # share the local pivots with root process @@ -2320,7 +2322,7 @@ def __pivot_sorting( sorted_pivots, _ = sort_op(pivot_buffer, dim=0, descending=descending) else: sorted_pivots = sort_op(pivot_buffer, dim=0, **kwargs)[0] - length = sorted_pivots.size()[0] + length = sorted_pivots.shape[0] global_partitions = [x * length // size for x in range(1, size)] global_pivots = sorted_pivots[global_partitions] @@ -2328,14 +2330,14 @@ def __pivot_sorting( # special case: unique along axis if unique_along_axis: # find position of global pivots in local sorted uniques - local_sorted, local_inv = torch.cat((local_sorted, global_pivots), dim=0).unique( + local_sorted, local_inverse_ind = torch.cat((local_sorted, global_pivots), dim=0).unique( dim=0, sorted=kwargs.get("sorted") if kwargs.get("sorted") else True, return_inverse=True, ) # Use the inverse indices of the global pivots to work out the local partition slices local_slices = torch.zeros(size + 1, dtype=torch.int64, device=local_sorted.device) - local_slices[1:-1] = local_inv[-global_pivots.shape[0] :] + 1 + local_slices[1:-1] = local_inverse_ind[-global_pivots.shape[0] :] + 1 local_slices[-1] = torch.tensor([local_sorted.shape[0]]) # how many rows will be sent and received where send_matrix = torch.tensor( @@ -2373,7 +2375,7 @@ def __pivot_sorting( index_matrix = torch.empty_like(local_sorted, dtype=torch.int64) # Matrix holding information which process get how many values from where - shape = (size,) + transposed.size()[1:] + shape = (size,) + transposed.shape[1:] send_matrix = torch.zeros(shape, dtype=partition_matrix.dtype) recv_matrix = torch.zeros(shape, dtype=partition_matrix.dtype) @@ -2386,9 +2388,9 @@ def __pivot_sorting( scounts = local_partitions rcounts = recv_matrix - shape = (partition_matrix[rank].max(),) + transposed.size()[1:] + shape = (partition_matrix[rank].max(),) + transposed.shape[1:] - first_result = torch.empty(shape, dtype=local_sorted.dtype) + first_result = torch.empty(shape, dtype=local_sorted.dtype, device=local_sorted.device) if sort_op is torch.sort: first_indices = torch.empty_like(first_result) @@ -2404,14 +2406,18 @@ def __pivot_sorting( else: idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - send_count = scounts[idx_slice].reshape(-1).tolist() - send_disp = [0] + list(np.cumsum(send_count[:-1])) + send_count = scounts[idx_slice].reshape(-1) + send_disp = [0] + torch.cumsum(send_count[:-1], dim=0).tolist() + send_count = send_count.tolist() s_val = local_sorted[idx_slice].clone() - recv_count = rcounts[idx_slice].reshape(-1).tolist() - recv_disp = [0] + list(np.cumsum(recv_count[:-1])) + recv_count = rcounts[idx_slice].reshape(-1) + recv_disp = [0] + torch.cumsum(recv_count[:-1], dim=0).tolist() + recv_count = recv_count.tolist() rcv_length = rcounts[idx_slice].sum().item() - r_val = torch.empty((rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype) + r_val = torch.empty( + (rcv_length,) + s_val.shape[1:], dtype=local_sorted.dtype, device=local_sorted.device + ) a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) first_result[idx_slice][:rcv_length] = r_val @@ -2426,12 +2432,15 @@ def __pivot_sorting( return first_result # The process might not have the correct number of values therefore the tensors need to be rebalanced - send_vec = torch.zeros(local_sorted.shape[1:] + (size, size), dtype=torch.int64) - target_cumsum = np.cumsum(counts) + send_vec = torch.zeros( + local_sorted.shape[1:] + (size, size), dtype=torch.int64, device=local_sorted.device + ) + target_cumsum = torch.cumsum(torch.tensor(counts), dim=0) for idx in np.ndindex(local_sorted.shape[1:]): idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] - current_counts = partition_matrix[idx_slice].reshape(-1).tolist() - current_cumsum = list(np.cumsum(current_counts)) + current_counts = partition_matrix[idx_slice].reshape(-1) + current_cumsum = torch.cumsum(current_counts, dim=0).tolist() + current_counts = current_counts.tolist() for proc in range(size): # process has to many values which will be sent to higher ranks if current_cumsum[proc] > target_cumsum[proc]: From f5e360c9d5441e355c180380b55a5b7c3a75f4f0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 10 Sep 2021 10:00:59 +0200 Subject: [PATCH 42/87] Address review Part II of II --- heat/core/manipulations.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index b98db47870..9b8f59c81d 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2483,7 +2483,7 @@ def __pivot_sorting( # process doesn't need more values send_vec[idx][proc][proc] = partition_matrix[proc][idx] - send_vec[idx][proc].sum() current_counts[proc] = counts[proc] - current_cumsum = list(np.cumsum(current_counts)) + current_cumsum = torch.cumsum(torch.tensor(current_counts), dim=0).tolist() # Iterate through one layer again to create the final balanced local tensors second_result = torch.empty_like(local_sorted) @@ -2499,7 +2499,9 @@ def __pivot_sorting( end = partition_matrix[rank][idx] s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0) - r_val = torch.empty((counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype) + r_val = torch.empty( + (counts[rank],) + s_val.shape[1:], dtype=local_sorted.dtype, device=local_sorted.device + ) a.comm.Alltoallv((s_val, send_count, send_disp), (r_val, recv_count, recv_disp)) second_result[idx_slice] = r_val @@ -3270,8 +3272,7 @@ def unique( gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally - _, data_max_lshape, _ = a.comm.chunk(a.gshape, a.split, rank=0) - data_max_lbytes = torch.prod(torch.tensor(data_max_lshape)) * a.larray.element_size() + data_max_lbytes = torch.prod(a.lshape_map[0]) * a.larray.element_size() if gres.nbytes <= data_max_lbytes: # gather local uniques gres.resplit_(None) @@ -3323,7 +3324,7 @@ def unique( ) # loop through unique elements, find matching position in data for i, el in enumerate(lres): - counts = torch.zeros_like(local_data, dtype=torch.int8, device=local_data.device) + counts = torch.zeros_like(local_data, dtype=torch.int32, device=local_data.device) counts[torch.where(local_data == el)] = 1 if lres.ndim > 1: counts = torch.sum(counts, dim=tuple(range(lres.ndim))[1:]) @@ -3347,7 +3348,7 @@ def unique( gres = linalg.basics.transpose(gres, (axis, 0)) if return_inverse: - return (gres, global_inverse) + return gres, global_inverse return gres From 2753bfe3b71a4d376e22cca9ba3f8e4d7888b919 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 10 Sep 2021 10:57:49 +0200 Subject: [PATCH 43/87] Reshape empty `local_sorted` to match global shape --- heat/core/manipulations.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 9b8f59c81d..26aa514262 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2285,6 +2285,10 @@ def __pivot_sorting( actual_indices = local_indices.to(dtype=local_sorted.dtype) + disp[rank] elif sort_op is torch.unique: local_sorted = sort_op(transposed, dim=0, **kwargs)[0] + if 0 in local_sorted.shape: + local_shape = list(transposed.shape) + local_shape[0] = 0 + local_sorted = local_sorted.reshape(local_shape) local_sorted = factories.array(local_sorted, is_split=0, device=a.device) counts, _ = local_sorted.counts_displs() local_sorted = local_sorted.larray From bfa58ba76da21c99b53dae608277746eabfe3da3 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 12:15:04 +0200 Subject: [PATCH 44/87] GPU Debugging --- heat/core/manipulations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 26aa514262..ae1d3ce4b9 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2289,9 +2289,9 @@ def __pivot_sorting( local_shape = list(transposed.shape) local_shape[0] = 0 local_sorted = local_sorted.reshape(local_shape) - local_sorted = factories.array(local_sorted, is_split=0, device=a.device) - counts, _ = local_sorted.counts_displs() - local_sorted = local_sorted.larray + g_local_sorted = factories.array(local_sorted, is_split=0, device=a.device, copy=False) + counts, _ = g_local_sorted.counts_displs() + local_sorted = g_local_sorted.larray unique_along_axis = True if sort_op is torch.unique and axis is not None else False From 68d437d36640e38c7ed430b885294fc674776567 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 12:58:16 +0200 Subject: [PATCH 45/87] Debug devices --- heat/core/manipulations.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ae1d3ce4b9..8ac423e7ab 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2355,6 +2355,12 @@ def __pivot_sorting( scounts = send_matrix rcounts = recv_matrix + print( + "DEBUGGING: DEVICES: ", + recv_matrix.sum(dim=0).device, + local_sorted.device, + local_sorted.shape[1:].device, + ) shape = (recv_matrix.sum(dim=0),) + local_sorted.shape[1:] else: lt_partitions = torch.empty((size,) + local_sorted.shape, dtype=torch.int64) From 1cfd565434fd15323096a3f0d40cf0259db0666e Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 15:25:55 +0200 Subject: [PATCH 46/87] Debug tensor devices --- heat/core/manipulations.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 8ac423e7ab..27091fc6c3 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2355,12 +2355,6 @@ def __pivot_sorting( scounts = send_matrix rcounts = recv_matrix - print( - "DEBUGGING: DEVICES: ", - recv_matrix.sum(dim=0).device, - local_sorted.device, - local_sorted.shape[1:].device, - ) shape = (recv_matrix.sum(dim=0),) + local_sorted.shape[1:] else: lt_partitions = torch.empty((size,) + local_sorted.shape, dtype=torch.int64) @@ -2482,8 +2476,10 @@ def __pivot_sorting( ) ) last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) + print("DEBUGGING: devices: partition_matrix", partition_matrix.device) for i, x in enumerate(partition_matrix[idx_slice][first:last]): # Taking as many elements as possible from each following process + print("DEBUGGING: devices: x, send_vec ", x.device, send_vec.device) send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum()) current_counts[first + i] = 0 # Taking just enough elements from the last element to fill the current processes tensor From 1f6cec2817fe2db48ff4ae274ea7d39c59520d59 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 15:48:36 +0200 Subject: [PATCH 47/87] Debug tensor devices --- heat/core/manipulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 27091fc6c3..8e53a4b6b2 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2292,6 +2292,7 @@ def __pivot_sorting( g_local_sorted = factories.array(local_sorted, is_split=0, device=a.device, copy=False) counts, _ = g_local_sorted.counts_displs() local_sorted = g_local_sorted.larray + print("DEBUGGING: DEVICES: a, local_sorted = ", a.device, local_sorted.device) unique_along_axis = True if sort_op is torch.unique and axis is not None else False From f923834f7b7d96dbbc192ea424c942bb89e0c25d Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 16:19:12 +0200 Subject: [PATCH 48/87] Specify device for torch `_like` factories --- heat/core/manipulations.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 8e53a4b6b2..dfc3769f4d 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2377,7 +2377,7 @@ def __pivot_sorting( a.comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM) # Matrix that holds information which value will be shipped where - index_matrix = torch.empty_like(local_sorted, dtype=torch.int64) + index_matrix = torch.empty_like(local_sorted, dtype=torch.int64, device=local_sorted.device) # Matrix holding information which process get how many values from where shape = (size,) + transposed.shape[1:] @@ -2397,7 +2397,7 @@ def __pivot_sorting( first_result = torch.empty(shape, dtype=local_sorted.dtype, device=local_sorted.device) if sort_op is torch.sort: - first_indices = torch.empty_like(first_result) + first_indices = torch.empty_like(first_result, device=first_result.device) # Iterate through one layer and send values with alltoallv if unique_along_axis: @@ -2428,7 +2428,7 @@ def __pivot_sorting( if sort_op is torch.sort: s_ind = actual_indices[idx_slice].clone().to(dtype=local_sorted.dtype) - r_ind = torch.empty_like(r_val) + r_ind = torch.empty_like(r_val, device=r_val.device) a.comm.Alltoallv((s_ind, send_count, send_disp), (r_ind, recv_count, recv_disp)) first_indices[idx_slice][:rcv_length] = r_ind @@ -2493,8 +2493,8 @@ def __pivot_sorting( current_cumsum = torch.cumsum(torch.tensor(current_counts), dim=0).tolist() # Iterate through one layer again to create the final balanced local tensors - second_result = torch.empty_like(local_sorted) - second_indices = torch.empty_like(second_result) + second_result = torch.empty_like(local_sorted, device=local_sorted.device) + second_indices = torch.empty_like(second_result, device=second_result.device) for idx in np.ndindex(local_sorted.shape[1:]): idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] @@ -2519,7 +2519,7 @@ def __pivot_sorting( second_result, tmp_indices = sort_op(second_result, dim=0, descending=descending) final_result = second_result.transpose(0, axis) - final_indices = torch.empty_like(second_indices) + final_indices = torch.empty_like(second_indices, device=second_indices.device) # Update the indices in case the ordering changed during the last sort for idx in np.ndindex(tmp_indices.shape): val = tmp_indices[idx] From 34667c85cc040f699228e4ba7849a944d2c56c78 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 16:36:12 +0200 Subject: [PATCH 49/87] Specify device for all torch _like factories --- heat/core/manipulations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index dfc3769f4d..6c219f1f2b 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2369,11 +2369,13 @@ def __pivot_sorting( else: lt_partitions[idx] = lt last = lt - lt_partitions[size - 1] = torch.ones_like(local_sorted, dtype=last.dtype) - last + lt_partitions[size - 1] = ( + torch.ones_like(local_sorted, dtype=last.dtype, device=local_sorted.device) - last + ) # Matrix holding information how many values will be sent where local_partitions = torch.sum(lt_partitions, dim=1) - partition_matrix = torch.empty_like(local_partitions) + partition_matrix = torch.empty_like(local_partitions, device=local_partitions.device) a.comm.Allreduce(local_partitions, partition_matrix, op=MPI.SUM) # Matrix that holds information which value will be shipped where From e1c28c545e4dd739808acac86a658384d7f9a5ab Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 16:39:02 +0200 Subject: [PATCH 50/87] Devices --- heat/core/manipulations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 6c219f1f2b..45f9423ed3 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2358,8 +2358,10 @@ def __pivot_sorting( rcounts = recv_matrix shape = (recv_matrix.sum(dim=0),) + local_sorted.shape[1:] else: - lt_partitions = torch.empty((size,) + local_sorted.shape, dtype=torch.int64) - last = torch.zeros_like(local_sorted, dtype=torch.int64) + lt_partitions = torch.empty( + (size,) + local_sorted.shape, dtype=torch.int64, device=local_sorted.device + ) + last = torch.zeros_like(local_sorted, dtype=torch.int64, device=local_sorted.device) comp_op = torch.gt if descending else torch.lt # Iterate over all pivots and store which pivot is the first greater than the elements value for idx, p in enumerate(global_pivots): From 1ca6cebbc42314d1ce6c9a2b2d590d27e2ed4f9b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 14 Sep 2021 17:01:47 +0200 Subject: [PATCH 51/87] Add more missing devices --- heat/core/manipulations.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 45f9423ed3..c5f9e720bf 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2385,8 +2385,12 @@ def __pivot_sorting( # Matrix holding information which process get how many values from where shape = (size,) + transposed.shape[1:] - send_matrix = torch.zeros(shape, dtype=partition_matrix.dtype) - recv_matrix = torch.zeros(shape, dtype=partition_matrix.dtype) + send_matrix = torch.zeros( + shape, dtype=partition_matrix.dtype, device=partition_matrix.device + ) + recv_matrix = torch.zeros( + shape, dtype=partition_matrix.dtype, device=partition_matrix.device + ) for i, x in enumerate(lt_partitions): index_matrix[x > 0] = i From 209d5d3cc6b517e8fc0af6587bc2661a538b36ab Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 15 Sep 2021 06:02:52 +0200 Subject: [PATCH 52/87] Replace np.cumsum calls with torch.cumsum --- heat/core/manipulations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index c5f9e720bf..2a2a358a39 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2507,10 +2507,10 @@ def __pivot_sorting( idx_slice = [slice(None)] + [slice(ind, ind + 1) for ind in idx] send_count = send_vec[idx][rank] - send_disp = [0] + list(np.cumsum(send_count[:-1])) + send_disp = [0] + torch.cumsum(send_count[:-1], dim=0).tolist() recv_count = send_vec[idx][:, rank] - recv_disp = [0] + list(np.cumsum(recv_count[:-1])) + recv_disp = [0] + torch.cumsum(recv_count[:-1], dim=0).tolist() end = partition_matrix[rank][idx] s_val, indices = first_result[0:end][idx_slice].sort(descending=descending, dim=0) From 01281ccf2250b3f953ffb1403678d64ff92bed87 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 15 Sep 2021 06:26:38 +0200 Subject: [PATCH 53/87] Debug test_sort on GPU --- heat/core/tests/test_manipulations.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 72bb6ad36f..8b70136d19 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2640,6 +2640,11 @@ def test_sort(self): exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device) result, result_indices = ht.sort(data, descending=True, axis=0) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + print( + "DEBUGGING: result_indices.larray, exp_indices = ", + result_indices.larray, + exp_indices.int(), + ) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) exp_axis_one, exp_indices = ( From badd041d3e7c1cc92f8ec1c62888e50ff42d426f Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 15 Sep 2021 06:27:25 +0200 Subject: [PATCH 54/87] Remove print (debugging) statements --- heat/core/manipulations.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 2a2a358a39..def34445b1 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2292,7 +2292,6 @@ def __pivot_sorting( g_local_sorted = factories.array(local_sorted, is_split=0, device=a.device, copy=False) counts, _ = g_local_sorted.counts_displs() local_sorted = g_local_sorted.larray - print("DEBUGGING: DEVICES: a, local_sorted = ", a.device, local_sorted.device) unique_along_axis = True if sort_op is torch.unique and axis is not None else False @@ -2485,10 +2484,8 @@ def __pivot_sorting( ) ) last = next(i for i, x in enumerate(current_cumsum) if target_cumsum[proc] <= x) - print("DEBUGGING: devices: partition_matrix", partition_matrix.device) for i, x in enumerate(partition_matrix[idx_slice][first:last]): # Taking as many elements as possible from each following process - print("DEBUGGING: devices: x, send_vec ", x.device, send_vec.device) send_vec[idx][first + i][proc] = int(x - send_vec[idx][first + i].sum()) current_counts[first + i] = 0 # Taking just enough elements from the last element to fill the current processes tensor From 369203e90edad598e31a972b3b8a6e708a96ca21 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 29 Sep 2021 11:36:22 +0200 Subject: [PATCH 55/87] Set up memory profiling --- heat/core/manipulations.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index def34445b1..281bbb43bc 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -6,6 +6,7 @@ import numpy as np import torch import warnings +from memory_profiler import profile from typing import Iterable, Type, List, Callable, Union, Tuple, Sequence, Optional @@ -22,6 +23,7 @@ from . import types from . import _operations + __all__ = [ "balance", "column_stack", @@ -2255,6 +2257,7 @@ def shape(a: DNDarray) -> Tuple[int, ...]: return a.gshape +@profile def __pivot_sorting( a: DNDarray, sort_op: Callable, axis: Optional[int] = None, descending: bool = False, **kwargs ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -2533,6 +2536,7 @@ def __pivot_sorting( return final_result, final_indices +@profile def sort( a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DNDarray] = None ) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: @@ -3170,6 +3174,7 @@ def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray: DNDarray.swapaxes.__doc__ = swapaxes.__doc__ +@profile def unique( a: DNDarray, return_inverse: bool = False, axis: Optional[int] = None ) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: From b46924d3549cd485184b4dc0f419714157f9c084 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 17 Nov 2021 11:24:05 +0100 Subject: [PATCH 56/87] Improve efficiency, adopt `dndarray.lshape_map` and `dndarray.counts_displs()` where possible --- heat/core/manipulations.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 1a0f490fa4..7599ef5241 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2262,7 +2262,7 @@ def shape(a: DNDarray) -> Tuple[int, ...]: return a.gshape -@profile +# @profile def __pivot_sorting( a: DNDarray, sort_op: Callable, axis: Optional[int] = None, descending: bool = False, **kwargs ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -2541,7 +2541,7 @@ def __pivot_sorting( return final_result, final_indices -@profile +# @profile def sort( a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DNDarray] = None ) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: @@ -3179,7 +3179,7 @@ def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray: DNDarray.swapaxes.__doc__ = swapaxes.__doc__ -@profile +# @profile def unique( a: DNDarray, return_inverse: bool = False, axis: Optional[int] = None ) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: @@ -3326,13 +3326,11 @@ def unique( global_inverse = factories.array(inverse, is_split=inv_split, device=gres.device) unique_ranks = size if gres.is_distributed() else 1 + gres_map = gres.lshape_map if unique_ranks > 1: - gres_map = gres.create_lshape_map() - gres_offsets = torch.cat( - (torch.tensor([0], device=gres_map.device), gres_map[:-1, gres.split]) - ).cumsum(dim=0) + _, gres_offsets = gres.counts_displs() + gres_offsets = torch.tensor(gres_offsets, device=gres_map.device) else: - gres_map = torch.tensor(gres.gshape, device=inverse.device) gres_offsets = torch.tensor([0], device=gres_map.device) lres = gres.larray for p in range(unique_ranks): From e5a713c21190916f6a841c6b01c17162d7bc10cd Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 17 Nov 2021 11:46:12 +0100 Subject: [PATCH 57/87] Comment out memory_profiler import --- heat/core/manipulations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 7599ef5241..1d63bc0efd 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -6,7 +6,8 @@ import numpy as np import torch import warnings -from memory_profiler import profile + +# from memory_profiler import profile from typing import Iterable, Type, List, Callable, Union, Tuple, Sequence, Optional From 690f9e091f6e3dfb460cf809c10cba8c3d7636c0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 11:08:47 +0100 Subject: [PATCH 58/87] Remove redundant split sanitation from self.comm.chunk --- heat/core/communication.py | 15 +++---- heat/core/manipulations.py | 56 +++++++++++++++------------ heat/core/tests/test_communication.py | 9 ----- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/heat/core/communication.py b/heat/core/communication.py index abc29cf564..c446ae958f 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -7,6 +7,7 @@ import os import subprocess import torch +import tracemalloc from mpi4py import MPI from typing import Any, Callable, Optional, List, Tuple, Union @@ -170,21 +171,21 @@ def chunk( Parameters ---------- shape : Tuple[int,...] - The global shape of the data to be split + The global shape of the data to be split. split : int - The axis along which to chunk the data + The axis along which to chunk the data. Must be within the range of ``shape``. rank : int, optional Process for which the chunking is calculated for, defaults to ``self.rank``. - Intended for creating chunk maps without communication + Intended for creating chunk maps without communication. w_size : int, optional The MPI world size, defaults to ``self.size``. - Intended for creating chunk maps without communication - + Intended for creating chunk maps without communication. """ - # ensure the split axis is valid, we actually do not need it - split = sanitize_axis(shape, split) if split is None: return 0, shape, tuple(slice(0, end) for end in shape) + if split < 0: + split = len(shape) + split + rank = self.rank if rank is None else rank w_size = self.size if w_size is None else w_size if not isinstance(rank, int) or not isinstance(w_size, int): diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 1d63bc0efd..ec2d256ebb 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -7,8 +7,6 @@ import torch import warnings -# from memory_profiler import profile - from typing import Iterable, Type, List, Callable, Union, Tuple, Sequence, Optional from .communication import MPI @@ -3252,12 +3250,18 @@ def unique( if isinstance(torch_output, tuple): heat_output = tuple( factories.array( - i, dtype=types.canonical_heat_type(i.dtype), split=None, device=a.device + i, + dtype=types.canonical_heat_type(i.dtype), + split=None, + device=a.device, + copy=False, ) for i in torch_output ) else: - heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) + heat_output = factories.array( + torch_output, dtype=a.dtype, split=None, device=a.device, copy=False + ) return heat_output rank = a.comm.rank @@ -3292,28 +3296,28 @@ def unique( lres = torch.empty(res_shape, dtype=a.dtype.torch_type()) else: lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) - gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device) + gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device, copy=False) # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally - data_max_lbytes = torch.prod(a.lshape_map[0]) * a.larray.element_size() - if gres.nbytes <= data_max_lbytes: - # gather local uniques - gres.resplit_(None) - # final round of torch.unique - lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) - lres_split = None - gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) - else: - # balance gres if needed - gres.balance_() - # global sorted unique - lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) - # second local unique - if 0 not in lres.shape: - lres = torch.unique(lres, sorted=True, dim=unique_axis) - lres_split = 0 - - gres = factories.array(lres, dtype=a.dtype, is_split=lres_split, device=a.device) + # data_max_lbytes = torch.prod(a.lshape_map[0]) * a.larray.element_size() + # if gres.nbytes <= data_max_lbytes: + # # gather local uniques + # gres.resplit_(None) + # # final round of torch.unique + # lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + # lres_split = None + # gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) + # else: + # balance gres if needed + gres.balance_() + # global sorted unique + lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) + # second local unique + if 0 not in lres.shape: + lres = torch.unique(lres, sorted=True, dim=unique_axis) + lres_split = 0 + + gres = factories.array(lres, dtype=a.dtype, is_split=lres_split, device=a.device, copy=False) gres.balance_() if return_inverse: @@ -3324,7 +3328,9 @@ def unique( inv_split = 0 if inverse.ndim == 1 else a.split else: inv_split = None - global_inverse = factories.array(inverse, is_split=inv_split, device=gres.device) + global_inverse = factories.array( + inverse, is_split=inv_split, device=gres.device, copy=False + ) unique_ranks = size if gres.is_distributed() else 1 gres_map = gres.lshape_map diff --git a/heat/core/tests/test_communication.py b/heat/core/tests/test_communication.py index 1410eaf9cc..504fd4161a 100644 --- a/heat/core/tests/test_communication.py +++ b/heat/core/tests/test_communication.py @@ -23,10 +23,6 @@ def setUpClass(cls): def test_self_communicator(self): comm = ht.core.communication.MPI_SELF - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=2) - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=-3) with self.assertRaises(TypeError): comm.chunk(self.data.shape, split=0, rank="dicndjh") @@ -47,11 +43,6 @@ def test_mpi_communicator(self): comm = ht.core.communication.MPI_WORLD self.assertLess(comm.rank, comm.size) - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=2) - with self.assertRaises(ValueError): - comm.chunk(self.data.shape, split=-3) - offset, lshape, chunks = comm.chunk(self.data.shape, split=0) self.assertIsInstance(offset, int) self.assertGreaterEqual(offset, 0) From dc2672fecf694675db1d383a89ea5122faaa4afb Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 11:11:44 +0100 Subject: [PATCH 59/87] Do not clone obj[slices] if copy is False. Remove unnecessary split sanitation. --- heat/core/factories.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/heat/core/factories.py b/heat/core/factories.py index 50a55a0613..0c51c3cf7f 100644 --- a/heat/core/factories.py +++ b/heat/core/factories.py @@ -17,7 +17,6 @@ from . import devices from . import types - __all__ = [ "arange", "array", @@ -320,6 +319,7 @@ def array( if device is not None else devices.get_device().torch_device, ) + except RuntimeError: raise TypeError("invalid data of type {}".format(type(obj))) else: @@ -360,11 +360,13 @@ def array( if ndmin_abs > 0 > ndmin: obj = obj.reshape(ndmin_abs * (1,) + obj.shape) - # sanitize the split axes, ensure mutual exclusiveness - split = sanitize_axis(obj.shape, split) - is_split = sanitize_axis(obj.shape, is_split) - if split is not None and is_split is not None: - raise ValueError("split and is_split are mutually exclusive parameters") + # sanitize split or is_split + if split is not None: + if is_split is not None: + raise ValueError("cannot specify both split and is_split") + split = sanitize_axis(obj.shape, split) + elif is_split is not None: + is_split = sanitize_axis(obj.shape, is_split) # sanitize comm object comm = sanitize_comm(comm) @@ -377,8 +379,12 @@ def array( # content shall be split, chunk the passed data object up if split is not None: _, _, slices = comm.chunk(gshape, split) - obj = obj[slices].clone() + if not copy: + obj = obj[slices] + else: + obj = obj[slices].clone() obj = sanitize_memory_layout(obj, order=order) + # check with the neighboring rank whether the local shape would fit into a global shape elif is_split is not None: gshape = np.array(gshape) From 49efd0a7c196ba844dd3a8d7ca43a4f0efd3b164 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 12:16:41 +0100 Subject: [PATCH 60/87] Improve memory usage for sanitize_memory_layout --- heat/core/memory.py | 46 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/heat/core/memory.py b/heat/core/memory.py index 99554c88ed..0a2ebca9ce 100644 --- a/heat/core/memory.py +++ b/heat/core/memory.py @@ -3,6 +3,7 @@ """ import torch +import tracemalloc from . import sanitation from .dndarray import DNDarray @@ -58,30 +59,33 @@ def sanitize_memory_layout(x: torch.Tensor, order: str = "C") -> torch.Tensor: if x.ndim < 2 or x.numel() == 0: # do nothing return x - dims = list(range(x.ndim)) stride = torch.tensor(x.stride()) # since strides can get a bit wonky with operations like transpose # we should assume that the tensors are row major or are distributed the default way - sdiff = stride[1:] - stride[:-1] - column_major = all(sdiff >= 0) - row_major = True if not column_major else False - if (order == "C" and row_major) or (order == "F" and column_major): + column_major = (stride[1:] - stride[:-1] >= 0).all() + if (order == "C" and not column_major) or (order == "F" and column_major): # do nothing return x - elif (order == "C" and column_major) or (order == "F" and row_major): - dims = tuple(reversed(dims)) - y = torch.empty_like(x) - permutation = x.permute(dims).contiguous() - y = y.set_( - permutation.storage(), - x.storage_offset(), - x.shape, - tuple(reversed(permutation.stride())), - ) - return y - else: - raise ValueError( - "combination of order and layout not permitted, order: {} column major: {} row major: {}".format( - order, column_major, row_major - ) + if (order == "C" and column_major) or (order == "F" and not column_major): + dims = tuple(range(x.ndim - 1, -1, -1)) + storage_offset = x.storage_offset() + shape = x.shape + x = x.permute(dims).contiguous() + reversed_stride = tuple(reversed(x.stride())) + x.set_(x.storage(), storage_offset, shape, reversed_stride) + return x + # y = torch.empty_like(x) + # permutation = x.permute(dims).contiguous() + # y = y.set_( + # permutation.storage(), + # x.storage_offset(), + # x.shape, + # tuple(reversed(permutation.stride())), + # ) + # return y + + raise ValueError( + "combination of order and layout not permitted, order: {} column major: {} row major: {}".format( + order, column_major, not column_major ) + ) From 1bf8d66f847496f1a91fba47d800cdabf54c5482 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 12:21:18 +0100 Subject: [PATCH 61/87] Reorganize sanitation logic --- heat/core/stride_tricks.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/heat/core/stride_tricks.py b/heat/core/stride_tricks.py index f07af2418a..7e3c5db571 100644 --- a/heat/core/stride_tricks.py +++ b/heat/core/stride_tricks.py @@ -97,14 +97,14 @@ def sanitize_axis( axis = None if axis is not None: - if not isinstance(axis, int) and not isinstance(axis, tuple): + if isinstance(axis, tuple): + axis = tuple(dim + len(shape) if dim < 0 else dim for dim in axis) + for dim in axis: + if dim < 0 or dim >= len(shape): + raise ValueError("axis {} is out of bounds for shape {}".format(axis, shape)) + return axis + if not isinstance(axis, int): raise TypeError("axis must be None or int or tuple, but was {}".format(type(axis))) - if isinstance(axis, tuple): - axis = tuple(dim + len(shape) if dim < 0 else dim for dim in axis) - for dim in axis: - if dim < 0 or dim >= len(shape): - raise ValueError("axis {} is out of bounds for shape {}".format(axis, shape)) - return axis if axis is None or 0 <= axis < len(shape): return axis @@ -113,7 +113,6 @@ def sanitize_axis( if axis < 0 or axis >= len(shape): raise ValueError("axis {} is out of bounds for shape {}".format(axis, shape)) - return axis From 3a91bca1e84d18829a6d50a5f38e4a934684a92a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 13:05:38 +0100 Subject: [PATCH 62/87] Remove "sparse unique" implementation --- heat/core/manipulations.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ec2d256ebb..c842693963 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3297,19 +3297,8 @@ def unique( else: lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device, copy=False) - - # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally - # data_max_lbytes = torch.prod(a.lshape_map[0]) * a.larray.element_size() - # if gres.nbytes <= data_max_lbytes: - # # gather local uniques - # gres.resplit_(None) - # # final round of torch.unique - # lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) - # lres_split = None - # gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device) - # else: - # balance gres if needed gres.balance_() + # global sorted unique lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) # second local unique From 10354b7befb24209c562aef43ed2f4ec04f41a32 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 13:11:35 +0100 Subject: [PATCH 63/87] Specify factories.array(copy=False) --- heat/core/manipulations.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index c842693963..be0b57ccde 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2261,7 +2261,6 @@ def shape(a: DNDarray) -> Tuple[int, ...]: return a.gshape -# @profile def __pivot_sorting( a: DNDarray, sort_op: Callable, axis: Optional[int] = None, descending: bool = False, **kwargs ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -2540,7 +2539,6 @@ def __pivot_sorting( return final_result, final_indices -# @profile def sort( a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DNDarray] = None ) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: @@ -2601,14 +2599,14 @@ def sort( final_result, final_indices = __pivot_sorting(a, torch.sort, axis, descending=descending) return_indices = factories.array( - final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm + final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm, copy=False ) if out is not None: out.larray = final_result return return_indices else: tensor = factories.array( - final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm, copy=False ) return tensor, return_indices @@ -3178,7 +3176,6 @@ def swapaxes(x: DNDarray, axis1: int, axis2: int) -> DNDarray: DNDarray.swapaxes.__doc__ = swapaxes.__doc__ -# @profile def unique( a: DNDarray, return_inverse: bool = False, axis: Optional[int] = None ) -> Union[DNDarray, Tuple[DNDarray, DNDarray]]: From 0ea2c3a55f76f8d9ff8a6dfb9cdd954e1befb4c2 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 13:24:29 +0100 Subject: [PATCH 64/87] Always copy obj if specified dtype is different from original dtype --- heat/core/factories.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/factories.py b/heat/core/factories.py index 0c51c3cf7f..b8b3026ab7 100644 --- a/heat/core/factories.py +++ b/heat/core/factories.py @@ -171,7 +171,7 @@ def array( the :func:`~heat.core.dndarray.astype` method. copy : bool, optional If ``True`` (default), then the object is copied. Otherwise, a copy will only be made if obj is a nested - sequence or if a copy is needed to satisfy any of the other requirements, e.g. ``dtype``. + sequence or if a copy is needed to satisfy any of the other requirements, e.g. ``dtype`` or ``order``. ndmin : int, optional Specifies the minimum number of dimensions that the resulting array should have. Ones will, if needed, be attached to the shape if ``ndim > 0`` and prefaced in case of ``ndim < 0`` to meet the requirement. @@ -337,7 +337,7 @@ def array( else: torch_dtype = dtype.torch_type() if obj.dtype != torch_dtype: - obj = obj.type(torch_dtype) + obj = obj.clone().type(torch_dtype) # infer device from obj if not explicitly given if device is None: From f809c951372e31e25b092b54d18eae2c1bff33d7 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 15:31:18 +0100 Subject: [PATCH 65/87] Copy obj when specified dtype different from original --- heat/core/factories.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/heat/core/factories.py b/heat/core/factories.py index b8b3026ab7..ec04f2d54c 100644 --- a/heat/core/factories.py +++ b/heat/core/factories.py @@ -337,7 +337,12 @@ def array( else: torch_dtype = dtype.torch_type() if obj.dtype != torch_dtype: - obj = obj.clone().type(torch_dtype) + if not copy: + # different dtype, copy anyway + obj = obj.clone().type(torch_dtype) + else: + # obj is already a copy + obj = obj.type(torch_dtype) # infer device from obj if not explicitly given if device is None: From d1248190db8998ee3df654f9df072b9c777f4ef7 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 15:32:13 +0100 Subject: [PATCH 66/87] Copy output of torch.diagonal (partial view) --- heat/core/manipulations.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index be0b57ccde..815102dc27 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -301,7 +301,10 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: # no splits, local concat if s0 is None and s1 is None: return factories.array( - torch.cat((arr0.larray, arr1.larray), dim=axis), device=arr0.device, comm=arr0.comm + torch.cat((arr0.larray, arr1.larray), dim=axis), + device=arr0.device, + comm=arr0.comm, + copy=False, ) # non-matching splits when both arrays are split @@ -320,6 +323,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: is_split=s1 if s1 is not None else s0, device=arr1.device, comm=arr0.comm, + copy=False, ) return out @@ -335,6 +339,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: is_split=s0, device=arr0.device, comm=arr0.comm, + copy=False, ) return out @@ -505,6 +510,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: dtype=out_dtype, device=arr0.device, comm=arr0.comm, + copy=False, ) return out @@ -582,7 +588,9 @@ def diag(a: DNDarray, offset: int = 0) -> DNDarray: local = torch.zeros(lshape, dtype=a.dtype.torch_type(), device=a.device.torch_device) local[indices_x, indices_y] = a.larray[indices_x] - return factories.array(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) + return factories.array( + local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm, copy=False + ) def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray: @@ -808,14 +816,24 @@ def flatten(a: DNDarray) -> DNDarray: if a.split is None: return factories.array( - torch.flatten(a.larray), dtype=a.dtype, is_split=None, device=a.device, comm=a.comm + torch.flatten(a.larray), + dtype=a.dtype, + is_split=None, + device=a.device, + comm=a.comm, + copy=False, ) if a.split > 0: a = resplit(a, 0) a = factories.array( - torch.flatten(a.larray), dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + torch.flatten(a.larray), + dtype=a.dtype, + is_split=a.split, + device=a.device, + comm=a.comm, + copy=False, ) a.balance_() @@ -866,7 +884,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: if a.split not in axis: return factories.array( - flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm + flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm, copy=False ) # Need to redistribute tensors on split axis @@ -1446,6 +1464,7 @@ def pad( is_split=array.split, device=array.device, comm=array.comm, + copy=False, ) padded_tensor.balance_() @@ -1605,9 +1624,9 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr # sanitation `a` if not isinstance(a, DNDarray): if isinstance(a, (int, float)): - a = factories.array([a]) + a = factories.array([a], copy=False) elif isinstance(a, (tuple, list, np.ndarray)): - a = factories.array(a) + a = factories.array(a, copy=False) else: raise TypeError( "`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {}".format( From a4dd4d4b6003481c1b38106228703ad2facacc25 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 18 Nov 2021 15:34:40 +0100 Subject: [PATCH 67/87] remove dead code --- heat/core/tests/test_manipulations.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 8b70136d19..72bb6ad36f 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2640,11 +2640,6 @@ def test_sort(self): exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device) result, result_indices = ht.sort(data, descending=True, axis=0) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - print( - "DEBUGGING: result_indices.larray, exp_indices = ", - result_indices.larray, - exp_indices.int(), - ) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) exp_axis_one, exp_indices = ( From 16318ef7d443287abc982851dfc68ed3318d9215 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 19 Nov 2021 05:40:57 +0100 Subject: [PATCH 68/87] Debugging GPU error --- heat/core/manipulations.py | 10 ++++++++-- heat/core/tests/test_manipulations.py | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 815102dc27..ccacdc4460 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -664,7 +664,9 @@ def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDa vz = 1 if a.split == dim1 else -1 off, _, _ = a.comm.chunk(a.shape, a.split) result = torch.diagonal(a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2) - return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) + return factories.array( + result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm, copy=True + ) def dsplit(x: Sequence[DNDarray, ...], indices_or_sections: Iterable) -> List[DNDarray, ...]: @@ -2618,7 +2620,11 @@ def sort( final_result, final_indices = __pivot_sorting(a, torch.sort, axis, descending=descending) return_indices = factories.array( - final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm, copy=False + final_indices, + dtype=types.int32, + is_split=a.split, + device=a.device, + comm=a.comm, # , copy=False ) if out is not None: out.larray = final_result diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 72bb6ad36f..8b70136d19 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2640,6 +2640,11 @@ def test_sort(self): exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device) result, result_indices = ht.sort(data, descending=True, axis=0) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + print( + "DEBUGGING: result_indices.larray, exp_indices = ", + result_indices.larray, + exp_indices.int(), + ) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) exp_axis_one, exp_indices = ( From 07b6fb1658e067844e14fd0768d35fba6562e584 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 19 Nov 2021 05:57:30 +0100 Subject: [PATCH 69/87] Debugging test_sort on GPU --- heat/core/tests/test_manipulations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 8b70136d19..706974fd71 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2641,7 +2641,8 @@ def test_sort(self): result, result_indices = ht.sort(data, descending=True, axis=0) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) print( - "DEBUGGING: result_indices.larray, exp_indices = ", + "DEBUGGING: rank, result_indices.larray, exp_indices = ", + rank, result_indices.larray, exp_indices.int(), ) From 5cd0b8c2c475f87a3a3f177a38b4f7df574f62a9 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Sat, 20 Nov 2021 07:59:21 +0100 Subject: [PATCH 70/87] Make test_sort more stable incl. for GPUs --- heat/core/tests/test_manipulations.py | 78 +++++++++++---------------- 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 706974fd71..481a38447b 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2619,41 +2619,33 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - tensor = ( - torch.arange(size, device=self.device.torch_device).repeat(size).reshape(size, size) - ) - + tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) + # sort along axis 0, split None data = ht.array(tensor, split=None) result, result_indices = ht.sort(data, axis=0, descending=True) - expected, exp_indices = torch.sort(tensor, dim=0, descending=True) - self.assertTrue(torch.equal(result.larray, expected)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - + expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) + self.assertTrue(torch.equal(result.larray, expected_dim0)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) + # sort along axis 1, split None result, result_indices = ht.sort(data, axis=1, descending=True) - expected, exp_indices = torch.sort(tensor, dim=1, descending=True) - self.assertTrue(torch.equal(result.larray, expected)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - + expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) + self.assertTrue(torch.equal(result.larray, expected_dim1)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) + # sort along axis 0, split 0 data = ht.array(tensor, split=0) - - exp_axis_zero = torch.arange(size, device=self.device.torch_device).reshape(1, size) - exp_indices = torch.tensor([[rank] * size], device=self.device.torch_device) result, result_indices = ht.sort(data, descending=True, axis=0) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - print( - "DEBUGGING: rank, result_indices.larray, exp_indices = ", - rank, - result_indices.larray, - exp_indices.int(), - ) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - - exp_axis_one, exp_indices = ( - torch.arange(size, device=self.device.torch_device) - .reshape(1, size) - .sort(dim=1, descending=True) - ) + # sort along axis 1, split 0 result, result_indices = ht.sort(data, descending=True, axis=1) + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] self.assertTrue(torch.equal(result.larray, exp_axis_one)) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) @@ -2661,30 +2653,24 @@ def test_sort(self): result2 = ht.sort(data, descending=True) self.assertTrue(ht.equal(result1[0], result2[0])) self.assertTrue(ht.equal(result1[1], result2[1])) - + # sort along axis 0, split 1 data = ht.array(tensor, split=1) - - exp_axis_zero = ( - torch.tensor(rank, device=self.device.torch_device).repeat(size).reshape(size, 1) - ) - indices_axis_zero = torch.arange( - size, dtype=torch.int64, device=self.device.torch_device - ).reshape(size, 1) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] result, result_indices = ht.sort(data, axis=0, descending=True) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - # comparison value is only true on CPU - if result_indices.larray.is_cuda is False: - self.assertTrue(torch.equal(result_indices.larray, indices_axis_zero.int())) - - exp_axis_one = ( - torch.tensor(size - rank - 1, device=self.device.torch_device) - .repeat(size) - .reshape(size, 1) - ) + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 1, split 1 + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] result, result_indices = ht.sort(data, descending=True, axis=1) self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_axis_one.int())) - + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # 3D array tensor = torch.tensor( [ [[2, 8, 5], [7, 2, 3]], From 90b0d716030b292a5cf41905a82ada5fbca1a3e9 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Nov 2021 06:00:05 +0100 Subject: [PATCH 71/87] Debug GPU test_sort --- heat/core/tests/test_manipulations.py | 33 ++++++++++++++------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 481a38447b..f662621880 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2654,22 +2654,23 @@ def test_sort(self): self.assertTrue(ht.equal(result1[0], result2[0])) self.assertTrue(ht.equal(result1[1], result2[1])) # sort along axis 0, split 1 - data = ht.array(tensor, split=1) - _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) - _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) - exp_axis_zero = expected_dim0[local_slice] - exp_indices = exp_indices_dim0[local_slice_ind] - result, result_indices = ht.sort(data, axis=0, descending=True) - self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # sort along axis 1, split 1 - _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) - _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) - exp_axis_one = expected_dim1[local_slice] - exp_indices = exp_indices_dim1[local_slice_ind] - result, result_indices = ht.sort(data, descending=True, axis=1) - self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # data = ht.array(tensor, split=1) + # _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) + # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) + # exp_axis_zero = expected_dim0[local_slice] + # exp_indices = exp_indices_dim0[local_slice_ind] + # result, result_indices = ht.sort(data, axis=0, descending=True) + # self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # sort along axis 1, split 1 + # _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) + # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) + # exp_axis_one = expected_dim1[local_slice] + # exp_indices = exp_indices_dim1[local_slice_ind] + # result, result_indices = ht.sort(data, descending=True, axis=1) + # self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # 3D array tensor = torch.tensor( [ From 7aa51aeac01425c9527c297a69246bcaedc9321e Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Nov 2021 06:06:05 +0100 Subject: [PATCH 72/87] Degug test_sort on GPU --- heat/core/tests/test_manipulations.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index f662621880..0e3dc5223c 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2640,20 +2640,20 @@ def test_sort(self): exp_indices = exp_indices_dim0[local_slice_ind] self.assertTrue(torch.equal(result.larray, exp_axis_zero)) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # sort along axis 1, split 0 - result, result_indices = ht.sort(data, descending=True, axis=1) - _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) - _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) - exp_axis_one = expected_dim1[local_slice] - exp_indices = exp_indices_dim1[local_slice_ind] - self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # sort along axis 1, split 0 + # result, result_indices = ht.sort(data, descending=True, axis=1) + # _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) + # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) + # exp_axis_one = expected_dim1[local_slice] + # exp_indices = exp_indices_dim1[local_slice_ind] + # self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - result1 = ht.sort(data, axis=1, descending=True) - result2 = ht.sort(data, descending=True) - self.assertTrue(ht.equal(result1[0], result2[0])) - self.assertTrue(ht.equal(result1[1], result2[1])) - # sort along axis 0, split 1 + # result1 = ht.sort(data, axis=1, descending=True) + # result2 = ht.sort(data, descending=True) + # self.assertTrue(ht.equal(result1[0], result2[0])) + # self.assertTrue(ht.equal(result1[1], result2[1])) + # # sort along axis 0, split 1 # data = ht.array(tensor, split=1) # _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) From dbce096a621a751a0b335c48840f21ca91d355e5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Nov 2021 06:11:16 +0100 Subject: [PATCH 73/87] Debug test_sort on GPU --- heat/core/tests/test_manipulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 0e3dc5223c..a45cc2e074 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2619,6 +2619,7 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank + torch.manual_seed(42) tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) # sort along axis 0, split None data = ht.array(tensor, split=None) From f615ade4b2999e61302cd59bdb31c38164b8eee5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Nov 2021 06:15:05 +0100 Subject: [PATCH 74/87] Debug test_sort on GPU --- heat/core/tests/test_manipulations.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index a45cc2e074..1cab15af85 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2632,16 +2632,16 @@ def test_sort(self): expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) self.assertTrue(torch.equal(result.larray, expected_dim1)) self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) - # sort along axis 0, split 0 - data = ht.array(tensor, split=0) - result, result_indices = ht.sort(data, descending=True, axis=0) - _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) - _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) - exp_axis_zero = expected_dim0[local_slice] - exp_indices = exp_indices_dim0[local_slice_ind] - self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # # sort along axis 1, split 0 + # # sort along axis 0, split 0 + # data = ht.array(tensor, split=0) + # result, result_indices = ht.sort(data, descending=True, axis=0) + # _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) + # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) + # exp_axis_zero = expected_dim0[local_slice] + # exp_indices = exp_indices_dim0[local_slice_ind] + # self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # # sort along axis 1, split 0 # result, result_indices = ht.sort(data, descending=True, axis=1) # _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) From ea0aab479db99d173fea37854e85077665ea84fc Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Nov 2021 06:19:21 +0100 Subject: [PATCH 75/87] Debug --- heat/core/tests/test_manipulations.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 1cab15af85..84f4dfb9e3 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2619,19 +2619,19 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size rank = ht.MPI_WORLD.rank - torch.manual_seed(42) - tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) - # sort along axis 0, split None - data = ht.array(tensor, split=None) - result, result_indices = ht.sort(data, axis=0, descending=True) - expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) - self.assertTrue(torch.equal(result.larray, expected_dim0)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) - # sort along axis 1, split None - result, result_indices = ht.sort(data, axis=1, descending=True) - expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) - self.assertTrue(torch.equal(result.larray, expected_dim1)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) + # torch.manual_seed(42) + # tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) + # # sort along axis 0, split None + # data = ht.array(tensor, split=None) + # result, result_indices = ht.sort(data, axis=0, descending=True) + # expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) + # self.assertTrue(torch.equal(result.larray, expected_dim0)) + # self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) + # # sort along axis 1, split None + # result, result_indices = ht.sort(data, axis=1, descending=True) + # expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) + # self.assertTrue(torch.equal(result.larray, expected_dim1)) + # self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) # # sort along axis 0, split 0 # data = ht.array(tensor, split=0) # result, result_indices = ht.sort(data, descending=True, axis=0) From 152fc553cfb5220c3252ba3da64bd29a88f8714e Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Mon, 22 Nov 2021 09:43:32 +0100 Subject: [PATCH 76/87] Debug --- heat/core/tests/test_manipulations.py | 246 +++++++++++++------------- 1 file changed, 123 insertions(+), 123 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 84f4dfb9e3..7adc77d085 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2618,134 +2618,134 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size - rank = ht.MPI_WORLD.rank - # torch.manual_seed(42) - # tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) - # # sort along axis 0, split None - # data = ht.array(tensor, split=None) - # result, result_indices = ht.sort(data, axis=0, descending=True) - # expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) - # self.assertTrue(torch.equal(result.larray, expected_dim0)) - # self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) - # # sort along axis 1, split None - # result, result_indices = ht.sort(data, axis=1, descending=True) - # expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) - # self.assertTrue(torch.equal(result.larray, expected_dim1)) - # self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) - # # sort along axis 0, split 0 - # data = ht.array(tensor, split=0) - # result, result_indices = ht.sort(data, descending=True, axis=0) - # _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) - # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) - # exp_axis_zero = expected_dim0[local_slice] - # exp_indices = exp_indices_dim0[local_slice_ind] - # self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # # # sort along axis 1, split 0 - # result, result_indices = ht.sort(data, descending=True, axis=1) - # _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) - # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) - # exp_axis_one = expected_dim1[local_slice] - # exp_indices = exp_indices_dim1[local_slice_ind] - # self.assertTrue(torch.equal(result.larray, exp_axis_one)) - # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - - # result1 = ht.sort(data, axis=1, descending=True) - # result2 = ht.sort(data, descending=True) - # self.assertTrue(ht.equal(result1[0], result2[0])) - # self.assertTrue(ht.equal(result1[1], result2[1])) - # # sort along axis 0, split 1 - # data = ht.array(tensor, split=1) - # _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) - # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) - # exp_axis_zero = expected_dim0[local_slice] - # exp_indices = exp_indices_dim0[local_slice_ind] - # result, result_indices = ht.sort(data, axis=0, descending=True) - # self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # # sort along axis 1, split 1 - # _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) - # _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) - # exp_axis_one = expected_dim1[local_slice] - # exp_indices = exp_indices_dim1[local_slice_ind] - # result, result_indices = ht.sort(data, descending=True, axis=1) - # self.assertTrue(torch.equal(result.larray, exp_axis_one)) - # self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - - # 3D array - tensor = torch.tensor( - [ - [[2, 8, 5], [7, 2, 3]], - [[6, 5, 2], [1, 8, 7]], - [[9, 3, 0], [1, 2, 4]], - [[8, 4, 7], [0, 8, 9]], - ], - dtype=torch.int32, - device=self.device.torch_device, - ) - + # rank = ht.MPI_WORLD.rank + torch.manual_seed(42) + tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) + # sort along axis 0, split None + data = ht.array(tensor, split=None) + result, result_indices = ht.sort(data, axis=0, descending=True) + expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) + self.assertTrue(torch.equal(result.larray, expected_dim0)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) + # sort along axis 1, split None + result, result_indices = ht.sort(data, axis=1, descending=True) + expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) + self.assertTrue(torch.equal(result.larray, expected_dim1)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) + # sort along axis 0, split 0 data = ht.array(tensor, split=0) - exp_axis_zero = torch.tensor( - [[2, 3, 0], [0, 2, 3]], dtype=torch.int32, device=self.device.torch_device - ) - if torch.cuda.is_available() and data.device == ht.gpu and size < 4: - indices_axis_zero = torch.tensor( - [[0, 2, 2], [3, 2, 0]], dtype=torch.int32, device=self.device.torch_device - ) - else: - indices_axis_zero = torch.tensor( - [[0, 2, 2], [3, 0, 0]], dtype=torch.int32, device=self.device.torch_device - ) - result, result_indices = ht.sort(data, axis=0) - first = result[0].larray - first_indices = result_indices[0].larray - if rank == 0: - self.assertTrue(torch.equal(first, exp_axis_zero)) - self.assertTrue(torch.equal(first_indices, indices_axis_zero)) - + result, result_indices = ht.sort(data, descending=True, axis=0) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # sort along axis 1, split 0 + result, result_indices = ht.sort(data, descending=True, axis=1) + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + + result1 = ht.sort(data, axis=1, descending=True) + result2 = ht.sort(data, descending=True) + self.assertTrue(ht.equal(result1[0], result2[0])) + self.assertTrue(ht.equal(result1[1], result2[1])) + # sort along axis 0, split 1 data = ht.array(tensor, split=1) - exp_axis_one = torch.tensor([[2, 2, 3]], dtype=torch.int32, device=self.device.torch_device) - indices_axis_one = torch.tensor( - [[0, 1, 1]], dtype=torch.int32, device=self.device.torch_device - ) - result, result_indices = ht.sort(data, axis=1) - first = result[0].larray[:1] - first_indices = result_indices[0].larray[:1] - if rank == 0: - self.assertTrue(torch.equal(first, exp_axis_one)) - self.assertTrue(torch.equal(first_indices, indices_axis_one)) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + result, result_indices = ht.sort(data, axis=0, descending=True) + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 1, split 1 + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] + result, result_indices = ht.sort(data, descending=True, axis=1) + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + + # # 3D array + # tensor = torch.tensor( + # [ + # [[2, 8, 5], [7, 2, 3]], + # [[6, 5, 2], [1, 8, 7]], + # [[9, 3, 0], [1, 2, 4]], + # [[8, 4, 7], [0, 8, 9]], + # ], + # dtype=torch.int32, + # device=self.device.torch_device, + # ) - data = ht.array(tensor, split=2) - exp_axis_two = torch.tensor([[2], [2]], dtype=torch.int32, device=self.device.torch_device) - indices_axis_two = torch.tensor( - [[0], [1]], dtype=torch.int32, device=self.device.torch_device - ) - result, result_indices = ht.sort(data, axis=2) - first = result[0].larray[:, :1] - first_indices = result_indices[0].larray[:, :1] - if rank == 0: - self.assertTrue(torch.equal(first, exp_axis_two)) - self.assertTrue(torch.equal(first_indices, indices_axis_two)) - # - out = ht.empty_like(data) - indices = ht.sort(data, axis=2, out=out) - self.assertTrue(ht.equal(out, result)) - self.assertTrue(ht.equal(indices, result_indices)) + # data = ht.array(tensor, split=0) + # exp_axis_zero = torch.tensor( + # [[2, 3, 0], [0, 2, 3]], dtype=torch.int32, device=self.device.torch_device + # ) + # if torch.cuda.is_available() and data.device == ht.gpu and size < 4: + # indices_axis_zero = torch.tensor( + # [[0, 2, 2], [3, 2, 0]], dtype=torch.int32, device=self.device.torch_device + # ) + # else: + # indices_axis_zero = torch.tensor( + # [[0, 2, 2], [3, 0, 0]], dtype=torch.int32, device=self.device.torch_device + # ) + # result, result_indices = ht.sort(data, axis=0) + # first = result[0].larray + # first_indices = result_indices[0].larray + # if rank == 0: + # self.assertTrue(torch.equal(first, exp_axis_zero)) + # self.assertTrue(torch.equal(first_indices, indices_axis_zero)) - with self.assertRaises(ValueError): - ht.sort(data, axis=3) - with self.assertRaises(TypeError): - ht.sort(data, axis="1") + # data = ht.array(tensor, split=1) + # exp_axis_one = torch.tensor([[2, 2, 3]], dtype=torch.int32, device=self.device.torch_device) + # indices_axis_one = torch.tensor( + # [[0, 1, 1]], dtype=torch.int32, device=self.device.torch_device + # ) + # result, result_indices = ht.sort(data, axis=1) + # first = result[0].larray[:1] + # first_indices = result_indices[0].larray[:1] + # if rank == 0: + # self.assertTrue(torch.equal(first, exp_axis_one)) + # self.assertTrue(torch.equal(first_indices, indices_axis_one)) + + # data = ht.array(tensor, split=2) + # exp_axis_two = torch.tensor([[2], [2]], dtype=torch.int32, device=self.device.torch_device) + # indices_axis_two = torch.tensor( + # [[0], [1]], dtype=torch.int32, device=self.device.torch_device + # ) + # result, result_indices = ht.sort(data, axis=2) + # first = result[0].larray[:, :1] + # first_indices = result_indices[0].larray[:, :1] + # if rank == 0: + # self.assertTrue(torch.equal(first, exp_axis_two)) + # self.assertTrue(torch.equal(first_indices, indices_axis_two)) + # # + # out = ht.empty_like(data) + # indices = ht.sort(data, axis=2, out=out) + # self.assertTrue(ht.equal(out, result)) + # self.assertTrue(ht.equal(indices, result_indices)) + + # with self.assertRaises(ValueError): + # ht.sort(data, axis=3) + # with self.assertRaises(TypeError): + # ht.sort(data, axis="1") - rank = ht.MPI_WORLD.rank - ht.random.seed(1) - data = ht.random.randn(100, 1, split=0) - result, _ = ht.sort(data, axis=0) - counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0) - for i, c in enumerate(counts): - for idx in range(c - 1): - if rank == i: - self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all()) + # rank = ht.MPI_WORLD.rank + # ht.random.seed(1) + # data = ht.random.randn(100, 1, split=0) + # result, _ = ht.sort(data, axis=0) + # counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0) + # for i, c in enumerate(counts): + # for idx in range(c - 1): + # if rank == i: + # self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all()) def test_split(self): # ==================================== From 3939564181ea72bee988b395a830fa9a00e4ca9f Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 06:21:10 +0100 Subject: [PATCH 77/87] Debugging --- heat/core/tests/test_manipulations.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 7adc77d085..2d17da3e70 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2618,7 +2618,7 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size - # rank = ht.MPI_WORLD.rank + rank = ht.MPI_WORLD.rank torch.manual_seed(42) tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) # sort along axis 0, split None @@ -2647,13 +2647,15 @@ def test_sort(self): _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) exp_axis_one = expected_dim1[local_slice] exp_indices = exp_indices_dim1[local_slice_ind] + print("DEBUGGING: heat/torch result: ", rank, result.larray, exp_axis_one) + print("DEBUGGING: heat/torch indices: ", rank, result_indices.larray, exp_indices.int()) self.assertTrue(torch.equal(result.larray, exp_axis_one)) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - result1 = ht.sort(data, axis=1, descending=True) - result2 = ht.sort(data, descending=True) - self.assertTrue(ht.equal(result1[0], result2[0])) - self.assertTrue(ht.equal(result1[1], result2[1])) + # result1 = ht.sort(data, axis=1, descending=True) + # result2 = ht.sort(data, descending=True) + # self.assertTrue(ht.equal(result1[0], result2[0])) + # self.assertTrue(ht.equal(result1[1], result2[1])) # sort along axis 0, split 1 data = ht.array(tensor, split=1) _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) From a27f7f87aba6c971b7c8faa64964be201ca57621 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 09:19:33 +0100 Subject: [PATCH 78/87] Debugging --- heat/core/tests/test_manipulations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 2d17da3e70..88a15637f6 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2671,6 +2671,8 @@ def test_sort(self): exp_axis_one = expected_dim1[local_slice] exp_indices = exp_indices_dim1[local_slice_ind] result, result_indices = ht.sort(data, descending=True, axis=1) + print("DEBUGGING: heat/torch result: ", rank, result.larray, exp_axis_one) + print("DEBUGGING: heat/torch indices: ", rank, result_indices.larray, exp_indices.int()) self.assertTrue(torch.equal(result.larray, exp_axis_one)) self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) From f0f79264a047cb2d32f18a5abb23a2ec17d1c941 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 10:32:04 +0100 Subject: [PATCH 79/87] Do not test sorting indices on GPU if sorting non-unique values --- heat/core/tests/test_manipulations.py | 55 ++++++++++++++++++--------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 88a15637f6..61d272a0d2 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2618,20 +2618,30 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size - rank = ht.MPI_WORLD.rank + # rank = ht.MPI_WORLD.rank torch.manual_seed(42) - tensor = torch.randint(0, 20, (size, size), device=self.device.torch_device) + tensor = torch.randint(0, 10 * size, (size, size), device=self.device.torch_device) # sort along axis 0, split None data = ht.array(tensor, split=None) result, result_indices = ht.sort(data, axis=0, descending=True) expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) self.assertTrue(torch.equal(result.larray, expected_dim0)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices_dim0).numel() == exp_indices_dim0.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim0.int())) # sort along axis 1, split None result, result_indices = ht.sort(data, axis=1, descending=True) expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) self.assertTrue(torch.equal(result.larray, expected_dim1)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices_dim1).numel() == exp_indices_dim1.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices_dim1.int())) # sort along axis 0, split 0 data = ht.array(tensor, split=0) result, result_indices = ht.sort(data, descending=True, axis=0) @@ -2640,22 +2650,26 @@ def test_sort(self): exp_axis_zero = expected_dim0[local_slice] exp_indices = exp_indices_dim0[local_slice_ind] self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) # # sort along axis 1, split 0 result, result_indices = ht.sort(data, descending=True, axis=1) _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) exp_axis_one = expected_dim1[local_slice] exp_indices = exp_indices_dim1[local_slice_ind] - print("DEBUGGING: heat/torch result: ", rank, result.larray, exp_axis_one) - print("DEBUGGING: heat/torch indices: ", rank, result_indices.larray, exp_indices.int()) self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # result1 = ht.sort(data, axis=1, descending=True) - # result2 = ht.sort(data, descending=True) - # self.assertTrue(ht.equal(result1[0], result2[0])) - # self.assertTrue(ht.equal(result1[1], result2[1])) # sort along axis 0, split 1 data = ht.array(tensor, split=1) _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) @@ -2664,18 +2678,25 @@ def test_sort(self): exp_indices = exp_indices_dim0[local_slice_ind] result, result_indices = ht.sort(data, axis=0, descending=True) self.assertTrue(torch.equal(result.larray, exp_axis_zero)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) # sort along axis 1, split 1 _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) exp_axis_one = expected_dim1[local_slice] exp_indices = exp_indices_dim1[local_slice_ind] result, result_indices = ht.sort(data, descending=True, axis=1) - print("DEBUGGING: heat/torch result: ", rank, result.larray, exp_axis_one) - print("DEBUGGING: heat/torch indices: ", rank, result_indices.larray, exp_indices.int()) self.assertTrue(torch.equal(result.larray, exp_axis_one)) - self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) # # 3D array # tensor = torch.tensor( # [ From 73c914db712bd5e6566037c5eedb3c898017ab41 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 11:34:42 +0100 Subject: [PATCH 80/87] Expand test_sort to 3d --- heat/core/tests/test_manipulations.py | 118 +++++++++++++++++++++----- 1 file changed, 97 insertions(+), 21 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 61d272a0d2..d3cbe59fba 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2620,7 +2620,8 @@ def test_sort(self): size = ht.MPI_WORLD.size # rank = ht.MPI_WORLD.rank torch.manual_seed(42) - tensor = torch.randint(0, 10 * size, (size, size), device=self.device.torch_device) + tensor_3d = torch.randint(0, 10 * size, (size, size, size), device=self.device.torch_device) + tensor = tensor_3d[0] # sort along axis 0, split None data = ht.array(tensor, split=None) result, result_indices = ht.sort(data, axis=0, descending=True) @@ -2698,18 +2699,93 @@ def test_sort(self): ): self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) # # 3D array - # tensor = torch.tensor( - # [ - # [[2, 8, 5], [7, 2, 3]], - # [[6, 5, 2], [1, 8, 7]], - # [[9, 3, 0], [1, 2, 4]], - # [[8, 4, 7], [0, 8, 9]], - # ], - # dtype=torch.int32, - # device=self.device.torch_device, - # ) + tensor = tensor_3d + expected_dim0, exp_indices_dim0 = torch.sort(tensor, dim=0, descending=True) + expected_dim1, exp_indices_dim1 = torch.sort(tensor, dim=1, descending=True) + expected_dim2, exp_indices_dim2 = torch.sort(tensor, dim=2, descending=True) + # sort along axis 0, split 0 + data = ht.array(tensor, split=0) + result, result_indices = ht.sort(data, descending=True, axis=0) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=0) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # # sort along axis 1, split 0 + result, result_indices = ht.sort(data, descending=True, axis=1) + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 0, split 1 + data = ht.array(tensor, split=1) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=1) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + result, result_indices = ht.sort(data, axis=0, descending=True) + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 1, split 1 + _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=1) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=1) + exp_axis_one = expected_dim1[local_slice] + exp_indices = exp_indices_dim1[local_slice_ind] + result, result_indices = ht.sort(data, descending=True, axis=1) + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + + # sort along axis 0, split 2 + data = ht.array(tensor, split=2) + _, _, local_slice = data.comm.chunk(expected_dim0.shape, split=2) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim0.shape, split=2) + exp_axis_zero = expected_dim0[local_slice] + exp_indices = exp_indices_dim0[local_slice_ind] + result, result_indices = ht.sort(data, axis=0, descending=True) + self.assertTrue(torch.equal(result.larray, exp_axis_zero)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) + # sort along axis 2, split 2 + _, _, local_slice = data.comm.chunk(expected_dim2.shape, split=2) + _, _, local_slice_ind = data.comm.chunk(exp_indices_dim2.shape, split=2) + exp_axis_one = expected_dim2[local_slice] + exp_indices = exp_indices_dim2[local_slice_ind] + result, result_indices = ht.sort(data, descending=True, axis=2) + self.assertTrue(torch.equal(result.larray, exp_axis_one)) + # indices unstable on GPU if sorting non-unique values + if ( + torch.unique(exp_indices).numel() == exp_indices.numel() + or result_indices.larray.is_cuda is False + ): + self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # data = ht.array(tensor, split=0) # exp_axis_zero = torch.tensor( # [[2, 3, 0], [0, 2, 3]], dtype=torch.int32, device=self.device.torch_device # ) @@ -2752,15 +2828,15 @@ def test_sort(self): # self.assertTrue(torch.equal(first, exp_axis_two)) # self.assertTrue(torch.equal(first_indices, indices_axis_two)) # # - # out = ht.empty_like(data) - # indices = ht.sort(data, axis=2, out=out) - # self.assertTrue(ht.equal(out, result)) - # self.assertTrue(ht.equal(indices, result_indices)) - - # with self.assertRaises(ValueError): - # ht.sort(data, axis=3) - # with self.assertRaises(TypeError): - # ht.sort(data, axis="1") + out = ht.empty_like(data) + indices = ht.sort(data, axis=2, out=out) + self.assertTrue(ht.equal(out, result)) + self.assertTrue(ht.equal(indices, result_indices)) + + with self.assertRaises(ValueError): + ht.sort(data, axis=3) + with self.assertRaises(TypeError): + ht.sort(data, axis="1") # rank = ht.MPI_WORLD.rank # ht.random.seed(1) From b972a3a943c0e8969a61aedb5af0afbc29e4b2c6 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 11:46:47 +0100 Subject: [PATCH 81/87] Expand test_sort for empty-node case --- heat/core/tests/test_manipulations.py | 70 ++++++--------------------- 1 file changed, 16 insertions(+), 54 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index d3cbe59fba..eba6fac6f2 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2618,9 +2618,10 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size - # rank = ht.MPI_WORLD.rank torch.manual_seed(42) - tensor_3d = torch.randint(0, 10 * size, (size, size, size), device=self.device.torch_device) + tensor_3d = torch.randint( + 0, 10 * size, (size, size - 1, size), device=self.device.torch_device + ) tensor = tensor_3d[0] # sort along axis 0, split None data = ht.array(tensor, split=None) @@ -2717,7 +2718,7 @@ def test_sort(self): or result_indices.larray.is_cuda is False ): self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # # sort along axis 1, split 0 + # sort along axis 1, split 0 result, result_indices = ht.sort(data, descending=True, axis=1) _, _, local_slice = data.comm.chunk(expected_dim1.shape, split=0) _, _, local_slice_ind = data.comm.chunk(exp_indices_dim1.shape, split=0) @@ -2786,67 +2787,28 @@ def test_sort(self): ): self.assertTrue(torch.equal(result_indices.larray, exp_indices.int())) - # exp_axis_zero = torch.tensor( - # [[2, 3, 0], [0, 2, 3]], dtype=torch.int32, device=self.device.torch_device - # ) - # if torch.cuda.is_available() and data.device == ht.gpu and size < 4: - # indices_axis_zero = torch.tensor( - # [[0, 2, 2], [3, 2, 0]], dtype=torch.int32, device=self.device.torch_device - # ) - # else: - # indices_axis_zero = torch.tensor( - # [[0, 2, 2], [3, 0, 0]], dtype=torch.int32, device=self.device.torch_device - # ) - # result, result_indices = ht.sort(data, axis=0) - # first = result[0].larray - # first_indices = result_indices[0].larray - # if rank == 0: - # self.assertTrue(torch.equal(first, exp_axis_zero)) - # self.assertTrue(torch.equal(first_indices, indices_axis_zero)) - - # data = ht.array(tensor, split=1) - # exp_axis_one = torch.tensor([[2, 2, 3]], dtype=torch.int32, device=self.device.torch_device) - # indices_axis_one = torch.tensor( - # [[0, 1, 1]], dtype=torch.int32, device=self.device.torch_device - # ) - # result, result_indices = ht.sort(data, axis=1) - # first = result[0].larray[:1] - # first_indices = result_indices[0].larray[:1] - # if rank == 0: - # self.assertTrue(torch.equal(first, exp_axis_one)) - # self.assertTrue(torch.equal(first_indices, indices_axis_one)) - - # data = ht.array(tensor, split=2) - # exp_axis_two = torch.tensor([[2], [2]], dtype=torch.int32, device=self.device.torch_device) - # indices_axis_two = torch.tensor( - # [[0], [1]], dtype=torch.int32, device=self.device.torch_device - # ) - # result, result_indices = ht.sort(data, axis=2) - # first = result[0].larray[:, :1] - # first_indices = result_indices[0].larray[:, :1] - # if rank == 0: - # self.assertTrue(torch.equal(first, exp_axis_two)) - # self.assertTrue(torch.equal(first_indices, indices_axis_two)) - # # + # test out, descending=False + result, result_indices = ht.sort(data, axis=2) out = ht.empty_like(data) indices = ht.sort(data, axis=2, out=out) self.assertTrue(ht.equal(out, result)) self.assertTrue(ht.equal(indices, result_indices)) + # test exceptions with self.assertRaises(ValueError): ht.sort(data, axis=3) with self.assertRaises(TypeError): ht.sort(data, axis="1") - # rank = ht.MPI_WORLD.rank - # ht.random.seed(1) - # data = ht.random.randn(100, 1, split=0) - # result, _ = ht.sort(data, axis=0) - # counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0) - # for i, c in enumerate(counts): - # for idx in range(c - 1): - # if rank == i: - # self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all()) + rank = ht.MPI_WORLD.rank + ht.random.seed(1) + data = ht.random.randn(100, 1, split=0) + result, _ = ht.sort(data, axis=0) + counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0) + for i, c in enumerate(counts): + for idx in range(c - 1): + if rank == i: + self.assertTrue(torch.lt(result.larray[idx], result.larray[idx + 1]).all()) def test_split(self): # ==================================== From 6daa7c251c589f69e2b7b20597fae07cbe7eea0d Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 11:53:06 +0100 Subject: [PATCH 82/87] Remove size-1 test in test_sort --- heat/core/tests/test_manipulations.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index eba6fac6f2..5b6dc21a99 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -2619,9 +2619,7 @@ def test_shape(self): def test_sort(self): size = ht.MPI_WORLD.size torch.manual_seed(42) - tensor_3d = torch.randint( - 0, 10 * size, (size, size - 1, size), device=self.device.torch_device - ) + tensor_3d = torch.randint(0, 10 * size, (size, size, size), device=self.device.torch_device) tensor = tensor_3d[0] # sort along axis 0, split None data = ht.array(tensor, split=None) From 96f5d9d15eba0603aba3a753c560064691a21c12 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 12:50:45 +0100 Subject: [PATCH 83/87] Update changelog --- CHANGELOG.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df24024a8e..ff48ee45c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,18 +1,15 @@ # Pending additions -- [#867](https://github.com/helmholtz-analytics/heat/pull/867) Upgraded to support torch 1.9.0 -- [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN) -- [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use. - ## Bug Fixes - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when axis and keepdim were set. - [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min`, `max` where DNDarrays with empty processes can't be computed. - [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element. +- [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN) ## Feature Additions -### Linear Algebra -- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot` +- [#867](https://github.com/helmholtz-analytics/heat/pull/867) Support torch 1.9.0 +- [#884](https://github.com/helmholtz-analytics/heat/pull/884) Support PyTorch 1.10.0, this is now the recommended version to use. ### Communication - [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split` @@ -21,12 +18,16 @@ - [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__` - [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj` +### Factories +- [#749](https://github.com/helmholtz-analytics/heat/pull/749) `ht.array(copy=False)` behaviour now more in line with `np.array(copy=False)`, reduced memory footprint ### Linear Algebra - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()` +- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot` - [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm` ### Logical - [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit` ### Manipulations +- [#749](https://github.com/helmholtz-analytics/heat/pull/749) Distributed sorted `ht.unique` - [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll` - [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes` - [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis` @@ -35,6 +36,7 @@ ### Rounding - [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn` + # v1.1.1 - [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range. @@ -98,7 +100,6 @@ Example on 2 processes: - [#768](https://github.com/helmholtz-analytics/heat/pull/768) New feature: unary positive and negative operations ### Manipulations -- [#749](https://github.com/helmholtz-analytics/heat/pull/749) Distributed sorted `ht.unique` - [#820](https://github.com/helmholtz-analytics/heat/pull/820) `dot` can handle matrix vector operation now - [#820](https://github.com/helmholtz-analytics/heat/pull/820) `dot` can handle matrix-vector operation now From 6a41a9c9bc41987721a2c4f5c2447985b6a74d3a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 24 Nov 2021 13:22:28 +0100 Subject: [PATCH 84/87] Remove dead code --- heat/core/communication.py | 1 - heat/core/memory.py | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/heat/core/communication.py b/heat/core/communication.py index c446ae958f..8297ccc851 100644 --- a/heat/core/communication.py +++ b/heat/core/communication.py @@ -7,7 +7,6 @@ import os import subprocess import torch -import tracemalloc from mpi4py import MPI from typing import Any, Callable, Optional, List, Tuple, Union diff --git a/heat/core/memory.py b/heat/core/memory.py index 0a2ebca9ce..07b864e9de 100644 --- a/heat/core/memory.py +++ b/heat/core/memory.py @@ -3,7 +3,6 @@ """ import torch -import tracemalloc from . import sanitation from .dndarray import DNDarray @@ -74,15 +73,6 @@ def sanitize_memory_layout(x: torch.Tensor, order: str = "C") -> torch.Tensor: reversed_stride = tuple(reversed(x.stride())) x.set_(x.storage(), storage_offset, shape, reversed_stride) return x - # y = torch.empty_like(x) - # permutation = x.permute(dims).contiguous() - # y = y.set_( - # permutation.storage(), - # x.storage_offset(), - # x.shape, - # tuple(reversed(permutation.stride())), - # ) - # return y raise ValueError( "combination of order and layout not permitted, order: {} column major: {} row major: {}".format( From 6007d3104d0686b7444ce9cb721903d7daa1c7b0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 25 Nov 2021 05:06:02 +0100 Subject: [PATCH 85/87] Reinstate "sparse" unique --- heat/core/manipulations.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index ccacdc4460..b7bb0daf1e 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3319,14 +3319,25 @@ def unique( else: lres = torch.unique(local_data, sorted=True, return_inverse=False, dim=unique_axis) gres = factories.array(lres, dtype=a.dtype, is_split=0, device=a.device, copy=False) - gres.balance_() - # global sorted unique - lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) - # second local unique - if 0 not in lres.shape: - lres = torch.unique(lres, sorted=True, dim=unique_axis) - lres_split = 0 + # calculate size (bytes) of local unique. If less than local_data, gather and run everything locally + data_max_lbytes = torch.prod(a.lshape_map[0]) * a.larray.element_size() + if gres.nbytes <= data_max_lbytes: + # gather local uniques + gres.resplit_(None) + # final round of torch.unique + lres = torch.unique(gres.larray, sorted=True, dim=unique_axis) + lres_split = None + gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device, copy=False) + else: + # TODO: balancing should be unnecessary before distributed sorting + gres.balance_() + # global sorted unique + lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) + # second local unique + if 0 not in lres.shape: + lres = torch.unique(lres, sorted=True, dim=unique_axis) + lres_split = 0 gres = factories.array(lres, dtype=a.dtype, is_split=lres_split, device=a.device, copy=False) gres.balance_() From ec0684cde463fcfe9b1326e154432cb8362e0957 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 25 Nov 2021 05:54:47 +0100 Subject: [PATCH 86/87] Remove unnecessary `balance_` before distributed `unique` --- heat/core/manipulations.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index b7bb0daf1e..a5520257fc 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2620,11 +2620,7 @@ def sort( final_result, final_indices = __pivot_sorting(a, torch.sort, axis, descending=descending) return_indices = factories.array( - final_indices, - dtype=types.int32, - is_split=a.split, - device=a.device, - comm=a.comm, # , copy=False + final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm, copy=False ) if out is not None: out.larray = final_result @@ -3330,8 +3326,6 @@ def unique( lres_split = None gres = factories.array(lres, dtype=a.dtype, is_split=None, device=a.device, copy=False) else: - # TODO: balancing should be unnecessary before distributed sorting - gres.balance_() # global sorted unique lres = __pivot_sorting(gres, torch.unique, 0, sorted=True, return_inverse=True) # second local unique From 6d385e8b3bd24ad94c2538df15a93ca0eff8f8a2 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 20 Jan 2022 12:15:50 +0100 Subject: [PATCH 87/87] Bring back `factories.array` to original state, changes forked to dedicated branch --- heat/core/factories.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/heat/core/factories.py b/heat/core/factories.py index ec04f2d54c..50a55a0613 100644 --- a/heat/core/factories.py +++ b/heat/core/factories.py @@ -17,6 +17,7 @@ from . import devices from . import types + __all__ = [ "arange", "array", @@ -171,7 +172,7 @@ def array( the :func:`~heat.core.dndarray.astype` method. copy : bool, optional If ``True`` (default), then the object is copied. Otherwise, a copy will only be made if obj is a nested - sequence or if a copy is needed to satisfy any of the other requirements, e.g. ``dtype`` or ``order``. + sequence or if a copy is needed to satisfy any of the other requirements, e.g. ``dtype``. ndmin : int, optional Specifies the minimum number of dimensions that the resulting array should have. Ones will, if needed, be attached to the shape if ``ndim > 0`` and prefaced in case of ``ndim < 0`` to meet the requirement. @@ -319,7 +320,6 @@ def array( if device is not None else devices.get_device().torch_device, ) - except RuntimeError: raise TypeError("invalid data of type {}".format(type(obj))) else: @@ -337,12 +337,7 @@ def array( else: torch_dtype = dtype.torch_type() if obj.dtype != torch_dtype: - if not copy: - # different dtype, copy anyway - obj = obj.clone().type(torch_dtype) - else: - # obj is already a copy - obj = obj.type(torch_dtype) + obj = obj.type(torch_dtype) # infer device from obj if not explicitly given if device is None: @@ -365,13 +360,11 @@ def array( if ndmin_abs > 0 > ndmin: obj = obj.reshape(ndmin_abs * (1,) + obj.shape) - # sanitize split or is_split - if split is not None: - if is_split is not None: - raise ValueError("cannot specify both split and is_split") - split = sanitize_axis(obj.shape, split) - elif is_split is not None: - is_split = sanitize_axis(obj.shape, is_split) + # sanitize the split axes, ensure mutual exclusiveness + split = sanitize_axis(obj.shape, split) + is_split = sanitize_axis(obj.shape, is_split) + if split is not None and is_split is not None: + raise ValueError("split and is_split are mutually exclusive parameters") # sanitize comm object comm = sanitize_comm(comm) @@ -384,12 +377,8 @@ def array( # content shall be split, chunk the passed data object up if split is not None: _, _, slices = comm.chunk(gshape, split) - if not copy: - obj = obj[slices] - else: - obj = obj[slices].clone() + obj = obj[slices].clone() obj = sanitize_memory_layout(obj, order=order) - # check with the neighboring rank whether the local shape would fit into a global shape elif is_split is not None: gshape = np.array(gshape)