llm_client/src/clients/openai.rs (15 changes: 13 additions & 2 deletions)
@@ -243,8 +243,19 @@ impl LLMClient for OpenAIClient {
             request_builder = request_builder.temperature(request.temperature());
         }
 
-        // if its o1 or o3-mini we should set reasoning_effort to high
-        if llm_model == &LLMType::O1 || llm_model == &LLMType::O3MiniHigh {
+        if let Some(reasoning_effort) = request.reasoning_effort() {
+            match reasoning_effort {
+                crate::clients::types::ReasoningEffort::Low => {
+                    request_builder = request_builder.reasoning_effort(ReasoningEffort::Low);
+                }
+                crate::clients::types::ReasoningEffort::Medium => {
+                    request_builder = request_builder.reasoning_effort(ReasoningEffort::Medium);
+                }
+                crate::clients::types::ReasoningEffort::High => {
+                    request_builder = request_builder.reasoning_effort(ReasoningEffort::High);
+                }
+            }
+        } else if llm_model == &LLMType::O1 || llm_model == &LLMType::O3MiniHigh {
             request_builder = request_builder.reasoning_effort(ReasoningEffort::High);
         }

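The net behavior in openai.rs: a per-request effort, when present, takes precedence, and the old always-High default for O1/O3MiniHigh survives as the fallback. A self-contained sketch of that precedence rule, using hypothetical stand-in enums rather than the crate's real types:

```rust
// Standalone illustration of the precedence rule above; `Effort` and `Model`
// are hypothetical stand-ins, not the crate's actual types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Effort {
    Low,
    Medium,
    High,
}

#[derive(PartialEq)]
enum Model {
    O1,
    O3MiniHigh,
    Other,
}

// An explicit per-request effort wins; otherwise reasoning models default to High.
fn effective_effort(requested: Option<Effort>, model: &Model) -> Option<Effort> {
    match requested {
        Some(effort) => Some(effort),
        None if *model == Model::O1 || *model == Model::O3MiniHigh => Some(Effort::High),
        None => None,
    }
}

fn main() {
    // An explicit Low now overrides the old always-High default for O1.
    assert_eq!(effective_effort(Some(Effort::Low), &Model::O1), Some(Effort::Low));
    // With nothing requested, the previous behavior is preserved.
    assert_eq!(effective_effort(None, &Model::O1), Some(Effort::High));
    assert_eq!(effective_effort(None, &Model::Other), None);
}
```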
llm_client/src/clients/openai_compatible.rs (14 changes: 14 additions & 0 deletions)
@@ -167,6 +167,20 @@ impl LLMClient for OpenAICompatibleClient {
             .messages(messages)
             .temperature(request.temperature())
             .stream(true);
+
+        if let Some(reasoning_effort) = request.reasoning_effort() {
+            match reasoning_effort {
+                crate::clients::types::ReasoningEffort::Low => {
+                    request_builder = request_builder.reasoning_effort(async_openai::types::ReasoningEffort::Low);
+                }
+                crate::clients::types::ReasoningEffort::Medium => {
+                    request_builder = request_builder.reasoning_effort(async_openai::types::ReasoningEffort::Medium);
+                }
+                crate::clients::types::ReasoningEffort::High => {
+                    request_builder = request_builder.reasoning_effort(async_openai::types::ReasoningEffort::High);
+                }
+            }
+        }
         if let Some(frequency_penalty) = request.frequency_penalty() {
             request_builder = request_builder.frequency_penalty(frequency_penalty);
         }
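The same three-arm match now appears in both clients. If the duplication becomes a nuisance, one option would be a `From` conversion on the crate-local enum so each call site shrinks to `.reasoning_effort(effort.into())`. A minimal sketch with hypothetical local stand-ins for the two enums (the real ones live in `crate::clients::types` and `async_openai::types`):

```rust
// Hypothetical refactor sketch: local stand-ins keep the conversion pattern
// self-contained; `ApiReasoningEffort` plays the role of the async_openai enum.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}

#[derive(Debug, PartialEq)]
pub enum ApiReasoningEffort {
    Low,
    Medium,
    High,
}

// Converting from a reference matches the accessor, which hands out
// Option<&ReasoningEffort> rather than an owned value.
impl From<&ReasoningEffort> for ApiReasoningEffort {
    fn from(effort: &ReasoningEffort) -> Self {
        match effort {
            ReasoningEffort::Low => ApiReasoningEffort::Low,
            ReasoningEffort::Medium => ApiReasoningEffort::Medium,
            ReasoningEffort::High => ApiReasoningEffort::High,
        }
    }
}

fn main() {
    // Each client's if-let body could then become a single line, e.g.:
    // request_builder = request_builder.reasoning_effort(effort.into());
    let effort = ReasoningEffort::Medium;
    assert_eq!(ApiReasoningEffort::from(&effort), ApiReasoningEffort::Medium);
}
```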
llm_client/src/clients/types.rs (18 changes: 18 additions & 0 deletions)
@@ -685,6 +685,13 @@ impl LLMClientMessage {
     }
 }
 
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ReasoningEffort {
+    Low,
+    Medium,
+    High,
+}
+
 #[derive(Clone, Debug)]
 pub struct LLMClientCompletionRequest {
     model: LLMType,
@@ -693,6 +700,7 @@ pub struct LLMClientCompletionRequest {
     frequency_penalty: Option<f32>,
     stop_words: Option<Vec<String>>,
     max_tokens: Option<usize>,
+    reasoning_effort: Option<ReasoningEffort>,
 }
 
 #[derive(Clone)]
@@ -771,6 +779,7 @@ impl LLMClientCompletionRequest {
             frequency_penalty,
             stop_words: None,
             max_tokens: None,
+            reasoning_effort: None,
         }
     }
 
@@ -859,6 +868,15 @@ impl LLMClientCompletionRequest {
     pub fn get_max_tokens(&self) -> Option<usize> {
         self.max_tokens
     }
+
+    pub fn set_reasoning_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
+        self.reasoning_effort = Some(reasoning_effort);
+        self
+    }
+
+    pub fn reasoning_effort(&self) -> Option<&ReasoningEffort> {
+        self.reasoning_effort.as_ref()
+    }
 }
 
 #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
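For callers, the new field is opt-in through the consuming setter added in types.rs. A minimal sketch of the intended round trip on a stripped-down stand-in struct, since the full `LLMClientCompletionRequest` constructor is not shown in this diff:

```rust
// Stand-in for the real request type: only the new field is modeled here.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}

#[derive(Default)]
struct Request {
    reasoning_effort: Option<ReasoningEffort>,
}

impl Request {
    // Same consuming builder style as set_reasoning_effort in the diff.
    fn set_reasoning_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(reasoning_effort);
        self
    }

    // Borrowing accessor, as in the diff: clients read it without taking ownership.
    fn reasoning_effort(&self) -> Option<&ReasoningEffort> {
        self.reasoning_effort.as_ref()
    }
}

fn main() {
    // Effort is opt-in: leaving it unset preserves the model-based default
    // in the OpenAI client, while setting it overrides that default.
    let request = Request::default().set_reasoning_effort(ReasoningEffort::Medium);
    assert_eq!(request.reasoning_effort(), Some(&ReasoningEffort::Medium));
}
```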