@@ -1,9 +1,11 @@
 import asyncio
+import json
 import logging
 import time
-from typing import List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

 import chatglm_cpp
+import uvicorn
 from fastapi import FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field, computed_field
@@ -14,18 +16,41 @@


 class Settings(BaseSettings):
-    model: str = "chatglm-ggml.bin"
+    model: str = "chatglm3-ggml.bin"
     num_threads: int = 0


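+# Response-side schema for tool calls, mirroring the OpenAI function-call format:
+# each call carries the function name and its JSON-encoded arguments.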
+class ToolCallFunction(BaseModel):
+    arguments: str
+    name: str
+
+
+class ToolCall(BaseModel):
+    function: Optional[ToolCallFunction] = None
+    type: Literal["function"]
+
+
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant"]
     content: str
+    tool_calls: Optional[List[ToolCall]] = None


 class DeltaMessage(BaseModel):
     role: Optional[Literal["system", "user", "assistant"]] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = None
+
+
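+# Request-side tool definition, matching the shape of the OpenAI "tools" parameter.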
+class ChatCompletionToolFunction(BaseModel):
+    description: Optional[str] = None
+    name: str
+    parameters: Dict
+
+
+class ChatCompletionTool(BaseModel):
+    type: Literal["function"] = "function"
+    function: ChatCompletionToolFunction


 class ChatCompletionRequest(BaseModel):
@@ -35,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: float = Field(default=0.7, ge=0.0, le=1.0)
     stream: bool = False
     max_tokens: int = Field(default=2048, ge=0)
+    tools: Optional[List[ChatCompletionTool]] = None

     model_config = {
         "json_schema_extra": {"examples": [{"model": "default-model", "messages": [{"role": "user", "content": "你好"}]}]}
@@ -44,7 +70,7 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int = 0
     message: ChatMessage
-    finish_reason: Literal["stop", "length"] = "stop"
+    finish_reason: Literal["stop", "length", "function_call"]


 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -144,10 +170,25 @@ async def stream_chat_event_publisher(history, body):

 @app.post("/v1/chat/completions")
 async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
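+    # ChatGLM3 emits tool-call arguments as a Python-style call expression, e.g.
+    # tool_call(location="Beijing"). The stub tool_call() below captures the
+    # keyword arguments as a dict so they can be re-serialized to a JSON string;
+    # note that eval() executes model output, and any failure falls back to
+    # returning the raw argument string unchanged.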
+    def to_json_arguments(arguments):
+        def tool_call(**kwargs):
+            return kwargs
+
+        try:
+            return json.dumps(eval(arguments, dict(tool_call=tool_call)))
+        except Exception:
+            return arguments
+
     if not body.messages:
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "empty messages")

     messages = [chatglm_cpp.ChatMessage(role=msg.role, content=msg.content) for msg in body.messages]
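+    # ChatGLM3 performs function calling through its system prompt: the tool
+    # definitions are serialized to JSON and prepended as a system message.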
+    if body.tools:
+        system_content = (
+            "Answer the following questions as best as you can. You have access to the following tools:\n"
+            + json.dumps([tool.model_dump() for tool in body.tools], indent=4)
+        )
+        messages.insert(0, chatglm_cpp.ChatMessage(role="system", content=system_content))

     if body.stream:
         generator = stream_chat_event_publisher(messages, body)
@@ -166,9 +207,28 @@ async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
     prompt_tokens = len(pipeline.tokenizer.encode_messages(messages, max_context_length))
     completion_tokens = len(pipeline.tokenizer.encode(output.content, body.max_tokens))

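+    # Map pipeline tool calls onto the OpenAI response schema and report
+    # finish_reason="function_call" when the model decided to call a tool.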
+    finish_reason = "stop"
+    tool_calls = None
+    if output.tool_calls:
+        tool_calls = [
+            ToolCall(
+                type=tool_call.type,
+                function=ToolCallFunction(
+                    name=tool_call.function.name, arguments=to_json_arguments(tool_call.function.arguments)
+                ),
+            )
+            for tool_call in output.tool_calls
+        ]
+        finish_reason = "function_call"
+
     return ChatCompletionResponse(
         object="chat.completion",
-        choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output.content))],
+        choices=[
+            ChatCompletionResponseChoice(
+                message=ChatMessage(role="assistant", content=output.content, tool_calls=tool_calls),
+                finish_reason=finish_reason,
+            )
+        ],
         usage=ChatCompletionUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
     )

@@ -199,3 +259,7 @@ class ModelList(BaseModel):
 @app.get("/v1/models")
 async def list_models() -> ModelList:
     return ModelList(data=[ModelCard(id="gpt-3.5-turbo")])
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
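
A minimal client sketch for exercising the new tools parameter (assumptions:
the server above is running on localhost:8000, the requests package is
installed, and the get_weather tool and its schema are invented for
illustration, not part of this commit):

import json

import requests

payload = {
    "model": "default-model",
    "messages": [{"role": "user", "content": "What is the weather like in Beijing?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
choice = resp.json()["choices"][0]
if choice["finish_reason"] == "function_call":
    # arguments is the JSON string produced server-side by to_json_arguments()
    for call in choice["message"]["tool_calls"]:
        print(call["function"]["name"], json.loads(call["function"]["arguments"]))
else:
    print(choice["message"]["content"])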