Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def __init__(
self.max_tokens = max_tokens
self.max_batch_size = max_batch_size
self.tokenizer = tokenizer

# Configure deepseek models
self.is_deepseek = "deepseek-r1" in model.lower()
if self.is_deepseek:
# Set recommended temperature and strategy
temperature = 0.6 # Override temperature for deepseek
kwargs["strategy"] = kwargs.get("strategy", "deepseek")

self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

self.stats: LMStats = LMStats()
Expand All @@ -50,6 +58,10 @@ def __call__(
**kwargs: dict[str, Any],
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

# Remove response_format for Ollama models since they don't support function calling
if "ollama" in self.model and "response_format" in all_kwargs:
del all_kwargs["response_format"]

# Set top_logprobs if logprobs requested
if all_kwargs.get("logprobs", False):
Expand Down Expand Up @@ -144,14 +156,18 @@ def _update_stats(self, response: ModelResponse):
# Sometimes the model's pricing information is not available
lotus.logger.debug(f"Error updating completion cost: {e}")

def _get_top_choice(self, response: ModelResponse | OpenAIError) -> str:
    """Return the message content of the first choice, re-raising stored API errors."""
    if isinstance(response, OpenAIError):
        # Batched calls may store an error object in place of a response; surface it.
        raise response
    top = response.choices[0]
    assert isinstance(top, Choices)
    content = top.message.content
    if content is None:
        raise ValueError(f"No content in response: {response}")
    return content

def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]:
def _get_top_choice_logprobs(self, response: ModelResponse | OpenAIError) -> list[ChatCompletionTokenLogprob]:
if isinstance(response, OpenAIError):
raise response
choice = response.choices[0]
assert isinstance(choice, Choices)
logprobs = choice.logprobs["content"]
Expand Down
47 changes: 47 additions & 0 deletions lotus/sem_ops/deepseek_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Utilities for handling deepseek model outputs with reasoning traces."""

from typing import Tuple

def extract_deepseek_reasoning(llm_answer: str) -> Tuple[str | None, str]:
"""
Extract reasoning and answer from deepseek model output.

Args:
llm_answer: Raw LLM output that may contain <think></think> tags

Returns:
Tuple of (reasoning, answer) where reasoning may be None if no think tags found
"""
think_start = llm_answer.find("<think>")
think_end = llm_answer.find("</think>")

if think_start != -1 and think_end != -1:
# Extract the reasoning from between the think tags
reasoning = llm_answer[think_start + 7:think_end].strip()
# Extract the answer from after the closing think tag
answer = llm_answer[think_end + 8:].strip()
# Return reasoning first, then answer
return reasoning, answer

# If no think tags found, treat the whole thing as the answer with no reasoning
return None, llm_answer.strip()

def format_deepseek_prompt(instruction: str) -> str:
    """
    Format instruction for deepseek models following official guidelines:
    - No system prompts
    - Instructions in user prompt
    - Enforce <think>\n start
    - Temperature 0.6 (handled in LM class)

    Args:
        instruction: Base instruction for the task

    Returns:
        Modified instruction that enforces <think>\n start
    """
    think_directive = (
        "Start your response with '<think>\\n' to show your reasoning, "
        "then end with '</think>' and provide your final answer."
    )
    return f"{instruction}\n\n{think_directive}"
197 changes: 163 additions & 34 deletions lotus/sem_ops/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,86 @@
import json
import re

import lotus
from lotus.types import (
SemanticExtractPostprocessOutput,
SemanticFilterPostprocessOutput,
SemanticMapPostprocessOutput,
SemanticAggPostprocessOutput,
)
from lotus.sem_ops.deepseek_utils import extract_deepseek_reasoning

def _process_deepseek_output(llm_answer: str) -> tuple[str | None, str]:
    """Split a deepseek response into (reasoning, answer) via the shared helper."""
    return extract_deepseek_reasoning(llm_answer)

def extract_json_from_text(text: str) -> dict:
    """
    Extract a JSON object from text that may contain code blocks or raw JSON.

    Tries, in order:
      1. the outermost {...} span parsed as-is (valid JSON),
      2. the same span with single quotes swapped to double quotes
         (recovers Python-style dict reprs),
      3. any non-nested {...} fragment found by regex, raw then quote-swapped.

    Args:
        text: Text that may contain a JSON object.

    Returns:
        The parsed dict, or {} if no parsable object is found.
    """
    start = text.find("{")
    end = text.rfind("}") + 1
    if start != -1 and end > start:
        candidate = text[start:end]
        # Parse the raw span first: blindly replacing quotes corrupts valid
        # JSON whose string values contain apostrophes (e.g. "it's").
        for json_str in (candidate, candidate.replace("'", '"')):
            try:
                return json.loads(json_str)
            except json.JSONDecodeError:
                continue

    # Fall back to any non-nested {...} fragment in the text.
    for match in re.finditer(r"({[^{}]*})", text):
        fragment = match.group(1)
        for json_str in (fragment, fragment.replace("'", '"')):
            try:
                return json.loads(json_str)
            except json.JSONDecodeError:
                continue

    # No valid JSON found.
    return {}

def extract_postprocess(
    llm_answers: list[str],
    strategy: str | None = None
) -> SemanticExtractPostprocessOutput:
    """
    Postprocess the output of the extract operator to extract the schema.

    Args:
        llm_answers (list[str]): The list of llm answers containing the extract.
        strategy (str | None): The reasoning strategy ("deepseek" or None).

    Returns:
        SemanticExtractPostprocessOutput
    """
    extract_data = []
    for llm_answer in llm_answers:
        lotus.logger.debug(f"Extract raw answer: {llm_answer}")
        if strategy == "deepseek":
            # Deepseek output carries a <think> trace; the JSON payload
            # follows the closing </think> tag.
            _, payload = extract_deepseek_reasoning(llm_answer)
        else:
            payload = llm_answer

        # extract_json_from_text never raises — it returns {} when no parsable
        # JSON is present — so the previous except json.JSONDecodeError branch
        # was unreachable; log the parse failure on the empty result instead.
        output = extract_json_from_text(payload)
        if not output:
            lotus.logger.info(f"\t Failed to parse: {llm_answer}")
        lotus.logger.debug(f"Parsed JSON: {output}")

        # Normalize all values to strings for downstream consumers.
        extract_data.append({key: str(value) for key, value in output.items()})

    return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data)

def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
"""
Expand All @@ -29,57 +103,59 @@ def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
reasoning_idx += len("Reasoning:\n")

answer_idx = llm_answer.find("Answer:")
reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n")
answer = llm_answer[answer_idx + len("Answer:") :]
if answer_idx == -1:
# No explicit Answer: marker, treat whole thing as answer
answer = llm_answer[reasoning_idx:].strip()
reasoning = None
else:
reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n")
answer = llm_answer[answer_idx + len("Answer:"):].strip()

outputs.append(answer)
explanations.append(reasoning)

return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def map_postprocess(
    llm_answers: list[str],
    strategy: str | None = None,
    cot_reasoning: bool = False
) -> SemanticMapPostprocessOutput:
    """
    Postprocess the output of the map operator.

    Args:
        llm_answers (list[str]): The list of llm answers.
        strategy (str | None): The reasoning strategy ("deepseek", "cot", or None).
        cot_reasoning (bool): Whether there is CoT reasoning (deprecated, use strategy="cot" instead).

    Returns:
        SemanticMapPostprocessOutput
    """
    if strategy == "deepseek":
        outputs: list[str] = []
        explanations: list[str | None] = []
        for raw in llm_answers:
            lotus.logger.debug(f"Raw LLM answer: {raw}")
            reasoning, answer = _process_deepseek_output(raw)
            lotus.logger.debug(f"Extracted reasoning: {reasoning}")
            lotus.logger.debug(f"Extracted answer: {answer}")
            explanations.append(reasoning)
            outputs.append(answer)
        return SemanticMapPostprocessOutput(
            raw_outputs=llm_answers,
            outputs=outputs,
            explanations=explanations
        )

    if cot_reasoning or strategy == "cot":
        # Legacy "Reasoning:/Answer:" format is handled by the CoT parser.
        return map_postprocess_cot(llm_answers)

    # No reasoning strategy: answers pass through unchanged, no explanations.
    return SemanticMapPostprocessOutput(
        raw_outputs=llm_answers,
        outputs=llm_answers,
        explanations=[None] * len(llm_answers)
    )


def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOutput:
    """
    Postprocess the output of the extract operator to extract the schema.

    Args:
        llm_answers (list[str]): The list of llm answers containing the extract.

    Returns:
        SemanticExtractPostprocessOutput
    """
    extract_data = []
    for llm_answer in llm_answers:
        try:
            # Each answer is expected to be a bare JSON object.
            output = json.loads(llm_answer)
        except json.JSONDecodeError:
            # Unparsable answer: log it and fall back to an empty record.
            lotus.logger.info(f"\t Failed to parse: {llm_answer}")
            output = {}

        # Normalize all values to strings for downstream consumers.
        output = {key: str(value) for key, value in output.items()}
        extract_data.append(output)

    return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data)


