diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index e29acf7d27..3764cff4e8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -985,11 +985,12 @@ def __process_key( distr_mask_fast_path = False # mask along split axis within tuple? if arr.is_distributed(): - split_key = None - if isinstance(key, tuple) and len(key) > (arr.split or 0): + if isinstance(key, tuple) and len(key) > arr.split: split_key = key[arr.split] - elif not isinstance(key, tuple): + elif isinstance(key, DNDarray): split_key = key + else: + split_key = None if ( isinstance(split_key, DNDarray) @@ -1010,19 +1011,19 @@ def __process_key( # 1D mask on split=0 distr_mask_fast_path = True - # early out if mask and not tuple key - if distr_mask_fast_path and not isinstance(key, tuple): - return arr, ProcessedKey( - key=key.larray, - op_type="distr_mask", - output_shape=(), # Dummy shape, bypassed safely in __setitem__ - output_split=0 if op == "get" else arr.split, - split_key_is_ordered=0, - key_is_mask_like=True, - out_is_balanced=False, - root=None, - backwards_transpose_axes=tuple(range(arr.ndim)), - ) + # early out if mask and not tuple key + if distr_mask_fast_path and not isinstance(key, tuple): + return arr, ProcessedKey( + key=key.larray, + op_type="distr_mask", + output_shape=(), # Dummy shape, bypassed safely in __setitem__ + output_split=0 if op == "get" else arr.split, + split_key_is_ordered=0, + key_is_mask_like=True, + out_is_balanced=False, + root=None, + backwards_transpose_axes=tuple(range(arr.ndim)), + ) # normalize index components if isinstance(key, DNDarray): @@ -1098,19 +1099,15 @@ def __process_key( tuple(key.shape), arr.shape ) ) - if not distr_mask_fast_path: - if key_ndim == 0: - # 0-D boolean mask: keep as 0-D tensor, do not extract non-zero - key = key.larray if isinstance(key, DNDarray) else key - else: - # extract non-zero elements - try: - key = key.nonzero(as_tuple=True) - except TypeError: - key = key.nonzero() - else: - # keep the raw boolean mask + if key_ndim == 0: + # 0-D boolean mask: keep as 0-D tensor, do not extract non-zero key = key.larray if isinstance(key, DNDarray) else key + else: + # extract non-zero elements + try: + key = key.nonzero(as_tuple=True) + except TypeError: + key = key.nonzero() key_is_mask_like = True else: @@ -1204,7 +1201,7 @@ def __process_key( elif split_key_is_ordered == 0: op_type = "distributed" elif key_is_mask_like: - op_type = "distr_mask" if distr_mask_fast_path else "local_mask" + op_type = "local_mask" else: op_type = "advanced"