Skip to content

Commit eff7f44

Browse files
authored
Support tool call in openai api server (#254)
1 parent 23442b0 commit eff7f44

3 files changed

Lines changed: 113 additions & 17 deletions

File tree

README.md

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ For more options, please refer to [examples/langchain_client.py](examples/langch
498498
499499
Start an API server compatible with [OpenAI chat completions protocol](https://platform.openai.com/docs/api-reference/chat):
500500
```sh
501-
MODEL=./chatglm2-ggml.bin uvicorn chatglm_cpp.openai_api:app --host 127.0.0.1 --port 8000
501+
MODEL=./chatglm3-ggml.bin uvicorn chatglm_cpp.openai_api:app --host 127.0.0.1 --port 8000
502502
```
503503
504504
Test your endpoint with `curl`:
@@ -509,17 +509,22 @@ curl http://127.0.0.1:8000/v1/chat/completions -H 'Content-Type: application/jso
509509
510510
Use the OpenAI client to chat with your model:
511511
```python
512-
>>> import openai
512+
>>> from openai import OpenAI
513513
>>>
514-
>>> openai.api_base = "http://127.0.0.1:8000/v1"
515-
>>> response = openai.ChatCompletion.create(model="default-model", messages=[{"role": "user", "content": "你好"}])
516-
>>> response["choices"][0]["message"]["content"]
517-
'你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。'
514+
>>> client = OpenAI(base_url="http://127.0.0.1:8000/v1")
515+
>>> response = client.chat.completions.create(model="default-model", messages=[{"role": "user", "content": "你好"}])
516+
>>> response.choices[0].message.content
517+
'你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。'
518518
```
519519
520520
For stream response, check out the example client script:
521521
```sh
522-
OPENAI_API_BASE=http://127.0.0.1:8000/v1 python3 examples/openai_client.py --stream --prompt 你好
522+
OPENAI_BASE_URL=http://127.0.0.1:8000/v1 python3 examples/openai_client.py --stream --prompt 你好
523+
```
524+
525+
Tool calling is also supported:
526+
```sh
527+
OPENAI_BASE_URL=http://127.0.0.1:8000/v1 python3 examples/openai_client.py --tool_call --prompt 上海天气怎么样
523528
```
524529
525530
With this API server as backend, ChatGLM.cpp models can be seamlessly integrated into any frontend that uses OpenAI-style API, including [mckaywrigley/chatbot-ui](https://github.com/mckaywrigley/chatbot-ui), [fuergaosi233/wechat-chatgpt](https://github.com/fuergaosi233/wechat-chatgpt), [Yidadaa/ChatGPT-Next-Web](https://github.com/Yidadaa/ChatGPT-Next-Web), and more.

chatglm_cpp/openai_api.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
2+
import json
23
import logging
34
import time
4-
from typing import List, Literal, Optional, Union
5+
from typing import Dict, List, Literal, Optional, Union
56

67
import chatglm_cpp
8+
import uvicorn
79
from fastapi import FastAPI, HTTPException, status
810
from fastapi.middleware.cors import CORSMiddleware
911
from pydantic import BaseModel, Field, computed_field
@@ -14,18 +16,41 @@
1416

1517

1618
class Settings(BaseSettings):
    """Server configuration, read from environment variables by pydantic-settings."""

    # Path to the GGML model file to load (env var: MODEL).
    model: str = "chatglm3-ggml.bin"
    # Worker thread count; 0 lets the backend choose a default.
    num_threads: int = 0
1921

2022

23+
class ToolCallFunction(BaseModel):
    """A function invocation produced by the model: a name plus its arguments string."""

    arguments: str
    name: str
26+
27+
28+
class ToolCall(BaseModel):
    """One tool call attached to an assistant message; only "function" calls exist."""

    function: Optional[ToolCallFunction] = None
    type: Literal["function"]
31+
32+
2133
class ChatMessage(BaseModel):
    """A single chat message in a request or a (non-streaming) response."""

    role: Literal["system", "user", "assistant"]
    content: str
    # Present only on assistant messages that invoke tools.
    tool_calls: Optional[List[ToolCall]] = None
2437

2538

2639
class DeltaMessage(BaseModel):
    """An incremental message fragment used by streaming responses; all fields optional."""

    role: Optional[Literal["system", "user", "assistant"]] = None
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
43+
44+
45+
class ChatCompletionToolFunction(BaseModel):
    """Client-declared callable function; ``parameters`` holds its JSON-Schema spec."""

    description: Optional[str] = None
    name: str
    parameters: Dict
49+
50+
51+
class ChatCompletionTool(BaseModel):
    """Wrapper pairing a tool's kind (always "function") with its function schema."""

    type: Literal["function"] = "function"
    function: ChatCompletionToolFunction
2954

3055

3156
class ChatCompletionRequest(BaseModel):
@@ -35,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
3560
top_p: float = Field(default=0.7, ge=0.0, le=1.0)
3661
stream: bool = False
3762
max_tokens: int = Field(default=2048, ge=0)
63+
tools: Optional[List[ChatCompletionTool]] = None
3864

3965
model_config = {
4066
"json_schema_extra": {"examples": [{"model": "default-model", "messages": [{"role": "user", "content": "你好"}]}]}
@@ -44,7 +70,7 @@ class ChatCompletionRequest(BaseModel):
4470
class ChatCompletionResponseChoice(BaseModel):
    """One completion choice; ``finish_reason`` is "function_call" when tools were used."""

    index: int = 0
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]
4874

4975

5076
class ChatCompletionResponseStreamChoice(BaseModel):
@@ -144,10 +170,25 @@ async def stream_chat_event_publisher(history, body):
144170

145171
@app.post("/v1/chat/completions")
146172
async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
173+
def to_json_arguments(arguments):
    """Convert a model-emitted ``tool_call(...)`` expression into a JSON object string.

    The model writes tool arguments as a Python-style call such as
    ``tool_call(location="Shanghai", unit="celsius")``.  Evaluating that text
    against a ``tool_call`` helper that simply returns its keyword arguments
    yields a dict, which is serialized to JSON.  On any parse/serialization
    failure the raw string is returned unchanged so callers still see the
    model's output.
    """

    def tool_call(**kwargs):
        # Capture the keyword arguments of the generated call as a dict.
        return kwargs

    try:
        # SECURITY: ``arguments`` is model-generated text.  Supplying an empty
        # ``__builtins__`` prevents eval from auto-injecting the real builtins,
        # so the expression can only reference ``tool_call`` and literals.
        return json.dumps(eval(arguments, {"__builtins__": {}, "tool_call": tool_call}))
    except Exception:
        # Not a well-formed tool_call expression — pass the raw text through.
        return arguments
181+
147182
if not body.messages:
148183
raise HTTPException(status.HTTP_400_BAD_REQUEST, "empty messages")
149184

150185
messages = [chatglm_cpp.ChatMessage(role=msg.role, content=msg.content) for msg in body.messages]
186+
if body.tools:
187+
system_content = (
188+
"Answer the following questions as best as you can. You have access to the following tools:\n"
189+
+ json.dumps([tool.model_dump() for tool in body.tools], indent=4)
190+
)
191+
messages.insert(0, chatglm_cpp.ChatMessage(role="system", content=system_content))
151192

152193
if body.stream:
153194
generator = stream_chat_event_publisher(messages, body)
@@ -166,9 +207,28 @@ async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionR
166207
prompt_tokens = len(pipeline.tokenizer.encode_messages(messages, max_context_length))
167208
completion_tokens = len(pipeline.tokenizer.encode(output.content, body.max_tokens))
168209

210+
finish_reason = "stop"
211+
tool_calls = None
212+
if output.tool_calls:
213+
tool_calls = [
214+
ToolCall(
215+
type=tool_call.type,
216+
function=ToolCallFunction(
217+
name=tool_call.function.name, arguments=to_json_arguments(tool_call.function.arguments)
218+
),
219+
)
220+
for tool_call in output.tool_calls
221+
]
222+
finish_reason = "function_call"
223+
169224
return ChatCompletionResponse(
170225
object="chat.completion",
171-
choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output.content))],
226+
choices=[
227+
ChatCompletionResponseChoice(
228+
message=ChatMessage(role="assistant", content=output.content, tool_calls=tool_calls),
229+
finish_reason=finish_reason,
230+
)
231+
],
172232
usage=ChatCompletionUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
173233
)
174234

@@ -199,3 +259,7 @@ class ModelList(BaseModel):
199259
@app.get("/v1/models")
async def list_models() -> ModelList:
    """Return the model list; a single placeholder card for API compatibility."""
    cards = [ModelCard(id="gpt-3.5-turbo")]
    return ModelList(data=cards)
262+
263+
264+
if __name__ == "__main__":
    # Development entry point.  NOTE(review): binds all interfaces (0.0.0.0);
    # restrict the host or front with a reverse proxy in production.
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

examples/openai_client.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,46 @@
11
"""Example client for the OpenAI-compatible chatglm.cpp API server.

Supports plain completion, streaming output, and tool calling via CLI flags.
"""
import argparse

from openai import OpenAI

parser = argparse.ArgumentParser()
parser.add_argument("--stream", action="store_true")
parser.add_argument("--prompt", default="你好", type=str)
parser.add_argument("--tool_call", action="store_true")
args = parser.parse_args()

client = OpenAI()

# Single demo tool: a weather lookup described with a JSON-Schema parameter spec.
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}
tools = [weather_tool] if args.tool_call else None

messages = [{"role": "user", "content": args.prompt}]
if args.stream:
    stream = client.chat.completions.create(model="default-model", messages=messages, stream=True, tools=tools)
    for chunk in stream:
        piece = chunk.choices[0].delta.content
        # A delta may carry no content (e.g. role-only or tool-call chunks).
        if piece is not None:
            print(piece, end="", flush=True)
    print()  # terminate the streamed line
else:
    completion = client.chat.completions.create(model="default-model", messages=messages, tools=tools)
    print(completion.choices[0].message.content)

0 commit comments

Comments (0)