From 3d055739a440ba220c9e9df31b386397f9d468a6 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 3 Jun 2026 12:16:04 +0200 Subject: [PATCH 1/2] First small cleanup --- heat/core/dndarray.py | 49 +++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index e29acf7d27..c94979ccf8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -990,7 +990,6 @@ def __process_key( split_key = key[arr.split] elif not isinstance(key, tuple): split_key = key - if ( isinstance(split_key, DNDarray) and split_key.dtype in (ht_bool, ht_uint8) @@ -1010,19 +1009,19 @@ def __process_key( # 1D mask on split=0 distr_mask_fast_path = True - # early out if mask and not tuple key - if distr_mask_fast_path and not isinstance(key, tuple): - return arr, ProcessedKey( - key=key.larray, - op_type="distr_mask", - output_shape=(), # Dummy shape, bypassed safely in __setitem__ - output_split=0 if op == "get" else arr.split, - split_key_is_ordered=0, - key_is_mask_like=True, - out_is_balanced=False, - root=None, - backwards_transpose_axes=tuple(range(arr.ndim)), - ) + # early out if mask and not tuple key + if distr_mask_fast_path and not isinstance(key, tuple): + return arr, ProcessedKey( + key=key.larray, + op_type="distr_mask", + output_shape=(), # Dummy shape, bypassed safely in __setitem__ + output_split=0 if op == "get" else arr.split, + split_key_is_ordered=0, + key_is_mask_like=True, + out_is_balanced=False, + root=None, + backwards_transpose_axes=tuple(range(arr.ndim)), + ) # normalize index components if isinstance(key, DNDarray): @@ -1098,19 +1097,15 @@ def __process_key( tuple(key.shape), arr.shape ) ) - if not distr_mask_fast_path: - if key_ndim == 0: - # 0-D boolean mask: keep as 0-D tensor, do not extract non-zero - key = key.larray if isinstance(key, DNDarray) else key - else: - # extract non-zero elements - try: - key = key.nonzero(as_tuple=True) - except TypeError: - key = key.nonzero() - else: - # keep the raw boolean mask + if key_ndim == 0: + # 0-D boolean mask: keep as 0-D tensor, do not extract non-zero key = key.larray if isinstance(key, DNDarray) else key + else: + # extract non-zero elements + try: + key = key.nonzero(as_tuple=True) + except TypeError: + key = key.nonzero() key_is_mask_like = True else: @@ -1204,7 +1199,7 @@ def __process_key( elif split_key_is_ordered == 0: op_type = "distributed" elif key_is_mask_like: - op_type = "distr_mask" if distr_mask_fast_path else "local_mask" + op_type = "local_mask" else: op_type = "advanced" From 3f855a9058ad8d8a1f3516ca02693ba57ce114b1 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 3 Jun 2026 16:27:39 +0200 Subject: [PATCH 2/2] Another small simplification --- heat/core/dndarray.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index c94979ccf8..3764cff4e8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -985,11 +985,13 @@ def __process_key( distr_mask_fast_path = False # mask along split axis within tuple? if arr.is_distributed(): - split_key = None - if isinstance(key, tuple) and len(key) > (arr.split or 0): + if isinstance(key, tuple) and len(key) > arr.split: split_key = key[arr.split] - elif not isinstance(key, tuple): + elif isinstance(key, DNDarray): split_key = key + else: + split_key = None + if ( isinstance(split_key, DNDarray) and split_key.dtype in (ht_bool, ht_uint8)