Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 26 additions & 29 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,11 +985,12 @@ def __process_key(
distr_mask_fast_path = False
# mask along split axis within tuple?
if arr.is_distributed():
split_key = None
if isinstance(key, tuple) and len(key) > (arr.split or 0):
if isinstance(key, tuple) and len(key) > arr.split:
split_key = key[arr.split]
elif not isinstance(key, tuple):
elif isinstance(key, DNDarray):
split_key = key
else:
split_key = None

if (
isinstance(split_key, DNDarray)
Expand All @@ -1010,19 +1011,19 @@ def __process_key(
# 1D mask on split=0
distr_mask_fast_path = True

# early out if mask and not tuple key
if distr_mask_fast_path and not isinstance(key, tuple):
return arr, ProcessedKey(
key=key.larray,
op_type="distr_mask",
output_shape=(), # Dummy shape, bypassed safely in __setitem__
output_split=0 if op == "get" else arr.split,
split_key_is_ordered=0,
key_is_mask_like=True,
out_is_balanced=False,
root=None,
backwards_transpose_axes=tuple(range(arr.ndim)),
)
# early out if mask and not tuple key
if distr_mask_fast_path and not isinstance(key, tuple):
return arr, ProcessedKey(
key=key.larray,
op_type="distr_mask",
output_shape=(), # Dummy shape, bypassed safely in __setitem__
output_split=0 if op == "get" else arr.split,
split_key_is_ordered=0,
key_is_mask_like=True,
out_is_balanced=False,
root=None,
backwards_transpose_axes=tuple(range(arr.ndim)),
)

# normalize index components
if isinstance(key, DNDarray):
Expand Down Expand Up @@ -1098,19 +1099,15 @@ def __process_key(
tuple(key.shape), arr.shape
)
)
if not distr_mask_fast_path:
if key_ndim == 0:
# 0-D boolean mask: keep as 0-D tensor, do not extract non-zero
key = key.larray if isinstance(key, DNDarray) else key
else:
# extract non-zero elements
try:
key = key.nonzero(as_tuple=True)
except TypeError:
key = key.nonzero()
else:
# keep the raw boolean mask
if key_ndim == 0:
# 0-D boolean mask: keep as 0-D tensor, do not extract non-zero
key = key.larray if isinstance(key, DNDarray) else key
else:
# extract non-zero elements
try:
key = key.nonzero(as_tuple=True)
except TypeError:
key = key.nonzero()

key_is_mask_like = True
else:
Expand Down Expand Up @@ -1204,7 +1201,7 @@ def __process_key(
elif split_key_is_ordered == 0:
op_type = "distributed"
elif key_is_mask_like:
op_type = "distr_mask" if distr_mask_fast_path else "local_mask"
op_type = "local_mask"
else:
op_type = "advanced"

Expand Down
Loading