From 902c87fb920b87dbb081efd5fdecde887c8679ea Mon Sep 17 00:00:00 2001 From: hanjun Date: Fri, 13 Mar 2026 20:07:46 +0900 Subject: [PATCH 1/2] feat(gpt-oss): implement tool_choice=required via EBNF grammar Use xgrammar EBNF grammar to enforce tool calls for Harmony models instead of the previous LogitsProcessor approach. This avoids dependency on model runner internals (which are being refactored for v2) and preserves the model's reasoning ability. The grammar allows analysis/commentary channels while blocking the final channel entirely, requiring at least one tool call. Changes: - Add adjust_request() to OpenAIToolParser with EBNF grammar builder for tool_choice=required - Use raw_decode() for JSON extraction to handle trailing structural tokens from partial Harmony parsing - Add error handling in parse_output_into_messages for grammar-constrained token sequences - Add comprehensive unit tests (19 cases) for EBNF grammar acceptance/blocking via xgrammar - Add E2E test script (7 scenarios) for live server testing Signed-off-by: hanjun --- tests/tool_parsers/e2e_gptoss_tool_choice.py | 361 +++++++++++++++++ .../test_openai_tool_parser_ebnf.py | 379 ++++++++++++++++++ .../openai/parser/harmony_utils.py | 14 +- vllm/tool_parsers/openai_tool_parser.py | 63 ++- 4 files changed, 813 insertions(+), 4 deletions(-) create mode 100644 tests/tool_parsers/e2e_gptoss_tool_choice.py create mode 100644 tests/tool_parsers/test_openai_tool_parser_ebnf.py diff --git a/tests/tool_parsers/e2e_gptoss_tool_choice.py b/tests/tool_parsers/e2e_gptoss_tool_choice.py new file mode 100644 index 000000000000..c6f268edb097 --- /dev/null +++ b/tests/tool_parsers/e2e_gptoss_tool_choice.py @@ -0,0 +1,361 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""E2E test for GPT-OSS tool_choice=required. 
+ +Usage: + vllm serve --tool-parser-plugin openai --enable-auto-tool-choice + python tests/tool_parsers/e2e_gptoss_tool_choice.py [-v] [--scenario NAME] +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path + +from openai import OpenAI + +logger = logging.getLogger("e2e_tool_choice") + +# --------------------------------------------------------------------------- +# Tool definitions +# --------------------------------------------------------------------------- + +TOOL_GET_WEATHER = { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, +} + +TOOL_SEARCH = { + "type": "function", + "function": { + "name": "search", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, +} + +TOOL_CALCULATE = { + "type": "function", + "function": { + "name": "calculate", + "description": "Evaluate a math expression", + "parameters": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + }, +} + +TOOL_DATABASE_QUERY = { + "type": "function", + "function": { + "name": "database_query", + "description": "Execute a database query", + "parameters": { + "type": "object", + "properties": { + "sql": {"type": "string"}, + "database": {"type": "string"}, + }, + "required": ["sql", "database"], + }, + }, +} + +TOOL_GET_TIME = { + "type": "function", + "function": { + "name": "get_current_time", + "description": "Get the current time", + "parameters": {"type": "object", "properties": {}}, + }, +} + +# 
--------------------------------------------------------------------------- +# Scenarios +# --------------------------------------------------------------------------- + + +@dataclass +class TestScenario: + name: str + messages: list[dict] + tools: list[dict] + expected_tool_names: list[str] | None = None + min_tool_calls: int = 1 + + +SCENARIOS: list[TestScenario] = [ + TestScenario( + name="simple_weather", + messages=[{"role": "user", "content": "What's the weather in Tokyo?"}], + tools=[TOOL_GET_WEATHER], + expected_tool_names=["get_weather"], + ), + TestScenario( + name="select_from_multiple", + messages=[{"role": "user", "content": "What is the weather in Seoul?"}], + tools=[TOOL_GET_WEATHER, TOOL_SEARCH, TOOL_CALCULATE], + expected_tool_names=["get_weather"], + ), + TestScenario( + name="nested_json_args", + messages=[ + { + "role": "user", + "content": "Query the users database: " + "SELECT * FROM users WHERE age > 18 AND status = 'active'", + } + ], + tools=[TOOL_DATABASE_QUERY], + expected_tool_names=["database_query"], + ), + TestScenario( + name="multi_turn_with_tool_result", + messages=[ + {"role": "user", "content": "What's the weather in Paris?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_001", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_001", + "content": '{"temperature": 18, "condition": "cloudy"}', + }, + {"role": "user", "content": "Now search for indoor activities."}, + ], + tools=[TOOL_GET_WEATHER, TOOL_SEARCH], + expected_tool_names=["search"], + ), + TestScenario( + name="special_chars", + messages=[ + { + "role": "user", + "content": "My code has if x < 10 && y >= 20. 
Search for help.", + } + ], + tools=[TOOL_SEARCH], + expected_tool_names=["search"], + ), + TestScenario( + name="korean_unicode", + messages=[ + { + "role": "user", + "content": "서울의 현재 날씨를 알려주세요.", + } + ], + tools=[TOOL_GET_WEATHER], + expected_tool_names=["get_weather"], + ), + TestScenario( + name="no_arg_tool", + messages=[{"role": "user", "content": "What time is it right now?"}], + tools=[TOOL_GET_TIME], + expected_tool_names=["get_current_time"], + ), +] + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + + +@dataclass +class TestResult: + scenario: str + passed: bool + tool_calls: list[dict] = field(default_factory=list) + error: str | None = None + latency_ms: float = 0.0 + + +def _validate_tool_args(tool_call: dict, tool_defs: list[dict]) -> str | None: + name = tool_call["name"] + try: + args = json.loads(tool_call["arguments"]) + except json.JSONDecodeError: + return f"{name}: invalid JSON" + tool_def = next((t for t in tool_defs if t["function"]["name"] == name), None) + if tool_def is None: + return f"{name}: not in tool definitions" + for req in tool_def["function"].get("parameters", {}).get("required", []): + if req not in args: + return f"{name}: missing required field '{req}'" + return None + + +def _detect_model(client: OpenAI) -> str: + models = [m.id for m in client.models.list().data] + if len(models) == 1: + return models[0] + logger.error("Expected 1 model, got %s", models) + sys.exit(1) + + +def run_scenario( + client: OpenAI, + model: str, + scenario: TestScenario, + verbose: bool, +) -> TestResult: + logger.info("--- [%s] ---", scenario.name) + if verbose: + logger.debug( + "Request:\n%s", json.dumps(scenario.messages, indent=2, ensure_ascii=False) + ) + + t0 = time.monotonic() + try: + response = client.chat.completions.create( + model=model, + messages=scenario.messages, + tools=scenario.tools, + 
tool_choice="required", + temperature=0, + max_tokens=4096, + ) + except Exception as e: + return TestResult(scenario=scenario.name, passed=False, error=str(e)) + latency = (time.monotonic() - t0) * 1000 + + choice = response.choices[0] + msg = choice.message + if verbose: + logger.debug( + "Response:\n%s", + json.dumps(response.model_dump(), indent=2, ensure_ascii=False), + ) + + tc_data = [] + if msg.tool_calls: + for tc in msg.tool_calls: + tc_data.append( + { + "name": tc.function.name, + "arguments": tc.function.arguments, + } + ) + logger.info(" Tool: %s(%s)", tc.function.name, tc.function.arguments) + + passed = True + errors: list[str] = [] + + if len(tc_data) < scenario.min_tool_calls: + passed = False + errors.append( + f"Expected >= {scenario.min_tool_calls} calls, got {len(tc_data)}" + ) + + if scenario.expected_tool_names and tc_data: + actual = {tc["name"] for tc in tc_data} + if not actual & set(scenario.expected_tool_names): + passed = False + errors.append(f"Expected {scenario.expected_tool_names}, got {actual}") + + if choice.finish_reason in ("error", "length"): + passed = False + errors.append(f"finish_reason={choice.finish_reason}") + + for tc in tc_data: + err = _validate_tool_args(tc, scenario.tools) + if err: + passed = False + errors.append(err) + + logger.info(" %s (%.0fms)", "PASS" if passed else "FAIL", latency) + if not passed: + logger.error(" Error: %s", "; ".join(errors)) + + return TestResult( + scenario=scenario.name, + passed=passed, + tool_calls=tc_data, + error="; ".join(errors) if errors else None, + latency_ms=latency, + ) + + +def main() -> None: + ap = argparse.ArgumentParser(description="E2E test: GPT-OSS tool_choice=required") + ap.add_argument("--base-url", default="http://localhost:8000/v1") + ap.add_argument("--model", default=None) + ap.add_argument("--api-key", default="EMPTY") + ap.add_argument("-v", "--verbose", action="store_true") + ap.add_argument("--log-dir", default=None) + ap.add_argument("--scenario", 
default=None) + args = ap.parse_args() + + handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)] + if args.log_dir: + Path(args.log_dir).mkdir(parents=True, exist_ok=True) + handlers.append( + logging.FileHandler( + Path(args.log_dir) / f"e2e_{time.strftime('%Y%m%d_%H%M%S')}.log" + ) + ) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)-7s %(message)s", + handlers=handlers, + ) + + client = OpenAI(base_url=args.base_url, api_key=args.api_key) + model = args.model or _detect_model(client) + logger.info("Model: %s", model) + + scenarios = SCENARIOS + if args.scenario: + scenarios = [s for s in SCENARIOS if s.name == args.scenario] + if not scenarios: + logger.error("Scenario '%s' not found", args.scenario) + sys.exit(1) + + results = [run_scenario(client, model, s, args.verbose) for s in scenarios] + passed = sum(1 for r in results if r.passed) + logger.info( + "\n%s\nSUMMARY: %d/%d passed\n%s", "=" * 40, passed, len(results), "=" * 40 + ) + for r in results: + logger.info( + " %-30s %s", r.scenario, "PASS" if r.passed else f"FAIL: {r.error}" + ) + sys.exit(0 if all(r.passed for r in results) else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/tool_parsers/test_openai_tool_parser_ebnf.py b/tests/tool_parsers/test_openai_tool_parser_ebnf.py new file mode 100644 index 000000000000..84d0591ea59a --- /dev/null +++ b/tests/tool_parsers/test_openai_tool_parser_ebnf.py @@ -0,0 +1,379 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for OpenAIToolParser EBNF grammar and xgrammar validation.""" + +import pytest + +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + FunctionDefinition, +) +from vllm.sampling_params import StructuredOutputsParams +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.openai_tool_parser import 
OpenAIToolParser + +MODEL = "gpt2" + + +@pytest.fixture(scope="module") +def tokenizer(): + return get_tokenizer(MODEL) + + +@pytest.fixture +def parser(tokenizer): + return OpenAIToolParser(tokenizer) + + +# --------------------------------------------------------------------------- +# Grammar generation +# --------------------------------------------------------------------------- + + +def test_build_grammar_single_tool(parser: OpenAIToolParser) -> None: + grammar = parser._build_tool_required_grammar(["get_weather"]) + assert '"functions.get_weather"' in grammar + assert "root ::=" in grammar + assert "tool_block" in grammar + + +def test_build_grammar_multiple_tools(parser: OpenAIToolParser) -> None: + grammar = parser._build_tool_required_grammar( + ["get_weather", "search", "calculate"] + ) + assert ( + '"functions.get_weather" | "functions.search" | "functions.calculate"' + in grammar + ) + + +def test_build_grammar_no_final_channel(parser: OpenAIToolParser) -> None: + grammar = parser._build_tool_required_grammar(["f"]) + assert '"final"' not in grammar + + +def test_build_grammar_rejects_tool_name_with_quotes( + parser: OpenAIToolParser, +) -> None: + with pytest.raises(ValueError, match="invalid for EBNF grammar"): + parser._build_tool_required_grammar(['get"weather']) + + +def test_build_grammar_rejects_tool_name_with_newlines( + parser: OpenAIToolParser, +) -> None: + with pytest.raises(ValueError, match="invalid for EBNF grammar"): + parser._build_tool_required_grammar(["get\nweather"]) + + +# --------------------------------------------------------------------------- +# adjust_request +# --------------------------------------------------------------------------- + + +def _make_tools(*names: str) -> list[ChatCompletionToolsParam]: + return [ + ChatCompletionToolsParam( + type="function", + function=FunctionDefinition( + name=name, + description=f"Tool {name}", + parameters={"type": "object", "properties": {"q": {"type": "string"}}}, + ), + ) + for name 
in names + ] + + +def test_adjust_request_required(parser: OpenAIToolParser) -> None: + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "hi"}], + tools=_make_tools("get_weather", "search"), + tool_choice="required", + ) + result = parser.adjust_request(request) + assert isinstance(result.structured_outputs, StructuredOutputsParams) + assert result.structured_outputs.grammar is not None + assert '"functions.get_weather"' in result.structured_outputs.grammar + assert result.response_format is None + + +def test_adjust_request_auto_unchanged(parser: OpenAIToolParser) -> None: + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "hi"}], + tools=_make_tools("f"), + tool_choice="auto", + ) + assert parser.adjust_request(request).structured_outputs is None + + +def test_adjust_request_no_tools_unchanged(parser: OpenAIToolParser) -> None: + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "hi"}], + ) + assert parser.adjust_request(request).structured_outputs is None + + +# --------------------------------------------------------------------------- +# xgrammar validation (require xgrammar installed) +# --------------------------------------------------------------------------- + +xgrammar = pytest.importorskip("xgrammar") + +VOCAB = [ + "<|end|>", # 0 + "<|start|>", # 1 + "<|channel|>", # 2 + "<|message|>", # 3 + "<|return|>", # 4 + "<|call|>", # 5 + "assistant", # 6 + "analysis", # 7 + "commentary", # 8 + "final", # 9 + " to=", # 10 + "functions.", # 11 + "get_weather", # 12 + "search", # 13 + "I need", # 14 + " to check", # 15 + "{", # 16 + "}", # 17 + '"', # 18 + "location", # 19 + ":", # 20 + "Tokyo", # 21 + "<|eos|>", # 22 + "Let me", # 23 + " call", # 24 + " < ", # 25 — comparison operator + "hello", # 26 +] +V = {s: i for i, s in enumerate(VOCAB)} + + +@pytest.fixture(scope="module") +def xgr_compiler(): + tokenizer_info = xgrammar.TokenizerInfo( + 
encoded_vocab=VOCAB, + vocab_type=xgrammar.VocabType.RAW, + vocab_size=len(VOCAB), + stop_token_ids=[V["<|eos|>"]], + ) + return xgrammar.GrammarCompiler(tokenizer_info) + + +def _compile_and_run(compiler, tool_names, token_ids) -> bool: + grammar = OpenAIToolParser._build_tool_required_grammar(tool_names) + ctx = compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + return all(matcher.accept_token(tid) for tid in token_ids) + + +def _bitmask_allowed(bitmask, token_id: int) -> bool: + return bool(bitmask[0, token_id // 32].item() & (1 << (token_id % 32))) + + +# -- Acceptance -- + + +class TestXgrammarAcceptance: + def test_direct_tool_call(self, xgr_compiler) -> None: + seq = [ + V["commentary"], + V[" to="], + V["functions."], + V["get_weather"], + V["<|message|>"], + V["{"], + V['"'], + V["location"], + V['"'], + V[":"], + V['"'], + V["Tokyo"], + V['"'], + V["}"], + V["<|end|>"], + V["<|call|>"], + ] + assert _compile_and_run(xgr_compiler, ["get_weather"], seq) + + def test_analysis_then_tool_call(self, xgr_compiler) -> None: + seq = [ + V["analysis"], + V["<|message|>"], + V["I need"], + V[" to check"], + V["<|end|>"], + V["<|start|>"], + V["assistant"], + V["<|channel|>"], + V["commentary"], + V[" to="], + V["functions."], + V["get_weather"], + V["<|message|>"], + V["{"], + V["}"], + V["<|end|>"], + V["<|call|>"], + ] + assert _compile_and_run(xgr_compiler, ["get_weather"], seq) + + def test_preamble_then_tool_call(self, xgr_compiler) -> None: + seq = [ + V["commentary"], + V["<|message|>"], + V["Let me"], + V[" call"], + V["<|end|>"], + V["<|start|>"], + V["assistant"], + V["<|channel|>"], + V["commentary"], + V[" to="], + V["functions."], + V["get_weather"], + V["<|message|>"], + V["{"], + V["}"], + V["<|end|>"], + V["<|call|>"], + ] + assert _compile_and_run(xgr_compiler, ["get_weather", "search"], seq) + + def test_two_tool_calls(self, xgr_compiler) -> None: + seq = [ + V["commentary"], + V[" to="], + V["functions."], + 
V["get_weather"], + V["<|message|>"], + V["{"], + V["}"], + V["<|end|>"], + V["<|call|>"], + V["<|start|>"], + V["assistant"], + V["<|channel|>"], + V["commentary"], + V[" to="], + V["functions."], + V["search"], + V["<|message|>"], + V["{"], + V["}"], + V["<|end|>"], + V["<|call|>"], + ] + assert _compile_and_run(xgr_compiler, ["get_weather", "search"], seq) + + def test_content_with_lt_operator(self, xgr_compiler) -> None: + seq = [ + V["analysis"], + V["<|message|>"], + V["hello"], + V[" < "], + V["hello"], + V["<|end|>"], + V["<|start|>"], + V["assistant"], + V["<|channel|>"], + V["commentary"], + V[" to="], + V["functions."], + V["get_weather"], + V["<|message|>"], + V["{"], + V["}"], + V["<|end|>"], + V["<|call|>"], + ] + assert _compile_and_run(xgr_compiler, ["get_weather"], seq) + + +# -- Blocking -- + + +class TestXgrammarBlocking: + def test_final_channel_blocked(self, xgr_compiler) -> None: + grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) + ctx = xgr_compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + assert not matcher.accept_token(V["final"]) + + def test_return_token_blocked(self, xgr_compiler) -> None: + grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) + ctx = xgr_compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + assert not matcher.accept_token(V["<|return|>"]) + + def test_wrong_function_name_blocked(self, xgr_compiler) -> None: + grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) + ctx = xgr_compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + for tid in [V["commentary"], V[" to="], V["functions."]]: + assert matcher.accept_token(tid) + assert not matcher.accept_token(V["search"]) + + +# -- Termination -- + + +class TestXgrammarTermination: + def test_eos_blocked_before_tool_call(self, xgr_compiler) -> None: + grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) + ctx = 
xgr_compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + for tid in [V["analysis"], V["<|message|>"], V["hello"], V["<|end|>"]]: + assert matcher.accept_token(tid) + bitmask = xgrammar.allocate_token_bitmask(1, len(VOCAB)) + matcher.fill_next_token_bitmask(bitmask, 0) + assert not _bitmask_allowed(bitmask, V["<|eos|>"]) + + def test_eos_allowed_after_tool_call(self, xgr_compiler) -> None: + grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) + ctx = xgr_compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + seq = [ + V["commentary"], + V[" to="], + V["functions."], + V["get_weather"], + V["<|message|>"], + V["{"], + V["}"], + V["<|end|>"], + V["<|call|>"], + ] + for tid in seq: + assert matcher.accept_token(tid) + bitmask = xgrammar.allocate_token_bitmask(1, len(VOCAB)) + matcher.fill_next_token_bitmask(bitmask, 0) + assert _bitmask_allowed(bitmask, V["<|eos|>"]) + + def test_channel_bitmask(self, xgr_compiler) -> None: + """After <|channel|>, only analysis/commentary are allowed.""" + grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) + ctx = xgr_compiler.compile_grammar(grammar) + matcher = xgrammar.GrammarMatcher(ctx) + for tid in [ + V["analysis"], + V["<|message|>"], + V["hello"], + V["<|end|>"], + V["<|start|>"], + V["assistant"], + V["<|channel|>"], + ]: + assert matcher.accept_token(tid) + bitmask = xgrammar.allocate_token_bitmask(1, len(VOCAB)) + matcher.fill_next_token_bitmask(bitmask, 0) + assert _bitmask_allowed(bitmask, V["analysis"]) + assert _bitmask_allowed(bitmask, V["commentary"]) + assert not _bitmask_allowed(bitmask, V["final"]) diff --git a/vllm/entrypoints/openai/parser/harmony_utils.py b/vllm/entrypoints/openai/parser/harmony_utils.py index 9b4264456c51..c61139b58089 100644 --- a/vllm/entrypoints/openai/parser/harmony_utils.py +++ b/vllm/entrypoints/openai/parser/harmony_utils.py @@ -335,7 +335,19 @@ def get_streamable_parser_for_assistant() -> 
StreamableParser: def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: parser = get_streamable_parser_for_assistant() for token_id in token_ids: - parser.process(token_id) + try: + parser.process(token_id) + except (ValueError, RuntimeError): + # Grammar-constrained output (e.g. tool_choice=required EBNF) + # may produce token sequences that the Harmony parser cannot + # fully handle (e.g. <|call|> after <|end|>). Return the + # partial parse so callers can still extract messages parsed + # before the error. + logger.warning( + "HarmonyError while parsing token %d, returning partial parse results.", + token_id, + ) + break return parser diff --git a/vllm/tool_parsers/openai_tool_parser.py b/vllm/tool_parsers/openai_tool_parser.py index 76f7a49dfaea..b567d0968caa 100644 --- a/vllm/tool_parsers/openai_tool_parser.py +++ b/vllm/tool_parsers/openai_tool_parser.py @@ -15,6 +15,7 @@ ) from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages from vllm.logger import init_logger +from vllm.sampling_params import StructuredOutputsParams from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) @@ -28,9 +29,62 @@ class OpenAIToolParser(ToolParser): + """ + Tool parser for GPT-OSS Harmony models. + + Supports tool_choice="required" via EBNF grammar that constrains + generation to analysis/commentary channels, blocking the final channel. 
+ """ + def __init__(self, tokenizer: "TokenizerLike"): super().__init__(tokenizer) + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if not request.tools or request.tool_choice != "required": + return super().adjust_request(request) + + tool_names = [t.function.name for t in request.tools] + grammar = self._build_tool_required_grammar(tool_names) + request.structured_outputs = StructuredOutputsParams(grammar=grammar) + request.response_format = None + logger.debug( + "GPT-OSS tool_choice=required: using EBNF grammar with %d tools", + len(tool_names), + ) + return request + + @staticmethod + def _build_tool_required_grammar(tool_names: list[str]) -> str: + """Build EBNF grammar that enforces tool calls for Harmony format. + + The grammar: + - Allows analysis blocks (multi-round reasoning) + - Allows commentary preambles + - Requires at least one tool call (commentary to=functions.X) + - Blocks the final channel entirely (not defined in grammar) + + Content rule uses ([^<] | "<" [^|])* to allow '<' in text + while blocking Harmony special tokens (<|...|>). + """ + for n in tool_names: + if '"' in n or "\n" in n: + raise ValueError( + f"Tool name {n!r} contains characters invalid for EBNF grammar" + ) + func_alts = " | ".join(f'"functions.{n}"' for n in tool_names) + return ( + "root ::= non_tool_block* tool_block more_tool*\n" + 'non_tool_block ::= ("analysis" | "commentary")' + ' "<|message|>" content "<|end|>"' + ' "<|start|>" "assistant" "<|channel|>"\n' + 'tool_block ::= "commentary to=" func_name' + ' "<|message|>" content "<|end|>" "<|call|>"\n' + 'more_tool ::= "<|start|>" "assistant" "<|channel|>"' + " non_tool_block* tool_block\n" + f"func_name ::= {func_alts}\n" + 'content ::= ([^<] | "<" [^|])*' + ) + def extract_tool_calls( self, model_output: str, @@ -57,10 +111,13 @@ def extract_tool_calls( # most common case with gpt-oss models. 
if not msg.content_type or "json" in msg.content_type: # load and dump the JSON text to check validity and - # remove any extra newlines or other odd formatting + # remove any extra newlines or other odd formatting. + # Use raw_decode to handle trailing garbage from + # partial Harmony parsing (e.g. structural tokens). try: - tool_args = json.dumps(json.loads(msg_text)) - except json.JSONDecodeError: + obj, _ = json.JSONDecoder().raw_decode(msg_text) + tool_args = json.dumps(obj) + except (json.JSONDecodeError, ValueError): logger.exception( "Error decoding JSON tool call from response." ) From 87f3f8d1b8523ce714e494e7fe55a7bd54854226 Mon Sep 17 00:00:00 2001 From: hanjun Date: Sat, 14 Mar 2026 00:45:05 +0900 Subject: [PATCH 2/2] fix(gpt-oss): call adjust_request in Harmony render path The Harmony code path in render_chat() uses _make_request_with_harmony() instead of _preprocess_chat(), which bypassed tool_parser.adjust_request(). This meant the EBNF grammar for tool_choice="required" was never applied, leaving generation unconstrained (~85% tool call rate instead of 100%). Add adjust_request() call after _make_request_with_harmony() so grammar constraints are applied for GPT-OSS models. Also clean up debug logging and consolidate redundant test cases. Verified: 100/100 requests produce tool calls on gpt-oss-120b. 
Signed-off-by: hanjun --- tests/tool_parsers/e2e_gptoss_tool_choice.py | 361 ------------------ .../test_openai_tool_parser_ebnf.py | 92 +---- vllm/entrypoints/serve/render/serving.py | 9 + vllm/tool_parsers/openai_tool_parser.py | 6 +- 4 files changed, 21 insertions(+), 447 deletions(-) delete mode 100644 tests/tool_parsers/e2e_gptoss_tool_choice.py diff --git a/tests/tool_parsers/e2e_gptoss_tool_choice.py b/tests/tool_parsers/e2e_gptoss_tool_choice.py deleted file mode 100644 index c6f268edb097..000000000000 --- a/tests/tool_parsers/e2e_gptoss_tool_choice.py +++ /dev/null @@ -1,361 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""E2E test for GPT-OSS tool_choice=required. - -Usage: - vllm serve --tool-parser-plugin openai --enable-auto-tool-choice - python tests/tool_parsers/e2e_gptoss_tool_choice.py [-v] [--scenario NAME] -""" - -from __future__ import annotations - -import argparse -import json -import logging -import sys -import time -from dataclasses import dataclass, field -from pathlib import Path - -from openai import OpenAI - -logger = logging.getLogger("e2e_tool_choice") - -# --------------------------------------------------------------------------- -# Tool definitions -# --------------------------------------------------------------------------- - -TOOL_GET_WEATHER = { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get current weather for a location", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string"}, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, -} - -TOOL_SEARCH = { - "type": "function", - "function": { - "name": "search", - "description": "Search the web", - "parameters": { - "type": "object", - "properties": {"query": {"type": "string"}}, - "required": ["query"], - }, - }, -} - -TOOL_CALCULATE = { - "type": "function", - 
"function": { - "name": "calculate", - "description": "Evaluate a math expression", - "parameters": { - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, - }, -} - -TOOL_DATABASE_QUERY = { - "type": "function", - "function": { - "name": "database_query", - "description": "Execute a database query", - "parameters": { - "type": "object", - "properties": { - "sql": {"type": "string"}, - "database": {"type": "string"}, - }, - "required": ["sql", "database"], - }, - }, -} - -TOOL_GET_TIME = { - "type": "function", - "function": { - "name": "get_current_time", - "description": "Get the current time", - "parameters": {"type": "object", "properties": {}}, - }, -} - -# --------------------------------------------------------------------------- -# Scenarios -# --------------------------------------------------------------------------- - - -@dataclass -class TestScenario: - name: str - messages: list[dict] - tools: list[dict] - expected_tool_names: list[str] | None = None - min_tool_calls: int = 1 - - -SCENARIOS: list[TestScenario] = [ - TestScenario( - name="simple_weather", - messages=[{"role": "user", "content": "What's the weather in Tokyo?"}], - tools=[TOOL_GET_WEATHER], - expected_tool_names=["get_weather"], - ), - TestScenario( - name="select_from_multiple", - messages=[{"role": "user", "content": "What is the weather in Seoul?"}], - tools=[TOOL_GET_WEATHER, TOOL_SEARCH, TOOL_CALCULATE], - expected_tool_names=["get_weather"], - ), - TestScenario( - name="nested_json_args", - messages=[ - { - "role": "user", - "content": "Query the users database: " - "SELECT * FROM users WHERE age > 18 AND status = 'active'", - } - ], - tools=[TOOL_DATABASE_QUERY], - expected_tool_names=["database_query"], - ), - TestScenario( - name="multi_turn_with_tool_result", - messages=[ - {"role": "user", "content": "What's the weather in Paris?"}, - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_001", - 
"type": "function", - "function": { - "name": "get_weather", - "arguments": '{"location": "Paris"}', - }, - } - ], - }, - { - "role": "tool", - "tool_call_id": "call_001", - "content": '{"temperature": 18, "condition": "cloudy"}', - }, - {"role": "user", "content": "Now search for indoor activities."}, - ], - tools=[TOOL_GET_WEATHER, TOOL_SEARCH], - expected_tool_names=["search"], - ), - TestScenario( - name="special_chars", - messages=[ - { - "role": "user", - "content": "My code has if x < 10 && y >= 20. Search for help.", - } - ], - tools=[TOOL_SEARCH], - expected_tool_names=["search"], - ), - TestScenario( - name="korean_unicode", - messages=[ - { - "role": "user", - "content": "서울의 현재 날씨를 알려주세요.", - } - ], - tools=[TOOL_GET_WEATHER], - expected_tool_names=["get_weather"], - ), - TestScenario( - name="no_arg_tool", - messages=[{"role": "user", "content": "What time is it right now?"}], - tools=[TOOL_GET_TIME], - expected_tool_names=["get_current_time"], - ), -] - -# --------------------------------------------------------------------------- -# Runner -# --------------------------------------------------------------------------- - - -@dataclass -class TestResult: - scenario: str - passed: bool - tool_calls: list[dict] = field(default_factory=list) - error: str | None = None - latency_ms: float = 0.0 - - -def _validate_tool_args(tool_call: dict, tool_defs: list[dict]) -> str | None: - name = tool_call["name"] - try: - args = json.loads(tool_call["arguments"]) - except json.JSONDecodeError: - return f"{name}: invalid JSON" - tool_def = next((t for t in tool_defs if t["function"]["name"] == name), None) - if tool_def is None: - return f"{name}: not in tool definitions" - for req in tool_def["function"].get("parameters", {}).get("required", []): - if req not in args: - return f"{name}: missing required field '{req}'" - return None - - -def _detect_model(client: OpenAI) -> str: - models = [m.id for m in client.models.list().data] - if len(models) == 1: - return 
models[0] - logger.error("Expected 1 model, got %s", models) - sys.exit(1) - - -def run_scenario( - client: OpenAI, - model: str, - scenario: TestScenario, - verbose: bool, -) -> TestResult: - logger.info("--- [%s] ---", scenario.name) - if verbose: - logger.debug( - "Request:\n%s", json.dumps(scenario.messages, indent=2, ensure_ascii=False) - ) - - t0 = time.monotonic() - try: - response = client.chat.completions.create( - model=model, - messages=scenario.messages, - tools=scenario.tools, - tool_choice="required", - temperature=0, - max_tokens=4096, - ) - except Exception as e: - return TestResult(scenario=scenario.name, passed=False, error=str(e)) - latency = (time.monotonic() - t0) * 1000 - - choice = response.choices[0] - msg = choice.message - if verbose: - logger.debug( - "Response:\n%s", - json.dumps(response.model_dump(), indent=2, ensure_ascii=False), - ) - - tc_data = [] - if msg.tool_calls: - for tc in msg.tool_calls: - tc_data.append( - { - "name": tc.function.name, - "arguments": tc.function.arguments, - } - ) - logger.info(" Tool: %s(%s)", tc.function.name, tc.function.arguments) - - passed = True - errors: list[str] = [] - - if len(tc_data) < scenario.min_tool_calls: - passed = False - errors.append( - f"Expected >= {scenario.min_tool_calls} calls, got {len(tc_data)}" - ) - - if scenario.expected_tool_names and tc_data: - actual = {tc["name"] for tc in tc_data} - if not actual & set(scenario.expected_tool_names): - passed = False - errors.append(f"Expected {scenario.expected_tool_names}, got {actual}") - - if choice.finish_reason in ("error", "length"): - passed = False - errors.append(f"finish_reason={choice.finish_reason}") - - for tc in tc_data: - err = _validate_tool_args(tc, scenario.tools) - if err: - passed = False - errors.append(err) - - logger.info(" %s (%.0fms)", "PASS" if passed else "FAIL", latency) - if not passed: - logger.error(" Error: %s", "; ".join(errors)) - - return TestResult( - scenario=scenario.name, - passed=passed, - 
tool_calls=tc_data, - error="; ".join(errors) if errors else None, - latency_ms=latency, - ) - - -def main() -> None: - ap = argparse.ArgumentParser(description="E2E test: GPT-OSS tool_choice=required") - ap.add_argument("--base-url", default="http://localhost:8000/v1") - ap.add_argument("--model", default=None) - ap.add_argument("--api-key", default="EMPTY") - ap.add_argument("-v", "--verbose", action="store_true") - ap.add_argument("--log-dir", default=None) - ap.add_argument("--scenario", default=None) - args = ap.parse_args() - - handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)] - if args.log_dir: - Path(args.log_dir).mkdir(parents=True, exist_ok=True) - handlers.append( - logging.FileHandler( - Path(args.log_dir) / f"e2e_{time.strftime('%Y%m%d_%H%M%S')}.log" - ) - ) - logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.INFO, - format="%(asctime)s %(levelname)-7s %(message)s", - handlers=handlers, - ) - - client = OpenAI(base_url=args.base_url, api_key=args.api_key) - model = args.model or _detect_model(client) - logger.info("Model: %s", model) - - scenarios = SCENARIOS - if args.scenario: - scenarios = [s for s in SCENARIOS if s.name == args.scenario] - if not scenarios: - logger.error("Scenario '%s' not found", args.scenario) - sys.exit(1) - - results = [run_scenario(client, model, s, args.verbose) for s in scenarios] - passed = sum(1 for r in results if r.passed) - logger.info( - "\n%s\nSUMMARY: %d/%d passed\n%s", "=" * 40, passed, len(results), "=" * 40 - ) - for r in results: - logger.info( - " %-30s %s", r.scenario, "PASS" if r.passed else f"FAIL: {r.error}" - ) - sys.exit(0 if all(r.passed for r in results) else 1) - - -if __name__ == "__main__": - main() diff --git a/tests/tool_parsers/test_openai_tool_parser_ebnf.py b/tests/tool_parsers/test_openai_tool_parser_ebnf.py index 84d0591ea59a..718246a2acb7 100644 --- a/tests/tool_parsers/test_openai_tool_parser_ebnf.py +++ 
b/tests/tool_parsers/test_openai_tool_parser_ebnf.py @@ -53,16 +53,11 @@ def test_build_grammar_no_final_channel(parser: OpenAIToolParser) -> None: assert '"final"' not in grammar -def test_build_grammar_rejects_tool_name_with_quotes( +def test_build_grammar_rejects_invalid_tool_names( parser: OpenAIToolParser, ) -> None: with pytest.raises(ValueError, match="invalid for EBNF grammar"): parser._build_tool_required_grammar(['get"weather']) - - -def test_build_grammar_rejects_tool_name_with_newlines( - parser: OpenAIToolParser, -) -> None: with pytest.raises(ValueError, match="invalid for EBNF grammar"): parser._build_tool_required_grammar(["get\nweather"]) @@ -100,22 +95,15 @@ def test_adjust_request_required(parser: OpenAIToolParser) -> None: assert result.response_format is None -def test_adjust_request_auto_unchanged(parser: OpenAIToolParser) -> None: - request = ChatCompletionRequest( - model="test", - messages=[{"role": "user", "content": "hi"}], - tools=_make_tools("f"), - tool_choice="auto", - ) - assert parser.adjust_request(request).structured_outputs is None - - -def test_adjust_request_no_tools_unchanged(parser: OpenAIToolParser) -> None: - request = ChatCompletionRequest( - model="test", - messages=[{"role": "user", "content": "hi"}], - ) - assert parser.adjust_request(request).structured_outputs is None +def test_adjust_request_non_required_unchanged(parser: OpenAIToolParser) -> None: + for tool_choice in ["auto", "none"]: + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "hi"}], + tools=_make_tools("f"), + tool_choice=tool_choice, + ) + assert parser.adjust_request(request).structured_outputs is None # --------------------------------------------------------------------------- @@ -150,7 +138,7 @@ def test_adjust_request_no_tools_unchanged(parser: OpenAIToolParser) -> None: "<|eos|>", # 22 "Let me", # 23 " call", # 24 - " < ", # 25 — comparison operator + " < ", # 25 "hello", # 26 ] V = {s: i for i, s in 
enumerate(VOCAB)} @@ -178,9 +166,6 @@ def _bitmask_allowed(bitmask, token_id: int) -> bool: return bool(bitmask[0, token_id // 32].item() & (1 << (token_id % 32))) -# -- Acceptance -- - - class TestXgrammarAcceptance: def test_direct_tool_call(self, xgr_compiler) -> None: seq = [ @@ -225,28 +210,6 @@ def test_analysis_then_tool_call(self, xgr_compiler) -> None: ] assert _compile_and_run(xgr_compiler, ["get_weather"], seq) - def test_preamble_then_tool_call(self, xgr_compiler) -> None: - seq = [ - V["commentary"], - V["<|message|>"], - V["Let me"], - V[" call"], - V["<|end|>"], - V["<|start|>"], - V["assistant"], - V["<|channel|>"], - V["commentary"], - V[" to="], - V["functions."], - V["get_weather"], - V["<|message|>"], - V["{"], - V["}"], - V["<|end|>"], - V["<|call|>"], - ] - assert _compile_and_run(xgr_compiler, ["get_weather", "search"], seq) - def test_two_tool_calls(self, xgr_compiler) -> None: seq = [ V["commentary"], @@ -297,9 +260,6 @@ def test_content_with_lt_operator(self, xgr_compiler) -> None: assert _compile_and_run(xgr_compiler, ["get_weather"], seq) -# -- Blocking -- - - class TestXgrammarBlocking: def test_final_channel_blocked(self, xgr_compiler) -> None: grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) @@ -307,12 +267,6 @@ def test_final_channel_blocked(self, xgr_compiler) -> None: matcher = xgrammar.GrammarMatcher(ctx) assert not matcher.accept_token(V["final"]) - def test_return_token_blocked(self, xgr_compiler) -> None: - grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) - ctx = xgr_compiler.compile_grammar(grammar) - matcher = xgrammar.GrammarMatcher(ctx) - assert not matcher.accept_token(V["<|return|>"]) - def test_wrong_function_name_blocked(self, xgr_compiler) -> None: grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) ctx = xgr_compiler.compile_grammar(grammar) @@ -322,9 +276,6 @@ def test_wrong_function_name_blocked(self, xgr_compiler) -> None: assert not 
matcher.accept_token(V["search"]) -# -- Termination -- - - class TestXgrammarTermination: def test_eos_blocked_before_tool_call(self, xgr_compiler) -> None: grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) @@ -356,24 +307,3 @@ def test_eos_allowed_after_tool_call(self, xgr_compiler) -> None: bitmask = xgrammar.allocate_token_bitmask(1, len(VOCAB)) matcher.fill_next_token_bitmask(bitmask, 0) assert _bitmask_allowed(bitmask, V["<|eos|>"]) - - def test_channel_bitmask(self, xgr_compiler) -> None: - """After <|channel|>, only analysis/commentary are allowed.""" - grammar = OpenAIToolParser._build_tool_required_grammar(["get_weather"]) - ctx = xgr_compiler.compile_grammar(grammar) - matcher = xgrammar.GrammarMatcher(ctx) - for tid in [ - V["analysis"], - V["<|message|>"], - V["hello"], - V["<|end|>"], - V["<|start|>"], - V["assistant"], - V["<|channel|>"], - ]: - assert matcher.accept_token(tid) - bitmask = xgrammar.allocate_token_bitmask(1, len(VOCAB)) - matcher.fill_next_token_bitmask(bitmask, 0) - assert _bitmask_allowed(bitmask, V["analysis"]) - assert _bitmask_allowed(bitmask, V["commentary"]) - assert not _bitmask_allowed(bitmask, V["final"]) diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 0ff737824596..9744fb54e8b6 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -178,6 +178,15 @@ async def render_chat( request, should_include_tools ) + # Apply tool parser adjust_request (e.g. EBNF grammar for + # tool_choice="required") — _make_request_with_harmony does + # not go through _preprocess_chat where this normally happens. 
+ if tool_parser is not None: + tool_choice = getattr(request, "tool_choice", "none") + if tool_choice != "none": + assert tokenizer is not None + request = tool_parser(tokenizer).adjust_request(request=request) + return conversation, engine_prompts async def render_completion_request( diff --git a/vllm/tool_parsers/openai_tool_parser.py b/vllm/tool_parsers/openai_tool_parser.py index b567d0968caa..032a18c3c176 100644 --- a/vllm/tool_parsers/openai_tool_parser.py +++ b/vllm/tool_parsers/openai_tool_parser.py @@ -45,12 +45,8 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques tool_names = [t.function.name for t in request.tools] grammar = self._build_tool_required_grammar(tool_names) - request.structured_outputs = StructuredOutputsParams(grammar=grammar) + request.structured_outputs = StructuredOutputsParams(grammar=grammar) # type: ignore[call-arg] request.response_format = None - logger.debug( - "GPT-OSS tool_choice=required: using EBNF grammar with %d tools", - len(tool_names), - ) return request @staticmethod