diff --git a/pyproject.toml b/pyproject.toml index 3a067a18..0bbb2f0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ yuanrong = [ "openyuanrong-datasystem" ] mooncake = [ - "mooncake-transfer-engine==0.3.10.post1" + "mooncake-transfer-engine==0.3.10.post2" ] # If you need to mimic `package_dir={'': '.'}`: diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 8b171717..83f8b7eb 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -212,7 +212,6 @@ def test_kv_put_with_dict_fields(self, controller, tq_api): expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["data"], expected) - # delete the key (MooncakeStore does not support updating existing key, so we need to clear it before next test) tq_api.kv_clear(keys=key, partition_id=partition_id) def test_kv_put_with_tensordict_fields(self, controller, tq_api): diff --git a/tests/test_tensor_utils.py b/tests/test_tensor_utils.py new file mode 100644 index 00000000..5d534938 --- /dev/null +++ b/tests/test_tensor_utils.py @@ -0,0 +1,196 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for transfer_queue.utils.tensor_utils.""" + +import pytest +import torch + +from transfer_queue.utils.tensor_utils import ( + allocate_empty_tensors, + compute_stride, + get_nbytes, + merge_contiguous_memory, +) + + +class TestComputeStride: + """Tests for compute_stride.""" + + def test_3d(self): + assert compute_stride((2, 3, 4)) == (12, 4, 1) + + def test_1d(self): + assert compute_stride((5,)) == (1,) + + def test_scalar(self): + assert compute_stride(()) == () + + def test_2d(self): + assert compute_stride((3, 5)) == (5, 1) + + +class TestGetNbytes: + """Tests for get_nbytes.""" + + def test_basic(self): + dtypes = [torch.float32, torch.int32] + shapes = [(2, 3), (4,)] + result = get_nbytes(dtypes, shapes) + assert result == [2 * 3 * 4, 4 * 4] # float32=4, int32=4 + + def test_scalar(self): + dtypes = [torch.float64] + shapes = [()] + result = get_nbytes(dtypes, shapes) + assert result == [8] # scalar = 1 element + + def test_list_shape(self): + dtypes = [torch.float32] + shapes = [[]] # list instead of tuple + result = get_nbytes(dtypes, shapes) + assert result == [4] + + def test_mixed_dtypes(self): + dtypes = [torch.float16, torch.float32, torch.int64] + shapes = [(10,), (10,), (10,)] + result = get_nbytes(dtypes, shapes) + assert result == [10 * 2, 10 * 4, 10 * 8] + + +class TestAllocateEmptyTensors: + """Tests for allocate_empty_tensors.""" + + def test_basic(self): + dtypes = [torch.float32, torch.float32, torch.int32] + shapes = [(2, 3), (4,), (5,)] + tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + + assert len(tensors) == 3 + assert len(ptrs) == 3 + assert len(region_ptrs) == 2 # float32 group + int32 group + assert len(region_sizes) == 2 + + # Same dtype tensors share the same underlying storage + assert tensors[0].untyped_storage().data_ptr() == region_ptrs[0] + assert tensors[1].untyped_storage().data_ptr() == region_ptrs[0] + assert tensors[2].untyped_storage().data_ptr() == region_ptrs[1] + + # 
Shapes are correct + assert list(tensors[0].shape) == [2, 3] + assert list(tensors[1].shape) == [4] + assert list(tensors[2].shape) == [5] + + def test_scalar(self): + dtypes = [torch.float32, torch.int32] + shapes = [(), ()] + tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + + assert len(tensors) == 2 + assert tensors[0].numel() == 1 + assert tensors[1].numel() == 1 + assert len(region_ptrs) == 2 + + def test_empty(self): + result = allocate_empty_tensors([], []) + assert result == ([], [], [], []) + + def test_regions_complex(self): + """Mixed dtypes and shapes: verify region counts, sizes, and per-tensor offsets.""" + dtypes = [ + torch.float32, # group 0: (2, 3) -> 6 elements + torch.int32, # group 1: (4,) -> 4 elements + torch.float32, # group 0: scalar -> 1 element + torch.float64, # group 2: (2, 2) -> 4 elements + torch.int32, # group 1: (3, 2) -> 6 elements + ] + shapes = [(2, 3), (4,), (), (2, 2), (3, 2)] + tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + + # 3 dtype groups in insertion order: float32, int32, float64 + assert len(region_ptrs) == 3 + assert len(region_sizes) == 3 + assert len(set(region_ptrs)) == 3 # distinct allocations + + # float32 region: 6 + 1 = 7 elements * 4 bytes = 28 bytes + assert region_sizes[0] == 7 * 4 + # int32 region: 4 + 6 = 10 elements * 4 bytes = 40 bytes + assert region_sizes[1] == 10 * 4 + # float64 region: 4 elements * 8 bytes = 32 bytes + assert region_sizes[2] == 4 * 8 + + # Per-tensor ptrs must lie inside their respective regions + # tensor 0 (float32, shape (2,3), offset 0) + assert ptrs[0] == region_ptrs[0] + # tensor 1 (int32, shape (4,), offset 0) + assert ptrs[1] == region_ptrs[1] + # tensor 2 (float32, scalar, offset 6) + assert ptrs[2] == region_ptrs[0] + 6 * 4 + # tensor 3 (float64, shape (2,2), offset 0) + assert ptrs[3] == region_ptrs[2] + # tensor 4 (int32, shape (3,2), offset 4) + assert ptrs[4] == region_ptrs[1] + 4 * 4 + + +class 
TestMergeContiguousMemory: + """Tests for merge_contiguous_memory.""" + + def test_basic_merge(self): + ptrs = [0, 10, 30] + sizes = [10, 20, 10] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + # 0+10=10 (contiguous with 10), 10+20=30 (contiguous with 30) -> all merge into [0] + assert merged_ptrs == [0] + assert merged_sizes == [40] + + def test_no_contiguous(self): + ptrs = [0, 100, 200] + sizes = [50, 50, 50] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + assert merged_ptrs == [0, 100, 200] + assert merged_sizes == [50, 50, 50] + + def test_unsorted_input(self): + ptrs = [100, 0, 50] + sizes = [50, 50, 50] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + # After sorting: 0, 50, 100; all contiguous -> merge into [0] + assert merged_ptrs == [0] + assert merged_sizes == [150] + + def test_single_region(self): + ptrs = [10] + sizes = [100] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + assert merged_ptrs == [10] + assert merged_sizes == [100] + + def test_empty(self): + assert merge_contiguous_memory([], []) == ([], []) + + def test_mismatched_lengths_both_empty_not_triggered(self): + # If one is empty and other is not, should raise ValueError + with pytest.raises(ValueError, match="ptrs and sizes must have the same length"): + merge_contiguous_memory([], [10]) + + with pytest.raises(ValueError, match="ptrs and sizes must have the same length"): + merge_contiguous_memory([0], []) + + def test_three_continuous(self): + ptrs = [0, 10, 20] + sizes = [10, 10, 10] + merged_ptrs, merged_sizes = merge_contiguous_memory(ptrs, sizes) + assert merged_ptrs == [0] + assert merged_sizes == [30] diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 47a0c318..8f311b4d 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -14,6 +14,7 @@ # limitations under the 
License. import pickle +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Optional import torch @@ -22,16 +23,19 @@ from transfer_queue.storage.clients.base import TransferQueueStorageKVClient from transfer_queue.storage.clients.factory import StorageClientFactory from transfer_queue.utils.logging_utils import get_logger +from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory logger = get_logger(__name__) MOONCAKE_STORE_IMPORTED: bool = True try: - from mooncake.store import MooncakeDistributedStore + from mooncake.store import MooncakeDistributedStore, ReplicateConfig + except ImportError: MOONCAKE_STORE_IMPORTED = False -BATCH_SIZE_LIMIT: int = 500 +BATCH_SIZE_LIMIT: int = 200 +MAX_WORKER_THREADS = 4 @StorageClientFactory.register("MooncakeStoreClient") @@ -76,10 +80,8 @@ def __init__(self, config: dict[str, Any]): if not self.metadata_server.startswith("etcd://") and not self.metadata_server.endswith("/metadata"): self.metadata_server = self.metadata_server + "/metadata" - if self.metadata_server is None: - raise ValueError("Missing 'metadata_server' in config") - if self.master_server_address is None: - raise ValueError("Missing 'master_server_address' in config") + self.replica_config = ReplicateConfig() + self.replica_config.with_hard_pin = True self._store = MooncakeDistributedStore() ret = self._store.setup( @@ -94,7 +96,7 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> None: """Stores multiple key-value pairs to MooncakeStore. 
Args: @@ -114,61 +116,77 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: for key, value in zip(keys, values, strict=True): if isinstance(value, torch.Tensor): - tensor = value.contiguous() - # TODO: use gpu direct rdma instead - if tensor.device.type == "cuda": - tensor = tensor.cpu() tensor_keys.append(key) - tensor_values.append(tensor) + tensor_values.append(value) else: non_tensor_keys.append(key) - non_tensor_values.append(pickle.dumps(value)) + non_tensor_values.append(value) + + futures = [] + with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: + for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): + batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] + futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) - if tensor_keys: - self._batch_put_tensors(tensor_keys, tensor_values) + for i in range(0, len(non_tensor_keys), BATCH_SIZE_LIMIT): + batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_values = non_tensor_values[i : i + BATCH_SIZE_LIMIT] + futures.append(executor.submit(self._put_bytes_thread_worker, batch_keys, batch_values)) - if non_tensor_keys: - self._batch_put_bytes(non_tensor_keys, non_tensor_values) + for future in as_completed(futures): + future.result() return None - def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_tensors = tensors[i : i + BATCH_SIZE_LIMIT] + def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None: + """Worker thread for putting batch of tensors to MooncakeStore.""" - results = self._store.batch_put_tensor(batch_keys, batch_tensors) + batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) + batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, 
batch_sizes) + self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) + + try: + results = self._store.batch_upsert_from(batch_keys, batch_ptrs, batch_sizes, config=self.replica_config) if not all(r == 0 for r in results): failed_indices = [j for j, r in enumerate(results) if r != 0] error_codes = [results[j] for j in failed_indices] raise RuntimeError( - f"batch_put_tensor failed for indices {failed_indices} with error codes: {error_codes}" + f"batch_upsert_from failed for indices {failed_indices} with error codes: {error_codes}" ) + finally: + self._unregister_all_buffers(batch_ptr_reduced) - def _batch_put_bytes(self, keys: list[str], values: list[bytes]): - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_values = values[i : i + BATCH_SIZE_LIMIT] + def _put_bytes_thread_worker(self, batch_keys: list[str], batch_values: list[Any]): + """Worker thread for putting batch of non-tensors to MooncakeStore.""" - ret = self._store.put_batch(batch_keys, batch_values) - if ret != 0: - raise RuntimeError(f"put_batch failed with error code: {ret}") + batch_values = [pickle.dumps(v, protocol=pickle.HIGHEST_PROTOCOL) for v in batch_values] - def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: + ret = self._store.upsert_batch(batch_keys, batch_values, self.replica_config) + if ret != 0: + raise RuntimeError(f"upsert_batch failed with error code: {ret}") + + def get( + self, + keys: list[str], + shapes: Optional[list[Any]] = None, + dtypes: Optional[list[Any]] = None, + custom_backend_meta: Optional[list[str]] = None, + ) -> list[Any]: """Get multiple key-value pairs from MooncakeStore. Args: - keys (List[str]): Keys to fetch. - shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). - dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. - custom_backend_meta (List[str], optional): ... + keys: Keys to fetch. 
+ shapes: Expected tensor shapes (use [] for scalars). + dtypes: Expected dtypes; use None for non-tensor data. + custom_backend_meta: Optional custom backend metadata. Returns: - List[Any]: Retrieved values in the same order as input keys. + Retrieved values in the same order as input keys. """ if shapes is None or dtypes is None: - raise ValueError("MooncakeStoreClient needs shapes and dtypes") + raise ValueError("MooncakeStoreClient needs shapes and dtypes for zero-copy transfer.") if not (len(keys) == len(shapes) == len(dtypes)): raise ValueError("Lengths of keys, shapes, dtypes must match") @@ -183,84 +201,103 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non results = [None] * len(keys) - if tensor_indices: - tensor_keys = [keys[i] for i in tensor_indices] - tensor_shapes = [shapes[i] for i in tensor_indices] - tensor_dtypes = [dtypes[i] for i in tensor_indices] - tensor_results = self._batch_get_tensors(tensor_keys, tensor_shapes, tensor_dtypes) - # TODO: optimize these for loops - for idx, tensor in zip(tensor_indices, tensor_results, strict=True): - results[idx] = tensor - - if non_tensor_indices: - non_tensor_keys = [keys[i] for i in non_tensor_indices] - non_tensor_results = self._batch_get_bytes(non_tensor_keys) - for idx, data in zip(non_tensor_indices, non_tensor_results, strict=True): - results[idx] = pickle.loads(data) + futures = [] + with ThreadPoolExecutor(max_workers=MAX_WORKER_THREADS) as executor: + for i in range(0, len(tensor_indices), BATCH_SIZE_LIMIT): + batch_indexes = tensor_indices[i : i + BATCH_SIZE_LIMIT] + batch_keys = [keys[i] for i in batch_indexes] + batch_shapes = [shapes[i] for i in batch_indexes] + batch_dtypes = [dtypes[i] for i in batch_indexes] + futures.append( + executor.submit( + self._get_tensors_thread_worker, batch_keys, batch_shapes, batch_dtypes, batch_indexes + ) + ) + + for i in range(0, len(non_tensor_indices), BATCH_SIZE_LIMIT): + batch_indexes = non_tensor_indices[i : i + 
BATCH_SIZE_LIMIT] + batch_keys = [keys[i] for i in batch_indexes] + futures.append(executor.submit(self._get_bytes_thread_worker, batch_keys, batch_indexes)) + + for future in as_completed(futures): + retrieved_values, batch_indexes = future.result() + for idx, val in zip(batch_indexes, retrieved_values, strict=True): + results[idx] = val return results - def _batch_get_tensors(self, keys: list[str], shapes: list, dtypes: list) -> list[Tensor]: - tensors = [None] * len(keys) + def _get_tensors_thread_worker( + self, batch_keys: list[str], batch_shapes: list[tuple], batch_dtypes: list[torch.dtype], indexes: list[int] + ) -> tuple[list[Tensor], list[int]]: + batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) + batch_buffer_tensors, batch_buffer_ptrs, region_ptrs, region_sizes = allocate_empty_tensors( + batch_dtypes, batch_shapes + ) - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_shapes = shapes[i : i + BATCH_SIZE_LIMIT] - batch_dtypes = dtypes[i : i + BATCH_SIZE_LIMIT] + self._register_all_buffers(region_ptrs, region_sizes) + try: + ret_codes = self._store.batch_get_into(batch_keys, batch_buffer_ptrs, batch_nbytes) + if len(ret_codes) != len(batch_keys): + raise RuntimeError(f"batch_get_into returned {len(ret_codes)} results, expected {len(batch_keys)}") + for i, ret in enumerate(ret_codes): + if ret < 0: + raise RuntimeError(f"batch_get_into failed for key `{batch_keys[i]}` with error code: {ret}") + finally: + self._unregister_all_buffers(region_ptrs) - batch_results = self._store.batch_get_tensor(batch_keys) + return batch_buffer_tensors, indexes - if len(batch_results) != len(batch_keys): - raise RuntimeError(f"batch_get_tensor returned {len(batch_results)} items, expected {len(batch_keys)}") + def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) -> tuple[list[Any], list[int]]: + results = [] - for j, (tensor, shape, dtype) in enumerate(zip(batch_results, batch_shapes, 
batch_dtypes, strict=True)): - if tensor is None: - raise RuntimeError(f"batch_get_tensor returned None for key '{batch_keys[j]}'") - if tensor.shape != torch.Size(shape): - raise RuntimeError( - f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" - ) - if tensor.dtype != dtype: - raise RuntimeError( - f"Dtype mismatch for key '{batch_keys[j]}': expected {dtype}, got {tensor.dtype}" - ) - tensors[i + j] = tensor + batch_results = self._store.get_batch(batch_keys) + if len(batch_results) != len(batch_keys): + raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") - return tensors + batch_results = [pickle.loads(result) if result != b"" else None for result in batch_results] + results.extend(batch_results) - def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: - results = [] - for i in range(0, len(keys), BATCH_SIZE_LIMIT): - batch_keys = keys[i : i + BATCH_SIZE_LIMIT] - batch_results = self._store.get_batch(batch_keys) - if len(batch_results) != len(batch_keys): - raise RuntimeError(f"get_batch returned {len(batch_results)} items, expected {len(batch_keys)}") - results.extend(batch_results) - return results + return results, indexes - def clear(self, keys: list[str], custom_backend_meta=None): + def clear(self, keys: list[str], custom_backend_meta: Optional[list[Any]] = None) -> None: """Deletes multiple keys from MooncakeStore. - Args: keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... """ - global_indexes_patterns = {key.split("@")[0] + "@.*" for key in keys} - for p in global_indexes_patterns: - ret = self._store.remove_by_regex(p, force=True) - if ret < 0: - logger.warning(f"remove failed for key '{p}' with error code: {ret}") - - # FIXME: controller returned BatchMeta may have mismatched fields in some case, preventing - # key-value based backends to accurately clear all existing keys.. 
- # for key in keys: - # ret = self._store.remove(key) - # if not (ret == 0 or ret == -704): - # logger.warning(f"remove failed for key '{key}' with error code: {ret}") + ret_codes = self._store.batch_remove(keys, force=True) + for i, ret in enumerate(ret_codes): + if not (ret == 0 or ret == -704): + logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}") def close(self): """Closes MooncakeStore.""" if self._store: self._store.close() self._store = None + + @staticmethod + def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[int], list[int], list[Tensor]]: + ptr_list: list[int] = [] + size_list: list[int] = [] + tensor_list: list[Tensor] = [] # hold reference for the contiguous tensor + for t in values: + # TODO: support gpu direct rdma and use different data paths. + # For GPU, it's more reasonable to perform data copy since + # The register overhead is much higher than CPU + if t.device.type == "cuda": + t = t.cpu() + t = t.contiguous() + tensor_list.append(t) + ptr_list.append(t.data_ptr()) + size_list.append(t.nbytes) + return ptr_list, size_list, tensor_list + + def _register_all_buffers(self, ptrs, sizes): + for ptr, size in zip(ptrs, sizes, strict=True): + self._store.register_buffer(ptr, size) + + def _unregister_all_buffers(self, ptrs): + for ptr in ptrs: + self._store.unregister_buffer(ptr) diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index bfd6cf49..9aeffdc1 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -55,7 +55,7 @@ def supports_put(self, value: Any) -> bool: """Check if this strategy can store the given value.""" @abstractmethod - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> None: """Store key-value pairs using this strategy.""" @abstractmethod @@ -71,7 +71,7 @@ def supports_clear(self, strategy_tag: Any) -> bool: 
"""Check if this strategy owns data identified by metadata.""" @abstractmethod - def clear(self, keys: list[str]): + def clear(self, keys: list[str]) -> None: """Delete keys from storage.""" @@ -129,7 +129,7 @@ def supports_put(self, value: Any) -> bool: # Only contiguous NPU tensors are supported by this adapter. return value.is_contiguous() - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> None: """Store NPU tensors in batches; deletes before overwrite.""" for i in range(0, len(keys), self.KEYS_LIMIT): batch_keys = keys[i : i + self.KEYS_LIMIT] @@ -167,14 +167,14 @@ def supports_clear(self, strategy_tag: str) -> bool: """Matches 'DsTensorClient' strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def clear(self, keys: list[str]): + def clear(self, keys: list[str]) -> None: """Delete NPU tensor keys in batches.""" for i in range(0, len(keys), self.KEYS_LIMIT): batch = keys[i : i + self.KEYS_LIMIT] # Todo(dpj): Test call clear when no (key,value) put in ds self._ds_client.delete(batch) - def _create_empty_npu_tensorlist(self, shapes: list, dtypes: list): + def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) -> list[Tensor]: """ Create a list of empty NPU tensors with given shapes and dtypes. 
@@ -182,7 +182,7 @@ def _create_empty_npu_tensorlist(self, shapes: list, dtypes: list): shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)]) dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64]) Returns: - list: List of uninitialized NPU tensors + list[Tensor]: List of uninitialized NPU tensors """ tensors: list[Tensor] = [] for shape, dtype in zip(shapes, dtypes, strict=True): @@ -241,7 +241,7 @@ def supports_put(self, value: Any) -> bool: """Accepts any Python object.""" return True - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> None: """Store objects via zero-copy serialization in batches.""" for i in range(0, len(keys), self.PUT_KEYS_LIMIT): batch_keys = keys[i : i + self.PUT_KEYS_LIMIT] @@ -265,7 +265,7 @@ def supports_clear(self, strategy_tag: str) -> bool: """Matches 'KVClient' strategy tag.""" return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag() - def clear(self, keys: list[str]): + def clear(self, keys: list[str]) -> None: """Delete keys in batches.""" for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT): batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT] @@ -431,7 +431,13 @@ def put_task(strategy, indexes): strategy_tags[original_index] = tag return strategy_tags - def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]: + def get( + self, + keys: list[str], + shapes: Optional[list[Any]] = None, + dtypes: Optional[list[Any]] = None, + custom_backend_meta: Optional[list[str]] = None, + ) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. Requires shape and dtype hints to reconstruct NPU tensors correctly. 
@@ -470,7 +476,7 @@ def get_task(strategy, indexes): results[original_index] = value return results - def clear(self, keys: list[str], custom_backend_meta=None): + def clear(self, keys: list[str], custom_backend_meta: Optional[list[str]] = None) -> None: """Deletes multiple keys from remote storage. Args: @@ -511,8 +517,8 @@ def _route_to_strategies( The order must correspond to the original keys. selector: A function that determines whether a strategy supports an item. Signature: `(strategy: StorageStrategy, item: Any) -> bool`. - failback: If True, items that don't match any strategy will be ignored (not included in output). - If False, a ValueError will be raised for any unmatched item. + ignore_unmatched: If True, items that don't match any strategy will be ignored (not included in output). + If False, a ValueError will be raised for any unmatched item. Returns: A dictionary mapping each active strategy to a list of indexes in `items` diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 9f245a77..55901e31 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -405,19 +405,20 @@ def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[st return [pfx + sfx for sfx, pfx in itertools.product(keys_suffixes, keys_prefixes)] @staticmethod - def _generate_values(data: TensorDict) -> list[Tensor]: + def _generate_values(data: TensorDict) -> list[Any]: """ - Extract and flatten tensor values from a TensorDict in field-major order. + Extract and flatten values from a TensorDict in field-major order. Values are ordered by sorted field names, then by row (sample) order within each field. This matches the key order generated by `_generate_keys`. Args: - data (TensorDict): Input data where keys are field names and values are tensors. + data (TensorDict): Input data where keys are field names and values are tensors or any type + wrapped by NonTensorStack. 
Returns: - list[Tensor]: Flattened list of tensors, e.g., - [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] + list[Any]: Flattened list of values, e.g., + [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] """ - results: list[Tensor] = [] + results: list[Any] = [] for field in sorted(data.keys()): field_data = data[field] if isinstance(field_data, Tensor) and field_data.is_nested: @@ -461,17 +462,17 @@ def _get_executor(self) -> ThreadPoolExecutor: assert self._multi_threads_executor is not None return self._multi_threads_executor - def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Tensor]) -> TensorDict: + def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Any]) -> TensorDict: """ Reconstruct a TensorDict from a list of values using metadata. The values list is assumed to be in the same order as keys generated by `_generate_keys`. According to field names and global indexes in metadata, this method can determine - which dict key and which row this tensor belongs to. Then it reshapes the flat tensors list + which dict key and which row this value belongs to. Then it reshapes the flat values list back into a structured TensorDict . Args: metadata (BatchMeta): Metadata containing global indexes and field names. - values (list[Tensor]): List of tensors in field-major order. + values (list[Any]): List of values in field-major order. Returns: TensorDict: Reconstructed tensor dictionary with batch size equal to number of samples. """ @@ -538,7 +539,9 @@ def process_field(field_idx: int): return TensorDict(merged_data, batch_size=num_samples) @staticmethod - def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): + def _get_shape_type_custom_backend_meta_list( + metadata: BatchMeta, + ) -> tuple[list[torch.Size], list[torch.dtype], list[Any]]: """ Extract the expected shape, dtype, and custom_backend_meta for each field-sample pair in metadata. 
The order matches the key/value order: sorted by field name, then by global index. diff --git a/transfer_queue/utils/tensor_utils.py b/transfer_queue/utils/tensor_utils.py new file mode 100644 index 00000000..b3b8fa06 --- /dev/null +++ b/transfer_queue/utils/tensor_utils.py @@ -0,0 +1,182 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import operator +import os +from functools import reduce + +import torch +from torch import Tensor + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + + +def allocate_empty_tensors( + dtypes: list[torch.dtype], shapes: list[tuple] +) -> tuple[list[Tensor], list[int], list[int], list[int]]: + """Allocate empty tensors, grouping same dtypes into shared memory blocks. + + Instead of allocating each tensor separately, this function groups tensors + by their dtype and allocates one large contiguous memory block per dtype. + Each tensor is then created as a view into this shared memory. + + Args: + dtypes: List of torch dtypes for each tensor. + shapes: List of shapes (tuples) for each tensor. + + Returns: + A tuple containing: + - List of tensors sharing memory within their dtype groups. + - List of memory pointers (data_ptr) for each tensor. + - List of base pointers for each allocated memory region (one per dtype). 
+ - List of total bytes for each allocated memory region (one per dtype). + + Example: + >>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32] + >>> shapes = [(10,), (20,), (5,), (15,)] + >>> tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes) + >>> # tensors[0], [1], [3] share the same dtype and memory block + """ + assert len(dtypes) == len(shapes), "dtypes and shapes must have the same length" + + if len(dtypes) == 0: + return [], [], [], [] + + # Group indices by dtype + dtype_groups: dict[torch.dtype, list[int]] = {} + for i, dtype in enumerate(dtypes): + if dtype not in dtype_groups: + dtype_groups[dtype] = [] + dtype_groups[dtype].append(i) + + tensor_list = [torch.empty(()) for _ in range(len(dtypes))] + ptr_list = [0] * len(dtypes) + region_ptrs: list[int] = [] + region_sizes: list[int] = [] + + # For each dtype group, allocate one big tensor and create views + for dtype, indices in dtype_groups.items(): + # Calculate total number of elements needed for this dtype + total_elements = 0 + shape_info = [] # Store (index, shape, num_elements, offset) + + for idx in indices: + shape = tuple(shapes[idx]) + num_elements = reduce(operator.mul, shape, 1) + shape_info.append((idx, shape, num_elements, total_elements)) + total_elements += num_elements + + # Allocate one big contiguous memory block for this dtype + big_tensor = torch.empty(total_elements, dtype=dtype) + region_ptrs.append(big_tensor.data_ptr()) + region_sizes.append(big_tensor.nbytes) + + # Create views into the big tensor for each small tensor + for idx, shape, num_elements, offset in shape_info: + # Use as_strided to create a view with the correct shape + small_tensor = big_tensor.as_strided(size=shape, stride=compute_stride(shape), storage_offset=offset) + tensor_list[idx] = small_tensor + ptr_list[idx] = small_tensor.data_ptr() + + return tensor_list, ptr_list, region_ptrs, region_sizes + + +def compute_stride(shape: tuple[int, ...]) -> tuple[int, 
...]: + """Compute stride for a contiguous row-major (C-style) tensor. + + Args: + shape: The shape of the tensor. + + Returns: + Stride tuple for contiguous storage. + + Example: + >>> compute_stride((2, 3, 4)) + (12, 4, 1) + """ + stride = [] + cumulative = 1 + # Iterate from last dimension to first + for dim in reversed(shape): + stride.append(cumulative) + cumulative *= dim + return tuple(reversed(stride)) + + +def get_nbytes(dtypes, shapes) -> list[int]: + """Calculate number of bytes according to tensor dtypes and shapes.""" + assert len(dtypes) == len(shapes) + nbytes = [] + for i in range(len(dtypes)): + elem_size = torch.tensor([], dtype=dtypes[i]).element_size() + shape = tuple(shapes[i]) + numel = reduce(operator.mul, shape, 1) + nbytes.append(elem_size * numel) + + return nbytes + + +def merge_contiguous_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]: + """Merge contiguous memory regions to reduce register_buffer overhead + + Args: + ptrs: List of memory pointers (starting addresses). + sizes: List of memory region sizes corresponding to each pointer. + + Returns: + A tuple of (merged_ptrs, merged_sizes) where contiguous regions + have been merged into single regions. 
+ + Example: + >>> merge_contiguous_memory([0, 10, 30], [10, 20, 10]) + ([0], [40]) + + >>> merge_contiguous_memory([0, 5, 20], [5, 5, 10]) + ([0, 20], [10, 10]) + """ + if len(ptrs) != len(sizes): + raise ValueError("ptrs and sizes must have the same length") + + if not ptrs: + return [], [] + + # Create list of (ptr, size) pairs and sort by pointer address + regions = sorted(zip(ptrs, sizes, strict=False), key=lambda x: x[0]) + + merged_ptrs = [] + merged_sizes = [] + + # Initialize with the first region + current_ptr, current_size = regions[0] + + for ptr, size in regions[1:]: + # Check if current region is contiguous with the next one + # A region is contiguous if: ptr == current_ptr + current_size + if ptr == current_ptr + current_size: + # Merge: extend the current region + current_size += size + else: + # Not contiguous: save the current region and start a new one + merged_ptrs.append(current_ptr) + merged_sizes.append(current_size) + current_ptr, current_size = ptr, size + + # Add the last region + merged_ptrs.append(current_ptr) + merged_sizes.append(current_size) + + return merged_ptrs, merged_sizes