server.py: 95 additions & 31 deletions
@@ -9,6 +9,8 @@
import soundfile as sf
import logging
import sys
+import time
+from typing import Optional, List

# Configure logging
logging.basicConfig(
@@ -24,17 +26,19 @@
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
-processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, image_processor_kwargs={'use_fast': True})
+processor = AutoProcessor.from_pretrained(
+    model_path, trust_remote_code=True, image_processor_kwargs={'use_fast': True})

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
+    attn_implementation='flash_attention_2',
)
generation_config = GenerationConfig.from_pretrained(model_path)
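
A note on the new attn_implementation='flash_attention_2' argument: it makes the flash-attn package a hard requirement at model load time. A minimal fallback sketch, assuming flash-attn may be unavailable on some hosts (the try/except and the 'eager' fallback are illustrative, not part of this PR):

# Illustrative fallback, not part of this PR: try FlashAttention-2 first,
# then fall back to the standard eager attention implementation.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype="auto",
        trust_remote_code=True, attn_implementation='flash_attention_2')
except (ImportError, ValueError):
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype="auto",
        trust_remote_code=True, attn_implementation='eager')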


def process_message_content(content_list):
    prompt_text = ""
    images = []
@@ -61,7 +65,8 @@ def process_message_content(content_list):
            image = Image.open(response.raw)
            image = image.convert('RGB')  # Ensure RGB mode
            images.append(image)
-            logger.info(f"Loaded image from {image_url}, size: {image.size}, mode: {image.mode}")
+            logger.info(
+                f"Loaded image from {image_url}, size: {image.size}, mode: {image.mode}")
            prompt_text += f"<|image_{image_count+1}|>"
            image_count += 1
        elif item["type"] == "input_audio":
@@ -73,7 +78,8 @@ def process_message_content(content_list):
                temp_file.flush()
                audio, sample_rate = sf.read(temp_file.name)
                audios.append((audio, sample_rate))
-                logger.info(f"Loaded audio from base64, shape: {audio.shape}, sample rate: {sample_rate}")
+                logger.info(
+                    f"Loaded audio from base64, shape: {audio.shape}, sample rate: {sample_rate}")
            prompt_text += f"<|audio_{audio_count+1}|>"
            audio_count += 1
        elif item["type"] == "audio_URL":
@@ -86,51 +92,91 @@ def process_message_content(content_list):
                temp_file.flush()
                audio, sample_rate = sf.read(temp_file.name)
                audios.append((audio, sample_rate))
-                logger.info(f"Loaded audio from {url}, shape: {audio.shape}, sample rate: {sample_rate}")
+                logger.info(
+                    f"Loaded audio from {url}, shape: {audio.shape}, sample rate: {sample_rate}")
            prompt_text += f"<|audio_{audio_count+1}|>"
            audio_count += 1
    return prompt_text, images, audios


-def generate_response(prompt_text, images, audios):
+def generate_response(prompt_text, images, audios, max_tokens, temperature):
    user_prompt = '<|user|>'
    assistant_prompt = '<|assistant|>'
    prompt_suffix = '<|end|>'
    full_prompt = f"{user_prompt}{prompt_text}{prompt_suffix}{assistant_prompt}"

    logger.info(f"Prompt text: {prompt_text}")
    logger.info(f"Number of images: {len(images)}")
+    logger.info(f"Max tokens: {max_tokens}")
+    if temperature is not None:
+        logger.info(f"Temperature: {temperature}")

    if len(images) == 0:
        logger.info("No images provided")
        images = None
    logger.info(f"Number of audios: {len(audios)}")
    if len(audios) == 0:
        logger.info("No audios provided")
        audios = None

    try:
-        inputs = processor(text=full_prompt, images=images, audios=audios, return_tensors='pt').to(model.device)
-        generate_ids = model.generate(
+        inputs = processor(text=full_prompt, images=images,
+                           audios=audios, return_tensors='pt').to(model.device)
+        input_ids = inputs['input_ids']
+        input_token_count = input_ids.shape[1]
+
+        # Configure generation parameters
+        generation_kwargs = {
            **inputs,
-            max_new_tokens=1000,
-            generation_config=generation_config,
-        )
-        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
-        response = processor.batch_decode(
-            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            'max_new_tokens': max_tokens,
+            'generation_config': generation_config,
+            'return_dict_in_generate': True,
+            'output_scores': True
+        }
+
+        # Add temperature if specified
+        if temperature is not None:
+            if temperature <= 0.0:
+                generation_kwargs['do_sample'] = False
+            else:
+                generation_kwargs['temperature'] = temperature
+                # do_sample=True must be set, otherwise the temperature value has no effect
+                generation_kwargs['do_sample'] = True
+
+        generation_output = model.generate(**generation_kwargs)
+
+        generated_ids = generation_output.sequences[:, input_ids.shape[1]:]
+        completion_token_count = generated_ids.shape[1]
+
+        response_text = processor.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
-        return response
+
+        return {
+            'text': response_text,
+            'prompt_tokens': input_token_count,
+            'completion_tokens': completion_token_count
+        }
    except Exception as e:
        logger.error(f"Failed to generate response: {e}")
        raise


class ChatCompletionRequest(BaseModel):
    model: str
-    messages: list[dict]
-    max_completion_tokens: int
+    messages: List[dict]
+    max_completion_tokens: Optional[int] = 1000
+    temperature: Optional[float] = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: list[dict]
    usage: dict


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
@@ -141,27 +187,45 @@ async def chat_completions(request: ChatCompletionRequest):
            user_message = message
    if not user_message:
        return {"error": "No user message found"}

    content_list = user_message["content"]
    prompt_text, images, audios = process_message_content(content_list)

-    response_text = generate_response(prompt_text, images, audios)
-
-    # Format the response
+    result = generate_response(
+        prompt_text,
+        images,
+        audios,
+        max_tokens=request.max_completion_tokens,
+        temperature=request.temperature
+    )
+
+    # Create timestamp and unique ID
+    timestamp = int(time.time())
+    completion_id = f"chatcmpl-{timestamp}"
+
    response = {
+        "id": completion_id,
        "object": "chat.completion",
+        "created": timestamp,
        "model": request.model,
        "choices": [
            {
                "message": {
                    "role": "assistant",
-                    "content": response_text
+                    "content": result['text']
                },
                "finish_reason": "stop"
            }
-        ]
+        ],
+        "usage": {
+            "prompt_tokens": result['prompt_tokens'],
+            "completion_tokens": result['completion_tokens'],
+            "total_tokens": result['prompt_tokens'] + result['completion_tokens']
+        }
    }

    return response


@app.get("/health")
async def health():
return {"status": "OK"}
return {"status": "OK"}