From 58e7029ca22f3eaa3584063cc57a11163bf91a89 Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Thu, 5 Mar 2026 02:20:40 +0000 Subject: [PATCH 1/2] Add TransferEngine backend support to P2P metadata Extend WorkerMetadata with oneof backend_metadata supporting both NIXL (bytes) and TransferEngine (session_id string). Update all metadata backends (memory, redis, kubernetes, layered) and regenerate Python protobuf stubs for protobuf 5.x compatibility. --- Cargo.lock | 8 +-- .../python/modelexpress/p2p_pb2.py | 45 ++++++++--------- .../python/modelexpress/p2p_pb2_grpc.py | 3 -- .../python/modelexpress/types.py | 3 +- modelexpress_common/proto/p2p.proto | 14 ++++-- modelexpress_server/src/k8s_types.rs | 6 ++- modelexpress_server/src/metadata_backend.rs | 33 +++++++++++-- .../src/metadata_backend/kubernetes.rs | 49 +++++++++++++++---- .../src/metadata_backend/layered.rs | 3 +- .../src/metadata_backend/memory.rs | 12 +++-- .../src/metadata_backend/redis.rs | 26 +++++++++- modelexpress_server/src/state.rs | 19 ++++--- 12 files changed, 157 insertions(+), 64 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1000aa23..d4384d8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1923,7 +1923,7 @@ dependencies = [ [[package]] name = "model-express-workspace-tests" -version = "0.2.2" +version = "0.3.0" dependencies = [ "criterion", "modelexpress-client", @@ -1939,7 +1939,7 @@ dependencies = [ [[package]] name = "modelexpress-client" -version = "0.2.2" +version = "0.3.0" dependencies = [ "anyhow", "clap", @@ -1962,7 +1962,7 @@ dependencies = [ [[package]] name = "modelexpress-common" -version = "0.2.2" +version = "0.3.0" dependencies = [ "anyhow", "async-trait", @@ -1988,7 +1988,7 @@ dependencies = [ [[package]] name = "modelexpress-server" -version = "0.2.2" +version = "0.3.0" dependencies = [ "anyhow", "async-trait", diff --git a/modelexpress_client/python/modelexpress/p2p_pb2.py b/modelexpress_client/python/modelexpress/p2p_pb2.py index 0fc2b3e0..9286b5bf 100644 --- a/modelexpress_client/python/modelexpress/p2p_pb2.py +++ b/modelexpress_client/python/modelexpress/p2p_pb2.py @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE @@ -27,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tp2p.proto\x12\x11model_express.p2p\"^\n\x10TensorDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x01(\x04\x12\x0c\n\x04size\x18\x03 \x01(\x04\x12\x11\n\tdevice_id\x18\x04 \x01(\r\x12\r\n\x05\x64type\x18\x05 \x01(\t\"r\n\x0eWorkerMetadata\x12\x13\n\x0bworker_rank\x18\x01 \x01(\r\x12\x15\n\rnixl_metadata\x18\x02 \x01(\x0c\x12\x34\n\x07tensors\x18\x03 \x03(\x0b\x32#.model_express.p2p.TensorDescriptor\"`\n\x16PublishMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\";\n\x17PublishMetadataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x12GetMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\"X\n\x13GetMetadataResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\"\x97\x01\n\x13PublishReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x12\n\nnixl_ready\x18\x05 \x01(\x08\x12\x1a\n\x12stability_verified\x18\x06 \x01(\x08\"8\n\x14PublishReadyResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"8\n\x0fGetReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\"n\n\x10GetReadyResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\r\n\x05ready\x18\x02 \x01(\x08\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x01\x32\x8a\x03\n\nP2pService\x12h\n\x0fPublishMetadata\x12).model_express.p2p.PublishMetadataRequest\x1a*.model_express.p2p.PublishMetadataResponse\x12\\\n\x0bGetMetadata\x12%.model_express.p2p.GetMetadataRequest\x1a&.model_express.p2p.GetMetadataResponse\x12_\n\x0cPublishReady\x12&.model_express.p2p.PublishReadyRequest\x1a\'.model_express.p2p.PublishReadyResponse\x12S\n\x08GetReady\x12\".model_express.p2p.GetReadyRequest\x1a#.model_express.p2p.GetReadyResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tp2p.proto\x12\x11model_express.p2p\"^\n\x10TensorDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x01(\x04\x12\x0c\n\x04size\x18\x03 \x01(\x04\x12\x11\n\tdevice_id\x18\x04 \x01(\r\x12\r\n\x05\x64type\x18\x05 \x01(\t\"\xae\x01\n\x0eWorkerMetadata\x12\x13\n\x0bworker_rank\x18\x01 \x01(\r\x12\x17\n\rnixl_metadata\x18\x02 \x01(\x0cH\x00\x12$\n\x1atransfer_engine_session_id\x18\n \x01(\tH\x00\x12\x34\n\x07tensors\x18\x03 \x03(\x0b\x32#.model_express.p2p.TensorDescriptorB\x12\n\x10\x62\x61\x63kend_metadata\"`\n\x16PublishMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\";\n\x17PublishMetadataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x12GetMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\"X\n\x13GetMetadataResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\"\x97\x01\n\x13PublishReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x12\n\nnixl_ready\x18\x05 \x01(\x08\x12\x1a\n\x12stability_verified\x18\x06 \x01(\x08\"8\n\x14PublishReadyResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"8\n\x0fGetReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\"n\n\x10GetReadyResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\r\n\x05ready\x18\x02 \x01(\x08\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x01\x32\x8a\x03\n\nP2pService\x12h\n\x0fPublishMetadata\x12).model_express.p2p.PublishMetadataRequest\x1a*.model_express.p2p.PublishMetadataResponse\x12\\\n\x0bGetMetadata\x12%.model_express.p2p.GetMetadataRequest\x1a&.model_express.p2p.GetMetadataResponse\x12_\n\x0cPublishReady\x12&.model_express.p2p.PublishReadyRequest\x1a\'.model_express.p2p.PublishReadyResponse\x12S\n\x08GetReady\x12\".model_express.p2p.GetReadyRequest\x1a#.model_express.p2p.GetReadyResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -36,24 +33,24 @@ DESCRIPTOR._loaded_options = None _globals['_TENSORDESCRIPTOR']._serialized_start=32 _globals['_TENSORDESCRIPTOR']._serialized_end=126 - _globals['_WORKERMETADATA']._serialized_start=128 - _globals['_WORKERMETADATA']._serialized_end=242 - _globals['_PUBLISHMETADATAREQUEST']._serialized_start=244 - _globals['_PUBLISHMETADATAREQUEST']._serialized_end=340 - _globals['_PUBLISHMETADATARESPONSE']._serialized_start=342 - _globals['_PUBLISHMETADATARESPONSE']._serialized_end=401 - _globals['_GETMETADATAREQUEST']._serialized_start=403 - _globals['_GETMETADATAREQUEST']._serialized_end=443 - _globals['_GETMETADATARESPONSE']._serialized_start=445 - _globals['_GETMETADATARESPONSE']._serialized_end=533 - _globals['_PUBLISHREADYREQUEST']._serialized_start=536 - _globals['_PUBLISHREADYREQUEST']._serialized_end=687 - _globals['_PUBLISHREADYRESPONSE']._serialized_start=689 - _globals['_PUBLISHREADYRESPONSE']._serialized_end=745 - _globals['_GETREADYREQUEST']._serialized_start=747 - _globals['_GETREADYREQUEST']._serialized_end=803 - _globals['_GETREADYRESPONSE']._serialized_start=805 - _globals['_GETREADYRESPONSE']._serialized_end=915 - _globals['_P2PSERVICE']._serialized_start=918 - _globals['_P2PSERVICE']._serialized_end=1312 + _globals['_WORKERMETADATA']._serialized_start=129 + _globals['_WORKERMETADATA']._serialized_end=303 + _globals['_PUBLISHMETADATAREQUEST']._serialized_start=305 + _globals['_PUBLISHMETADATAREQUEST']._serialized_end=401 + _globals['_PUBLISHMETADATARESPONSE']._serialized_start=403 + _globals['_PUBLISHMETADATARESPONSE']._serialized_end=462 + _globals['_GETMETADATAREQUEST']._serialized_start=464 + _globals['_GETMETADATAREQUEST']._serialized_end=504 + _globals['_GETMETADATARESPONSE']._serialized_start=506 + _globals['_GETMETADATARESPONSE']._serialized_end=594 + _globals['_PUBLISHREADYREQUEST']._serialized_start=597 + _globals['_PUBLISHREADYREQUEST']._serialized_end=748 + _globals['_PUBLISHREADYRESPONSE']._serialized_start=750 + _globals['_PUBLISHREADYRESPONSE']._serialized_end=806 + _globals['_GETREADYREQUEST']._serialized_start=808 + _globals['_GETREADYREQUEST']._serialized_end=864 + _globals['_GETREADYRESPONSE']._serialized_start=866 + _globals['_GETREADYRESPONSE']._serialized_end=976 + _globals['_P2PSERVICE']._serialized_start=979 + _globals['_P2PSERVICE']._serialized_end=1373 # @@protoc_insertion_point(module_scope) diff --git a/modelexpress_client/python/modelexpress/p2p_pb2_grpc.py b/modelexpress_client/python/modelexpress/p2p_pb2_grpc.py index 3e28afc2..9597298c 100644 --- a/modelexpress_client/python/modelexpress/p2p_pb2_grpc.py +++ b/modelexpress_client/python/modelexpress/p2p_pb2_grpc.py @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/modelexpress_client/python/modelexpress/types.py b/modelexpress_client/python/modelexpress/types.py index 86906433..749db358 100644 --- a/modelexpress_client/python/modelexpress/types.py +++ b/modelexpress_client/python/modelexpress/types.py @@ -20,8 +20,9 @@ class TensorDescriptor: class WorkerMetadata: """Metadata for a single GPU worker.""" worker_rank: int - nixl_metadata: bytes tensors: list[TensorDescriptor] + nixl_metadata: bytes = b"" + transfer_engine_session_id: str = "" @dataclass diff --git a/modelexpress_common/proto/p2p.proto b/modelexpress_common/proto/p2p.proto index d76a2f69..83078373 100644 --- a/modelexpress_common/proto/p2p.proto +++ b/modelexpress_common/proto/p2p.proto @@ -50,11 +50,15 @@ message TensorDescriptor { message WorkerMetadata { // Worker rank (GPU index within the instance) uint32 worker_rank = 1; - - // Serialized NIXL agent metadata for this worker - // Used by remote agents to establish RDMA connections - bytes nixl_metadata = 2; - + + // Backend-specific metadata for establishing transfers. + // NIXL: serialized agent metadata for RDMA connections. + // TransferEngine: Mooncake session ID ("ip:port"). + oneof backend_metadata { + bytes nixl_metadata = 2; + string transfer_engine_session_id = 10; + } + // Tensor descriptors for this worker's GPU repeated TensorDescriptor tensors = 3; } diff --git a/modelexpress_server/src/k8s_types.rs b/modelexpress_server/src/k8s_types.rs index bf862cb2..5b4b5f33 100644 --- a/modelexpress_server/src/k8s_types.rs +++ b/modelexpress_server/src/k8s_types.rs @@ -87,10 +87,14 @@ pub struct WorkerStatus { #[serde(rename = "workerRank")] pub worker_rank: i32, - /// Base64-encoded NIXL agent metadata blob + /// Base64-encoded NIXL agent metadata blob (mutually exclusive with transferEngineSessionId) #[serde(rename = "nixlMetadata", default)] pub nixl_metadata: String, + /// Mooncake TransferEngine session ID (mutually exclusive with nixlMetadata) + #[serde(rename = "transferEngineSessionId", default)] + pub transfer_engine_session_id: Option, + /// Number of tensors registered by this worker #[serde(rename = "tensorCount", default)] pub tensor_count: i32, diff --git a/modelexpress_server/src/metadata_backend.rs b/modelexpress_server/src/metadata_backend.rs index a5bbd7e4..b9d13e76 100644 --- a/modelexpress_server/src/metadata_backend.rs +++ b/modelexpress_server/src/metadata_backend.rs @@ -32,11 +32,22 @@ pub struct ModelMetadataRecord { pub published_at: i64, } +/// Backend-specific metadata for a worker +#[derive(Debug, Clone)] +pub enum BackendMetadataRecord { + /// Serialized NIXL agent metadata for RDMA connections + Nixl(Vec), + /// Mooncake TransferEngine session ID ("ip:port") + TransferEngine(String), + /// No backend metadata provided + None, +} + /// Worker metadata record #[derive(Debug, Clone)] pub struct WorkerRecord { pub worker_rank: u32, - pub nixl_metadata: Vec, + pub backend_metadata: BackendMetadataRecord, pub tensors: Vec, } @@ -53,9 +64,17 @@ pub struct TensorRecord { // Conversions from gRPC types impl From for WorkerRecord { fn from(meta: WorkerMetadata) -> Self { + use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; + let backend_metadata = match meta.backend_metadata { + Some(BackendMetadata::NixlMetadata(data)) => BackendMetadataRecord::Nixl(data), + Some(BackendMetadata::TransferEngineSessionId(sid)) => { + BackendMetadataRecord::TransferEngine(sid) + } + None => BackendMetadataRecord::None, + }; Self { worker_rank: meta.worker_rank, - nixl_metadata: meta.nixl_metadata, + backend_metadata, tensors: meta.tensors.into_iter().map(TensorRecord::from).collect(), } } @@ -76,9 +95,17 @@ impl From for TensorRecord { // Conversions back to gRPC types impl From for WorkerMetadata { fn from(record: WorkerRecord) -> Self { + use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; + let backend_metadata = match record.backend_metadata { + BackendMetadataRecord::Nixl(data) => Some(BackendMetadata::NixlMetadata(data)), + BackendMetadataRecord::TransferEngine(sid) => { + Some(BackendMetadata::TransferEngineSessionId(sid)) + } + BackendMetadataRecord::None => None, + }; Self { worker_rank: record.worker_rank, - nixl_metadata: record.nixl_metadata, + backend_metadata, tensors: record .tensors .into_iter() diff --git a/modelexpress_server/src/metadata_backend/kubernetes.rs b/modelexpress_server/src/metadata_backend/kubernetes.rs index 66c13b04..a21eed93 100644 --- a/modelexpress_server/src/metadata_backend/kubernetes.rs +++ b/modelexpress_server/src/metadata_backend/kubernetes.rs @@ -260,9 +260,18 @@ impl MetadataBackend for KubernetesBackend { ) .await?; + let (nixl_metadata, transfer_engine_session_id) = match &worker.backend_metadata { + super::BackendMetadataRecord::Nixl(data) => (BASE64.encode(data), None), + super::BackendMetadataRecord::TransferEngine(sid) => { + (String::new(), Some(sid.clone())) + } + super::BackendMetadataRecord::None => (String::new(), None), + }; + worker_statuses.push(WorkerStatus { worker_rank: worker.worker_rank as i32, - nixl_metadata: BASE64.encode(&worker.nixl_metadata), + nixl_metadata, + transfer_engine_session_id, tensor_count: worker.tensors.len() as i32, tensor_config_map: Some(cm_name), ready: true, @@ -383,13 +392,35 @@ impl MetadataBackend for KubernetesBackend { // Reconstruct workers from status + ConfigMaps let mut workers = Vec::new(); for worker_status in status.workers { - // Decode NIXL metadata — propagate error instead of silently returning empty - let nixl_metadata = BASE64.decode(&worker_status.nixl_metadata).map_err(|e| { - format!( - "Failed to decode NIXL metadata for worker {}: {}", - worker_status.worker_rank, e - ) - })?; + // Determine backend metadata variant + let backend_metadata = + if let Some(sid) = &worker_status.transfer_engine_session_id { + if !sid.is_empty() { + super::BackendMetadataRecord::TransferEngine(sid.clone()) + } else if !worker_status.nixl_metadata.is_empty() { + let data = + BASE64.decode(&worker_status.nixl_metadata).map_err(|e| { + format!( + "Failed to decode NIXL metadata for worker {}: {}", + worker_status.worker_rank, e + ) + })?; + super::BackendMetadataRecord::Nixl(data) + } else { + super::BackendMetadataRecord::None + } + } else if !worker_status.nixl_metadata.is_empty() { + let data = + BASE64.decode(&worker_status.nixl_metadata).map_err(|e| { + format!( + "Failed to decode NIXL metadata for worker {}: {}", + worker_status.worker_rank, e + ) + })?; + super::BackendMetadataRecord::Nixl(data) + } else { + super::BackendMetadataRecord::None + }; // Read tensors from ConfigMap let tensors = if let Some(cm_name) = &worker_status.tensor_config_map { @@ -406,7 +437,7 @@ impl MetadataBackend for KubernetesBackend { workers.push(WorkerRecord { worker_rank: worker_status.worker_rank as u32, - nixl_metadata, + backend_metadata, tensors, }); } diff --git a/modelexpress_server/src/metadata_backend/layered.rs b/modelexpress_server/src/metadata_backend/layered.rs index 40b9db9f..acdf3b1c 100644 --- a/modelexpress_server/src/metadata_backend/layered.rs +++ b/modelexpress_server/src/metadata_backend/layered.rs @@ -162,9 +162,10 @@ mod tests { use super::*; fn test_workers(rank: u32) -> Vec { + use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; vec![WorkerMetadata { worker_rank: rank, - nixl_metadata: vec![1, 2, 3], + backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1, 2, 3])), tensors: vec![], }] } diff --git a/modelexpress_server/src/metadata_backend/memory.rs b/modelexpress_server/src/metadata_backend/memory.rs index d15a5b8c..a0e64865 100644 --- a/modelexpress_server/src/metadata_backend/memory.rs +++ b/modelexpress_server/src/metadata_backend/memory.rs @@ -125,6 +125,8 @@ mod tests { use super::*; use modelexpress_common::grpc::p2p::TensorDescriptor; + use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; + #[tokio::test] async fn test_publish_and_get() { let backend = InMemoryBackend::new(); @@ -132,7 +134,7 @@ mod tests { let workers = vec![WorkerMetadata { worker_rank: 0, - nixl_metadata: vec![1, 2, 3], + backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1, 2, 3])), tensors: vec![TensorDescriptor { name: "layer.0.weight".to_string(), addr: 0x1000, @@ -167,7 +169,7 @@ mod tests { "test-model", vec![WorkerMetadata { worker_rank: 0, - nixl_metadata: vec![1], + backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1])), tensors: vec![], }], ) @@ -180,7 +182,7 @@ mod tests { "test-model", vec![WorkerMetadata { worker_rank: 1, - nixl_metadata: vec![2], + backend_metadata: Some(BackendMetadata::NixlMetadata(vec![2])), tensors: vec![], }], ) @@ -203,7 +205,7 @@ mod tests { "model-a", vec![WorkerMetadata { worker_rank: 0, - nixl_metadata: vec![], + backend_metadata: None, tensors: vec![], }], ) @@ -215,7 +217,7 @@ mod tests { "model-b", vec![WorkerMetadata { worker_rank: 0, - nixl_metadata: vec![], + backend_metadata: None, tensors: vec![], }], ) diff --git a/modelexpress_server/src/metadata_backend/redis.rs b/modelexpress_server/src/metadata_backend/redis.rs index 27fb0ac9..ebf4a345 100644 --- a/modelexpress_server/src/metadata_backend/redis.rs +++ b/modelexpress_server/src/metadata_backend/redis.rs @@ -118,15 +118,24 @@ impl From for TensorRecord { #[derive(Debug, Clone, Serialize, Deserialize)] struct WorkerRecordJson { pub worker_rank: u32, + #[serde(default)] pub nixl_metadata: Vec, + #[serde(default)] + pub transfer_engine_session_id: Option, pub tensors: Vec, } impl From for WorkerRecordJson { fn from(record: WorkerRecord) -> Self { + let (nixl_metadata, transfer_engine_session_id) = match record.backend_metadata { + super::BackendMetadataRecord::Nixl(data) => (data, None), + super::BackendMetadataRecord::TransferEngine(sid) => (Vec::new(), Some(sid)), + super::BackendMetadataRecord::None => (Vec::new(), None), + }; Self { worker_rank: record.worker_rank, - nixl_metadata: record.nixl_metadata, + nixl_metadata, + transfer_engine_session_id, tensors: record .tensors .into_iter() @@ -138,9 +147,22 @@ impl From for WorkerRecordJson { impl From for WorkerRecord { fn from(json: WorkerRecordJson) -> Self { + let backend_metadata = if let Some(sid) = json.transfer_engine_session_id { + if !sid.is_empty() { + super::BackendMetadataRecord::TransferEngine(sid) + } else if !json.nixl_metadata.is_empty() { + super::BackendMetadataRecord::Nixl(json.nixl_metadata) + } else { + super::BackendMetadataRecord::None + } + } else if !json.nixl_metadata.is_empty() { + super::BackendMetadataRecord::Nixl(json.nixl_metadata) + } else { + super::BackendMetadataRecord::None + }; Self { worker_rank: json.worker_rank, - nixl_metadata: json.nixl_metadata, + backend_metadata, tensors: json.tensors.into_iter().map(TensorRecord::from).collect(), } } diff --git a/modelexpress_server/src/state.rs b/modelexpress_server/src/state.rs index f27a6271..35967ed0 100644 --- a/modelexpress_server/src/state.rs +++ b/modelexpress_server/src/state.rs @@ -18,7 +18,9 @@ use tokio::sync::RwLock; use tracing::{debug, info}; // Re-export types for backwards compatibility -pub use crate::metadata_backend::{ModelMetadataRecord, TensorRecord, WorkerRecord}; +pub use crate::metadata_backend::{ + BackendMetadataRecord, ModelMetadataRecord, TensorRecord, WorkerRecord, +}; /// Ready state for a source worker (stored in-memory, always ephemeral). #[derive(Debug, Clone, Serialize, Deserialize)] @@ -236,9 +238,11 @@ mod tests { #[test] fn test_worker_record_conversion() { + use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; + let meta = WorkerMetadata { worker_rank: 3, - nixl_metadata: vec![1, 2, 3, 4, 5], + backend_metadata: Some(BackendMetadata::NixlMetadata(vec![1, 2, 3, 4, 5])), tensors: vec![TensorDescriptor { name: "test.weight".to_string(), addr: 0x1000, @@ -250,12 +254,15 @@ mod tests { let record = WorkerRecord::from(meta.clone()); assert_eq!(record.worker_rank, 3); - assert_eq!(record.nixl_metadata, vec![1, 2, 3, 4, 5]); + assert!(matches!( + &record.backend_metadata, + BackendMetadataRecord::Nixl(d) if d == &vec![1, 2, 3, 4, 5] + )); assert_eq!(record.tensors.len(), 1); let back: WorkerMetadata = record.into(); assert_eq!(back.worker_rank, meta.worker_rank); - assert_eq!(back.nixl_metadata, meta.nixl_metadata); + assert_eq!(back.backend_metadata, meta.backend_metadata); } #[test] @@ -265,7 +272,7 @@ mod tests { workers: vec![ WorkerRecord { worker_rank: 0, - nixl_metadata: vec![10, 20, 30], + backend_metadata: BackendMetadataRecord::Nixl(vec![10, 20, 30]), tensors: vec![TensorRecord { name: "layer.0.weight".to_string(), addr: 0x7f00_0000_0000, @@ -276,7 +283,7 @@ mod tests { }, WorkerRecord { worker_rank: 1, - nixl_metadata: vec![40, 50, 60], + backend_metadata: BackendMetadataRecord::Nixl(vec![40, 50, 60]), tensors: vec![TensorRecord { name: "layer.0.weight".to_string(), addr: 0x7f00_0000_0000, From 893a2c1d81c62639df7031bd64daa32a499ac67d Mon Sep 17 00:00:00 2001 From: Ishan Dhanani Date: Thu, 5 Mar 2026 21:40:41 +0000 Subject: [PATCH 2/2] Add shard metadata fields to TensorDescriptor for mixed TP Adds full_shape, shard_dim, effective_tp_size, and shard_index to TensorDescriptor in p2p.proto. These fields enable target workers to compute byte-range overlaps when source and target have different tensor parallelism degrees (e.g., seed TP=2, target TP=4). Updates all backends (memory, redis, k8s), types, and Python stubs. --- .../python/modelexpress/p2p_pb2.py | 46 +++++++++---------- modelexpress_common/proto/p2p.proto | 16 +++++++ modelexpress_server/src/k8s_types.rs | 24 ++++++++++ modelexpress_server/src/metadata_backend.rs | 13 ++++++ .../src/metadata_backend/kubernetes.rs | 8 ++++ .../src/metadata_backend/memory.rs | 4 ++ .../src/metadata_backend/redis.rs | 16 +++++++ modelexpress_server/src/state.rs | 16 +++++++ 8 files changed, 120 insertions(+), 23 deletions(-) diff --git a/modelexpress_client/python/modelexpress/p2p_pb2.py b/modelexpress_client/python/modelexpress/p2p_pb2.py index 9286b5bf..151c1409 100644 --- a/modelexpress_client/python/modelexpress/p2p_pb2.py +++ b/modelexpress_client/python/modelexpress/p2p_pb2.py @@ -24,33 +24,33 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tp2p.proto\x12\x11model_express.p2p\"^\n\x10TensorDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x01(\x04\x12\x0c\n\x04size\x18\x03 \x01(\x04\x12\x11\n\tdevice_id\x18\x04 \x01(\r\x12\r\n\x05\x64type\x18\x05 \x01(\t\"\xae\x01\n\x0eWorkerMetadata\x12\x13\n\x0bworker_rank\x18\x01 \x01(\r\x12\x17\n\rnixl_metadata\x18\x02 \x01(\x0cH\x00\x12$\n\x1atransfer_engine_session_id\x18\n \x01(\tH\x00\x12\x34\n\x07tensors\x18\x03 \x03(\x0b\x32#.model_express.p2p.TensorDescriptorB\x12\n\x10\x62\x61\x63kend_metadata\"`\n\x16PublishMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\";\n\x17PublishMetadataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x12GetMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\"X\n\x13GetMetadataResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\"\x97\x01\n\x13PublishReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x12\n\nnixl_ready\x18\x05 \x01(\x08\x12\x1a\n\x12stability_verified\x18\x06 \x01(\x08\"8\n\x14PublishReadyResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"8\n\x0fGetReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\"n\n\x10GetReadyResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\r\n\x05ready\x18\x02 \x01(\x08\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x01\x32\x8a\x03\n\nP2pService\x12h\n\x0fPublishMetadata\x12).model_express.p2p.PublishMetadataRequest\x1a*.model_express.p2p.PublishMetadataResponse\x12\\\n\x0bGetMetadata\x12%.model_express.p2p.GetMetadataRequest\x1a&.model_express.p2p.GetMetadataResponse\x12_\n\x0cPublishReady\x12&.model_express.p2p.PublishReadyRequest\x1a\'.model_express.p2p.PublishReadyResponse\x12S\n\x08GetReady\x12\".model_express.p2p.GetReadyRequest\x1a#.model_express.p2p.GetReadyResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tp2p.proto\x12\x11model_express.p2p\"\xb5\x01\n\x10TensorDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x01(\x04\x12\x0c\n\x04size\x18\x03 \x01(\x04\x12\x11\n\tdevice_id\x18\x04 \x01(\r\x12\r\n\x05\x64type\x18\x05 \x01(\t\x12\x12\n\nfull_shape\x18\x06 \x03(\x04\x12\x11\n\tshard_dim\x18\x07 \x01(\x05\x12\x19\n\x11\x65\x66\x66\x65\x63tive_tp_size\x18\x08 \x01(\r\x12\x13\n\x0bshard_index\x18\t \x01(\r\"\xae\x01\n\x0eWorkerMetadata\x12\x13\n\x0bworker_rank\x18\x01 \x01(\r\x12\x17\n\rnixl_metadata\x18\x02 \x01(\x0cH\x00\x12$\n\x1atransfer_engine_session_id\x18\n \x01(\tH\x00\x12\x34\n\x07tensors\x18\x03 \x03(\x0b\x32#.model_express.p2p.TensorDescriptorB\x12\n\x10\x62\x61\x63kend_metadata\"`\n\x16PublishMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\";\n\x17PublishMetadataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x12GetMetadataRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\"X\n\x13GetMetadataResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x32\n\x07workers\x18\x02 \x03(\x0b\x32!.model_express.p2p.WorkerMetadata\"\x97\x01\n\x13PublishReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x12\n\nnixl_ready\x18\x05 \x01(\x08\x12\x1a\n\x12stability_verified\x18\x06 \x01(\x08\"8\n\x14PublishReadyResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"8\n\x0fGetReadyRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\r\"n\n\x10GetReadyResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\r\n\x05ready\x18\x02 \x01(\x08\x12\x12\n\nsession_id\x18\x03 \x01(\t\x12\x15\n\rmetadata_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x01\x32\x8a\x03\n\nP2pService\x12h\n\x0fPublishMetadata\x12).model_express.p2p.PublishMetadataRequest\x1a*.model_express.p2p.PublishMetadataResponse\x12\\\n\x0bGetMetadata\x12%.model_express.p2p.GetMetadataRequest\x1a&.model_express.p2p.GetMetadataResponse\x12_\n\x0cPublishReady\x12&.model_express.p2p.PublishReadyRequest\x1a\'.model_express.p2p.PublishReadyResponse\x12S\n\x08GetReady\x12\".model_express.p2p.GetReadyRequest\x1a#.model_express.p2p.GetReadyResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'p2p_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TENSORDESCRIPTOR']._serialized_start=32 - _globals['_TENSORDESCRIPTOR']._serialized_end=126 - _globals['_WORKERMETADATA']._serialized_start=129 - _globals['_WORKERMETADATA']._serialized_end=303 - _globals['_PUBLISHMETADATAREQUEST']._serialized_start=305 - _globals['_PUBLISHMETADATAREQUEST']._serialized_end=401 - _globals['_PUBLISHMETADATARESPONSE']._serialized_start=403 - _globals['_PUBLISHMETADATARESPONSE']._serialized_end=462 - _globals['_GETMETADATAREQUEST']._serialized_start=464 - _globals['_GETMETADATAREQUEST']._serialized_end=504 - _globals['_GETMETADATARESPONSE']._serialized_start=506 - _globals['_GETMETADATARESPONSE']._serialized_end=594 - _globals['_PUBLISHREADYREQUEST']._serialized_start=597 - _globals['_PUBLISHREADYREQUEST']._serialized_end=748 - _globals['_PUBLISHREADYRESPONSE']._serialized_start=750 - _globals['_PUBLISHREADYRESPONSE']._serialized_end=806 - _globals['_GETREADYREQUEST']._serialized_start=808 - _globals['_GETREADYREQUEST']._serialized_end=864 - _globals['_GETREADYRESPONSE']._serialized_start=866 - _globals['_GETREADYRESPONSE']._serialized_end=976 - _globals['_P2PSERVICE']._serialized_start=979 - _globals['_P2PSERVICE']._serialized_end=1373 + _globals['_TENSORDESCRIPTOR']._serialized_start=33 + _globals['_TENSORDESCRIPTOR']._serialized_end=214 + _globals['_WORKERMETADATA']._serialized_start=217 + _globals['_WORKERMETADATA']._serialized_end=391 + _globals['_PUBLISHMETADATAREQUEST']._serialized_start=393 + _globals['_PUBLISHMETADATAREQUEST']._serialized_end=489 + _globals['_PUBLISHMETADATARESPONSE']._serialized_start=491 + _globals['_PUBLISHMETADATARESPONSE']._serialized_end=550 + _globals['_GETMETADATAREQUEST']._serialized_start=552 + _globals['_GETMETADATAREQUEST']._serialized_end=592 + _globals['_GETMETADATARESPONSE']._serialized_start=594 + _globals['_GETMETADATARESPONSE']._serialized_end=682 + _globals['_PUBLISHREADYREQUEST']._serialized_start=685 + _globals['_PUBLISHREADYREQUEST']._serialized_end=836 + _globals['_PUBLISHREADYRESPONSE']._serialized_start=838 + _globals['_PUBLISHREADYRESPONSE']._serialized_end=894 + _globals['_GETREADYREQUEST']._serialized_start=896 + _globals['_GETREADYREQUEST']._serialized_end=952 + _globals['_GETREADYRESPONSE']._serialized_start=954 + _globals['_GETREADYRESPONSE']._serialized_end=1064 + _globals['_P2PSERVICE']._serialized_start=1067 + _globals['_P2PSERVICE']._serialized_end=1461 # @@protoc_insertion_point(module_scope) diff --git a/modelexpress_common/proto/p2p.proto b/modelexpress_common/proto/p2p.proto index 83078373..5c207ab2 100644 --- a/modelexpress_common/proto/p2p.proto +++ b/modelexpress_common/proto/p2p.proto @@ -41,6 +41,22 @@ message TensorDescriptor { // Data type (e.g., "float16", "bfloat16") string dtype = 5; + + // Shard metadata for mixed tensor-parallelism support. + // When present, allows target instances with different TP to compute + // which byte ranges to read from which source ranks. + + // Unsharded logical shape (e.g., [5120, 5120] for a full weight matrix) + repeated uint64 full_shape = 6; + + // Which dimension is sharded: -1 = replicated, 0 or 1 = sharded dim + int32 shard_dim = 7; + + // Number of shards this tensor is split across (effective TP for this tensor) + uint32 effective_tp_size = 8; + + // Which shard this rank holds (0..effective_tp_size-1) + uint32 shard_index = 9; } // ============================================================================ diff --git a/modelexpress_server/src/k8s_types.rs b/modelexpress_server/src/k8s_types.rs index 5b4b5f33..86dac832 100644 --- a/modelexpress_server/src/k8s_types.rs +++ b/modelexpress_server/src/k8s_types.rs @@ -153,6 +153,22 @@ pub struct TensorDescriptorJson { pub size: String, pub device_id: u32, pub dtype: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub full_shape: Vec, + #[serde(default, skip_serializing_if = "is_zero_i32")] + pub shard_dim: i32, + #[serde(default, skip_serializing_if = "is_zero_u32")] + pub effective_tp_size: u32, + #[serde(default, skip_serializing_if = "is_zero_u32")] + pub shard_index: u32, +} + +fn is_zero_i32(v: &i32) -> bool { + *v == 0 +} + +fn is_zero_u32(v: &u32) -> bool { + *v == 0 } /// Sanitize model name to be a valid Kubernetes resource name @@ -212,6 +228,10 @@ mod tests { size: "134217728".to_string(), device_id: 0, dtype: "bfloat16".to_string(), + full_shape: vec![5120, 5120], + shard_dim: 0, + effective_tp_size: 2, + shard_index: 0, }; let json = serde_json::to_string(&original).expect("serialize"); @@ -237,6 +257,10 @@ mod tests { size: u64::MAX.to_string(), device_id: 7, dtype: "float16".to_string(), + full_shape: vec![], + shard_dim: 0, + effective_tp_size: 0, + shard_index: 0, }; let json = serde_json::to_string(&desc).expect("serialize"); diff --git a/modelexpress_server/src/metadata_backend.rs b/modelexpress_server/src/metadata_backend.rs index b9d13e76..5b194c7e 100644 --- a/modelexpress_server/src/metadata_backend.rs +++ b/modelexpress_server/src/metadata_backend.rs @@ -59,6 +59,11 @@ pub struct TensorRecord { pub size: u64, pub device_id: u32, pub dtype: String, + // Shard metadata for mixed tensor-parallelism + pub full_shape: Vec, + pub shard_dim: i32, + pub effective_tp_size: u32, + pub shard_index: u32, } // Conversions from gRPC types @@ -88,6 +93,10 @@ impl From for TensorRecord { size: desc.size, device_id: desc.device_id, dtype: desc.dtype, + full_shape: desc.full_shape, + shard_dim: desc.shard_dim, + effective_tp_size: desc.effective_tp_size, + shard_index: desc.shard_index, } } } @@ -123,6 +132,10 @@ impl From for modelexpress_common::grpc::p2p::TensorDescriptor { size: record.size, device_id: record.device_id, dtype: record.dtype, + full_shape: record.full_shape, + shard_dim: record.shard_dim, + effective_tp_size: record.effective_tp_size, + shard_index: record.shard_index, } } } diff --git a/modelexpress_server/src/metadata_backend/kubernetes.rs b/modelexpress_server/src/metadata_backend/kubernetes.rs index a21eed93..1551a2e4 100644 --- a/modelexpress_server/src/metadata_backend/kubernetes.rs +++ b/modelexpress_server/src/metadata_backend/kubernetes.rs @@ -71,6 +71,10 @@ impl KubernetesBackend { size: t.size.to_string(), device_id: t.device_id, dtype: t.dtype.clone(), + full_shape: t.full_shape.clone(), + shard_dim: t.shard_dim, + effective_tp_size: t.effective_tp_size, + shard_index: t.shard_index, }) .collect(); @@ -165,6 +169,10 @@ impl KubernetesBackend { size, device_id: t.device_id, dtype: t.dtype, + full_shape: t.full_shape, + shard_dim: t.shard_dim, + effective_tp_size: t.effective_tp_size, + shard_index: t.shard_index, }) }) .collect::>>()?; diff --git a/modelexpress_server/src/metadata_backend/memory.rs b/modelexpress_server/src/metadata_backend/memory.rs index a0e64865..7834cab4 100644 --- a/modelexpress_server/src/metadata_backend/memory.rs +++ b/modelexpress_server/src/metadata_backend/memory.rs @@ -141,6 +141,10 @@ mod tests { size: 4096, device_id: 0, dtype: "bfloat16".to_string(), + full_shape: vec![], + shard_dim: 0, + effective_tp_size: 0, + shard_index: 0, }], }]; diff --git a/modelexpress_server/src/metadata_backend/redis.rs b/modelexpress_server/src/metadata_backend/redis.rs index ebf4a345..360df68f 100644 --- a/modelexpress_server/src/metadata_backend/redis.rs +++ b/modelexpress_server/src/metadata_backend/redis.rs @@ -36,6 +36,14 @@ struct TensorRecordJson { pub size: u64, pub device_id: u32, pub dtype: String, + #[serde(default)] + pub full_shape: Vec, + #[serde(default)] + pub shard_dim: i32, + #[serde(default)] + pub effective_tp_size: u32, + #[serde(default)] + pub shard_index: u32, } fn serialize_u64_as_string(value: &u64, serializer: S) -> Result @@ -98,6 +106,10 @@ impl From for TensorRecordJson { size: record.size, device_id: record.device_id, dtype: record.dtype, + full_shape: record.full_shape, + shard_dim: record.shard_dim, + effective_tp_size: record.effective_tp_size, + shard_index: record.shard_index, } } } @@ -110,6 +122,10 @@ impl From for TensorRecord { size: json.size, device_id: json.device_id, dtype: json.dtype, + full_shape: json.full_shape, + shard_dim: json.shard_dim, + effective_tp_size: json.effective_tp_size, + shard_index: json.shard_index, } } } diff --git a/modelexpress_server/src/state.rs b/modelexpress_server/src/state.rs index 35967ed0..85559533 100644 --- a/modelexpress_server/src/state.rs +++ b/modelexpress_server/src/state.rs @@ -225,6 +225,10 @@ mod tests { size: 1024 * 1024 * 1024, device_id: 0, dtype: "bfloat16".to_string(), + full_shape: vec![5120, 5120], + shard_dim: 0, + effective_tp_size: 2, + shard_index: 0, }; let record = TensorRecord::from(desc.clone()); @@ -249,6 +253,10 @@ mod tests { size: 4096, device_id: 3, dtype: "float16".to_string(), + full_shape: vec![], + shard_dim: 0, + effective_tp_size: 0, + shard_index: 0, }], }; @@ -279,6 +287,10 @@ mod tests { size: 1_000_000, device_id: 0, dtype: "bfloat16".to_string(), + full_shape: vec![], + shard_dim: 0, + effective_tp_size: 0, + shard_index: 0, }], }, WorkerRecord { @@ -290,6 +302,10 @@ mod tests { size: 1_000_000, device_id: 1, dtype: "bfloat16".to_string(), + full_shape: vec![], + shard_dim: 0, + effective_tp_size: 0, + shard_index: 0, }], }, ],