diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py new file mode 100644 index 00000000..6c218b7c --- /dev/null +++ b/tests/e2e/test_kv_interface_e2e.py @@ -0,0 +1,241 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end tests for the high-level KV interface. + +These tests verify the full lifecycle of the KV API: + kv_put → kv_batch_put → kv_batch_get → kv_list → kv_clear + +Prerequisites: + - Ray must be available + - TransferQueue must be initializable (tq.init()) + +Run: + pytest tests/e2e/test_kv_interface_e2e.py -v +""" + +import sys +from pathlib import Path + +import pytest +import ray +import torch + +# Setup path +parent_dir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(parent_dir)) + +import transfer_queue as tq # noqa: E402 + + +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + """Initialize and tear down TransferQueue for the entire test module.""" + if not ray.is_initialized(): + ray.init(namespace="test_kv_interface_e2e") + tq.init() + yield + tq.close() + ray.shutdown() + + +@pytest.fixture(autouse=True) +def clean_partition(): + """Clean up the test partition before each test.""" + yield + # Best-effort cleanup after each test + try: + tq.kv_clear(partition_id="test_kv") + except Exception: + pass + + +class TestKVPut: + """Tests for kv_put and kv_batch_put.""" + + def test_kv_put_single(self): + """kv_put should insert a single sample and make it retrievable.""" + tq.kv_put( + key="k1", + fields={"x": torch.tensor([1.0, 2.0])}, + partition_id="test_kv", + ) + + result = tq.kv_batch_get(keys=["k1"], partition_id="test_kv") + assert "k1" in result + assert torch.allclose(result["k1"]["x"].squeeze(), torch.tensor([1.0, 2.0])) + + def test_kv_put_with_tag(self): + """kv_put with tag should store metadata alongside the sample.""" + tq.kv_put( + key="tagged", + fields={"v": torch.tensor([42])}, + partition_id="test_kv", + tag={"status": "ready"}, + ) + + entries = tq.kv_list(partition_id="test_kv") + tagged_entry = [e for e in entries if e["key"] == "tagged"] + assert len(tagged_entry) == 1 + assert tagged_entry[0]["tag"]["status"] == "ready" + + def test_kv_batch_put(self): + """kv_batch_put should insert multiple samples in one call.""" + tq.kv_batch_put( + kv_pairs={ + "a": {"val": torch.tensor([10])}, + "b": {"val": torch.tensor([20])}, + "c": {"val": torch.tensor([30])}, + }, + partition_id="test_kv", + ) + + result = tq.kv_batch_get(keys=["a", "b", "c"], partition_id="test_kv") + assert result["a"]["val"].item() == 10 + assert result["b"]["val"].item() == 20 + assert result["c"]["val"].item() == 30 + + def test_kv_batch_put_with_tags(self): + """kv_batch_put should support per-key tags.""" + tq.kv_batch_put( + kv_pairs={ + "x": {"d": torch.tensor([1])}, + "y": {"d": torch.tensor([2])}, + }, + partition_id="test_kv", + tags={ + "x": {"label": "pos"}, + "y": {"label": "neg"}, + }, + ) + + entries = tq.kv_list(partition_id="test_kv") + tag_map = {e["key"]: e.get("tag", {}) for e in entries} + assert tag_map["x"]["label"] == "pos" + assert tag_map["y"]["label"] == "neg" + + +class TestKVGet: + """Tests for kv_batch_get.""" + + def test_get_all_fields(self): + """kv_batch_get without field filter should return all fields.""" + tq.kv_batch_put( + kv_pairs={ + "m1": {"f1": torch.tensor([1.0]), "f2": torch.tensor([2.0])}, + }, + partition_id="test_kv", + ) + + result = tq.kv_batch_get(keys=["m1"], partition_id="test_kv") + assert "f1" in result["m1"].keys() + assert "f2" in result["m1"].keys() + + def test_get_selected_fields(self): + """kv_batch_get with field selection should return only requested fields.""" + tq.kv_batch_put( + kv_pairs={ + "m2": {"alpha": torch.tensor([3.0]), "beta": torch.tensor([4.0])}, + }, + partition_id="test_kv", + ) + + result = tq.kv_batch_get(keys=["m2"], fields=["alpha"], partition_id="test_kv") + assert "alpha" in result["m2"].keys() + + def test_get_missing_key_raises(self): + """kv_batch_get should raise KeyError for non-existent keys.""" + with pytest.raises(KeyError, match="Keys not found in partition"): + tq.kv_batch_get(keys=["nonexistent"], partition_id="test_kv") + + +class TestKVList: + """Tests for kv_list.""" + + def test_list_empty_partition(self): + """kv_list on empty/unknown partition should return empty list.""" + entries = tq.kv_list(partition_id="empty_partition_xyz") + assert entries == [] + + def test_list_returns_all_keys(self): + """kv_list should return all keys in the partition.""" + tq.kv_batch_put( + kv_pairs={ + "p": {"v": torch.tensor([1])}, + "q": {"v": torch.tensor([2])}, + "r": {"v": torch.tensor([3])}, + }, + partition_id="test_kv", + ) + + entries = tq.kv_list(partition_id="test_kv") + keys = {e["key"] for e in entries} + assert keys == {"p", "q", "r"} + + +class TestKVClear: + """Tests for kv_clear.""" + + def test_clear_specific_keys(self): + """kv_clear with keys should remove only specified entries.""" + tq.kv_batch_put( + kv_pairs={ + "d1": {"v": torch.tensor([1])}, + "d2": {"v": torch.tensor([2])}, + "d3": {"v": torch.tensor([3])}, + }, + partition_id="test_kv", + ) + + tq.kv_clear(keys=["d1", "d3"], partition_id="test_kv") + + entries = tq.kv_list(partition_id="test_kv") + keys = {e["key"] for e in entries} + assert "d1" not in keys + assert "d3" not in keys + assert "d2" in keys + + def test_clear_entire_partition(self): + """kv_clear without keys should wipe the entire partition.""" + tq.kv_batch_put( + kv_pairs={ + "e1": {"v": torch.tensor([1])}, + "e2": {"v": torch.tensor([2])}, + }, + partition_id="test_kv", + ) + + tq.kv_clear(partition_id="test_kv") + entries = tq.kv_list(partition_id="test_kv") + assert entries == [] + + +class TestPartitionIsolation: + """Tests for partition namespace isolation.""" + + def test_same_key_different_partitions(self): + """Same key in different partitions should hold independent values.""" + tq.kv_put(key="shared", fields={"v": torch.tensor([100])}, partition_id="ns_1") + tq.kv_put(key="shared", fields={"v": torch.tensor([200])}, partition_id="ns_2") + + r1 = tq.kv_batch_get(keys=["shared"], partition_id="ns_1") + r2 = tq.kv_batch_get(keys=["shared"], partition_id="ns_2") + + assert r1["shared"]["v"].item() == 100 + assert r2["shared"]["v"].item() == 200 + + tq.kv_clear(partition_id="ns_1") + tq.kv_clear(partition_id="ns_2") diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 592ef0d2..17f85293 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -23,6 +23,11 @@ async_clear_samples, async_get_data, async_get_meta, + async_kv_batch_get, + async_kv_batch_put, + async_kv_clear, + async_kv_list, + async_kv_put, async_put, async_set_custom_meta, clear_partition, @@ -31,6 +36,11 @@ get_data, get_meta, init, + kv_batch_get, + kv_batch_put, + kv_clear, + kv_list, + kv_put, put, set_custom_meta, ) @@ -57,6 +67,16 @@ "async_set_custom_meta", "async_clear_samples", "async_clear_partition", + "kv_put", + "kv_batch_put", + "kv_batch_get", + "kv_list", + "kv_clear", + "async_kv_put", + "async_kv_batch_put", + "async_kv_batch_get", + "async_kv_list", + "async_kv_clear", "close", ] + [ "TransferQueueClient", diff --git a/transfer_queue/client.py b/transfer_queue/client.py index cd9c6e84..7037bbae 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -34,6 +34,7 @@ TransferQueueStorageManagerFactory, ) from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads +from transfer_queue.utils.enum_utils import ProductionStatus from transfer_queue.utils.zmq_utils import ( ZMQMessage, ZMQRequestType, @@ -898,6 +899,265 @@ async def async_get_partition_list( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e + # ── High-level KV Interface ────────────────────────────────────────── + # Redis-style Put/Get/List/Clear that wraps the low-level metadata APIs. + # String keys are mapped to global_indexes internally. + + def _ensure_kv_index(self): + """Lazily create the per-partition kv index.""" + if not hasattr(self, "_kv_key_to_index"): + self._kv_key_to_index: dict[str, dict[str, int]] = {} # partition_id -> {key -> global_index} + self._kv_index_to_key: dict[str, dict[int, str]] = {} # partition_id -> {global_index -> key} + + async def async_kv_put( + self, + key: str, + fields: dict, + partition_id: str = "default", + tag: Optional[dict[str, Any]] = None, + ) -> None: + """Insert or update a multi-column sample by key. + + Args: + key: Unique string key identifying the sample within the partition. + fields: Dictionary mapping field names to tensor values (without batch dim). + partition_id: Logical partition namespace. + tag: Optional metadata dictionary for status tracking. + """ + await self.async_kv_batch_put( + kv_pairs={key: fields}, + partition_id=partition_id, + tags={key: tag} if tag else None, + ) + + async def async_kv_batch_put( + self, + kv_pairs: dict[str, dict], + partition_id: str = "default", + tags: Optional[dict[str, dict[str, Any]]] = None, + ) -> None: + """Put multiple key-value pairs efficiently in batch. + + Args: + kv_pairs: Dictionary mapping keys to field dictionaries. + partition_id: Logical partition namespace. + tags: Optional dictionary mapping keys to metadata tag dictionaries. + """ + self._ensure_kv_index() + if partition_id not in self._kv_key_to_index: + self._kv_key_to_index[partition_id] = {} + self._kv_index_to_key[partition_id] = {} + + keys = list(kv_pairs.keys()) + fields_list = list(kv_pairs.values()) + + # Stack tensors into a batched TensorDict + # All samples must have the same set of field names + field_names = sorted(fields_list[0].keys()) + stacked = {} + for fname in field_names: + tensors = [f[fname] if isinstance(f[fname], Tensor) else torch.tensor(f[fname]) for f in fields_list] + stacked[fname] = torch.stack(tensors, dim=0) + data = TensorDict(stacked, batch_size=[len(keys)]) + + # Insert via low-level API (allocates global_indexes) + metadata = await self.async_put(data, partition_id=partition_id) + + # Record key <-> global_index mapping + for i, k in enumerate(keys): + gidx = metadata.global_indexes[i] + self._kv_key_to_index[partition_id][k] = gidx + self._kv_index_to_key[partition_id][gidx] = k + + # Store tags via custom_meta + if tags: + custom = {} + for k, t in tags.items(): + if t and k in self._kv_key_to_index[partition_id]: + gidx = self._kv_key_to_index[partition_id][k] + custom[gidx] = {**t, "_kv_key": k} + if custom: + metadata.update_custom_meta(custom) + await self.async_set_custom_meta(metadata) + else: + # Still store key mapping in custom_meta for kv_list + custom = {} + for k in keys: + gidx = self._kv_key_to_index[partition_id][k] + custom[gidx] = {"_kv_key": k} + metadata.update_custom_meta(custom) + await self.async_set_custom_meta(metadata) + + async def async_kv_batch_get( + self, + keys: list[str], + fields: Optional[list[str]] = None, + partition_id: str = "default", + ) -> dict[str, TensorDict]: + """Retrieve samples by keys, supporting column selection by fields. + + Args: + keys: List of string keys to retrieve. + fields: Optional list of field names. If None, retrieves all fields. + partition_id: Logical partition namespace. + + Returns: + Dictionary mapping keys to TensorDict of field values. + """ + self._ensure_kv_index() + if partition_id not in self._kv_key_to_index: + raise ValueError(f"Partition '{partition_id}' has no KV data.") + + key_map = self._kv_key_to_index[partition_id] + missing = [k for k in keys if k not in key_map] + if missing: + raise KeyError(f"Keys not found in partition '{partition_id}': {missing}") + + global_indexes = [key_map[k] for k in keys] + + # Construct a force_fetch to get specific global_indexes + # First get meta to know the field schema + all_fields = fields if fields else [] + metadata = await self.async_get_meta( + data_fields=all_fields, + batch_size=len(keys), + partition_id=partition_id, + mode="force_fetch", + ) + + # Build targeted metadata with only the requested global_indexes + from transfer_queue.metadata import FieldMeta, SampleMeta + + if fields is None: + # Use field names from the fetched metadata + fields = metadata.field_names + + samples = [] + for gidx in global_indexes: + sample_fields = {} + for fname in fields: + # Find the field meta from the full metadata if available + field_meta = None + for s in metadata.samples: + if s.global_index == gidx: + field_meta = s.fields.get(fname) + break + if field_meta is None: + # Construct a minimal field meta + field_meta = FieldMeta( + name=fname, dtype=None, shape=None, production_status=ProductionStatus.READY_FOR_CONSUME + ) + sample_fields[fname] = field_meta + samples.append(SampleMeta(partition_id=partition_id, global_index=gidx, fields=sample_fields)) + + target_meta = BatchMeta(samples=samples) + + # Fetch data + data = await self.async_get_data(target_meta) + + # Split into per-key results + result = {} + for i, k in enumerate(keys): + per_sample = {} + for fname in fields: + per_sample[fname] = data[fname][i : i + 1] + result[k] = TensorDict(per_sample, batch_size=[1]) + + return result + + async def async_kv_list( + self, + partition_id: str = "default", + ) -> list[dict[str, Any]]: + """List keys and tags (metadata) in a partition. + + Args: + partition_id: Logical partition namespace. + + Returns: + List of dictionaries with "key" and optional "tag" entries. + """ + self._ensure_kv_index() + if partition_id not in self._kv_key_to_index: + return [] + + # Get partition metadata including custom_meta + metadata = await self._get_partition_meta(partition_id) + + result = [] + index_to_key = self._kv_index_to_key.get(partition_id, {}) + for gidx in metadata.global_indexes: + entry: dict[str, Any] = {} + key = index_to_key.get(gidx) + if key is None: + # Try to recover from custom_meta + cm = metadata.custom_meta.get(gidx, {}) + key = cm.get("_kv_key", str(gidx)) + entry["key"] = key + cm = metadata.custom_meta.get(gidx, {}) + tag = {k: v for k, v in cm.items() if k != "_kv_key"} + if tag: + entry["tag"] = tag + result.append(entry) + return result + + async def async_kv_clear( + self, + keys: Optional[list[str]] = None, + partition_id: str = "default", + ) -> None: + """Remove key-value pairs from storage. + + Args: + keys: Optional list of string keys to remove. If None, clears entire partition. + partition_id: Logical partition namespace. + """ + self._ensure_kv_index() + + if keys is None: + await self.async_clear_partition(partition_id) + self._kv_key_to_index.pop(partition_id, None) + self._kv_index_to_key.pop(partition_id, None) + else: + key_map = self._kv_key_to_index.get(partition_id, {}) + from transfer_queue.metadata import FieldMeta, SampleMeta + + global_indexes = [] + for k in keys: + if k in key_map: + global_indexes.append(key_map[k]) + + if global_indexes: + # Get partition meta to know the field schema + partition_meta = await self._get_partition_meta(partition_id) + samples = [] + for gidx in global_indexes: + for s in partition_meta.samples: + if s.global_index == gidx: + samples.append(s) + break + else: + # Construct minimal sample + samples.append( + SampleMeta( + partition_id=partition_id, + global_index=gidx, + fields={ + fn: FieldMeta(fn, None, None, ProductionStatus.READY_FOR_CONSUME) + for fn in partition_meta.field_names + }, + ) + ) + + target_meta = BatchMeta(samples=samples) + await self.async_clear_samples(target_meta) + + # Clean up local index + for k in keys: + gidx = key_map.pop(k, None) + if gidx is not None and partition_id in self._kv_index_to_key: + self._kv_index_to_key[partition_id].pop(gidx, None) + def close(self) -> None: """Close the client and cleanup resources including storage manager.""" try: @@ -972,6 +1232,11 @@ def wrapper(*args, **kwargs): self._get_partition_list = _make_sync(self.async_get_partition_list) self._set_custom_meta = _make_sync(self.async_set_custom_meta) self._reset_consumption = _make_sync(self.async_reset_consumption) + self._kv_put = _make_sync(self.async_kv_put) + self._kv_batch_put = _make_sync(self.async_kv_batch_put) + self._kv_batch_get = _make_sync(self.async_kv_batch_get) + self._kv_list = _make_sync(self.async_kv_list) + self._kv_clear = _make_sync(self.async_kv_clear) def put( self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None @@ -1286,6 +1551,85 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: return self._set_custom_meta(metadata=metadata) + # ── High-level KV Interface (Sync) ──────────────────────────────────── + + def kv_put( + self, + key: str, + fields: dict, + partition_id: str = "default", + tag: Optional[dict[str, Any]] = None, + ) -> None: + """Insert or update a multi-column sample by key. + + Args: + key: Unique string key identifying the sample within the partition. + fields: Dictionary mapping field names to tensor values (without batch dim). + partition_id: Logical partition namespace. + tag: Optional metadata dictionary for status tracking. + """ + return self._kv_put(key=key, fields=fields, partition_id=partition_id, tag=tag) + + def kv_batch_put( + self, + kv_pairs: dict[str, dict], + partition_id: str = "default", + tags: Optional[dict[str, dict[str, Any]]] = None, + ) -> None: + """Put multiple key-value pairs efficiently in batch. + + Args: + kv_pairs: Dictionary mapping keys to field dictionaries. + partition_id: Logical partition namespace. + tags: Optional dictionary mapping keys to metadata tag dictionaries. + """ + return self._kv_batch_put(kv_pairs=kv_pairs, partition_id=partition_id, tags=tags) + + def kv_batch_get( + self, + keys: list[str], + fields: Optional[list[str]] = None, + partition_id: str = "default", + ) -> dict[str, TensorDict]: + """Retrieve samples by keys, supporting column selection by fields. + + Args: + keys: List of string keys to retrieve. + fields: Optional list of field names. If None, retrieves all fields. + partition_id: Logical partition namespace. + + Returns: + Dictionary mapping keys to TensorDict of field values. + """ + return self._kv_batch_get(keys=keys, fields=fields, partition_id=partition_id) + + def kv_list( + self, + partition_id: str = "default", + ) -> list[dict[str, Any]]: + """List keys and tags (metadata) in a partition. + + Args: + partition_id: Logical partition namespace. + + Returns: + List of dictionaries with "key" and optional "tag" entries. + """ + return self._kv_list(partition_id=partition_id) + + def kv_clear( + self, + keys: Optional[list[str]] = None, + partition_id: str = "default", + ) -> None: + """Remove key-value pairs from storage. + + Args: + keys: Optional list of string keys to remove. If None, clears entire partition. + partition_id: Logical partition namespace. + """ + return self._kv_clear(keys=keys, partition_id=partition_id) + def close(self) -> None: """Close the client and cleanup resources including event loop and thread.""" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 37ebdb23..19a157ef 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -625,6 +625,222 @@ async def async_clear_partition(partition_id: str): return await tq_client.async_clear_partition(partition_id) +def kv_put( + key: str, + fields: dict, + partition_id: str = "default", + tag: Optional[dict[str, Any]] = None, +) -> None: + """Insert or update a multi-column sample by key, with optional metadata tag. + + This is a high-level KV-style API that provides Redis-like Put/Get/List semantics + on top of TransferQueue's storage backends. + + Args: + key: Unique string key identifying the sample within the partition. + fields: Dictionary mapping field names to tensor values. + Each value should be a torch.Tensor (without batch dimension). + partition_id: Logical partition namespace. Defaults to "default". + tag: Optional lightweight metadata dictionary for status tracking. + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> tq.kv_put( + ... key="sample_0", + ... fields={"input_ids": torch.tensor([1, 2, 3]), "score": torch.tensor([0.95])}, + ... partition_id="rollout_v1", + ... tag={"status": "completed"}, + ... ) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.kv_put(key, fields, partition_id, tag) + + +async def async_kv_put( + key: str, + fields: dict, + partition_id: str = "default", + tag: Optional[dict[str, Any]] = None, +) -> None: + """Asynchronously insert or update a multi-column sample by key, with optional metadata tag. + + Args: + key: Unique string key identifying the sample within the partition. + fields: Dictionary mapping field names to tensor values. + partition_id: Logical partition namespace. Defaults to "default". + tag: Optional lightweight metadata dictionary for status tracking. + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_kv_put(key, fields, partition_id, tag) + + +def kv_batch_put( + kv_pairs: dict[str, dict], + partition_id: str = "default", + tags: Optional[dict[str, dict[str, Any]]] = None, +) -> None: + """Put multiple key-value pairs efficiently in batch. + + Args: + kv_pairs: Dictionary mapping keys to field dictionaries. + Each field dictionary maps field names to tensor values. + partition_id: Logical partition namespace. Defaults to "default". + tags: Optional dictionary mapping keys to metadata tag dictionaries. + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> tq.kv_batch_put( + ... kv_pairs={ + ... "s0": {"input_ids": torch.tensor([1, 2]), "reward": torch.tensor([0.5])}, + ... "s1": {"input_ids": torch.tensor([3, 4]), "reward": torch.tensor([0.8])}, + ... }, + ... partition_id="rollout_v1", + ... tags={"s0": {"status": "done"}, "s1": {"status": "done"}}, + ... ) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.kv_batch_put(kv_pairs, partition_id, tags) + + +async def async_kv_batch_put( + kv_pairs: dict[str, dict], + partition_id: str = "default", + tags: Optional[dict[str, dict[str, Any]]] = None, +) -> None: + """Asynchronously put multiple key-value pairs efficiently in batch. + + Args: + kv_pairs: Dictionary mapping keys to field dictionaries. + partition_id: Logical partition namespace. Defaults to "default". + tags: Optional dictionary mapping keys to metadata tag dictionaries. + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_kv_batch_put(kv_pairs, partition_id, tags) + + +def kv_batch_get( + keys: list[str], + fields: Optional[list[str]] = None, + partition_id: str = "default", +) -> dict[str, TensorDict]: + """Retrieve samples by keys, supporting column selection by fields. + + Args: + keys: List of string keys to retrieve. + fields: Optional list of field names to retrieve. If None, retrieves all fields. + partition_id: Logical partition namespace. Defaults to "default". + + Returns: + Dictionary mapping keys to TensorDict of field values. + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> results = tq.kv_batch_get( + ... keys=["s0", "s1"], + ... fields=["input_ids"], + ... partition_id="rollout_v1", + ... ) + >>> print(results["s0"]["input_ids"]) + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.kv_batch_get(keys, fields, partition_id) + + +async def async_kv_batch_get( + keys: list[str], + fields: Optional[list[str]] = None, + partition_id: str = "default", +) -> dict[str, TensorDict]: + """Asynchronously retrieve samples by keys, supporting column selection by fields. + + Args: + keys: List of string keys to retrieve. + fields: Optional list of field names to retrieve. If None, retrieves all fields. + partition_id: Logical partition namespace. Defaults to "default". + + Returns: + Dictionary mapping keys to TensorDict of field values. + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_kv_batch_get(keys, fields, partition_id) + + +def kv_list( + partition_id: str = "default", +) -> list[dict[str, Any]]: + """List keys and tags (metadata) in a partition. + + Args: + partition_id: Logical partition namespace. Defaults to "default". + + Returns: + List of dictionaries, each containing "key" and optional "tag" entries. + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> entries = tq.kv_list(partition_id="rollout_v1") + >>> for entry in entries: + ... print(f"Key: {entry['key']}, Tag: {entry.get('tag', {})}") + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.kv_list(partition_id) + + +async def async_kv_list( + partition_id: str = "default", +) -> list[dict[str, Any]]: + """Asynchronously list keys and tags (metadata) in a partition. + + Args: + partition_id: Logical partition namespace. Defaults to "default". + + Returns: + List of dictionaries, each containing "key" and optional "tag" entries. + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_kv_list(partition_id) + + +def kv_clear( + keys: Optional[list[str]] = None, + partition_id: str = "default", +) -> None: + """Remove key-value pairs from storage. + + If keys is None, clears the entire partition. + + Args: + keys: Optional list of string keys to remove. If None, clears all keys in the partition. + partition_id: Logical partition namespace. Defaults to "default". + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> tq.kv_clear(keys=["s0", "s1"], partition_id="rollout_v1") + >>> tq.kv_clear(partition_id="rollout_v1") # Clear entire partition + """ + tq_client = _maybe_create_transferqueue_client() + return tq_client.kv_clear(keys, partition_id) + + +async def async_kv_clear( + keys: Optional[list[str]] = None, + partition_id: str = "default", +) -> None: + """Asynchronously remove key-value pairs from storage. + + Args: + keys: Optional list of string keys to remove. If None, clears all keys in the partition. + partition_id: Logical partition namespace. Defaults to "default". + """ + tq_client = _maybe_create_transferqueue_client() + return await tq_client.async_kv_clear(keys, partition_id) + + def close(): """Close the TransferQueue system.""" global _TRANSFER_QUEUE_CLIENT diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py new file mode 100644 index 00000000..1e2b3262 --- /dev/null +++ b/tutorial/02_kv_interface.py @@ -0,0 +1,288 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 2: High-Level KV Interface + +This tutorial demonstrates the Redis-style Key-Value interface for TransferQueue. +The KV API provides a simple Put/Get/List/Clear interface that requires zero +knowledge of TransferQueue's internal metadata concepts (BatchMeta, SampleMeta, +FieldMeta). + +TransferQueue offers three API tiers: + 1. KV Interface (this tutorial) — Familiar Redis-style semantics for zero + learning curve. Best when you need fine-grained data access and manage + sample dispatching externally (e.g., via a ReplayBuffer or single-controller). + 2. StreamingDataLoader (tutorial/05_streaming_dataloader.py) — PyTorch-native + DataLoader with built-in streaming, sampling, and DP-rank coordination. + Best for fully-streamed training pipelines. + 3. Low-level TransferQueueClient — Full control over BatchMeta, Samplers, and + production/consumption tracking. Best when you need maximum flexibility. + +Key Methods: + - kv_put / async_kv_put — Insert/update a single sample by key + - kv_batch_put / async_kv_batch_put — Batch insert multiple key-value pairs + - kv_batch_get / async_kv_batch_get — Retrieve samples by keys (with optional + column/field selection) + - kv_list / async_kv_list — List all keys and tags in a partition + - kv_clear / async_kv_clear — Remove key-value pairs from storage + +Key Features: + ✓ Redis-style Semantics — Familiar KV interface with zero learning curve + ✓ Fine-grained Access — Read/write specific fields (columns) per key (row) + ✓ Partition Isolation — Logical separation of storage namespaces + ✓ Metadata Tags — Lightweight per-sample metadata for status tracking + ✓ Pluggable Backends — Works with SimpleStorage, Yuanrong, MooncakeStore, etc. + +Use Cases: + - Fine-grained data access where extreme streaming performance is non-essential + - Integration with external ReplayBuffer / single-controller that manages + sample dispatching + +Limitations (vs low-level native APIs): + - No built-in production/consumption tracking (track status via tags) + - No built-in Sampler support (dispatch data externally) + - Not fully streaming (consumers wait for dispatched keys) +""" + +import os +import sys +import textwrap +import warnings +from pathlib import Path + +warnings.filterwarnings( + action="ignore", + message=r"The PyTorch API of nested tensors is in prototype stage*", + category=UserWarning, + module=r"torch\.nested", +) + +warnings.filterwarnings( + action="ignore", + message=r"Tip: In future versions of Ray, Ray will no longer override accelerator visible " + r"devices env var if num_gpus=0 or num_gpus=None.*", + category=FutureWarning, + module=r"ray\._private\.worker", +) + + +import ray # noqa: E402 +import torch # noqa: E402 + +# Add the parent directory to the path +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +import transfer_queue as tq # noqa: E402 + +# Configure Ray +os.environ["RAY_DEDUP_LOGS"] = "0" + + +def setup(): + """Initialize Ray and TransferQueue.""" + if not ray.is_initialized(): + ray.init(namespace="TransferQueueTutorial") + + tq.init() + print("[Setup]: TransferQueue initialized.\n") + + +def teardown(): + """Shutdown TransferQueue and Ray.""" + tq.close() + ray.shutdown() + print("\n[Teardown]: TransferQueue and Ray shut down.") + + +# ────────────────────────────────────────────────────────────────────── +# Example 1: Basic Put / Get +# ────────────────────────────────────────────────────────────────────── +def example_basic_put_get(): + """Demonstrate single-key put and batch get.""" + print( + textwrap.dedent(""" + ┌──────────────────────────────────────────────────────┐ + │ Example 1: Basic Put / Get │ + └──────────────────────────────────────────────────────┘ + """) + ) + + partition = "demo_partition" + + # --- Put a single sample --- + tq.kv_put( + key="sample_0", + fields={ + "input_ids": torch.tensor([101, 2003, 1037]), + "attention_mask": torch.tensor([1, 1, 1]), + }, + partition_id=partition, + ) + print("[Put]: Inserted 'sample_0' with fields: input_ids, attention_mask") + + # --- Put another sample with a tag --- + tq.kv_put( + key="sample_1", + fields={ + "input_ids": torch.tensor([101, 4567, 2345]), + "attention_mask": torch.tensor([1, 1, 0]), + }, + partition_id=partition, + tag={"status": "completed", "score": 0.95}, + ) + print("[Put]: Inserted 'sample_1' with tag: {status: completed, score: 0.95}") + + # --- Batch get --- + results = tq.kv_batch_get( + keys=["sample_0", "sample_1"], + partition_id=partition, + ) + for key, td in results.items(): + print(f"[Get]: {key} → input_ids={td['input_ids'].squeeze()}") + + # --- Get with field selection --- + results = tq.kv_batch_get( + keys=["sample_1"], + fields=["attention_mask"], + partition_id=partition, + ) + print(f"[Get]: sample_1 (attention_mask only) → {results['sample_1']['attention_mask'].squeeze()}") + + tq.kv_clear(partition_id=partition) + print("[Clear]: Partition cleared.\n") + + +# ────────────────────────────────────────────────────────────────────── +# Example 2: Batch Put +# ────────────────────────────────────────────────────────────────────── +def example_batch_put(): + """Demonstrate efficient batch insertion.""" + print( + textwrap.dedent(""" + ┌──────────────────────────────────────────────────────┐ + │ Example 2: Batch Put │ + └──────────────────────────────────────────────────────┘ + """) + ) + + partition = "batch_partition" + + tq.kv_batch_put( + kv_pairs={ + "s0": {"reward": torch.tensor([0.5]), "input_ids": torch.tensor([1, 2])}, + "s1": {"reward": torch.tensor([0.8]), "input_ids": torch.tensor([3, 4])}, + "s2": {"reward": torch.tensor([0.3]), "input_ids": torch.tensor([5, 6])}, + }, + partition_id=partition, + tags={ + "s0": {"status": "done"}, + "s1": {"status": "done"}, + "s2": {"status": "pending"}, + }, + ) + print("[Batch Put]: Inserted s0, s1, s2 with tags.") + + # List all keys + entries = tq.kv_list(partition_id=partition) + for entry in entries: + print(f" Key: {entry['key']}, Tag: {entry.get('tag', {})}") + + tq.kv_clear(partition_id=partition) + print("[Clear]: Partition cleared.\n") + + +# ────────────────────────────────────────────────────────────────────── +# Example 3: Partition Isolation +# ────────────────────────────────────────────────────────────────────── +def example_partition_isolation(): + """Demonstrate partition-level namespace isolation.""" + print( + textwrap.dedent(""" + ┌──────────────────────────────────────────────────────┐ + │ Example 3: Partition Isolation │ + └──────────────────────────────────────────────────────┘ + """) + ) + + tq.kv_put(key="x", fields={"v": torch.tensor([1.0])}, partition_id="ns_a") + tq.kv_put(key="x", fields={"v": torch.tensor([2.0])}, partition_id="ns_b") + + res_a = tq.kv_batch_get(keys=["x"], partition_id="ns_a") + res_b = tq.kv_batch_get(keys=["x"], partition_id="ns_b") + + print(f"[Partition A]: x → {res_a['x']['v'].item()}") + print(f"[Partition B]: x → {res_b['x']['v'].item()}") + + tq.kv_clear(partition_id="ns_a") + tq.kv_clear(partition_id="ns_b") + print("[Clear]: Both partitions cleared.\n") + + +# ────────────────────────────────────────────────────────────────────── +# Example 4: Selective Clear +# ────────────────────────────────────────────────────────────────────── +def example_selective_clear(): + """Demonstrate clearing specific keys from a partition.""" + print( + textwrap.dedent(""" + ┌──────────────────────────────────────────────────────┐ + │ Example 4: Selective Clear │ + └──────────────────────────────────────────────────────┘ + """) + ) + + partition = "clear_demo" + + tq.kv_batch_put( + kv_pairs={ + "a": {"val": torch.tensor([10])}, + "b": {"val": torch.tensor([20])}, + "c": {"val": torch.tensor([30])}, + }, + partition_id=partition, + ) + print("[Put]: Inserted keys a, b, c.") + + tq.kv_clear(keys=["a", "c"], partition_id=partition) + print("[Clear]: Removed keys a, c.") + + remaining = tq.kv_list(partition_id=partition) + print(f"[List]: Remaining keys = {[e['key'] for e in remaining]}") + + tq.kv_clear(partition_id=partition) + print("[Clear]: Partition cleared.\n") + + +# ────────────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────────────── +if __name__ == "__main__": + print("=" * 60) + print(" TransferQueue — KV Interface Tutorial") + print("=" * 60) + + setup() + + try: + example_basic_put_get() + example_batch_put() + example_partition_isolation() + example_selective_clear() + finally: + teardown() + + print("\nDone! For more details, see tests/e2e/test_kv_interface_e2e.py")