Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions e2e_test/responses/test_tools_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,124 @@ def assert_streaming_mcp_call_ids_match(events):
assert set(final_mcp_ids) == expected_ids


def mcp_list_tools_labels(output):
    """Return the server labels of every ``mcp_list_tools`` item in ``output``, in order."""
    labels = []
    for entry in output:
        if entry.type == "mcp_list_tools":
            labels.append(entry.server_label)
    return labels


def streaming_added_mcp_list_tools_labels(events):
    """Server labels of ``mcp_list_tools`` items announced via output_item.added events."""

    def _is_added_list_tools(event):
        # Only count item-added events that actually carry a list-tools item.
        if event.type != "response.output_item.added":
            return False
        item = event.item
        return item is not None and item.type == "mcp_list_tools"

    return [event.item.server_label for event in events if _is_added_list_tools(event)]


def assert_previous_response_id_mcp_binding_behavior_non_streaming(model, api_client):
    """Check MCP list-tools emission across turns chained via previous_response_id.

    Turn 1 attaches the brave server and must list its tools; turn 2 resumes
    with the same tool set and must NOT re-list brave (but must still be able
    to call it); turn 3 adds deepwiki and must list only the newly added
    server.
    """
    # NOTE(review): presumably a pacing delay to avoid backend rate limits — confirm.
    time.sleep(2)

    resp1 = api_client.responses.create(
        model=model,
        input=MCP_TEST_PROMPT,
        tools=[BRAVE_MCP_TOOL],
        stream=False,
        reasoning={"effort": "low"},
    )
    assert resp1.error is None
    assert resp1.status == "completed"
    # First turn: brave is newly bound, so its tool listing must be emitted.
    assert mcp_list_tools_labels(resp1.output) == ["brave"]

    resp2 = api_client.responses.create(
        model=model,
        input=(
            "Search the web for 'Rust programming language'. Set count to 1 and return one "
            "sentence response."
        ),
        previous_response_id=resp1.id,
        tools=[BRAVE_MCP_TOOL],
        stream=False,
        reasoning={"effort": "low"},
    )
    assert resp2.error is None
    assert resp2.status == "completed"
    # Resumed turn with an unchanged tool set: no re-listing, but the bound
    # server must still be callable.
    assert mcp_list_tools_labels(resp2.output) == []
    assert any(item.type == "mcp_call" for item in resp2.output)

    resp3 = api_client.responses.create(
        model=model,
        input=(
            "Use deepwiki to tell me which transport protocols the 2025-03-26 MCP spec "
            "supports, and also use brave_web_search to search the web for 'Rust programming "
            "language'. Return exactly two bullet points."
        ),
        previous_response_id=resp2.id,
        tools=[BRAVE_MCP_TOOL, DEEPWIKI_MCP_TOOL],
        stream=False,
        reasoning={"effort": "low"},
    )
    assert resp3.error is None
    assert resp3.status == "completed"
    # Resumed turn that adds a new server: only deepwiki gets listed.
    assert mcp_list_tools_labels(resp3.output) == ["deepwiki"]
    assert any(item.type == "mcp_call" for item in resp3.output)


def assert_previous_response_id_mcp_binding_behavior_streaming(model, api_client):
    """Streaming variant: resumed turns only announce newly added MCP servers.

    Mirrors the non-streaming helper, but inspects response.output_item.added
    events and the final response.completed payload instead of the response
    object returned by a non-streaming call.
    """
    # NOTE(review): presumably a pacing delay to avoid backend rate limits — confirm.
    time.sleep(2)

    events1 = list(
        api_client.responses.create(
            model=model,
            input=MCP_TEST_PROMPT,
            tools=[BRAVE_MCP_TOOL],
            stream=True,
            reasoning={"effort": "low"},
        )
    )
    # First turn: brave is newly bound, so its listing is streamed.
    assert streaming_added_mcp_list_tools_labels(events1) == ["brave"]

    # Extract the completed-response id up front; inlining this lookup inside
    # the create() call made it unreadable.
    response_id_1 = [e for e in events1 if e.type == "response.completed"][0].response.id

    events2 = list(
        api_client.responses.create(
            model=model,
            input=(
                "Search the web for 'Rust programming language'. Set count to 1 and return one "
                "sentence response."
            ),
            previous_response_id=response_id_1,
            tools=[BRAVE_MCP_TOOL],
            stream=True,
            reasoning={"effort": "low"},
        )
    )
    # Resumed turn with an unchanged tool set: no listing, tool still callable.
    assert streaming_added_mcp_list_tools_labels(events2) == []
    assert any(event.type == "response.mcp_call.completed" for event in events2)

    response_id_2 = [e for e in events2 if e.type == "response.completed"][0].response.id

    events3 = list(
        api_client.responses.create(
            model=model,
            input=(
                "Use deepwiki to tell me which transport protocols the 2025-03-26 MCP spec "
                "supports, and also use brave_web_search to search the web for 'Rust programming "
                "language'. Return exactly two bullet points."
            ),
            previous_response_id=response_id_2,
            tools=[BRAVE_MCP_TOOL, DEEPWIKI_MCP_TOOL],
            stream=True,
            reasoning={"effort": "low"},
        )
    )
    # Resumed turn adding deepwiki: exactly one listing cycle, for deepwiki only.
    assert streaming_added_mcp_list_tools_labels(events3) == ["deepwiki"]
    assert [e.type for e in events3].count("response.mcp_list_tools.in_progress") == 1
    assert [e.type for e in events3].count("response.mcp_list_tools.completed") == 1
    completed_events = [e for e in events3 if e.type == "response.completed"]
    assert len(completed_events) == 1
    assert mcp_list_tools_labels(completed_events[0].response.output) == ["deepwiki"]


