Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 23 additions & 26 deletions modelexpress_client/python/modelexpress/p2p_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions modelexpress_client/python/modelexpress/p2p_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
Expand Down
3 changes: 2 additions & 1 deletion modelexpress_client/python/modelexpress/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class TensorDescriptor:
class WorkerMetadata:
"""Metadata for a single GPU worker."""
worker_rank: int
nixl_metadata: bytes
tensors: list[TensorDescriptor]
nixl_metadata: bytes = b""
transfer_engine_session_id: str = ""


@dataclass
Expand Down
30 changes: 25 additions & 5 deletions modelexpress_common/proto/p2p.proto
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ message TensorDescriptor {

// Data type (e.g., "float16", "bfloat16")
string dtype = 5;

// Shard metadata for mixed tensor-parallelism support.
// When present, allows target instances with different TP to compute
// which byte ranges to read from which source ranks.

// Unsharded logical shape (e.g., [5120, 5120] for a full weight matrix)
repeated uint64 full_shape = 6;

// Which dimension is sharded: -1 = replicated, 0 or 1 = sharded dim
int32 shard_dim = 7;

// Number of shards this tensor is split across (effective TP for this tensor)
uint32 effective_tp_size = 8;

// Which shard this rank holds (0..effective_tp_size-1)
uint32 shard_index = 9;
}