def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput:
"""
Postprocess the output of the filter operator with CoT reasoning.
Expand Down Expand Up @@ -117,10 +193,10 @@ def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFil

return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def filter_postprocess(
llm_answers: list[str],
default: bool = True,
strategy: str | None = None,
cot_reasoning: bool = False,
) -> SemanticFilterPostprocessOutput:
"""
Expand All @@ -129,12 +205,33 @@ def filter_postprocess(
Args:
llm_answers (list[str]): The list of llm answers.
default (bool): The default value to use if we fail to parse the answer.
cot_reasoning (bool): Whether there is CoT reasoning.
strategy (str | None): The reasoning strategy ("deepseek", "cot", or None).
cot_reasoning (bool): Whether there is CoT reasoning (deprecated, use strategy="cot" instead).

Returns:
SemanticFilterPostprocessOutput
"""
if cot_reasoning:
if strategy == "deepseek":
outputs: list[bool] = []
explanations: list[str | None] = []

for llm_answer in llm_answers:
reasoning, answer = _process_deepseek_output(llm_answer)
if "True" in answer:
outputs.append(True)
elif "False" in answer:
outputs.append(False)
else:
lotus.logger.info(f"\t Failed to parse: defaulting to {default}")
outputs.append(default)
explanations.append(reasoning)

return SemanticFilterPostprocessOutput(
raw_outputs=llm_answers,
outputs=outputs,
explanations=explanations
)
elif cot_reasoning or strategy == "cot":
return filter_postprocess_cot(llm_answers, default)

outputs: list[bool] = []
Expand All @@ -149,3 +246,35 @@ def filter_postprocess(
outputs.append(default)

return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)

def agg_postprocess(
    llm_answers: list[str],
    strategy: str | None = None,
) -> SemanticAggPostprocessOutput:
    """
    Postprocess the output of the aggregate operator.

    Args:
        llm_answers (list[str]): The list of llm answers.
        strategy (str | None): The reasoning strategy ("deepseek" or None).

    Returns:
        SemanticAggPostprocessOutput
    """
    # The strategy is loop-invariant, so branch once instead of per answer.
    if strategy == "deepseek":
        # Split each answer into its <think> reasoning and final text.
        pairs = [_process_deepseek_output(raw) for raw in llm_answers]
        explanations: list[str | None] = [reasoning for reasoning, _ in pairs]
        outputs: list[str] = [answer for _, answer in pairs]
    else:
        outputs = list(llm_answers)
        explanations = [None] * len(llm_answers)

    return SemanticAggPostprocessOutput(
        raw_outputs=llm_answers,
        outputs=outputs,
        explanations=explanations
    )
Loading