diff --git a/.github/workflows/dco-check.yml b/.github/workflows/dco-check.yml new file mode 100644 index 00000000..4ff4479d --- /dev/null +++ b/.github/workflows/dco-check.yml @@ -0,0 +1,37 @@ +name: DCO Check + +on: + pull_request: + types: [opened, synchronize, reopened] + +permissions: + pull-requests: read + contents: read + +jobs: + check-dco: + runs-on: ubuntu-latest + steps: + - name: Check for Signed-off-by + uses: actions/github-script@v7 + with: + script: | + const commits = await github.rest.pulls.listCommits({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.issue.number, + }); + + const regex = /^Signed-off-by: .* <.*@.*>/m; + let failed = false; + + for (const commit of commits.data) { + if (!regex.test(commit.commit.message)) { + console.log(`Commit ${commit.sha} is missing Signed-off-by`); + failed = true; + } + } + + if (failed) { + core.setFailed('One or more commits are missing the Signed-off-by line.'); + } \ No newline at end of file diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 836a46f1..47006cdd 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -8,6 +8,7 @@ on: branches: - main - dev + - v0.* # Declare permissions just read content. permissions: diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 14339bb8..2917c789 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -5,9 +5,13 @@ name: Python package on: push: - branches: [ "main", "dev" ] + branches: + - main + - v0.* pull_request: - branches: [ "main", "dev" ] + branches: + - main + - v0.* jobs: build: diff --git a/.github/workflows/sanity.yml b/.github/workflows/sanity.yml new file mode 100644 index 00000000..3c689b2d --- /dev/null +++ b/.github/workflows/sanity.yml @@ -0,0 +1,54 @@ + +name: sanity + +on: + # Trigger the workflow on push or pull request + push: + branches: + - main + - v0.* + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + - .github/workflows/sanity.yml + - "tests/sanity/**" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + sanity: + runs-on: ubuntu-latest + timeout-minutes: 5 # Increase this timeout value as needed + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install build + python -m build --wheel + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip install dist/*.whl + - name: Run license test + run: | + python3 tests/sanity/check_license.py --directories . + - name: Check docstrings for specified files + run: | + python3 tests/sanity/check_docstrings.py + diff --git a/scripts/performance_test.py b/scripts/performance_test.py index 1fc68152..4b36a53b 100644 --- a/scripts/performance_test.py +++ b/scripts/performance_test.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio import logging import math diff --git a/scripts/put_benchmark.py b/scripts/put_benchmark.py index 96bad6ee..0e7e82cc 100644 --- a/scripts/put_benchmark.py +++ b/scripts/put_benchmark.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import asyncio import json diff --git a/tests/sanity/check_docstrings.py b/tests/sanity/check_docstrings.py new file mode 100644 index 00000000..de3530b8 --- /dev/null +++ b/tests/sanity/check_docstrings.py @@ -0,0 +1,170 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Python script to check docstrings for functions and classes in specified files. +Checks that every public function and class has proper docstring documentation. +""" + +import ast +import os +import sys + + +class DocstringChecker(ast.NodeVisitor): + """AST visitor to check for missing docstrings in functions and classes.""" + + def __init__(self, filename: str): + self.filename = filename + self.missing_docstrings: list[tuple[str, str, int]] = [] + self.current_class = None + self.function_nesting_level = 0 + + def visit_FunctionDef(self, node: ast.FunctionDef): + """Visit function definitions and check for docstrings.""" + if not node.name.startswith("_") and self.function_nesting_level == 0: + if not self._has_docstring(node): + func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name + self.missing_docstrings.append((func_name, self.filename, node.lineno)) + + self.function_nesting_level += 1 + self.generic_visit(node) + self.function_nesting_level -= 1 + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + """Visit async function definitions and check for docstrings.""" + if not node.name.startswith("_") and self.function_nesting_level == 0: + if not self._has_docstring(node): + func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name + self.missing_docstrings.append((func_name, self.filename, node.lineno)) + + self.function_nesting_level += 1 + self.generic_visit(node) + self.function_nesting_level -= 1 + + def visit_ClassDef(self, node: ast.ClassDef): + """Visit class definitions and check for docstrings.""" + if not node.name.startswith("_"): + if not self._has_docstring(node): + self.missing_docstrings.append((node.name, self.filename, node.lineno)) + + old_class = self.current_class + self.current_class = node.name + self.generic_visit(node) + self.current_class = old_class + + def _has_docstring(self, node) -> bool: + """Check if a node has a docstring.""" + return ast.get_docstring(node) is not None + + +def check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]: + """Check docstrings in a single file.""" + try: + with open(filepath, encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=filepath) + checker = DocstringChecker(filepath) + checker.visit(tree) + return checker.missing_docstrings + + except Exception as e: + print(f"Error processing {filepath}: {e}") + return [] + + +def get_python_files_in_transfer_queue(repo_path: str) -> list[str]: + """Get all Python files in the transfer_queue directory.""" + transfer_queue_path = os.path.join(repo_path, "transfer_queue") + if not os.path.exists(transfer_queue_path): + print(f"Warning: transfer_queue directory {transfer_queue_path} does not exist!") + return [] + + python_files = [] + for root, _, files in os.walk(transfer_queue_path): + for file in files: + if file.endswith(".py"): + python_files.append(os.path.join(root, file)) + + return sorted(python_files) + + +def main(): + """Main function to check docstrings in transfer_queue Python files.""" + + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_path = os.path.dirname(os.path.dirname(script_dir)) + + if not os.path.exists(repo_path): + print(f"Repository path {repo_path} does not exist!") + sys.exit(1) + + os.chdir(repo_path) + + files_to_check = get_python_files_in_transfer_queue(repo_path) + + if not files_to_check: + print("No Python files found in transfer_queue directory!") + sys.exit(1) + + all_missing_docstrings = [] + + print("Checking docstrings in transfer_queue Python files...") + print(f"Found {len(files_to_check)} Python files to check") + print("=" * 60) + + for file_path in files_to_check: + if not os.path.exists(file_path): + print(f"Warning: File {file_path} does not exist!") + continue + + print(f"Checking {file_path}...") + missing = check_file_docstrings(file_path) + all_missing_docstrings.extend(missing) + + if missing: + print(f" Found {len(missing)} missing docstrings") + else: + print(" All functions and classes have docstrings [OK]") + + print("=" * 60) + + if all_missing_docstrings: + print(f"\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:") + print("-" * 60) + + by_file = {} + for name, filepath, lineno in all_missing_docstrings: + if filepath not in by_file: + by_file[filepath] = [] + by_file[filepath].append((name, lineno)) + + for filepath in sorted(by_file.keys()): + print(f"\n{filepath}:") + for name, lineno in sorted(by_file[filepath], key=lambda x: x[1]): + print(f" - {name} (line {lineno})") + + print(f"\nTotal missing docstrings: {len(all_missing_docstrings)}") + + raise Exception(f"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!") + + else: + print("\n[OK] All functions and classes have proper docstrings!") + + +if __name__ == "__main__": + main() diff --git a/tests/sanity/check_license.py b/tests/sanity/check_license.py new file mode 100644 index 00000000..e1cac117 --- /dev/null +++ b/tests/sanity/check_license.py @@ -0,0 +1,72 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from pathlib import Path +from typing import Iterable + +# Add license headers below +license_head_huawei = "Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved." +license_head_tq = "Copyright 2025 The TransferQueue Team" + +license_headers = [ + license_head_huawei, + license_head_tq, +] + + +def get_py_files(path_arg: Path) -> Iterable[Path]: + """Get Python files under a directory. If already a Python file, return it. + + Args: + path_arg (Path): path to scan for .py files + + Returns: + Iterable[Path]: list of .py files + """ + if path_arg.is_dir(): + return path_arg.glob("**/*.py") + elif path_arg.is_file() and path_arg.suffix == ".py": + return [path_arg] + return [] + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument( + "--directories", + "-d", + required=True, + type=Path, + nargs="+", + help="List of directories to check for license headers", + ) + args = parser.parse_args() + + # Collect all Python files from specified directories + pathlist = set(path for path_arg in args.directories for path in get_py_files(path_arg)) + + for path in pathlist: + # because path is object not string + path_in_str = str(path.absolute()) + print(path_in_str) + with open(path_in_str, encoding="utf-8") as f: + file_content = f.read() + + has_license = False + for lh in license_headers: + if lh in file_content: + has_license = True + break + assert has_license, f"file {path_in_str} does not contain license" diff --git a/tests/test_ray_p2p.py b/tests/test_ray_p2p.py index 231bc26c..4bd54a7b 100644 --- a/tests/test_ray_p2p.py +++ b/tests/test_ray_p2p.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import sys import time from pathlib import Path diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index db6239dd..d84b381f 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -58,4 +58,5 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li @abstractmethod def clear(self, keys: list[str]) -> None: + """Clear key-value pairs in the storage backend.""" raise NotImplementedError("Subclasses must implement clear") diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 80efa09a..9f69a977 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os import pickle @@ -23,6 +38,10 @@ @StorageClientFactory.register("MooncakeStorageClient") class MooncakeStorageClient(TransferQueueStorageKVClient): + """ + Storage client for MooncakeStore. + """ + def __init__(self, config: dict[str, Any]): if not MOONCAKE_STORE_IMPORTED: raise ImportError("Mooncake Store not installed. Please install via: pip install mooncake-transfer-engine") @@ -54,6 +73,13 @@ def __init__(self, config: dict[str, Any]): raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + """Stores multiple key-value pairs to MooncakeStore. + + Args: + keys (List[str]): List of unique string identifiers. + values (List[Any]): List of values to store (tensors, scalars, dicts, etc.). + """ + if not isinstance(keys, list) or not isinstance(values, list): raise ValueError("keys and values must be lists") if len(keys) != len(values): @@ -107,6 +133,18 @@ def _batch_put_bytes(self, keys: list[str], values: list[bytes]): raise RuntimeError(f"put_batch failed with error code: {ret}") def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + """Get multiple key-value pairs from MooncakeStore. + + Args: + keys (List[str]): Keys to fetch. + shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). + dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. + custom_meta (List[str], optional): Device type (npu/cpu) for each key + + Returns: + List[Any]: Retrieved values in the same order as input keys. + """ + if shapes is None or dtypes is None: raise ValueError("MooncakeStorageClient needs shapes and dtypes") if not (len(keys) == len(shapes) == len(dtypes)): @@ -179,12 +217,18 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: return results def clear(self, keys: list[str]): + """Deletes multiple keys from MooncakeStore. + + Args: + keys (List[str]): List of keys to remove. + """ for key in keys: ret = self._store.remove(key) if ret != 0: logger.warning(f"remove failed for key '{key}' with error code: {ret}") def close(self): + """Closes MooncakeStore.""" if self._store: self._store.close() self._store = None diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index 5ffd0233..78a6c07d 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools from typing import Any, Optional @@ -10,17 +25,22 @@ @ray.remote(max_concurrency=8) class RayObjectRefStorage: + """Ray object ref storage.""" + def __init__(self): self.storage_dict = {} def put_obj_ref(self, keys: list[str], obj_refs: list[ray.ObjectRef]): + """Put object ref to remote storage.""" self.storage_dict.update(itertools.starmap(lambda k, v: (k, v), zip(keys, obj_refs, strict=True))) def get_obj_ref(self, keys: list[str]) -> list[ray.ObjectRef]: + """Get object ref from remote storage.""" obj_refs = [self.storage_dict.get(key, None) for key in keys] return obj_refs def clear_obj_ref(self, keys: list[str]): + """Clear object ref from remote storage.""" for key in keys: if key in self.storage_dict: del self.storage_dict[key] @@ -28,6 +48,10 @@ def clear_obj_ref(self, keys: list[str]): @StorageClientFactory.register("RayStorageClient") class RayStorageClient(TransferQueueStorageKVClient): + """ + Storage client for Ray RDT. + """ + def __init__(self, config=None): if not ray.is_initialized(): raise RuntimeError("Ray is not initialized. Please call ray.init() before creating RayStorageClient.") diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index c7da326b..69309cd6 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -83,9 +83,11 @@ def __init__(self, config: dict[str, Any]): self._cpu_ds_client.init() def npu_ds_client_is_available(self): + """Check if NPU client is available.""" return self._npu_ds_client is not None def cpu_ds_client_is_available(self): + """Check if CPU client is available.""" return self._cpu_ds_client is not None def _create_empty_npu_tensorlist(self, shapes, dtypes): diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index e2d5f1b0..6136f67d 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -272,14 +272,36 @@ async def notify_data_update( @abstractmethod async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + """ + Put data into the storage backend. + + Args: + data: Data to be put into the storage. + metadata: BatchMeta of the corresponding data. + """ raise NotImplementedError("Subclasses must implement put_data") @abstractmethod async def get_data(self, metadata: BatchMeta) -> TensorDict: + """ + Get data from the storage backend. + + Args: + metadata: BatchMeta of the data to be retrieved from the storage. + + Returns: + TensorDict containing the data retrieved from the storage. + """ raise NotImplementedError("Subclasses must implement get_data") @abstractmethod async def clear_data(self, metadata: BatchMeta) -> None: + """ + Clear data from the storage backend. + + Args: + metadata: BatchMeta of the data to be cleared from the storage. + """ raise NotImplementedError("Subclasses must implement clear_data") def close(self) -> None: diff --git a/transfer_queue/storage/managers/factory.py b/transfer_queue/storage/managers/factory.py index 0ccf44e6..e595ccd8 100644 --- a/transfer_queue/storage/managers/factory.py +++ b/transfer_queue/storage/managers/factory.py @@ -25,6 +25,8 @@ class TransferQueueStorageManagerFactory: @classmethod def register(cls, manager_type: str): + """Register a TransferQueueStorageManager class.""" + def decorator(manager_cls: type[TransferQueueStorageManager]): if not issubclass(manager_cls, TransferQueueStorageManager): raise TypeError( @@ -38,6 +40,7 @@ def decorator(manager_cls: type[TransferQueueStorageManager]): @classmethod def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager: + """Create and return a TransferQueueStorageManager instance.""" if manager_type not in cls._registry: raise ValueError( f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}" diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 3961efc3..ca555668 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os from typing import Any @@ -11,6 +26,8 @@ @TransferQueueStorageManagerFactory.register("MooncakeStorageManager") class MooncakeStorageManager(KVStorageManager): + """Storage manager for MooncakeStorage backend.""" + def __init__(self, config: dict[str, Any]): # Required: Address of the HTTP metadata server (e.g., "localhost:8080") metadata_server = config.get("metadata_server", None) diff --git a/transfer_queue/storage/managers/yuanrong_manager.py b/transfer_queue/storage/managers/yuanrong_manager.py index 33ccbdce..bfb79e6c 100644 --- a/transfer_queue/storage/managers/yuanrong_manager.py +++ b/transfer_queue/storage/managers/yuanrong_manager.py @@ -32,6 +32,8 @@ @TransferQueueStorageManagerFactory.register("YuanrongStorageManager") class YuanrongStorageManager(KVStorageManager): + """Storage manager for Yuanrong backend.""" + def __init__(self, config: dict[str, Any]): host = config.get("host", None) port = config.get("port", None) diff --git a/transfer_queue/utils/perf_utils.py b/transfer_queue/utils/perf_utils.py index 5e3ee797..d7b375c3 100644 --- a/transfer_queue/utils/perf_utils.py +++ b/transfer_queue/utils/perf_utils.py @@ -113,6 +113,7 @@ def _flush_logs(self): @contextmanager def measure(self, op_type: str): + """Measures performance statistics.""" start_time = time.perf_counter() try: yield diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index cd308a9d..c20ce3d4 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -60,6 +60,7 @@ def __init__(self): self.aux_buffers: list[bytestr] = [] def encode(self, obj: Any) -> Sequence[bytestr]: + """Encode a given object to a byte array.""" try: self.aux_buffers = bufs = [b""] bufs[0] = self.encoder.encode(obj) @@ -71,15 +72,6 @@ def encode(self, obj: Any) -> Sequence[bytestr]: finally: self.aux_buffers = [] - def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: - try: - self.aux_buffers = [buf] - bufs = self.aux_buffers - self.encoder.encode_into(obj, buf) - return bufs - finally: - self.aux_buffers = [] - def enc_hook(self, obj: Any) -> Any: """Custom encoding hook for types msgspec doesn't natively support. @@ -221,6 +213,7 @@ def __init__(self): self.aux_buffers: Sequence[bytestr] = () def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any: + """Decode a list of bytes.""" if isinstance(bufs, bytestr): result = self.decoder.decode(bufs) else: @@ -290,6 +283,13 @@ def _decode_nested_tensor(self, nested_meta: dict) -> torch.Tensor: return torch.nested.as_nested_tensor(sub_tensors, layout=torch.strided) def ext_hook(self, code: int, data: memoryview) -> Any: + """Custom decoding hook for types msgspec doesn't natively support. + + For zero-copy tensor serialization, we need to handle: + - torch.Tensor: Extract buffer, store metadata + - TensorDict: Convert to dict structure for recursive processing + - numpy.ndarray: Convert to tensor for unified handling + """ if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) if code == CUSTOM_TYPE_CLOUDPICKLE: diff --git a/transfer_queue/utils/utils.py b/transfer_queue/utils/utils.py index 0ab240ca..b84e824d 100644 --- a/transfer_queue/utils/utils.py +++ b/transfer_queue/utils/utils.py @@ -48,13 +48,18 @@ def _missing_(cls, value): class TransferQueueRole(ExplicitEnum): + """Available Roles of TransferQueue.""" + CONTROLLER = "TransferQueueController" STORAGE = "TransferQueueStorage" CLIENT = "TransferQueueClient" -# production_status enum: 0: not produced, 1: ready for consume class ProductionStatus(ExplicitEnum): + """ + Data Production Status. + """ + NOT_PRODUCED = 0 READY_FOR_CONSUME = 1 @@ -151,6 +156,7 @@ def limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None def get_env_bool(env_key: str, default: bool = False) -> bool: + """Robustly get a boolean from an environment variable.""" env_value = os.getenv(env_key) if env_value is None: diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 3ece8edc..4887d073 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -44,6 +44,10 @@ class ZMQRequestType(ExplicitEnum): + """ + Enumerate all available request types in TransferQueue. + """ + # HANDSHAKE HANDSHAKE = "HANDSHAKE" # TransferQueueStorageUnit -> TransferQueueController HANDSHAKE_ACK = "HANDSHAKE_ACK" # TransferQueueController -> TransferQueueStorageUnit @@ -91,6 +95,10 @@ class ZMQRequestType(ExplicitEnum): class ZMQServerInfo: + """ + TransferQueue server info class. + """ + def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]): self.role = role self.id = id @@ -98,9 +106,11 @@ def __init__(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, s self.ports = ports def to_addr(self, port_name: str) -> str: + """Convert zmq port name to address string.""" return f"tcp://{self.ip}:{self.ports[port_name]}" def to_dict(self): + """Convert ZMQServerInfo to dict.""" return { "role": self.role, "id": self.id, @@ -114,6 +124,10 @@ def __str__(self) -> str: @dataclass class ZMQMessage: + """ + ZMQMessage class for TransferQueue communication. + """ + request_type: ZMQRequestType sender_id: str receiver_id: str | None @@ -129,6 +143,7 @@ def create( body: dict[str, Any], receiver_id: Optional[str] = None, ) -> "ZMQMessage": + """Create ZMQMessage.""" return cls( request_type=request_type, sender_id=sender_id, @@ -173,6 +188,7 @@ def deserialize(cls, frames: list) -> "ZMQMessage": def get_free_port() -> str: + """Get free port of the host.""" with socket.socket() as sock: sock.bind(("", 0)) return sock.getsockname()[1] @@ -183,6 +199,7 @@ def create_zmq_socket( socket_type: Any, identity: Optional[bytestr] = None, ) -> zmq.Socket: + """Create ZMQ socket.""" mem = psutil.virtual_memory() socket = ctx.socket(socket_type)