From 054bbfb808c9c49d48b49bbc1a64ab21c669ef55 Mon Sep 17 00:00:00 2001 From: dpj135 <958208521@qq.com> Date: Tue, 17 Mar 2026 10:39:20 +0800 Subject: [PATCH 1/4] Fixed bugs Signed-off-by: dpj135 <958208521@qq.com> --- tests/e2e/test_e2e_lifecycle_consistency.py | 25 ++++++++++++++------- transfer_queue/storage/managers/base.py | 20 +++++++++++------ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 39b0b91e..a0929618 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -78,6 +78,18 @@ }, }, }, + "Yuanrong": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "Yuanrong", + "Yuanrong": { + "host": "127.0.0.1", + "port": 31501, + }, + }, + }, } @@ -507,14 +519,11 @@ def test_cross_shard_complex_update(e2e_client): update_positions_in_full = [ i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis ] - update_meta_with_backend = full_meta.select_samples(update_positions_in_full) - # Populate empty schema for fields not yet in field_schema so select_fields can include them - for f in ["new_extra_tensor", "new_extra_non_tensor"]: - if f not in update_meta_with_backend.field_schema: - update_meta_with_backend.field_schema[f] = {} - update_meta_with_backend._field_names = sorted(update_meta_with_backend.field_schema.keys()) - extended_meta = update_meta_with_backend.select_fields( - base_fields + ["new_extra_tensor", "new_extra_non_tensor"] + extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"] + extended_meta = ( + poll_for_meta(client, partition_id, extended_fields, 40, task_name, mode="force_fetch") + .select_samples(update_positions_in_full) + .select_fields(extended_fields) ) update_region_data = client.get_data(extended_meta) assert "new_extra_tensor" in update_region_data.keys(), "new_extra_tensor should exist" diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index d33c591f..3e4230c4 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -540,13 +540,20 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: Store tensor data in the backend storage and notify the controller. """ num_samples = len(metadata.global_indexes) - if num_samples == 0: + if data.batch_size[0] != num_samples: + raise ValueError(f"Batch size of data ({data.batch_size[0]}) does not match expected ({num_samples})") + + if data.batch_size[0] == 0: + logger.warning("Attempted to put data with batch size 0. Operation will be skipped.") return - keys = self._generate_keys(data.keys(), metadata.global_indexes) + # Generate keys and values. + # The field_name of metadata is old fashion, we will generate keys/values according to data. + data_field_names = list(sorted(data.keys())) + keys = self._generate_keys(data_field_names, metadata.global_indexes) values = self._generate_values(data) - loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) field_schema = extract_field_schema(data) @@ -562,15 +569,14 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: for global_idx in metadata.global_indexes: per_field_custom_backend_meta[global_idx] = {} - # FIXME(tianyi): the order of custom backend meta is coupled with keys/values - # FIXME: if put_data is called to partially update/add new fields, the current - # implementation will cause custom_backend_meta losses or mismatch! for (field_name, global_idx), meta_value in zip( - itertools.product(sorted(metadata.field_names), metadata.global_indexes), + itertools.product(data_field_names, metadata.global_indexes), custom_backend_meta, strict=True, ): per_field_custom_backend_meta[global_idx][field_name] = meta_value + # TODO: There should not visit private property of metadata, + # we should consider to add a public method in BatchMeta to set custom_backend_meta in the future. metadata._custom_backend_meta[global_index_to_position[global_idx]][field_name] = meta_value # Get current data partition id From d5d9b08f775b6c2ea4a6c84d02f137275abe2ef7 Mon Sep 17 00:00:00 2001 From: dpj135 <60139850+dpj135@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:56:26 +0800 Subject: [PATCH 2/4] Apply suggestions from code review Signed-off-by: dpj135 <958208521@qq.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tests/e2e/test_e2e_lifecycle_consistency.py | 15 +++++++++++++-- transfer_queue/storage/managers/base.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index a0929618..782ae993 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -520,9 +520,20 @@ def test_cross_shard_complex_update(e2e_client): i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis ] extended_fields = base_fields + ["new_extra_tensor", "new_extra_non_tensor"] + extended_meta = poll_for_meta( + client, + partition_id, + extended_fields, + 40, + task_name, + mode="force_fetch", + ) + assert extended_meta is not None and extended_meta.size > 0, ( + "Failed to fetch extended metadata for update region; " + "poll_for_meta returned no or empty metadata." + ) extended_meta = ( - poll_for_meta(client, partition_id, extended_fields, 40, task_name, mode="force_fetch") - .select_samples(update_positions_in_full) + extended_meta.select_samples(update_positions_in_full) .select_fields(extended_fields) ) update_region_data = client.get_data(extended_meta) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 3e4230c4..492bb70e 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -548,7 +548,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: return # Generate keys and values. - # The field_name of metadata is old fashion, we will generate keys/values according to data. + # metadata.field_names is legacy; generate keys/values from the actual data field names instead. data_field_names = list(sorted(data.keys())) keys = self._generate_keys(data_field_names, metadata.global_indexes) values = self._generate_values(data) From ad52a5c0a79d6d73009ca8ccc7096690fe3318ae Mon Sep 17 00:00:00 2001 From: dpj135 <958208521@qq.com> Date: Tue, 7 Apr 2026 16:36:03 +0800 Subject: [PATCH 3/4] Imporved robustness of 'yuanrong_client when calling clear' Signed-off-by: dpj135 <958208521@qq.com> --- transfer_queue/storage/clients/yuanrong_client.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 3b0de3ab..96a81363 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -585,7 +585,7 @@ def clear(self, keys: list[str], custom_backend_meta=None): strategy_tags = custom_backend_meta routed_indexes = self._route_to_strategies( - strategy_tags, lambda strategy_, item_: strategy_.supports_clear(item_) + strategy_tags, lambda strategy_, item_: strategy_.supports_clear(item_), failback=True ) def clear_task(strategy, indexes): @@ -598,6 +598,7 @@ def _route_to_strategies( self, items: list[Any], selector: Callable[[StorageStrategy, Any], bool], + failback: bool = False, ) -> dict[StorageStrategy, list[int]]: """Groups item indices by the first strategy that supports them. @@ -610,6 +611,8 @@ def _route_to_strategies( The order must correspond to the original keys. selector: A function that determines whether a strategy supports an item. Signature: `(strategy: StorageStrategy, item: Any) -> bool`. + failback: If True, items that don't match any strategy will be ignored (not included in output). + If False, a ValueError will be raised for any unmatched item. Returns: A dictionary mapping each active strategy to a list of indexes in `items` @@ -622,10 +625,11 @@ def _route_to_strategies( routed_indexes[strategy].append(i) break else: - raise ValueError( - f"No strategy supports item of type {type(item).__name__}: {item}. " - f"Available strategies: {[type(s).__name__ for s in self._strategies]}" - ) + if not failback: + raise ValueError( + f"No strategy supports item of type {type(item).__name__}: {item}. " + f"Available strategies: {[type(s).__name__ for s in self._strategies]}" + ) return routed_indexes @staticmethod From 67509f010d936390cfd83e0726e536485202ea44 Mon Sep 17 00:00:00 2001 From: dpj135 <958208521@qq.com> Date: Wed, 8 Apr 2026 14:44:19 +0800 Subject: [PATCH 4/4] Renamed failback to ignore_unmatched Signed-off-by: dpj135 <958208521@qq.com> --- transfer_queue/storage/clients/yuanrong_client.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 96a81363..bdf48d62 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -585,7 +585,7 @@ def clear(self, keys: list[str], custom_backend_meta=None): strategy_tags = custom_backend_meta routed_indexes = self._route_to_strategies( - strategy_tags, lambda strategy_, item_: strategy_.supports_clear(item_), failback=True + strategy_tags, lambda strategy_, item_: strategy_.supports_clear(item_), ignore_unmatched=True ) def clear_task(strategy, indexes): @@ -598,7 +598,7 @@ def _route_to_strategies( self, items: list[Any], selector: Callable[[StorageStrategy, Any], bool], - failback: bool = False, + ignore_unmatched: bool = False, ) -> dict[StorageStrategy, list[int]]: """Groups item indices by the first strategy that supports them. @@ -618,6 +618,7 @@ def _route_to_strategies( A dictionary mapping each active strategy to a list of indexes in `items` that it should handle. Every index appears exactly once. """ + unmatched_count = 0 routed_indexes: dict[StorageStrategy, list[int]] = {s: [] for s in self._strategies} for i, item in enumerate(items): for strategy in self._strategies: @@ -625,11 +626,16 @@ def _route_to_strategies( routed_indexes[strategy].append(i) break else: - if not failback: + if ignore_unmatched: + unmatched_count += 1 + else: raise ValueError( f"No strategy supports item of type {type(item).__name__}: {item}. " f"Available strategies: {[type(s).__name__ for s in self._strategies]}" ) + if unmatched_count > 0: + logger.warning(f"{unmatched_count} items were not matched to any strategy and will be ignored.") + return routed_indexes @staticmethod