// ============================================================================
Expand All @@ -50,11 +66,15 @@ message TensorDescriptor {
message WorkerMetadata {
// Worker rank (GPU index within the instance)
uint32 worker_rank = 1;

// Serialized NIXL agent metadata for this worker
// Used by remote agents to establish RDMA connections
bytes nixl_metadata = 2;


// Backend-specific metadata for establishing transfers.
// NIXL: serialized agent metadata for RDMA connections.
// TransferEngine: Mooncake session ID ("ip:port").
oneof backend_metadata {
bytes nixl_metadata = 2;
string transfer_engine_session_id = 10;
}

// Tensor descriptors for this worker's GPU
repeated TensorDescriptor tensors = 3;
}
Expand Down
30 changes: 29 additions & 1 deletion modelexpress_server/src/k8s_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,14 @@ pub struct WorkerStatus {
#[serde(rename = "workerRank")]
pub worker_rank: i32,

/// Base64-encoded NIXL agent metadata blob
/// Base64-encoded NIXL agent metadata blob (mutually exclusive with transferEngineSessionId)
#[serde(rename = "nixlMetadata", default)]
pub nixl_metadata: String,

/// Mooncake TransferEngine session ID (mutually exclusive with nixlMetadata)
#[serde(rename = "transferEngineSessionId", default)]
pub transfer_engine_session_id: Option<String>,

/// Number of tensors registered by this worker
#[serde(rename = "tensorCount", default)]
pub tensor_count: i32,
Expand Down Expand Up @@ -149,6 +153,22 @@ pub struct TensorDescriptorJson {
pub size: String,
pub device_id: u32,
pub dtype: String,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub full_shape: Vec<u64>,
#[serde(default, skip_serializing_if = "is_zero_i32")]
pub shard_dim: i32,
#[serde(default, skip_serializing_if = "is_zero_u32")]
pub effective_tp_size: u32,
#[serde(default, skip_serializing_if = "is_zero_u32")]
pub shard_index: u32,
}

/// Predicate for serde's `skip_serializing_if`: returns `true` exactly when
/// the referenced `i32` is zero, and `false` for every other value
/// (including negatives).
fn is_zero_i32(v: &i32) -> bool {
    matches!(*v, 0)
}

/// Predicate for serde's `skip_serializing_if`: returns `true` exactly when
/// the referenced `u32` is zero.
fn is_zero_u32(v: &u32) -> bool {
    matches!(v, &0)
}

/// Sanitize model name to be a valid Kubernetes resource name
Expand Down Expand Up @@ -208,6 +228,10 @@ mod tests {
size: "134217728".to_string(),
device_id: 0,
dtype: "bfloat16".to_string(),
full_shape: vec![5120, 5120],
shard_dim: 0,
effective_tp_size: 2,
shard_index: 0,
};

let json = serde_json::to_string(&original).expect("serialize");
Expand All @@ -233,6 +257,10 @@ mod tests {
size: u64::MAX.to_string(),
device_id: 7,
dtype: "float16".to_string(),
full_shape: vec![],
shard_dim: 0,
effective_tp_size: 0,
shard_index: 0,
};

let json = serde_json::to_string(&desc).expect("serialize");
Expand Down
46 changes: 43 additions & 3 deletions modelexpress_server/src/metadata_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,22 @@ pub struct ModelMetadataRecord {
pub published_at: i64,
}

/// Backend-specific metadata for a worker.
///
/// Each worker record carries exactly one of these variants. The `None`
/// variant is this enum's own "nothing provided" case and is unrelated to
/// `Option::None`; refer to it as `BackendMetadataRecord::None`.
///
/// `PartialEq`/`Eq` are derived so records can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendMetadataRecord {
    /// Serialized NIXL agent metadata for RDMA connections
    Nixl(Vec<u8>),
    /// Mooncake TransferEngine session ID ("ip:port")
    TransferEngine(String),
    /// No backend metadata provided
    None,
}

/// Worker metadata record
///
/// Storage-side representation of one worker's metadata: its rank, its
/// backend connection metadata, and the tensors it has registered.
#[derive(Debug, Clone)]
pub struct WorkerRecord {
/// Rank of the worker this record describes.
pub worker_rank: u32,
/// Raw NIXL metadata bytes.
// NOTE(review): this field appears next to `backend_metadata`, whose `Nixl`
// variant carries the same kind of payload; presumably one supersedes the
// other. Confirm which field is authoritative.
pub nixl_metadata: Vec<u8>,
/// Backend-specific metadata for this worker (NIXL bytes, a TransferEngine
/// session ID, or none).
pub backend_metadata: BackendMetadataRecord,
/// Descriptors of the tensors belonging to this worker.
pub tensors: Vec<TensorRecord>,
}

Expand All @@ -48,14 +59,27 @@ pub struct TensorRecord {
pub size: u64,
pub device_id: u32,
pub dtype: String,
// Shard metadata for mixed tensor-parallelism
pub full_shape: Vec<u64>,
pub shard_dim: i32,
pub effective_tp_size: u32,
pub shard_index: u32,
}

// Conversions from gRPC types
impl From<WorkerMetadata> for WorkerRecord {
fn from(meta: WorkerMetadata) -> Self {
use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
let backend_metadata = match meta.backend_metadata {
Some(BackendMetadata::NixlMetadata(data)) => BackendMetadataRecord::Nixl(data),
Some(BackendMetadata::TransferEngineSessionId(sid)) => {
BackendMetadataRecord::TransferEngine(sid)
}
None => BackendMetadataRecord::None,
};
Self {
worker_rank: meta.worker_rank,
nixl_metadata: meta.nixl_metadata,
backend_metadata,
tensors: meta.tensors.into_iter().map(TensorRecord::from).collect(),
}
}
Expand All @@ -69,16 +93,28 @@ impl From<modelexpress_common::grpc::p2p::TensorDescriptor> for TensorRecord {
size: desc.size,
device_id: desc.device_id,
dtype: desc.dtype,
full_shape: desc.full_shape,
shard_dim: desc.shard_dim,
effective_tp_size: desc.effective_tp_size,
shard_index: desc.shard_index,
}
}
}

// Conversions back to gRPC types
impl From<WorkerRecord> for WorkerMetadata {
fn from(record: WorkerRecord) -> Self {
use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata;
let backend_metadata = match record.backend_metadata {
BackendMetadataRecord::Nixl(data) => Some(BackendMetadata::NixlMetadata(data)),
BackendMetadataRecord::TransferEngine(sid) => {
Some(BackendMetadata::TransferEngineSessionId(sid))
}
BackendMetadataRecord::None => None,
};
Self {
worker_rank: record.worker_rank,
nixl_metadata: record.nixl_metadata,
backend_metadata,
tensors: record
.tensors
.into_iter()
Expand All @@ -96,6 +132,10 @@ impl From<TensorRecord> for modelexpress_common::grpc::p2p::TensorDescriptor {
size: record.size,
device_id: record.device_id,
dtype: record.dtype,
full_shape: record.full_shape,
shard_dim: record.shard_dim,
effective_tp_size: record.effective_tp_size,
shard_index: record.shard_index,
}
}
}
Expand Down
Loading
Loading