From 91e0afefdd1eb2fe23700f2dc3458934f49a0460 Mon Sep 17 00:00:00 2001 From: Caleb Winston Date: Wed, 12 Feb 2025 23:45:43 +0000 Subject: [PATCH 1/2] Reasoning LLM support --- lotus/models/lm.py | 20 ++- lotus/sem_ops/deepseek_utils.py | 47 +++++++ lotus/sem_ops/postprocessors.py | 164 +++++++++++++++++----- lotus/sem_ops/sem_agg.py | 17 ++- lotus/sem_ops/sem_extract.py | 17 ++- lotus/sem_ops/sem_filter.py | 5 +- lotus/sem_ops/sem_map.py | 6 +- lotus/templates/task_instructions.py | 196 +++++++++++++++++++-------- pytest.ini | 12 +- tests/test_deepseek.py | 177 ++++++++++++++++++++++++ 10 files changed, 556 insertions(+), 105 deletions(-) create mode 100644 lotus/sem_ops/deepseek_utils.py create mode 100644 tests/test_deepseek.py diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 616dc143..6be24284 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -36,6 +36,14 @@ def __init__( self.max_tokens = max_tokens self.max_batch_size = max_batch_size self.tokenizer = tokenizer + + # Configure deepseek models + self.is_deepseek = "deepseek-r1" in model.lower() + if self.is_deepseek: + # Set recommended temperature and strategy + temperature = 0.6 # Override temperature for deepseek + kwargs["strategy"] = kwargs.get("strategy", "deepseek") + self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.stats: LMStats = LMStats() @@ -50,6 +58,10 @@ def __call__( **kwargs: dict[str, Any], ) -> LMOutput: all_kwargs = {**self.kwargs, **kwargs} + + # Remove response_format for Ollama models since they don't support function calling + if "ollama" in self.model and "response_format" in all_kwargs: + del all_kwargs["response_format"] # Set top_logprobs if logprobs requested if all_kwargs.get("logprobs", False): @@ -144,14 +156,18 @@ def _update_stats(self, response: ModelResponse): # Sometimes the model's pricing information is not available lotus.logger.debug(f"Error updating completion cost: {e}") - def _get_top_choice(self, response: 
ModelResponse) -> str: + def _get_top_choice(self, response: ModelResponse | OpenAIError) -> str: + if isinstance(response, OpenAIError): + raise response choice = response.choices[0] assert isinstance(choice, Choices) if choice.message.content is None: raise ValueError(f"No content in response: {response}") return choice.message.content - def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]: + def _get_top_choice_logprobs(self, response: ModelResponse | OpenAIError) -> list[ChatCompletionTokenLogprob]: + if isinstance(response, OpenAIError): + raise response choice = response.choices[0] assert isinstance(choice, Choices) logprobs = choice.logprobs["content"] diff --git a/lotus/sem_ops/deepseek_utils.py b/lotus/sem_ops/deepseek_utils.py new file mode 100644 index 00000000..205b0ed9 --- /dev/null +++ b/lotus/sem_ops/deepseek_utils.py @@ -0,0 +1,47 @@ +"""Utilities for handling deepseek model outputs with reasoning traces.""" + +from typing import Tuple + +def extract_deepseek_reasoning(llm_answer: str) -> Tuple[str | None, str]: + """ + Extract reasoning and answer from deepseek model output. 
+ + Args: + llm_answer: Raw LLM output that may contain <think> tags + + Returns: + Tuple of (reasoning, answer) where reasoning may be None if no think tags found + """ + think_start = llm_answer.find("<think>") + think_end = llm_answer.find("</think>") + + if think_start != -1 and think_end != -1: + # Extract the reasoning from between the think tags + reasoning = llm_answer[think_start + 7:think_end].strip() + # Extract the answer from after the closing think tag + answer = llm_answer[think_end + 8:].strip() + # Return reasoning first, then answer + return reasoning, answer + + # If no think tags found, treat the whole thing as the answer with no reasoning + return None, llm_answer.strip() + +def format_deepseek_prompt(instruction: str) -> str: + """ + Format instruction for deepseek models following official guidelines: + - No system prompts + - Instructions in user prompt + - Enforce <think>\n start + - Temperature 0.6 (handled in LM class) + + Args: + instruction: Base instruction for the task + + Returns: + Modified instruction that enforces <think>\n start + """ + return ( + f"{instruction}\n\n" + "Start your response with '<think>\\n' to show your reasoning, " + "then end with '</think>' and provide your final answer." 
+ ) \ No newline at end of file diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index d531099c..e47540fe 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -1,4 +1,5 @@ import json +import re import lotus from lotus.types import ( @@ -6,7 +7,79 @@ SemanticFilterPostprocessOutput, SemanticMapPostprocessOutput, ) +from lotus.sem_ops.deepseek_utils import extract_deepseek_reasoning + +def _process_deepseek_output(llm_answer: str) -> tuple[str | None, str]: + """Helper function to process deepseek model output.""" + reasoning, answer = extract_deepseek_reasoning(llm_answer) + return reasoning, answer + +def extract_json_from_text(text: str) -> dict: + """Helper function to extract JSON from text that may contain code blocks or raw JSON.""" + # Try to find JSON between curly braces + try: + start = text.find("{") + if start != -1: + end = text.rfind("}") + 1 + if end > start: + json_str = text[start:end] + # Clean up any potential Python string formatting + json_str = json_str.replace("'", '"') + return json.loads(json_str) + except json.JSONDecodeError: + pass + + # Try to find any JSON-like structure in the text + try: + matches = re.finditer(r'({[^{}]*})', text) + for match in matches: + try: + json_str = match.group(1).replace("'", '"') + return json.loads(json_str) + except json.JSONDecodeError: + continue + except Exception: + pass + + # If no valid JSON found, return empty dict + return {} + +def extract_postprocess( + llm_answers: list[str], + strategy: str | None = None +) -> SemanticExtractPostprocessOutput: + """ + Postprocess the output of the extract operator to extract the schema. + Args: + llm_answers (list[str]): The list of llm answers containing the extract. + strategy (str | None): The reasoning strategy ("deepseek" or None). 
+ + Returns: + SemanticExtractPostprocessOutput + """ + extract_data = [] + for llm_answer in llm_answers: + lotus.logger.debug(f"Extract raw answer: {llm_answer}") + try: + if strategy == "deepseek": + # For deepseek models, extract the JSON from after </think> + _, answer = extract_deepseek_reasoning(llm_answer) + output = extract_json_from_text(answer) + else: + output = extract_json_from_text(llm_answer) + + lotus.logger.debug(f"Parsed JSON: {output}") + except json.JSONDecodeError as e: + lotus.logger.info(f"\t Failed to parse: {llm_answer}") + lotus.logger.debug(f"JSON parse error: {e}") + output = {} + + # Convert all values to strings + output = {key: str(value) for key, value in output.items()} + extract_data.append(output) + + return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data) def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: """ @@ -29,57 +102,59 @@ def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: reasoning_idx += len("Reasoning:\n") answer_idx = llm_answer.find("Answer:") - reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") - answer = llm_answer[answer_idx + len("Answer:") :] + if answer_idx == -1: + # No explicit Answer: marker, treat whole thing as answer + answer = llm_answer[reasoning_idx:].strip() + reasoning = None + else: + reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") + answer = llm_answer[answer_idx + len("Answer:"):].strip() + outputs.append(answer) explanations.append(reasoning) return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) - -def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> SemanticMapPostprocessOutput: +def map_postprocess( + llm_answers: list[str], + strategy: str | None = None, + cot_reasoning: bool = False +) -> SemanticMapPostprocessOutput: """ Postprocess the output of the map operator. 
Args: llm_answers (list[str]): The list of llm answers. - cot_reasoning (bool): Whether there is CoT reasoning. + strategy (str | None): The reasoning strategy ("deepseek", "cot", or None). + cot_reasoning (bool): Whether there is CoT reasoning (deprecated, use strategy="cot" instead). Returns: SemanticMapPostprocessOutput """ - if cot_reasoning: + if strategy == "deepseek": + outputs: list[str] = [] + explanations: list[str | None] = [] + + for llm_answer in llm_answers: + lotus.logger.debug(f"Raw LLM answer: {llm_answer}") + reasoning, answer = _process_deepseek_output(llm_answer) + lotus.logger.debug(f"Extracted reasoning: {reasoning}") + lotus.logger.debug(f"Extracted answer: {answer}") + outputs.append(answer) + explanations.append(reasoning) + + return SemanticMapPostprocessOutput( + raw_outputs=llm_answers, + outputs=outputs, + explanations=explanations + ) + elif cot_reasoning or strategy == "cot": return map_postprocess_cot(llm_answers) outputs: list[str] = llm_answers explanations: list[str | None] = [None] * len(llm_answers) return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) - -def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOutput: - """ - Postprocess the output of the extract operator to extract the schema. - - Args: - llm_answers (list[str]): The list of llm answers containging the extract. 
- - Returns: - SemanticExtractPostprocessOutput - """ - extract_data = [] - for llm_answer in llm_answers: - try: - output = json.loads(llm_answer) - except json.JSONDecodeError: - lotus.logger.info(f"\t Failed to parse: {llm_answer}") - output = {} - - output = {key: str(value) for key, value in output.items()} - extract_data.append(output) - - return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data) - - def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput: """ Postprocess the output of the filter operator with CoT reasoning. @@ -117,10 +192,10 @@ def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFil return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) - def filter_postprocess( llm_answers: list[str], default: bool = True, + strategy: str | None = None, cot_reasoning: bool = False, ) -> SemanticFilterPostprocessOutput: """ @@ -129,12 +204,33 @@ def filter_postprocess( Args: llm_answers (list[str]): The list of llm answers. default (bool): The default value to use if we fail to parse the answer. - cot_reasoning (bool): Whether there is CoT reasoning. + strategy (str | None): The reasoning strategy ("deepseek", "cot", or None). + cot_reasoning (bool): Whether there is CoT reasoning (deprecated, use strategy="cot" instead). 
Returns: SemanticFilterPostprocessOutput """ - if cot_reasoning: + if strategy == "deepseek": + outputs: list[bool] = [] + explanations: list[str | None] = [] + + for llm_answer in llm_answers: + reasoning, answer = _process_deepseek_output(llm_answer) + if "True" in answer: + outputs.append(True) + elif "False" in answer: + outputs.append(False) + else: + lotus.logger.info(f"\t Failed to parse: defaulting to {default}") + outputs.append(default) + explanations.append(reasoning) + + return SemanticFilterPostprocessOutput( + raw_outputs=llm_answers, + outputs=outputs, + explanations=explanations + ) + elif cot_reasoning or strategy == "cot": return filter_postprocess_cot(llm_answers, default) outputs: list[bool] = [] diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 706f12f9..0bf923c0 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -13,6 +13,7 @@ def sem_agg( model: lotus.models.LM, user_instruction: str, partition_ids: list[int], + strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Aggregating", ) -> SemanticAggOutput: @@ -93,7 +94,6 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: partition_id != cur_partition_id and not do_fold ): # close the current prompt - prompt = template.replace("{{docs_str}}", context_str) lotus.logger.debug(f"Prompt added to batch: {prompt}") batch.append([{"role": "user", "content": prompt}]) @@ -146,8 +146,15 @@ def _validate(obj: Any) -> None: @staticmethod def process_group(args): - group, user_instruction, all_cols, suffix, progress_bar_desc = args - return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) + group, user_instruction, all_cols, suffix, strategy, progress_bar_desc = args + return group.sem_agg( + user_instruction, + all_cols=all_cols, + suffix=suffix, + group_by=None, + strategy=strategy, + progress_bar_desc=progress_bar_desc + ) @operator_cache def __call__( @@ -156,6 +163,7 @@ def 
__call__( all_cols: bool = False, suffix: str = "_output", group_by: list[str] | None = None, + strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Aggregating", ) -> pd.DataFrame: @@ -190,7 +198,7 @@ def __call__( if group_by: grouped = self._obj.groupby(group_by) - group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] + group_args = [(group, user_instruction, all_cols, suffix, strategy, progress_bar_desc) for _, group in grouped] from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: @@ -213,6 +221,7 @@ def __call__( lotus.settings.lm, formatted_usr_instr, partition_ids, + strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, ) diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 6faf6c1c..8cb8213c 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -18,6 +18,7 @@ def sem_extract( output_cols: dict[str, str | None], extract_quotes: bool = False, postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess, + strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Extracting", ) -> SemanticExtractOutput: @@ -37,7 +38,7 @@ def sem_extract( # prepare model inputs inputs = [] for doc in docs: - prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes) + prompt = task_instructions.extract_formatter(doc, output_cols, extract_quotes, strategy=strategy) lotus.logger.debug(f"input to model: {prompt}") lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") inputs.append(prompt) @@ -48,11 +49,16 @@ def sem_extract( estimated_LM_calls = len(docs) show_safe_mode(estimated_cost, estimated_LM_calls) + # call model with response_format=json_object for all models + kwargs = {"response_format": {"type": "json_object"}} + if "ollama" not 
in model.model: # Only add temperature for non-Ollama models + kwargs["temperature"] = 0.0 # Use zero temperature for extractions + # call model - lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc) + lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc, **kwargs) # post process results - postprocess_output = postprocessor(lm_output.outputs) + postprocess_output = postprocessor(lm_output.outputs, strategy=strategy) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") if safe_mode: @@ -80,6 +86,7 @@ def __call__( extract_quotes: bool = False, postprocessor: Callable[[list[str]], SemanticExtractPostprocessOutput] = extract_postprocess, return_raw_outputs: bool = False, + strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Extracting", ) -> pd.DataFrame: @@ -114,11 +121,13 @@ def __call__( output_cols=output_cols, extract_quotes=extract_quotes, postprocessor=postprocessor, + strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, ) - new_df = self._obj.copy() + # Create a new DataFrame with just the extracted fields + new_df = pd.DataFrame(index=range(len(self._obj))) for i, output_dict in enumerate(out.outputs): for key, value in output_dict.items(): if key not in new_df.columns: diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index f62172fe..7f921b25 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -63,7 +63,10 @@ def sem_filter( ) postprocess_output = filter_postprocess( - lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"] + lm_output.outputs, + default=default, + strategy=strategy, + cot_reasoning=strategy in ["cot", "zs-cot"] ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") diff --git a/lotus/sem_ops/sem_map.py 
b/lotus/sem_ops/sem_map.py index 4cc22d88..d8d5b0dd 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -58,7 +58,11 @@ def sem_map( lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc) # post process results - postprocess_output = postprocessor(lm_output.outputs, strategy in ["cot", "zs-cot"]) + postprocess_output = postprocessor( + lm_output.outputs, + strategy=strategy, + cot_reasoning=strategy in ["cot", "zs-cot"] + ) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fc30efd9..d8558515 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -6,7 +6,7 @@ import lotus from lotus.dtype_extensions import ImageDtype from lotus.types import SerializationFormat - +from lotus.sem_ops.deepseek_utils import format_deepseek_prompt def context_formatter( multimodal_data: dict[str, Any] | str, @@ -35,18 +35,30 @@ def context_formatter( raise ValueError("multimodal_data must be a dictionary or a string") return text, image_inputs - def user_message_formatter( multimodal_data: dict[str, Any] | str, user_instruction_with_tag: str | None = None, + is_deepseek: bool = False, + base_instruction: str | None = None, ) -> dict[str, Any]: text, image_inputs = context_formatter(multimodal_data) + + # For deepseek, include instructions in user prompt and enforce <think> start + if is_deepseek and base_instruction: + instruction = ( + f"{base_instruction}\n\n" + "Start your response with '<think>' to show your reasoning, " + "then end with '</think>' and provide your final answer.\n\n" + ) + else: + instruction = "" + if not image_inputs or len(image_inputs) == 0: return { "role": "user", "content": 
f"{instruction}Context:\n{text}\n\n{user_instruction_with_tag}", } - content = [{"type": "text", "text": f"Context:\n{text}"}] + image_inputs + content = [{"type": "text", "text": f"{instruction}Context:\n{text}"}] + image_inputs if user_instruction_with_tag: content.append({"type": "text", "text": f"\n\n{user_instruction_with_tag}"}) return { @@ -54,6 +66,73 @@ def user_message_formatter( "content": content, } +def extract_formatter( + multimodal_data: dict[str, Any], + output_cols: dict[str, str | None], + extract_quotes: bool = True, + strategy: str | None = None +) -> list[dict[str, str]]: + output_col_names = list(output_cols.keys()) + # Set the description to be the key if no value is provided + output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc for col, desc in output_cols.items()} + + # Create example JSON with just the required fields + example_json = { + field: "example_value" for field in output_col_names + } + + if strategy == "deepseek": + # Create the instruction without any backslashes in f-strings + fields_list = ", ".join(output_col_names) + example_fields = ", ".join([f' "{field}": "value"' for field in output_col_names]) + + # Build instruction using concatenation instead of f-strings + base_instruction = ( + "Your task is to extract specific fields from the given context.\n" + "Fields to extract: " + str(output_cols_with_desc) + "\n\n" + "Instructions:\n" + "1. First, show your reasoning in tags\n" + "2. Then provide ONLY a valid JSON object with these exact fields:\n" + + fields_list + "\n\n" + "Example format:\n" + "\n" + "Your reasoning about how you extracted each field...\n" + "\n" + "{\n" + + example_fields + "\n" + "}\n\n" + "Important: The JSON must be properly formatted and contain exactly these fields. " + "Do not include any other text, code blocks, or explanations outside the tags." 
+ ) + else: + # Build instruction using concatenation instead of f-strings + base_instruction = ( + "Your task is to extract specific fields from the given context.\n" + "Fields to extract: " + str(output_cols_with_desc) + "\n\n" + "Instructions:\n" + "1. Extract the requested fields from the context\n" + "2. Return ONLY a valid JSON object with these exact fields:\n" + + ", ".join(output_col_names) + "\n\n" + "Example format:\n" + "{\n" + + ", ".join([f' "{field}": "value"' for field in output_col_names]) + "\n" + "}\n\n" + "Important: The JSON must be properly formatted and contain exactly these fields. " + "Do not include any other text, code blocks, or explanations." + ) + + is_deepseek = strategy == "deepseek" + messages = [] + + if not is_deepseek: + messages.append({"role": "system", "content": base_instruction}) + + messages.append(user_message_formatter( + multimodal_data, + is_deepseek=is_deepseek, + base_instruction=base_instruction if is_deepseek else None + )) + return messages def filter_formatter_cot( multimodal_data: dict[str, Any], @@ -88,7 +167,6 @@ def filter_formatter_cot( messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages - def filter_formatter_zs_cot( multimodal_data: dict[str, Any], user_instruction: str, @@ -105,7 +183,6 @@ def filter_formatter_zs_cot( messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages - def filter_formatter( multimodal_data: dict[str, Any], user_instruction: str, @@ -122,14 +199,17 @@ def filter_formatter( elif strategy == "zs-cot": return filter_formatter_zs_cot(multimodal_data, user_instruction) - sys_instruction = ( + base_instruction = ( "The user will provide a claim and some relevant context.\n" "Your job is to determine whether the claim is true for the given context.\n" 'You must answer with a single word, "True" or "False".' 
) - messages = [ - {"role": "system", "content": sys_instruction}, - ] + + is_deepseek = strategy == "deepseek" + messages = [] + + if not is_deepseek: + messages.append({"role": "system", "content": base_instruction}) if examples_multimodal_data: assert examples_answer is not None @@ -139,15 +219,24 @@ def filter_formatter( ex_ans = examples_answer[i] messages.extend( [ - user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), + user_message_formatter( + ex_multimodal_data, + f"Claim: {user_instruction}", + is_deepseek=is_deepseek, + base_instruction=base_instruction if is_deepseek else None + ), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) + messages.append(user_message_formatter( + multimodal_data, + f"Claim: {user_instruction}", + is_deepseek=is_deepseek, + base_instruction=base_instruction if is_deepseek else None + )) return messages - def map_formatter_cot( multimodal_data: dict[str, Any], user_instruction: str, @@ -181,7 +270,6 @@ def map_formatter_cot( messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages - def map_formatter_zs_cot( multimodal_data: dict[str, Any], user_instruction: str, @@ -198,7 +286,6 @@ def map_formatter_zs_cot( messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages - def map_formatter( multimodal_data: dict[str, Any], user_instruction: str, @@ -215,58 +302,54 @@ def map_formatter( elif strategy == "zs-cot": return map_formatter_zs_cot(multimodal_data, user_instruction) - sys_instruction = ( - "The user will provide an instruction and some relevant context.\n" - "Your job is to answer the user's instruction given the context." 
- ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] + is_deepseek = strategy == "deepseek" + messages = [] + + if is_deepseek: + base_instruction = ( + "Your task is to answer the given instruction based on the context.\n\n" + "Instructions:\n" + "1. First, show your reasoning in <think> tags\n" + "2. Then provide your final answer\n\n" + "Example format:\n" + "<think>\n" + "Your reasoning about how you arrived at the answer...\n" + "</think>\n" + "Your final answer here\n\n" + "Important: Make sure to include your reasoning in <think> tags " + "followed by a clear, concise answer." + ) + else: + base_instruction = ( + "The user will provide an instruction and some relevant context.\n" + "Your job is to answer the user's instruction given the context.\n" + "Provide a single concise answer that directly responds to the instruction." + ) + messages.append({"role": "system", "content": base_instruction}) if examples_multimodal_data: assert examples_answer is not None for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer): messages.extend( [ - user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), + user_message_formatter( + ex_df_txt, + f"Instruction: {user_instruction}", + is_deepseek=is_deepseek, + base_instruction=base_instruction if is_deepseek else None + ), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) + messages.append(user_message_formatter( + multimodal_data, + f"Instruction: {user_instruction}", + is_deepseek=is_deepseek, + base_instruction=base_instruction if is_deepseek else None + )) return messages - -def extract_formatter( - multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True -) -> list[dict[str, str]]: - output_col_names = list(output_cols.keys()) - # Set the description to be the key if no value is provided - output_cols_with_desc: dict[str, str] = {col: col if desc is None else desc 
for col, desc in output_cols.items()} - - all_fields = output_col_names - if extract_quotes: - quote_fields = [f"{col}_quote" for col in output_col_names] - all_fields += quote_fields - - fields_str = ", ".join(all_fields) - - sys_instruction = ( - "The user will provide the columns that need to be extracted and some relevant context.\n" - f"Your job is to extract these columns and provide only a concise value for each field " - f"and the corresponding full quote for each field in the '{', '.join([f'{col}_quote' for col in output_col_names])}' fields.\n" - f"Here is a description of each field: {output_cols_with_desc}\n" - f"The response should be valid JSON format with the following fields: {fields_str}.\n" - ) - - messages = [ - {"role": "system", "content": sys_instruction}, - user_message_formatter(multimodal_data), - ] - return messages - - -# returns a list of strings corresponding to df rows def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]: """Formats the given DataFrame into a string containing info from cols.""" @@ -305,7 +388,6 @@ def clean_and_escape_column_name(column_name: str) -> str: return formatted_rows - def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]: """ Formats the given DataFrame into a string containing info from cols. @@ -323,7 +405,6 @@ def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any] ] return multimodal_data - def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Merges two multimodal info lists into one. Each row of first is merged with each row of second. 
@@ -346,6 +427,5 @@ def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, An for j in range(len(second)) ] - def li2text(li: list[str], name: str) -> str: return "".join([f"[{name}] {li[i]}\n" for i in range(len(li))]) diff --git a/pytest.ini b/pytest.ini index 38be7969..a2097370 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,14 @@ testpaths = tests python_files = test_*.py python_classes = Test* -python_functions = test_* \ No newline at end of file +python_functions = test_* + +# Configure test discovery +addopts = -v --tb=short + +# Mark custom test categories +markers = + deepseek: tests for deepseek model support + +# Test file patterns +norecursedirs = .git .tox .eggs *.egg \ No newline at end of file diff --git a/tests/test_deepseek.py b/tests/test_deepseek.py new file mode 100644 index 00000000..dbe99502 --- /dev/null +++ b/tests/test_deepseek.py @@ -0,0 +1,177 @@ +"""Tests for deepseek model support across semantic operators.""" + +import pandas as pd +import pytest + +import lotus +from lotus.models import LM + +def test_model_config(): + """Test that deepseek models are configured correctly""" + # Deepseek model should set temperature=0.6 and strategy=deepseek + lm = LM(model="ollama/deepseek-r1:7b") + assert lm.kwargs.get("strategy") == "deepseek" + assert lm.kwargs.get("temperature") == 0.6 # Recommended temperature + assert lm.is_deepseek # Should detect deepseek model + + # Non-deepseek model should keep original settings + lm = LM(model="ollama/llama3.1:8b", temperature=0.7) + assert "strategy" not in lm.kwargs + assert lm.kwargs.get("temperature") == 0.7 + assert not lm.is_deepseek # Should not detect as deepseek + +def test_map_behavior(): + """Test that sem_map handles both deepseek and non-deepseek models correctly""" + df = pd.DataFrame({ + "Course": ["Machine Learning", "Data Structures"] + }) + + # Test deepseek model + lm = LM(model="ollama/deepseek-r1:7b") + lotus.settings.configure(lm=lm) + + result = 
df.sem_map( + "What is a similar course to {Course}?", + return_explanations=True, + strategy="deepseek" + ) + + # Should have reasoning in explanation_map + assert "explanation_map" in result.columns + assert result["explanation_map"].iloc[0] is not None + assert isinstance(result["_map"].iloc[0], str) + + # Test non-deepseek model + lm = LM(model="ollama/llama3.1:8b") + lotus.settings.configure(lm=lm) + + result = df.sem_map( + "What is a similar course to {Course}?", + return_explanations=True # Non-deepseek model should not use deepseek strategy + ) + + # Should have no reasoning + assert result["explanation_map"].iloc[0] is None + assert isinstance(result["_map"].iloc[0], str) + +def test_filter_behavior(): + """Test that sem_filter handles both deepseek and non-deepseek models correctly""" + df = pd.DataFrame({ + "Course": ["Machine Learning", "Art History"] + }) + + # Test deepseek model + lm = LM(model="ollama/deepseek-r1:7b") + lotus.settings.configure(lm=lm) + + result = df.sem_filter( + "Is {Course} a technical course?", + return_explanations=True, + strategy="deepseek" + ) + + # Should have reasoning in explanation_filter + assert "explanation_filter" in result.columns + filtered_rows = result[result["Course"] == "Machine Learning"] + assert len(filtered_rows) > 0 + assert filtered_rows["explanation_filter"].iloc[0] is not None + + # Test non-deepseek model + lm = LM(model="ollama/llama3.1:8b") + lotus.settings.configure(lm=lm) + + result = df.sem_filter( + "Is {Course} a technical course?", + return_explanations=True # Non-deepseek model should not use deepseek strategy + ) + + # Should have no reasoning + filtered_rows = result[result["Course"] == "Machine Learning"] + assert len(filtered_rows) > 0 + assert filtered_rows["explanation_filter"].iloc[0] is None + +def test_join_behavior(): + """Test that sem_join handles both deepseek and non-deepseek models correctly""" + df1 = pd.DataFrame({ + "Course1": ["Machine Learning"] + }) + df2 = pd.DataFrame({ 
+ "Course2": ["Statistics", "Art History"] + }) + + # Test deepseek model + lm = LM(model="ollama/deepseek-r1:7b") + lotus.settings.configure(lm=lm) + + result = df1.sem_join( + df2, + "{Course1} and {Course2} are related fields", + return_explanations=True, + strategy="deepseek" + ) + + # Should have reasoning in explanation_join + assert "explanation_join" in result.columns + assert len(result) > 0 + assert result["explanation_join"].iloc[0] is not None + + # Test non-deepseek model + lm = LM(model="ollama/llama3.1:8b") + lotus.settings.configure(lm=lm) + + result = df1.sem_join( + df2, + "{Course1} and {Course2} are related fields", + return_explanations=True # Non-deepseek model should not use deepseek strategy + ) + + # Should have no reasoning + assert len(result) > 0 + assert result["explanation_join"].iloc[0] is None + +def test_extract_behavior(): + """Test that sem_extract handles both deepseek and non-deepseek models correctly""" + df = pd.DataFrame({ + "Text": ["The course Machine Learning (CS229) is taught by Prof. 
Smith"] + }) + + # Test deepseek model + lm = LM(model="ollama/deepseek-r1:7b") + lotus.settings.configure(lm=lm) + + result = df.sem_extract( + input_cols=["Text"], + output_cols={ + "course_code": "Course code", + "professor": "Professor name" + }, + return_raw_outputs=True, + strategy="deepseek" + ) + + # Should extract fields correctly with reasoning + assert "course_code" in result.columns + assert "professor" in result.columns + assert isinstance(result["course_code"].iloc[0], str) + assert isinstance(result["professor"].iloc[0], str) + assert "" in result["raw_output"].iloc[0] + + # Test non-deepseek model + lm = LM(model="ollama/llama3.1:8b") + lotus.settings.configure(lm=lm) + + result = df.sem_extract( + input_cols=["Text"], + output_cols={ + "course_code": "Course code", + "professor": "Professor name" + }, + return_raw_outputs=True + ) + + # Should extract fields correctly without reasoning + assert "course_code" in result.columns + assert "professor" in result.columns + assert isinstance(result["course_code"].iloc[0], str) + assert isinstance(result["professor"].iloc[0], str) + assert "" not in result["raw_output"].iloc[0] \ No newline at end of file From 5c5828039b021404196d88074ab5ec02a1ec6e3d Mon Sep 17 00:00:00 2001 From: Caleb Winston Date: Thu, 13 Feb 2025 15:49:55 +0000 Subject: [PATCH 2/2] Add return_explanations to sem_agg --- lotus/sem_ops/postprocessors.py | 33 ++++++++++ lotus/sem_ops/sem_agg.py | 95 ++++++++++++---------------- lotus/templates/task_instructions.py | 54 ++++++++++++++-- lotus/types.py | 8 +++ tests/test_deepseek.py | 65 ++++++++++++++++++- 5 files changed, 194 insertions(+), 61 deletions(-) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index e47540fe..1e0c97d6 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -6,6 +6,7 @@ SemanticExtractPostprocessOutput, SemanticFilterPostprocessOutput, SemanticMapPostprocessOutput, + SemanticAggPostprocessOutput, ) from 
lotus.sem_ops.deepseek_utils import extract_deepseek_reasoning @@ -245,3 +246,35 @@ def filter_postprocess( outputs.append(default) return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) + +def agg_postprocess( + llm_answers: list[str], + strategy: str | None = None, +) -> SemanticAggPostprocessOutput: + """ + Postprocess the output of the aggregate operator. + + Args: + llm_answers (list[str]): The list of llm answers. + strategy (str | None): The reasoning strategy ("deepseek" or None). + + Returns: + SemanticAggPostprocessOutput + """ + outputs: list[str] = [] + explanations: list[str | None] = [] + + for llm_answer in llm_answers: + if strategy == "deepseek": + reasoning, answer = _process_deepseek_output(llm_answer) + outputs.append(answer) + explanations.append(reasoning) + else: + outputs.append(llm_answer) + explanations.append(None) + + return SemanticAggPostprocessOutput( + raw_outputs=llm_answers, + outputs=outputs, + explanations=explanations + ) diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 0bf923c0..fc5bdb24 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -1,11 +1,13 @@ -from typing import Any +from typing import Any, Callable import pandas as pd import lotus.models from lotus.cache import operator_cache from lotus.templates import task_instructions -from lotus.types import LMOutput, SemanticAggOutput +from lotus.types import LMOutput, SemanticAggOutput, SemanticAggPostprocessOutput +from lotus.utils import show_safe_mode +from .postprocessors import agg_postprocess def sem_agg( @@ -13,6 +15,7 @@ def sem_agg( model: lotus.models.LM, user_instruction: str, partition_ids: list[int], + postprocessor: Callable[[list[str], str | None], SemanticAggPostprocessOutput] = agg_postprocess, strategy: str | None = None, safe_mode: bool = False, progress_bar_desc: str = "Aggregating", @@ -25,84 +28,53 @@ def sem_agg( model (lotus.models.LM): The model to use. 
user_instruction (str): The user instruction for aggregation. partition_ids (list[int]): The partition ids for the documents. Documents with the same partition id will be aggregated together. + postprocessor (Callable): The postprocessor for the model outputs. Defaults to agg_postprocess. + strategy (str | None): The reasoning strategy ("deepseek" or None). Returns: - str: The aggregated answer. + SemanticAggOutput: The aggregated answer and explanations. """ - leaf_instr_template = ( - "Your job is to provide an answer to the user's instruction given the context below from multiple documents.\n" - "Remember that your job is to answer the user's instruction by combining all relevant information from all provided documents, into a single coherent answer.\n" - "Do NOT copy the format of the sources! Instead output your answer in a coherent, well-structured manner that best answers the user instruction.\n" - "You have limited space to provide your answer, so be concise and to the point.\n\n---\n\n" - "Follow the following format.\n\nContext: relevant facts from multiple documents\n\n" - "Instruction: the instruction provided by the user\n\nAnswer: Write your answer\n\n---\n\n" - "Context: {{docs_str}}\n\n" - f"Instruction: {user_instruction}\n\nAnswer:\n" - ) - - node_instr_template = ( - "Your job is to provide an answer to the user's instruction given the context below from multiple sources.\n" - "Note that each source may be formatted differently and contain information about several different documents.\n" - "Remember that your job is to answer the user's instruction by combining all relevant information from all provided sources, into a single coherent answer.\n" - "The sources may provide opposing viewpoints or complementary information.\n" - "Be sure to include information from ALL relevant sources in your answer.\n" - "Do NOT copy the format of the sources, instead output your answer in a coherent, well-structured manner that best answers the user 
instruction.\n" - "You have limited space to provide your answer, so be concise and to the point.\n" - "You may need to draw connections between sources to provide a complete answer.\n\n---\n\n" - "Follow the following format.\n\nContext: relevant facts from multiple sources\n\n" - "Instruction: the instruction provided by the user\n\nAnswer: Write your answer\n\n---\n\n" - "Context: {{docs_str}}\n\n" - f"Instruction: {user_instruction}\n\nAnswer:\n" - ) - - def leaf_doc_formatter(doc: str, ctr: int) -> str: - return f"\n\tDocument {ctr}: {doc}" - - def node_doc_formatter(doc: str, ctr: int) -> str: - return f"\n\tSource {ctr}: {doc}" - - def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: - return leaf_doc_formatter(doc, ctr) if tree_level == 0 else node_doc_formatter(doc, ctr) - if safe_mode: # TODO: implement safe mode lotus.logger.warning("Safe mode is not implemented yet") tree_level = 0 summaries: list[str] = [] + explanations: list[str | None] = [] new_partition_ids: list[int] = [] + while len(docs) != 1 or summaries == []: cur_partition_id = partition_ids[0] do_fold = len(partition_ids) == len(set(partition_ids)) context_str = "" - # prompt = "" batch = [] - if tree_level == 0: - template = leaf_instr_template - else: - template = node_instr_template - template_tokens = model.count_tokens(template) context_tokens = 0 doc_ctr = 1 # num docs in current prompt for idx in range(len(docs)): partition_id = partition_ids[idx] - formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr) + formatted_doc = f"\n\tDocument {doc_ctr}: {docs[idx]}" if tree_level == 0 else f"\n\tSource {doc_ctr}: {docs[idx]}" new_tokens = model.count_tokens(formatted_doc) - if (new_tokens + context_tokens + template_tokens > model.max_ctx_len - model.max_tokens) or ( + # Create multimodal data for the current batch + multimodal_data = {"text": context_str + formatted_doc} + prompt = task_instructions.agg_formatter(multimodal_data, user_instruction, strategy=strategy) + 
prompt_tokens = model.count_tokens(prompt) + + if (prompt_tokens > model.max_ctx_len - model.max_tokens) or ( partition_id != cur_partition_id and not do_fold ): # close the current prompt - prompt = template.replace("{{docs_str}}", context_str) + multimodal_data = {"text": context_str} + prompt = task_instructions.agg_formatter(multimodal_data, user_instruction, strategy=strategy) lotus.logger.debug(f"Prompt added to batch: {prompt}") - batch.append([{"role": "user", "content": prompt}]) + batch.append(prompt) new_partition_ids.append(cur_partition_id) cur_partition_id = partition_id doc_ctr = 1 # add new context to next prompt - formatted_doc = doc_formatter(tree_level, docs[idx], doc_ctr) + formatted_doc = f"\n\tDocument {doc_ctr}: {docs[idx]}" if tree_level == 0 else f"\n\tSource {doc_ctr}: {docs[idx]}" context_str = formatted_doc context_tokens = new_tokens doc_ctr += 1 @@ -112,14 +84,19 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: doc_ctr += 1 if doc_ctr > 1 or len(docs) == 1: - prompt = template.replace("{{docs_str}}", context_str) + multimodal_data = {"text": context_str} + prompt = task_instructions.agg_formatter(multimodal_data, user_instruction, strategy=strategy) lotus.logger.debug(f"Prompt added to batch: {prompt}") - batch.append([{"role": "user", "content": prompt}]) + batch.append(prompt) new_partition_ids.append(cur_partition_id) lm_output: LMOutput = model(batch, progress_bar_desc=progress_bar_desc) + + # Post process results + postprocess_output = postprocessor(lm_output.outputs, strategy=strategy) + summaries.extend(postprocess_output.outputs) + explanations.extend(postprocess_output.explanations) - summaries = lm_output.outputs partition_ids = new_partition_ids new_partition_ids = [] @@ -129,7 +106,7 @@ def doc_formatter(tree_level: int, doc: str, ctr: int) -> str: if safe_mode: model.print_total_usage() - return SemanticAggOutput(outputs=summaries) + return SemanticAggOutput(outputs=summaries, explanations=explanations) 
@pd.api.extensions.register_dataframe_accessor("sem_agg") @@ -146,13 +123,14 @@ def _validate(obj: Any) -> None: @staticmethod def process_group(args): - group, user_instruction, all_cols, suffix, strategy, progress_bar_desc = args + group, user_instruction, all_cols, suffix, strategy, return_explanations, progress_bar_desc = args return group.sem_agg( user_instruction, all_cols=all_cols, suffix=suffix, group_by=None, strategy=strategy, + return_explanations=return_explanations, progress_bar_desc=progress_bar_desc ) @@ -164,6 +142,7 @@ def __call__( suffix: str = "_output", group_by: list[str] | None = None, strategy: str | None = None, + return_explanations: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Aggregating", ) -> pd.DataFrame: @@ -175,6 +154,8 @@ def __call__( all_cols (bool): Whether to use all columns in the dataframe. Defaults to False. suffix (str): The suffix for the new column. Defaults to "_output". group_by (list[str] | None): The columns to group by before aggregation. Each group will be aggregated separately. + strategy (str | None): The reasoning strategy ("deepseek" or None). + return_explanations (bool): Whether to return explanations. Defaults to False. Returns: pd.DataFrame: The dataframe with the aggregated answer. 
""" @@ -198,7 +179,7 @@ def __call__( if group_by: grouped = self._obj.groupby(group_by) - group_args = [(group, user_instruction, all_cols, suffix, strategy, progress_bar_desc) for _, group in grouped] + group_args = [(group, user_instruction, all_cols, suffix, strategy, return_explanations, progress_bar_desc) for _, group in grouped] from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: @@ -227,5 +208,7 @@ def __call__( ) # package answer in a dataframe - answer_df = pd.DataFrame(answer.outputs, columns=[suffix]) + answer_df = pd.DataFrame({suffix: answer.outputs}) + if return_explanations: + answer_df[f"explanation{suffix}"] = answer.explanations return answer_df diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index d8558515..21fb012d 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -310,14 +310,12 @@ def map_formatter( "Your task is to answer the given instruction based on the context.\n\n" "Instructions:\n" "1. First, show your reasoning in tags\n" - "2. Then provide your final answer\n\n" + "2. Then provide ONLY a single word, phrase, or number as your final answer\n\n" "Example format:\n" "\n" "Your reasoning about how you arrived at the answer...\n" "\n" - "Your final answer here\n\n" - "Important: Make sure to include your reasoning in tags " - "followed by a clear, concise answer." + "answer\n\n" ) else: base_instruction = ( @@ -350,6 +348,54 @@ def map_formatter( )) return messages +def agg_formatter( + multimodal_data: dict[str, Any], + user_instruction: str, + strategy: str | None = None, +) -> list[dict[str, str]]: + """ + Format instructions for aggregation operator. + + Args: + multimodal_data (dict[str, Any]): The multimodal data to format. + user_instruction (str): The user instruction. + strategy (str | None): The reasoning strategy ("deepseek" or None). 
+ + Returns: + list[dict[str, str]]: The formatted messages. + """ + is_deepseek = strategy == "deepseek" + messages = [] + + if is_deepseek: + base_instruction = ( + "Your task is to analyze multiple documents and provide a comprehensive answer.\n\n" + "Instructions:\n" + "1. First, show your reasoning in tags about how you analyzed the documents\n" + "2. Then provide your final answer that combines information from all documents\n\n" + "Example format:\n" + "\n" + "Your reasoning about how you analyzed and combined the documents...\n" + "\n" + "Your final comprehensive answer\n\n" + ) + else: + base_instruction = ( + "Your job is to provide an answer to the user's instruction given the context below from multiple documents.\n" + "Remember that your job is to answer the user's instruction by combining all relevant information from all provided documents, into a single coherent answer.\n" + "Do NOT copy the format of the sources! Instead output your answer in a coherent, well-structured manner that best answers the user instruction.\n" + "You have limited space to provide your answer, so be concise and to the point.\n" + ) + messages.append({"role": "system", "content": base_instruction}) + + messages.append(user_message_formatter( + multimodal_data, + f"Instruction: {user_instruction}", + is_deepseek=is_deepseek, + base_instruction=base_instruction if is_deepseek else None + )) + return messages + def df2text(df: pd.DataFrame, cols: list[str]) -> list[str]: """Formats the given DataFrame into a string containing info from cols.""" diff --git a/lotus/types.py b/lotus/types.py index d3d19db2..6469777e 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -88,9 +88,17 @@ class SemanticFilterOutput: logprobs: list[list[ChatCompletionTokenLogprob]] | None = None +@dataclass +class SemanticAggPostprocessOutput: + raw_outputs: list[str] + outputs: list[str] + explanations: list[str | None] + + @dataclass class SemanticAggOutput: outputs: list[str] + explanations: list[str 
| None] @dataclass diff --git a/tests/test_deepseek.py b/tests/test_deepseek.py index dbe99502..53d416b2 100644 --- a/tests/test_deepseek.py +++ b/tests/test_deepseek.py @@ -110,6 +110,12 @@ def test_join_behavior(): strategy="deepseek" ) + # Print the full result DataFrame + print("\n=== Deepseek Join Result ===") + with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000): + print(result) + print("===========================\n") + # Should have reasoning in explanation_join assert "explanation_join" in result.columns assert len(result) > 0 @@ -125,6 +131,12 @@ def test_join_behavior(): return_explanations=True # Non-deepseek model should not use deepseek strategy ) + # Print the full result DataFrame + print("\n=== Non-Deepseek Join Result ===") + with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000): + print(result) + print("===========================\n") + # Should have no reasoning assert len(result) > 0 assert result["explanation_join"].iloc[0] is None @@ -174,4 +186,55 @@ def test_extract_behavior(): assert "professor" in result.columns assert isinstance(result["course_code"].iloc[0], str) assert isinstance(result["professor"].iloc[0], str) - assert "" not in result["raw_output"].iloc[0] \ No newline at end of file + assert "" not in result["raw_output"].iloc[0] + +def test_agg_behavior(): + """Test that sem_agg handles both deepseek and non-deepseek models correctly""" + df = pd.DataFrame({ + "Note": [ + "Patient reports chest pain", + "Follow-up shows improved symptoms", + "Final checkup confirms recovery" + ] + }) + + # Test deepseek model + lm = LM(model="ollama/deepseek-r1:7b") + lotus.settings.configure(lm=lm) + + result = df.sem_agg( + "Summarize the progression of {Note}", + return_explanations=True, + strategy="deepseek" + ) + + # Print the full result DataFrame + print("\n=== Deepseek Agg Result ===") + with pd.option_context('display.max_columns', None, 
'display.max_rows', None, 'display.width', 1000): + print(result) + print("===========================\n") + + # Should have reasoning in explanation column + assert "explanation_output" in result.columns + assert len(result) > 0 + assert result["explanation_output"].iloc[0] is not None + assert isinstance(result["_output"].iloc[0], str) + + # Test non-deepseek model + lm = LM(model="ollama/llama3.1:8b") + lotus.settings.configure(lm=lm) + + result = df.sem_agg( + "Summarize the progression of {Note}", + return_explanations=True # Non-deepseek model should not use deepseek strategy + ) + + # Print the full result DataFrame + print("\n=== Non-Deepseek Agg Result ===") + with pd.option_context('display.max_columns', None, 'display.max_rows', None, 'display.width', 1000): + print(result) + print("===========================\n") + + # Should have no reasoning + assert len(result) > 0 + assert result["explanation_output"].iloc[0] is None \ No newline at end of file