diff --git a/lotus/models/lm.py b/lotus/models/lm.py
index 616dc143..6be24284 100644
--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -36,6 +36,14 @@ def __init__(
self.max_tokens = max_tokens
self.max_batch_size = max_batch_size
self.tokenizer = tokenizer
+
+ # Configure deepseek models
+ self.is_deepseek = "deepseek-r1" in model.lower()
+ if self.is_deepseek:
+ # Set recommended temperature and strategy
+ temperature = 0.6 # Override temperature for deepseek
+ kwargs["strategy"] = kwargs.get("strategy", "deepseek")
+
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
self.stats: LMStats = LMStats()
@@ -50,6 +58,10 @@ def __call__(
**kwargs: dict[str, Any],
) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}
+
+ # Remove response_format for Ollama models since they don't support function calling
+ if "ollama" in self.model and "response_format" in all_kwargs:
+ del all_kwargs["response_format"]
# Set top_logprobs if logprobs requested
if all_kwargs.get("logprobs", False):
@@ -144,14 +156,18 @@ def _update_stats(self, response: ModelResponse):
# Sometimes the model's pricing information is not available
lotus.logger.debug(f"Error updating completion cost: {e}")
- def _get_top_choice(self, response: ModelResponse) -> str:
+ def _get_top_choice(self, response: ModelResponse | OpenAIError) -> str:
+ if isinstance(response, OpenAIError):
+ raise response
choice = response.choices[0]
assert isinstance(choice, Choices)
if choice.message.content is None:
raise ValueError(f"No content in response: {response}")
return choice.message.content
- def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]:
+ def _get_top_choice_logprobs(self, response: ModelResponse | OpenAIError) -> list[ChatCompletionTokenLogprob]:
+ if isinstance(response, OpenAIError):
+ raise response
choice = response.choices[0]
assert isinstance(choice, Choices)
logprobs = choice.logprobs["content"]
diff --git a/lotus/sem_ops/deepseek_utils.py b/lotus/sem_ops/deepseek_utils.py
new file mode 100644
index 00000000..205b0ed9
--- /dev/null
+++ b/lotus/sem_ops/deepseek_utils.py
@@ -0,0 +1,47 @@
+"""Utilities for handling deepseek model outputs with reasoning traces."""
+
+from typing import Tuple
+
+def extract_deepseek_reasoning(llm_answer: str) -> Tuple[str | None, str]:
+ """
+ Extract reasoning and answer from deepseek model output.
+
+ Args:
+        llm_answer: Raw LLM output that may contain <think> tags
+
+ Returns:
+ Tuple of (reasoning, answer) where reasoning may be None if no think tags found
+ """
+    think_start = llm_answer.find("<think>")
+    think_end = llm_answer.find("</think>")
+
+ if think_start != -1 and think_end != -1:
+ # Extract the reasoning from between the think tags
+ reasoning = llm_answer[think_start + 7:think_end].strip()
+ # Extract the answer from after the closing think tag
+ answer = llm_answer[think_end + 8:].strip()
+ # Return reasoning first, then answer
+ return reasoning, answer
+
+ # If no think tags found, treat the whole thing as the answer with no reasoning
+ return None, llm_answer.strip()
+
+def format_deepseek_prompt(instruction: str) -> str:
+ """
+ Format instruction for deepseek models following official guidelines:
+ - No system prompts
+ - Instructions in user prompt
+    - Enforce <think>\n start
+ - Temperature 0.6 (handled in LM class)
+
+ Args:
+ instruction: Base instruction for the task
+
+ Returns:
+        Modified instruction that enforces <think>\n start
+ """
+ return (
+ f"{instruction}\n\n"
+ "Start your response with '\\n' to show your reasoning, "
+ "then end with '' and provide your final answer."
+ )
\ No newline at end of file
diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py
index d531099c..1e0c97d6 100644
--- a/lotus/sem_ops/postprocessors.py
+++ b/lotus/sem_ops/postprocessors.py
@@ -1,12 +1,86 @@
import json
+import re
import lotus
from lotus.types import (
SemanticExtractPostprocessOutput,
SemanticFilterPostprocessOutput,
SemanticMapPostprocessOutput,
+ SemanticAggPostprocessOutput,
)
+from lotus.sem_ops.deepseek_utils import extract_deepseek_reasoning
+
+def _process_deepseek_output(llm_answer: str) -> tuple[str | None, str]:
+ """Helper function to process deepseek model output."""
+ reasoning, answer = extract_deepseek_reasoning(llm_answer)
+ return reasoning, answer
+
+def extract_json_from_text(text: str) -> dict:
+ """Helper function to extract JSON from text that may contain code blocks or raw JSON."""
+ # Try to find JSON between curly braces
+ try:
+ start = text.find("{")
+ if start != -1:
+ end = text.rfind("}") + 1
+ if end > start:
+ json_str = text[start:end]
+ # Clean up any potential Python string formatting
+ json_str = json_str.replace("'", '"')
+ return json.loads(json_str)
+ except json.JSONDecodeError:
+ pass
+
+ # Try to find any JSON-like structure in the text
+ try:
+ matches = re.finditer(r'({[^{}]*})', text)
+ for match in matches:
+ try:
+ json_str = match.group(1).replace("'", '"')
+ return json.loads(json_str)
+ except json.JSONDecodeError:
+ continue
+ except Exception:
+ pass
+
+ # If no valid JSON found, return empty dict
+ return {}
+
+def extract_postprocess(
+ llm_answers: list[str],
+ strategy: str | None = None
+) -> SemanticExtractPostprocessOutput:
+ """
+ Postprocess the output of the extract operator to extract the schema.
+ Args:
+ llm_answers (list[str]): The list of llm answers containing the extract.
+ strategy (str | None): The reasoning strategy ("deepseek" or None).
+
+ Returns:
+ SemanticExtractPostprocessOutput
+ """
+ extract_data = []
+ for llm_answer in llm_answers:
+ lotus.logger.debug(f"Extract raw answer: {llm_answer}")
+ try:
+ if strategy == "deepseek":
+                # For deepseek models, extract the JSON from after </think>
+ _, answer = extract_deepseek_reasoning(llm_answer)
+ output = extract_json_from_text(answer)
+ else:
+ output = extract_json_from_text(llm_answer)
+
+ lotus.logger.debug(f"Parsed JSON: {output}")
+ except json.JSONDecodeError as e:
+ lotus.logger.info(f"\t Failed to parse: {llm_answer}")
+ lotus.logger.debug(f"JSON parse error: {e}")
+ output = {}
+
+ # Convert all values to strings
+ output = {key: str(value) for key, value in output.items()}
+ extract_data.append(output)
+
+ return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data)
def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
"""
@@ -29,57 +103,59 @@ def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
reasoning_idx += len("Reasoning:\n")
answer_idx = llm_answer.find("Answer:")
- reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n")
- answer = llm_answer[answer_idx + len("Answer:") :]
+ if answer_idx == -1:
+ # No explicit Answer: marker, treat whole thing as answer
+ answer = llm_answer[reasoning_idx:].strip()
+ reasoning = None
+ else:
+ reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n")
+ answer = llm_answer[answer_idx + len("Answer:"):].strip()
+
outputs.append(answer)
explanations.append(reasoning)
return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)
-
-def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> SemanticMapPostprocessOutput:
+def map_postprocess(
+ llm_answers: list[str],
+ strategy: str | None = None,
+ cot_reasoning: bool = False
+) -> SemanticMapPostprocessOutput:
"""
Postprocess the output of the map operator.
Args:
llm_answers (list[str]): The list of llm answers.
- cot_reasoning (bool): Whether there is CoT reasoning.
+ strategy (str | None): The reasoning strategy ("deepseek", "cot", or None).
+ cot_reasoning (bool): Whether there is CoT reasoning (deprecated, use strategy="cot" instead).
Returns:
SemanticMapPostprocessOutput
"""
- if cot_reasoning:
+ if strategy == "deepseek":
+ outputs: list[str] = []
+ explanations: list[str | None] = []
+
+ for llm_answer in llm_answers:
+ lotus.logger.debug(f"Raw LLM answer: {llm_answer}")
+ reasoning, answer = _process_deepseek_output(llm_answer)
+ lotus.logger.debug(f"Extracted reasoning: {reasoning}")
+ lotus.logger.debug(f"Extracted answer: {answer}")
+ outputs.append(answer)
+ explanations.append(reasoning)
+
+ return SemanticMapPostprocessOutput(
+ raw_outputs=llm_answers,
+ outputs=outputs,
+ explanations=explanations
+ )
+ elif cot_reasoning or strategy == "cot":
return map_postprocess_cot(llm_answers)
outputs: list[str] = llm_answers
explanations: list[str | None] = [None] * len(llm_answers)
return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)
-
-def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOutput:
- """
- Postprocess the output of the extract operator to extract the schema.
-
- Args:
- llm_answers (list[str]): The list of llm answers containging the extract.
-
- Returns:
- SemanticExtractPostprocessOutput
- """
- extract_data = []
- for llm_answer in llm_answers:
- try:
- output = json.loads(llm_answer)
- except json.JSONDecodeError:
- lotus.logger.info(f"\t Failed to parse: {llm_answer}")
- output = {}
-
- output = {key: str(value) for key, value in output.items()}
- extract_data.append(output)
-
- return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data)
-
-
def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput:
"""
Postprocess the output of the filter operator with CoT reasoning.
@@ -117,10 +193,10 @@ def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFil
return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)
-
def filter_postprocess(
llm_answers: list[str],
default: bool = True,
+ strategy: str | None = None,
cot_reasoning: bool = False,
) -> SemanticFilterPostprocessOutput:
"""
@@ -129,12 +205,33 @@ def filter_postprocess(
Args:
llm_answers (list[str]): The list of llm answers.
default (bool): The default value to use if we fail to parse the answer.
- cot_reasoning (bool): Whether there is CoT reasoning.
+ strategy (str | None): The reasoning strategy ("deepseek", "cot", or None).
+ cot_reasoning (bool): Whether there is CoT reasoning (deprecated, use strategy="cot" instead).
Returns:
SemanticFilterPostprocessOutput
"""
- if cot_reasoning:
+ if strategy == "deepseek":
+ outputs: list[bool] = []
+ explanations: list[str | None] = []
+
+ for llm_answer in llm_answers:
+ reasoning, answer = _process_deepseek_output(llm_answer)
+ if "True" in answer:
+ outputs.append(True)
+ elif "False" in answer:
+ outputs.append(False)
+ else:
+ lotus.logger.info(f"\t Failed to parse: defaulting to {default}")
+ outputs.append(default)
+ explanations.append(reasoning)
+
+ return SemanticFilterPostprocessOutput(
+ raw_outputs=llm_answers,
+ outputs=outputs,
+ explanations=explanations
+ )
+ elif cot_reasoning or strategy == "cot":
return filter_postprocess_cot(llm_answers, default)
outputs: list[bool] = []
@@ -149,3 +246,35 @@ def filter_postprocess(
outputs.append(default)
return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)
+
+def agg_postprocess(
+ llm_answers: list[str],
+ strategy: str | None = None,
+) -> SemanticAggPostprocessOutput:
+ """
+ Postprocess the output of the aggregate operator.
+
+ Args:
+ llm_answers (list[str]): The list of llm answers.
+ strategy (str | None): The reasoning strategy ("deepseek" or None).
+
+ Returns:
+ SemanticAggPostprocessOutput
+ """
+ outputs: list[str] = []
+ explanations: list[str | None] = []
+
+ for llm_answer in llm_answers:
+ if strategy == "deepseek":
+ reasoning, answer = _process_deepseek_output(llm_answer)
+ outputs.append(answer)
+ explanations.append(reasoning)
+ else:
+ outputs.append(llm_answer)
+ explanations.append(None)
+
+ return SemanticAggPostprocessOutput(
+ raw_outputs=llm_answers,
+ outputs=outputs,
+ explanations=explanations
+ )
diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py
index 706f12f9..fc5bdb24 100644
--- a/lotus/sem_ops/sem_agg.py
+++ b/lotus/sem_ops/sem_agg.py
@@ -1,11 +1,13 @@
-from typing import Any
+from typing import Any, Callable
import pandas as pd
import lotus.models
from lotus.cache import operator_cache
from lotus.templates import task_instructions
-from lotus.types import LMOutput, SemanticAggOutput
+from lotus.types import LMOutput, SemanticAggOutput, SemanticAggPostprocessOutput
+from lotus.utils import show_safe_mode
+from .postprocessors import agg_postprocess
def sem_agg(
@@ -13,6 +15,8 @@ def sem_agg(
model: lotus.models.LM,
user_instruction: str,
partition_ids: list[int],
+ postprocessor: Callable[[list[str], str | None], SemanticAggPostprocessOutput] = agg_postprocess,
+ strategy: str | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Aggregating",
) -> SemanticAggOutput:
@@ -24,85 +28,53 @@ def sem_agg(
model (lotus.models.LM): The model to use.
user_instruction (str): The user instruction for aggregation.
partition_ids (list[int]): The partition ids for the documents. Documents with the same partition id will be aggregated together.
+ postprocessor (Callable): The postprocessor for the model outputs. Defaults to agg_postprocess.
+ strategy (str | None): The reasoning strategy ("deepseek" or None).
Returns:
- str: The aggregated answer.
+ SemanticAggOutput: The aggregated answer and explanations.
"""
- leaf_instr_template = (
- "Your job is to provide an answer to the user's instruction given the context below from multiple documents.\n"
- "Remember that your job is to answer the user's instruction by combining all relevant information from all provided documents, into a single coherent answer.\n"
- "Do NOT copy the format of the sources! Instead output your answer in a coherent, well-structured manner that best answers the user instruction.\n"
- "You have limited space to provide your answer, so be concise and to the point.\n\n---\n\n"
- "Follow the following format.\n\nContext: relevant facts from multiple documents\n\n"
- "Instruction: the instruction provided by the user\n\nAnswer: Write your answer\n\n---\n\n"
- "Context: {{docs_str}}\n\n"
- f"Instruction: {user_instruction}\n\nAnswer:\n"
- )
-
- node_instr_template = (
- "Your job is to provide an answer to the user's instruction given the context below from multiple sources.\n"
- "Note that each source may be formatted differently and contain information about several different documents.\n"
- "Remember that your job is to answer the user's instruction by combining all relevant information from all provided sources, into a single coherent answer.\n"
- "The sources may provide opposing viewpoints or complementary information.\n"
- "Be sure to include information from ALL relevant sources in your answer.\n"
- "Do NOT copy the format of the sources, instead output your answer in a coherent, well-structured manner that best answers the user instruction.\n"
- "You have limited space to provide your answer, so be concise and to the point.\n"
- "You may need to draw connections between sources to provide a complete answer.\n\n---\n\n"
- "Follow the following format.\n\nContext: relevant facts from multiple sources\n\n"
- "Instruction: the instruction provided by the user\n\nAnswer: Write your answer\n\n---\n\n"
- "Context: {{docs_str}}\n\n"
- f"Instruction: {user_instruction}\n\nAnswer:\n"
- )
-
- def leaf_doc_formatter(doc: str, ctr: int) -> str:
- return f"\n\tDocument {ctr}: {doc}"
-
- def node_doc_formatter(doc: str, ctr: int) -> str:
- return f"\n\tSource {ctr}: {doc}"
-
- def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
- return leaf_doc_formatter(doc, ctr) if tree_level == 0 else node_doc_formatter(doc, ctr)
-
if safe_mode:
# TODO: implement safe mode
lotus.logger.warning("Safe mode is not implemented yet")
tree_level = 0
summaries: list[str] = []
+ explanations: list[str | None] = []
new_partition_ids: list[int] = []
+
while len(docs) != 1 or summaries == []:
cur_partition_id = partition_ids[0]
do_fold = len(partition_ids) == len(set(partition_ids))
context_str = ""
- # prompt = ""
batch = []
- if tree_level == 0:
- template = leaf_instr_template
- else:
- template = node_instr_template
- template_tokens = model.count_tokens(template)
context_tokens = 0
doc_ctr = 1 # num docs in current prompt
for idx in range(len(docs)):
partition_id = partition_ids[idx]
- formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
+ formatted_doc = f"\n\tDocument {doc_ctr}: {docs[idx]}" if tree_level == 0 else f"\n\tSource {doc_ctr}: {docs[idx]}"
new_tokens = model.count_tokens(formatted_doc)
- if (new_tokens + context_tokens + template_tokens > model.max_ctx_len - model.max_tokens) or (
+ # Create multimodal data for the current batch
+ multimodal_data = {"text": context_str + formatted_doc}
+ prompt = task_instructions.agg_formatter(multimodal_data, user_instruction, strategy=strategy)
+ prompt_tokens = model.count_tokens(prompt)
+
+ if (prompt_tokens > model.max_ctx_len - model.max_tokens) or (
partition_id != cur_partition_id and not do_fold
):
# close the current prompt
-
- prompt = template.replace("{{docs_str}}", context_str)
+ multimodal_data = {"text": context_str}
+ prompt = task_instructions.agg_formatter(multimodal_data, user_instruction, strategy=strategy)
lotus.logger.debug(f"Prompt added to batch: {prompt}")
- batch.append([{"role": "user", "content": prompt}])
+ batch.append(prompt)
new_partition_ids.append(cur_partition_id)
cur_partition_id = partition_id
doc_ctr = 1
# add new context to next prompt
- formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr)
+ formatted_doc = f"\n\tDocument {doc_ctr}: {docs[idx]}" if tree_level == 0 else f"\n\tSource {doc_ctr}: {docs[idx]}"
context_str = formatted_doc
context_tokens = new_tokens
doc_ctr += 1
@@ -112,14 +84,19 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
doc_ctr += 1
if doc_ctr > 1 or len(docs) == 1:
- prompt = template.replace("{{docs_str}}", context_str)
+ multimodal_data = {"text": context_str}
+ prompt = task_instructions.agg_formatter(multimodal_data, user_instruction, strategy=strategy)
lotus.logger.debug(f"Prompt added to batch: {prompt}")
- batch.append([{"role": "user", "content": prompt}])
+ batch.append(prompt)
new_partition_ids.append(cur_partition_id)
lm_output: LMOutput = model(batch, progress_bar_desc=progress_bar_desc)
+
+ # Post process results
+ postprocess_output = postprocessor(lm_output.outputs, strategy=strategy)
+ summaries.extend(postprocess_output.outputs)
+ explanations.extend(postprocess_output.explanations)
- summaries = lm_output.outputs
partition_ids = new_partition_ids
new_partition_ids = []
@@ -129,7 +106,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str:
if safe_mode:
model.print_total_usage()
- return SemanticAggOutput(outputs=summaries)
+ return SemanticAggOutput(outputs=summaries, explanations=explanations)
@pd.api.extensions.register_dataframe_accessor("sem_agg")
@@ -146,8 +123,16 @@ def _validate(obj: Any) -> None:
@staticmethod
def process_group(args):
- group, user_instruction, all_cols, suffix, progress_bar_desc = args
- return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)
+ group, user_instruction, all_cols, suffix, strategy, return_explanations, progress_bar_desc = args
+ return group.sem_agg(
+ user_instruction,
+ all_cols=all_cols,
+ suffix=suffix,
+ group_by=None,
+ strategy=strategy,
+ return_explanations=return_explanations,
+ progress_bar_desc=progress_bar_desc
+ )
@operator_cache
def __call__(
@@ -156,6 +141,8 @@ def __call__(
all_cols: bool = False,
suffix: str = "_output",
group_by: list[str] | None = None,
+ strategy: str | None = None,
+ return_explanations: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Aggregating",
) -> pd.DataFrame:
@@ -167,6 +154,8 @@ def __call__(
all_cols (bool): Whether to use all columns in the dataframe. Defaults to False.
suffix (str): The suffix for the new column. Defaults to "_output".
group_by (list[str] | None): The columns to group by before aggregation. Each group will be aggregated separately.
+ strategy (str | None): The reasoning strategy ("deepseek" or None).
+ return_explanations (bool): Whether to return explanations. Defaults to False.
Returns:
pd.DataFrame: The dataframe with the aggregated answer.
"""
@@ -190,7 +179,7 @@ def __call__(
if group_by:
grouped = self._obj.groupby(group_by)
- group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped]
+ group_args = [(group, user_instruction, all_cols, suffix, strategy, return_explanations, progress_bar_desc) for _, group in grouped]
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor:
@@ -213,10 +202,13 @@ def __call__(
lotus.settings.lm,
formatted_usr_instr,
partition_ids,
+ strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
)
# package answer in a dataframe
- answer_df = pd.DataFrame(answer.outputs, columns=[suffix])
+ answer_df = pd.DataFrame({suffix: answer.outputs})
+ if return_explanations:
+ answer_df[f"explanation{suffix}"] = answer.explanations
return answer_df
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index d05dfe46..ebac22e4 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -18,6 +18,7 @@ def sem_extract(
output_cols: dict[str, str | None],
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
+ strategy: str | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> SemanticExtractOutput:
@@ -37,7 +38,7 @@ def sem_extract(
# prepare model inputs
inputs = []
for doc in docs:
- prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes)
+ prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes, strategy=strategy)
lotus.logger.debug(f"input to model: {prompt}")
lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}")
inputs.append(prompt)
@@ -48,11 +49,16 @@ def sem_extract(
estimated_LM_calls = len(docs)
show_safe_mode(estimated_cost, estimated_LM_calls)
+ # call model with response_format=json_object for all models
+ kwargs = {"response_format": {"type": "json_object"}}
+ if "ollama" not in model.model: # Only add temperature for non-Ollama models
+ kwargs["temperature"] = 0.0 # Use zero temperature for extractions
+
# call model
- lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc)
+ lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc, **kwargs)
# post process results
- postprocess_output = postprocessor(lm_output.outputs)
+ postprocess_output = postprocessor(lm_output.outputs, strategy=strategy)
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
if safe_mode:
@@ -80,6 +86,7 @@ def __call__(
extract_quotes: bool = False,
postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess,
return_raw_outputs: bool = False,
+ strategy: str | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Extracting",
) -> pd.DataFrame:
@@ -114,21 +121,26 @@ def __call__(
output_cols=output_cols,
extract_quotes=extract_quotes,
postprocessor=postprocessor,
+ strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
)
- new_df = self._obj.copy()
- indices = new_df.index.to_list()
- for i, output_dict in enumerate(out.outputs):
- if i >= len(indices):
- break
- for key, value in output_dict.items():
- if key not in new_df.columns:
- new_df[key] = None
- new_df.loc[indices[i], key] = value
+ # Create a copy of the original DataFrame so we can preserve original columns
+ new_df = self._obj.copy()
+ indices = new_df.index.to_list()
+
+ # Insert the extracted columns into new_df
+ for i, output_dict in enumerate(out.outputs):
+ if i >= len(indices):
+ break
+ for key, value in output_dict.items():
+ if key not in new_df.columns:
+ new_df[key] = None
+ new_df.loc[indices[i], key] = value
- if return_raw_outputs:
- new_df["raw_output"] = out.raw_outputs
+ # Optionally add raw outputs as a column
+ if return_raw_outputs:
+ new_df["raw_output"] = out.raw_outputs
- return new_df
+ return new_df
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index 0f30874f..13f4d7e6 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -63,7 +63,10 @@ def sem_filter(
)
postprocess_output = filter_postprocess(
- lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]
+ lm_output.outputs,
+ default=default,
+ strategy=strategy,
+ cot_reasoning=strategy in ["cot", "zs-cot"]
)
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py
index 4cc22d88..d8d5b0dd 100644
--- a/lotus/sem_ops/sem_map.py
+++ b/lotus/sem_ops/sem_map.py
@@ -58,7 +58,11 @@ def sem_map(
lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc)
# post process results
- postprocess_output = postprocessor(lm_output.outputs, strategy in ["cot", "zs-cot"])
+ postprocess_output = postprocessor(
+ lm_output.outputs,
+ strategy=strategy,
+ cot_reasoning=strategy in ["cot", "zs-cot"]
+ )
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py
index fc30efd9..21fb012d 100644
--- a/lotus/templates/task_instructions.py
+++ b/lotus/templates/task_instructions.py
@@ -6,7 +6,7 @@
import lotus
from lotus.dtype_extensions import ImageDtype
from lotus.types import SerializationFormat
-
+from lotus.sem_ops.deepseek_utils import format_deepseek_prompt
def context_formatter(
multimodal_data: dict[str, Any] | str,
@@ -35,18 +35,30 @@ def context_formatter(
raise ValueError("multimodal_data must be a dictionary or a string")
return text, image_inputs
-
def user_message_formatter(
multimodal_data: dict[str, Any] | str,
user_instruction_with_tag: str | None = None,
+ is_deepseek: bool = False,
+ base_instruction: str | None = None,
) -> dict[str, Any]:
text, image_inputs = context_formatter(multimodal_data)
+
+    # For deepseek, include instructions in user prompt and enforce <think> start
+ if is_deepseek and base_instruction:
+ instruction = (
+ f"{base_instruction}\n\n"
+ "Start your response with '' to show your reasoning, "
+ "then end with '' and provide your final answer.\n\n"
+ )
+ else:
+ instruction = ""
+
if not image_inputs or len(image_inputs) == 0:
return {
"role": "user",
- "content": f"Context:\n{text}\n\n{user_instruction_with_tag}",
+ "content": f"{instruction}Context:\n{text}\n\n{user_instruction_with_tag}",
}
- content = [{"type": "text", "text": f"Context:\n{text}"}] + image_inputs
+ content = [{"type": "text", "text": f"{instruction}Context:\n{text}"}] + image_inputs
if user_instruction_with_tag:
content.append({"type": "text", "text": f"\n\n{user_instruction_with_tag}"})
return {
@@ -54,6 +66,73 @@ def user_message_formatter(
"content": content,
}
+def extract_formatter(
+ multimodal_data: dict[str, Any],
+ output_cols: dict[str, str | None],
+ extract_quotes: bool = True,
+ strategy: str | None = None
+) -> list[dict[str, str]]:
+ output_col_names = list(output_cols.keys())
+ # Set the description to be the key if no value is provided
+ output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()}
+
+ # Create example JSON with just the required fields
+ example_json = {
+ field: "example_value" for field in output_col_names
+ }
+
+ if strategy == "deepseek":
+ # Create the instruction without any backslashes in f-strings
+ fields_list = ", ".join(output_col_names)
+ example_fields = ", ".join([f' "{field}": "value"' for field in output_col_names])
+
+ # Build instruction using concatenation instead of f-strings
+ base_instruction = (
+ "Your task is to extract specific fields from the given context.\n"
+ "Fields to extract: " + str(output_cols_with_desc) + "\n\n"
+ "Instructions:\n"
+ "1. First, show your reasoning in tags\n"
+ "2. Then provide ONLY a valid JSON object with these exact fields:\n" +
+ fields_list + "\n\n"
+ "Example format:\n"
+ "\n"
+ "Your reasoning about how you extracted each field...\n"
+ "\n"
+ "{\n" +
+ example_fields + "\n"
+ "}\n\n"
+ "Important: The JSON must be properly formatted and contain exactly these fields. "
+ "Do not include any other text, code blocks, or explanations outside the tags."
+ )
+ else:
+ # Build instruction using concatenation instead of f-strings
+ base_instruction = (
+ "Your task is to extract specific fields from the given context.\n"
+ "Fields to extract: " + str(output_cols_with_desc) + "\n\n"
+ "Instructions:\n"
+ "1. Extract the requested fields from the context\n"
+ "2. Return ONLY a valid JSON object with these exact fields:\n" +
+ ", ".join(output_col_names) + "\n\n"
+ "Example format:\n"
+ "{\n" +
+ ", ".join([f' "{field}": "value"' for field in output_col_names]) + "\n"
+ "}\n\n"
+ "Important: The JSON must be properly formatted and contain exactly these fields. "
+ "Do not include any other text, code blocks, or explanations."
+ )
+
+ is_deepseek = strategy == "deepseek"
+ messages = []
+
+ if not is_deepseek:
+ messages.append({"role": "system", "content": base_instruction})
+
+ messages.append(user_message_formatter(
+ multimodal_data,
+ is_deepseek=is_deepseek,
+ base_instruction=base_instruction if is_deepseek else None
+ ))
+ return messages
def filter_formatter_cot(
multimodal_data: dict[str, Any],
@@ -88,7 +167,6 @@ def filter_formatter_cot(
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
return messages
-
def filter_formatter_zs_cot(
multimodal_data: dict[str, Any],
user_instruction: str,
@@ -105,7 +183,6 @@ def filter_formatter_zs_cot(
messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
return messages
-
def filter_formatter(
multimodal_data: dict[str, Any],
user_instruction: str,
@@ -122,14 +199,17 @@ def filter_formatter(
elif strategy == "zs-cot":
return filter_formatter_zs_cot(multimodal_data, user_instruction)
- sys_instruction = (
+ base_instruction = (
"The user will provide a claim and some relevant context.\n"
"Your job is to determine whether the claim is true for the given context.\n"
'You must answer with a single word, "True" or "False".'
)
- messages = [
- {"role": "system", "content": sys_instruction},
- ]
+
+ is_deepseek = strategy == "deepseek"
+ messages = []
+
+ if not is_deepseek:
+ messages.append({"role": "system", "content": base_instruction})
if examples_multimodal_data:
assert examples_answer is not None
@@ -139,15 +219,24 @@ def filter_formatter(
ex_ans = examples_answer[i]
messages.extend(
[
- user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
+ user_message_formatter(
+ ex_multimodal_data,
+ f"Claim: {user_instruction}",
+ is_deepseek=is_deepseek,
+ base_instruction=base_instruction if is_deepseek else None
+ ),
{"role": "assistant", "content": str(ex_ans)},
]
)
- messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
+ messages.append(user_message_formatter(
+ multimodal_data,
+ f"Claim: {user_instruction}",
+ is_deepseek=is_deepseek,
+ base_instruction=base_instruction if is_deepseek else None
+ ))
return messages
-
def map_formatter_cot(
multimodal_data: dict[str, Any],
user_instruction: str,
@@ -181,7 +270,6 @@ def map_formatter_cot(
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
return messages
-
def map_formatter_zs_cot(
multimodal_data: dict[str, Any],
user_instruction: str,
@@ -198,7 +286,6 @@ def map_formatter_zs_cot(
messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
return messages
-
def map_formatter(
multimodal_data: dict[str, Any],
user_instruction: str,
@@ -215,58 +302,100 @@ def map_formatter(
elif strategy == "zs-cot":
return map_formatter_zs_cot(multimodal_data, user_instruction)
- sys_instruction = (
- "The user will provide an instruction and some relevant context.\n"
- "Your job is to answer the user's instruction given the context."
- )
- messages = [
- {"role": "system", "content": sys_instruction},
- ]
+ is_deepseek = strategy == "deepseek"
+ messages = []
+
+ if is_deepseek:
+ base_instruction = (
+ "Your task is to answer the given instruction based on the context.\n\n"
+ "Instructions:\n"
+ "1. First, show your reasoning in <think> tags\n"
+ "2. Then provide ONLY a single word, phrase, or number as your final answer\n\n"
+ "Example format:\n"
+ "<think>\n"
+ "Your reasoning about how you arrived at the answer...\n"
+ "</think>\n"
+ "answer\n\n"
+ )
+ else:
+ base_instruction = (
+ "The user will provide an instruction and some relevant context.\n"
+ "Your job is to answer the user's instruction given the context.\n"
+ "Provide a single concise answer that directly responds to the instruction."
+ )
+ messages.append({"role": "system", "content": base_instruction})
if examples_multimodal_data:
assert examples_answer is not None
for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer):
messages.extend(
[
- user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"),
+ user_message_formatter(
+ ex_df_txt,
+ f"Instruction: {user_instruction}",
+ is_deepseek=is_deepseek,
+ base_instruction=base_instruction if is_deepseek else None
+ ),
{"role": "assistant", "content": str(ex_ans)},
]
)
- messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}"))
+ messages.append(user_message_formatter(
+ multimodal_data,
+ f"Instruction: {user_instruction}",
+ is_deepseek=is_deepseek,
+ base_instruction=base_instruction if is_deepseek else None
+ ))
return messages
-
-def extract_formatter(
- multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True
+def agg_formatter(
+ multimodal_data: dict[str, Any],
+ user_instruction: str,
+ strategy: str | None = None,
) -> list[dict[str, str]]:
- output_col_names = list(output_cols.keys())
- # Set the description to be the key if no value is provided
- output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()}
-
- all_fields = output_col_names
- if extract_quotes:
- quote_fields = [f"{col}_quote" for col in output_col_names]
- all_fields += quote_fields
-
- fields_str = ", ".join(all_fields)
+ """
+ Format instructions for aggregation operator.
- sys_instruction = (
- "The user will provide the columns that need to be extracted and some relevant context.\n"
- f"Your job is to extract these columns and provide only a concise value for each field "
- f"and the corresponding full quote for each field in the '{', '.join([f'{col}_quote' for col in output_col_names])}' fields.\n"
- f"Here is a description of each field: {output_cols_with_desc}\n"
- f"The response should be valid JSON format with the following fields: {fields_str}.\n"
- )
+ Args:
+ multimodal_data (dict[str, Any]): The multimodal data to format.
+ user_instruction (str): The user instruction.
+ strategy (str | None): The reasoning strategy ("deepseek" or None).
- messages = [
- {"role": "system", "content": sys_instruction},
- user_message_formatter(multimodal_data),
- ]
+ Returns:
+ list[dict[str, str]]: The formatted messages.
+ """
+ is_deepseek = strategy == "deepseek"
+ messages = []
+
+ if is_deepseek:
+ base_instruction = (
+ "Your task is to analyze multiple documents and provide a comprehensive answer.\n\n"
+ "Instructions:\n"
+ "1. First, show your reasoning in <think> tags about how you analyzed the documents\n"
+ "2. Then provide your final answer that combines information from all documents\n\n"
+ "Example format:\n"
+ "<think>\n"
+ "Your reasoning about how you analyzed and combined the documents...\n"
+ "</think>\n"
+ "Your final comprehensive answer\n\n"
+ )
+ else:
+ base_instruction = (
+ "Your job is to provide an answer to the user's instruction given the context below from multiple documents.\n"
+ "Remember that your job is to answer the user's instruction by combining all relevant information from all provided documents, into a single coherent answer.\n"
+ "Do NOT copy the format of the sources! Instead output your answer in a coherent, well-structured manner that best answers the user instruction.\n"
+ "You have limited space to provide your answer, so be concise and to the point.\n"
+ )
+ messages.append({"role": "system", "content": base_instruction})
+
+ messages.append(user_message_formatter(
+ multimodal_data,
+ f"Instruction: {user_instruction}",
+ is_deepseek=is_deepseek,
+ base_instruction=base_instruction if is_deepseek else None
+ ))
return messages
-
-# returns a list of strings corresponding to df rows
def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]:
"""Formats the given DataFrame into a string containing info from cols."""
@@ -305,7 +434,6 @@ def clean_and_escape_column_name(column_name: str) -> str:
return formatted_rows
-
def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]:
"""
Formats the given DataFrame into a string containing info from cols.
@@ -323,7 +451,6 @@ def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]
]
return multimodal_data
-
def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Merges two multimodal info lists into one. Each row of first is merged with each row of second.
@@ -346,6 +473,5 @@ def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, An
for j in range(len(second))
]
-
def li2text(li: list[str], name: str) -> str:
return "".join([f"[{name}] {li[i]}\n" for i in range(len(li))])
diff --git a/lotus/types.py b/lotus/types.py
index fce2c4d1..445dee49 100644
--- a/lotus/types.py
+++ b/lotus/types.py
@@ -88,9 +88,17 @@ class SemanticFilterOutput:
logprobs: list[list[ChatCompletionTokenLogprob]] | None = None
+@dataclass
+class SemanticAggPostprocessOutput:
+ raw_outputs: list[str]
+ outputs: list[str]
+ explanations: list[str | None]
+
+
@dataclass
class SemanticAggOutput:
outputs: list[str]
+ explanations: list[str | None]
@dataclass
diff --git a/pytest.ini b/pytest.ini
index 38be7969..a2097370 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,4 +2,14 @@
testpaths = tests
python_files = test_*.py
python_classes = Test*
-python_functions = test_*
\ No newline at end of file
+python_functions = test_*
+
+# Default command-line options (verbose output, short tracebacks)
+addopts = -v --tb=short
+
+# Mark custom test categories
+markers =
+ deepseek: tests for deepseek model support
+
+# Test file patterns
+norecursedirs = .git .tox .eggs *.egg
\ No newline at end of file
diff --git a/tests/test_deepseek.py b/tests/test_deepseek.py
new file mode 100644
index 00000000..53d416b2
--- /dev/null
+++ b/tests/test_deepseek.py
@@ -0,0 +1,240 @@
+"""Tests for deepseek model support across semantic operators."""
+
+import pandas as pd
+import pytest
+
+import lotus
+from lotus.models import LM
+
+def test_model_config():
+ """Test that deepseek models are configured correctly"""
+ # Deepseek model should set temperature=0.6 and strategy=deepseek
+ lm = LM(model="ollama/deepseek-r1:7b")
+ assert lm.kwargs.get("strategy") == "deepseek"
+ assert lm.kwargs.get("temperature") == 0.6 # Recommended temperature
+ assert lm.is_deepseek # Should detect deepseek model
+
+ # Non-deepseek model should keep original settings
+ lm = LM(model="ollama/llama3.1:8b", temperature=0.7)
+ assert "strategy" not in lm.kwargs
+ assert lm.kwargs.get("temperature") == 0.7
+ assert not lm.is_deepseek # Should not detect as deepseek
+
+def test_map_behavior():
+ """Test that sem_map handles both deepseek and non-deepseek models correctly"""
+ df = pd.DataFrame({
+ "Course": ["Machine Learning", "Data Structures"]
+ })
+
+ # Test deepseek model
+ lm = LM(model="ollama/deepseek-r1:7b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_map(
+ "What is a similar course to {Course}?",
+ return_explanations=True,
+ strategy="deepseek"
+ )
+
+ # Should have reasoning in explanation_map
+ assert "explanation_map" in result.columns
+ assert result["explanation_map"].iloc[0] is not None
+ assert isinstance(result["_map"].iloc[0], str)
+
+ # Test non-deepseek model
+ lm = LM(model="ollama/llama3.1:8b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_map(
+ "What is a similar course to {Course}?",
+ return_explanations=True # Non-deepseek model should not use deepseek strategy
+ )
+
+ # Should have no reasoning
+ assert result["explanation_map"].iloc[0] is None
+ assert isinstance(result["_map"].iloc[0], str)
+
+def test_filter_behavior():
+ """Test that sem_filter handles both deepseek and non-deepseek models correctly"""
+ df = pd.DataFrame({
+ "Course": ["Machine Learning", "Art History"]
+ })
+
+ # Test deepseek model
+ lm = LM(model="ollama/deepseek-r1:7b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_filter(
+ "Is {Course} a technical course?",
+ return_explanations=True,
+ strategy="deepseek"
+ )
+
+ # Should have reasoning in explanation_filter
+ assert "explanation_filter" in result.columns
+ filtered_rows = result[result["Course"] == "Machine Learning"]
+ assert len(filtered_rows) > 0
+ assert filtered_rows["explanation_filter"].iloc[0] is not None
+
+ # Test non-deepseek model
+ lm = LM(model="ollama/llama3.1:8b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_filter(
+ "Is {Course} a technical course?",
+ return_explanations=True # Non-deepseek model should not use deepseek strategy
+ )
+
+ # Should have no reasoning
+ filtered_rows = result[result["Course"] == "Machine Learning"]
+ assert len(filtered_rows) > 0
+ assert filtered_rows["explanation_filter"].iloc[0] is None
+
+def test_join_behavior():
+ """Test that sem_join handles both deepseek and non-deepseek models correctly"""
+ df1 = pd.DataFrame({
+ "Course1": ["Machine Learning"]
+ })
+ df2 = pd.DataFrame({
+ "Course2": ["Statistics", "Art History"]
+ })
+
+ # Test deepseek model
+ lm = LM(model="ollama/deepseek-r1:7b")
+ lotus.settings.configure(lm=lm)
+
+ result = df1.sem_join(
+ df2,
+ "{Course1} and {Course2} are related fields",
+ return_explanations=True,
+ strategy="deepseek"
+ )
+
+ # Print the full result DataFrame
+ print("\n=== Deepseek Join Result ===")
+ with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000):
+ print(result)
+ print("===========================\n")
+
+ # Should have reasoning in explanation_join
+ assert "explanation_join" in result.columns
+ assert len(result) > 0
+ assert result["explanation_join"].iloc[0] is not None
+
+ # Test non-deepseek model
+ lm = LM(model="ollama/llama3.1:8b")
+ lotus.settings.configure(lm=lm)
+
+ result = df1.sem_join(
+ df2,
+ "{Course1} and {Course2} are related fields",
+ return_explanations=True # Non-deepseek model should not use deepseek strategy
+ )
+
+ # Print the full result DataFrame
+ print("\n=== Non-Deepseek Join Result ===")
+ with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000):
+ print(result)
+ print("===========================\n")
+
+ # Should have no reasoning
+ assert len(result) > 0
+ assert result["explanation_join"].iloc[0] is None
+
+def test_extract_behavior():
+ """Test that sem_extract handles both deepseek and non-deepseek models correctly"""
+ df = pd.DataFrame({
+ "Text": ["The course Machine Learning (CS229) is taught by Prof. Smith"]
+ })
+
+ # Test deepseek model
+ lm = LM(model="ollama/deepseek-r1:7b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_extract(
+ input_cols=["Text"],
+ output_cols={
+ "course_code": "Course code",
+ "professor": "Professor name"
+ },
+ return_raw_outputs=True,
+ strategy="deepseek"
+ )
+
+ # Should extract fields correctly with reasoning
+ assert "course_code" in result.columns
+ assert "professor" in result.columns
+ assert isinstance(result["course_code"].iloc[0], str)
+ assert isinstance(result["professor"].iloc[0], str)
+ assert "<think>" in result["raw_output"].iloc[0]
+
+ # Test non-deepseek model
+ lm = LM(model="ollama/llama3.1:8b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_extract(
+ input_cols=["Text"],
+ output_cols={
+ "course_code": "Course code",
+ "professor": "Professor name"
+ },
+ return_raw_outputs=True
+ )
+
+ # Should extract fields correctly without reasoning
+ assert "course_code" in result.columns
+ assert "professor" in result.columns
+ assert isinstance(result["course_code"].iloc[0], str)
+ assert isinstance(result["professor"].iloc[0], str)
+ assert "<think>" not in result["raw_output"].iloc[0]
+
+def test_agg_behavior():
+ """Test that sem_agg handles both deepseek and non-deepseek models correctly"""
+ df = pd.DataFrame({
+ "Note": [
+ "Patient reports chest pain",
+ "Follow-up shows improved symptoms",
+ "Final checkup confirms recovery"
+ ]
+ })
+
+ # Test deepseek model
+ lm = LM(model="ollama/deepseek-r1:7b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_agg(
+ "Summarize the progression of {Note}",
+ return_explanations=True,
+ strategy="deepseek"
+ )
+
+ # Print the full result DataFrame
+ print("\n=== Deepseek Agg Result ===")
+ with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000):
+ print(result)
+ print("===========================\n")
+
+ # Should have reasoning in explanation column
+ assert "explanation_output" in result.columns
+ assert len(result) > 0
+ assert result["explanation_output"].iloc[0] is not None
+ assert isinstance(result["_output"].iloc[0], str)
+
+ # Test non-deepseek model
+ lm = LM(model="ollama/llama3.1:8b")
+ lotus.settings.configure(lm=lm)
+
+ result = df.sem_agg(
+ "Summarize the progression of {Note}",
+ return_explanations=True # Non-deepseek model should not use deepseek strategy
+ )
+
+ # Print the full result DataFrame
+ print("\n=== Non-Deepseek Agg Result ===")
+ with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000):
+ print(result)
+ print("===========================\n")
+
+ # Should have no reasoning
+ assert len(result) > 0
+ assert result["explanation_output"].iloc[0] is None
\ No newline at end of file