@pytest.mark.vendor("openai")
@pytest.mark.gpu(0)
@pytest.mark.parametrize("setup_backend", ["openai"], indirect=True)
Expand Down Expand Up @@ -470,6 +588,16 @@ def test_mcp_multi_server_tool_call_streaming(self, model, api_client):
for mcp_call in mcp_calls:
assert mcp_call.server_label == "brave"

def test_previous_response_id_mcp_binding_behavior(self, model, api_client):
    """Resumed turns should not relist existing MCP bindings."""

    # Delegates to the shared non-streaming helper so the scenario can be
    # reused across backend parametrizations.
    assert_previous_response_id_mcp_binding_behavior_non_streaming(model, api_client)

def test_previous_response_id_mcp_binding_behavior_streaming(self, model, api_client):
    """Streaming resumed turns should only list newly added MCP bindings."""

    # Delegates to the shared streaming helper so the scenario can be reused
    # across backend parametrizations.
    assert_previous_response_id_mcp_binding_behavior_streaming(model, api_client)

def test_concurrent_mcp_different_servers(self, model, api_client):
"""Concurrent non-streaming requests with different MCP servers don't contaminate each other."""

Expand Down
1 change: 0 additions & 1 deletion model_gateway/src/routers/openai/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ pub(super) async fn route_chat(
ctx.state.payload = Some(PayloadState {
json: payload,
url: url.clone(),
previous_response_id: None,
});

// Wrap values in Arc to avoid cloning large objects on each retry attempt
Expand Down
15 changes: 14 additions & 1 deletion model_gateway/src/routers/openai/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ impl ComponentRefs {
/// Mutable per-request state accumulated while a request is being routed.
pub struct ProcessingState {
    /// Upstream worker chosen for this request, once selection has happened.
    pub worker: Option<WorkerSelection>,
    /// Prepared outbound payload (JSON body + target URL), once built.
    pub payload: Option<PayloadState>,
    /// Responses-API-specific payload state; presumably only populated for
    /// Responses requests — confirm against the responses router.
    pub responses_payload: Option<ResponsesPayloadState>,
}

pub struct WorkerSelection {
Expand All @@ -102,7 +103,12 @@ pub struct WorkerSelection {
/// Prepared outbound request: the JSON body and the URL it will be sent to.
pub struct PayloadState {
    /// Serialized request body.
    pub json: Value,
    /// Target endpoint URL.
    pub url: String,
}

/// Responses-API-specific request state carried alongside the main payload.
#[derive(Default)]
pub struct ResponsesPayloadState {
    /// `previous_response_id` from the incoming request, when resuming a turn.
    pub previous_response_id: Option<String>,
    /// Server labels whose `mcp_list_tools` output presumably already exists
    /// in the resumed conversation, so they are not re-listed this turn —
    /// confirm against the MCP tool loop.
    pub existing_mcp_list_tools_labels: Vec<String>,
}

impl RequestContext {
Expand Down Expand Up @@ -188,6 +194,10 @@ impl RequestContext {
/// Takes ownership of the prepared payload state, leaving `None` behind.
pub fn take_payload(&mut self) -> Option<PayloadState> {
    self.state.payload.take()
}

/// Takes ownership of the Responses-specific payload state, leaving `None` behind.
pub fn take_responses_payload(&mut self) -> Option<ResponsesPayloadState> {
    self.state.responses_payload.take()
}
}

pub struct StorageHandles {
Expand All @@ -202,12 +212,14 @@ pub struct OwnedStreamingContext {
pub payload: Value,
pub original_body: ResponsesRequest,
pub previous_response_id: Option<String>,
pub existing_mcp_list_tools_labels: Vec<String>,
pub storage: StorageHandles,
}

impl RequestContext {
pub fn into_streaming_context(mut self) -> Result<OwnedStreamingContext, &'static str> {
let payload_state = self.take_payload().ok_or("Payload not prepared")?;
let responses_payload_state = self.take_responses_payload().unwrap_or_default();
let original_body = self
.responses_request()
.ok_or("Expected responses request")?
Expand All @@ -232,7 +244,8 @@ impl RequestContext {
url: payload_state.url,
payload: payload_state.json,
original_body,
previous_response_id: payload_state.previous_response_id,
previous_response_id: responses_payload_state.previous_response_id,
existing_mcp_list_tools_labels: responses_payload_state.existing_mcp_list_tools_labels,
storage: StorageHandles {
response,
conversation,
Expand Down
4 changes: 2 additions & 2 deletions model_gateway/src/routers/openai/mcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ pub(crate) use tool_handler::{StreamAction, StreamingToolHandler};
// Re-export functions used by responses/streaming.rs and responses/non_streaming.rs
pub(crate) use tool_loop::{
build_resume_payload, execute_streaming_tool_calls, execute_tool_loop,
inject_mcp_metadata_streaming, prepare_mcp_tools_as_functions, send_mcp_list_tools_events,
ToolLoopState,
inject_mcp_metadata_streaming, mcp_list_tools_bindings_to_emit, prepare_mcp_tools_as_functions,
send_mcp_list_tools_events, ToolLoopExecutionContext, ToolLoopState,
};
Loading
Loading