From 7c822f4765b62d529c453491a07d24d609a12295 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Wed, 20 Aug 2025 20:24:19 -0700 Subject: [PATCH 1/8] initial refactor --- examples/op_examples/simple_reasoning.py | 90 +++++++ lotus/__init__.py | 3 + lotus/sem_ops/demonstration_bootstrap.py | 319 +++++++++++++++++++++++ lotus/sem_ops/sem_filter.py | 56 +++- lotus/sem_ops/sem_join.py | 2 +- lotus/sem_ops/sem_map.py | 58 ++++- lotus/sem_ops/sem_topk.py | 6 +- lotus/templates/task_instructions.py | 22 +- lotus/types.py | 27 +- 9 files changed, 558 insertions(+), 25 deletions(-) create mode 100644 examples/op_examples/simple_reasoning.py create mode 100644 lotus/sem_ops/demonstration_bootstrap.py diff --git a/examples/op_examples/simple_reasoning.py b/examples/op_examples/simple_reasoning.py new file mode 100644 index 00000000..3921f5b3 --- /dev/null +++ b/examples/op_examples/simple_reasoning.py @@ -0,0 +1,90 @@ +""" +Simple Reasoning Strategies Demo + +This example shows the new, simplified reasoning system in Lotus: +1. ReasoningStrategy.CoT - Chain-of-thought reasoning +2. ReasoningStrategy.Demonstrations - Few-shot examples +3. ReasoningStrategy.CoT_Demonstrations - Both combined +4. Automatic demonstration bootstrapping +""" + +import pandas as pd + +import lotus +from lotus.models import LM +from lotus.types import DemonstrationConfig, ReasoningStrategy + +# Configure the language model +lm = LM(model="gpt-4o-mini") +lotus.settings.configure(lm=lm) + +# Sample data +data = { + "Course Name": ["Linear Algebra", "Poetry Writing", "Calculus II", "Art History", "Statistics", "Creative Writing"] +} +df = pd.DataFrame(data) +user_instruction = "{Course Name} requires a lot of math" + +# Example 1: Basic filtering (no reasoning) +print("=== 1. Basic Filtering ===") +basic_df = df.sem_filter(user_instruction, return_all=True) +print(basic_df[["Course Name", "filter_label"]]) +print() + +# Example 2: Chain-of-Thought reasoning +print("=== 2. Chain-of-Thought Reasoning ===") +cot_df = df.sem_filter(user_instruction, strategy=ReasoningStrategy.CoT, return_explanations=True, return_all=True) +print(cot_df[["Course Name", "filter_label", "explanation_filter"]]) +print() + +# Example 3: Few-shot examples (demonstrations) +print("=== 3. Few-shot Examples ===") +examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature", "Physics"], "Answer": [True, False, True]}) + +demo_df = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.Demonstrations, + examples=examples, # Still works for backward compatibility + return_all=True, +) +print(demo_df[["Course Name", "filter_label"]]) +print() + +# Example 4: CoT + Demonstrations (the powerful combination) +print("=== 4. CoT + Demonstrations ===") +examples_with_reasoning = pd.DataFrame( + { + "Course Name": ["Machine Learning", "Literature", "Physics"], + "Answer": [True, False, True], + "Reasoning": [ + "Machine Learning requires linear algebra, calculus, and statistics", + "Literature focuses on reading, writing, and analysis - no math required", + "Physics is fundamentally mathematical with equations and calculations", + ], + } +) + +combined_df = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + examples=examples_with_reasoning, + return_explanations=True, + return_all=True, +) +print(combined_df[["Course Name", "filter_label", "explanation_filter"]]) +print() + +# Example 5: Automatic demonstration bootstrapping +print("=== 5. 
Bootstrapped Demonstrations ===") +bootstrap_config = DemonstrationConfig(bootstrap=True, num_demonstrations=2) + +bootstrap_df = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=bootstrap_config, + return_explanations=True, + return_all=True, +) +print("Automatically generated demonstrations:") +print(bootstrap_df[["Course Name", "filter_label", "explanation_filter"]]) +print() diff --git a/lotus/__init__.py b/lotus/__init__.py index 3193e400..3a4623f5 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -22,6 +22,7 @@ ) from lotus.web_search import web_search, WebSearchCorpus from lotus.settings import settings # type: ignore[attr-defined] +from lotus.types import ReasoningStrategy, DemonstrationConfig logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) @@ -51,4 +52,6 @@ "dtype_extensions", "web_search", "WebSearchCorpus", + "ReasoningStrategy", + "DemonstrationConfig", ] diff --git a/lotus/sem_ops/demonstration_bootstrap.py b/lotus/sem_ops/demonstration_bootstrap.py new file mode 100644 index 00000000..ba7fce9c --- /dev/null +++ b/lotus/sem_ops/demonstration_bootstrap.py @@ -0,0 +1,319 @@ +import random +from typing import Any + +import lotus +from lotus.models import LM +from lotus.templates import task_instructions +from lotus.types import DemonstrationConfig, ReasoningStrategy + + +def bootstrap_demonstrations_for_filter( + multimodal_data: list[dict[str, Any]], + user_instruction: str, + config: DemonstrationConfig, + oracle_model: LM | None = None, +) -> tuple[list[dict[str, Any]], list[bool], list[str] | None]: + """ + Bootstrap demonstrations for semantic filter operations. + + Args: + multimodal_data: The full dataset to sample from + user_instruction: The filter instruction + config: Configuration for demonstration generation + oracle_model: Oracle model for labeling (if None, uses main model) + + Returns: + Tuple of (examples_multimodal_data, examples_answers, cot_reasoning) + """ + if not config.bootstrap: + raise ValueError("Bootstrap must be enabled in DemonstrationConfig") + + # Sample data for demonstrations + sample_size = min(config.num_demonstrations, len(multimodal_data)) + sample_indices = random.sample(range(len(multimodal_data)), sample_size) + sample_data = [multimodal_data[i] for i in sample_indices] + + # Use oracle model or main model + model = oracle_model or lotus.settings.lm + if model is None: + raise ValueError("No oracle model or main model configured") + + # Generate labels using the oracle + examples_answers = [] + cot_reasoning = [] + + for doc in sample_data: + # Generate with CoT reasoning if needed + if config.oracle_model or hasattr(config, "include_reasoning"): + # Generate with CoT reasoning + prompt = task_instructions.filter_formatter(model, doc, user_instruction, strategy=ReasoningStrategy.CoT) + else: + # Generate without reasoning + prompt = task_instructions.filter_formatter(model, doc, user_instruction, strategy=None) + + # Get oracle response + response = model([prompt], progress_bar_desc="Bootstrapping demonstrations") + raw_output = response.outputs[0] + + # Parse the response + if config.oracle_model or hasattr(config, "include_reasoning"): + # Extract reasoning and answer from CoT response + reasoning, answer = _parse_cot_response(raw_output) + cot_reasoning.append(reasoning) + else: + answer = _parse_answer_response(raw_output) + + examples_answers.append(answer) + + return sample_data, examples_answers, cot_reasoning if 
cot_reasoning else None + + +def bootstrap_demonstrations_for_map( + multimodal_data: list[dict[str, Any]], + user_instruction: str, + config: DemonstrationConfig, + oracle_model: LM | None = None, +) -> tuple[list[dict[str, Any]], list[str], list[str] | None]: + """ + Bootstrap demonstrations for semantic map operations. + + Args: + multimodal_data: The full dataset to sample from + user_instruction: The map instruction + config: Configuration for demonstration generation + oracle_model: Oracle model for labeling (if None, uses main model) + + Returns: + Tuple of (examples_multimodal_data, examples_answers, cot_reasoning) + """ + if not config.bootstrap: + raise ValueError("Bootstrap must be enabled in DemonstrationConfig") + + # Sample data for demonstrations + sample_size = min(config.num_demonstrations, len(multimodal_data)) + sample_indices = random.sample(range(len(multimodal_data)), sample_size) + sample_data = [multimodal_data[i] for i in sample_indices] + + # Use oracle model or main model + model = oracle_model or lotus.settings.lm + if model is None: + raise ValueError("No oracle model or main model configured") + + # Generate labels using the oracle + examples_answers = [] + cot_reasoning = [] + + for doc in sample_data: + # Generate with CoT reasoning if needed + if config.oracle_model or hasattr(config, "include_reasoning"): + # Generate with CoT reasoning + prompt = task_instructions.map_formatter(model, doc, user_instruction, strategy=ReasoningStrategy.CoT) + else: + # Generate without reasoning + prompt = task_instructions.map_formatter(model, doc, user_instruction, strategy=None) + + # Get oracle response + response = model([prompt], progress_bar_desc="Bootstrapping demonstrations") + raw_output = response.outputs[0] + + # Parse the response + if config.oracle_model or hasattr(config, "include_reasoning"): + # Extract reasoning and answer from CoT response + reasoning, answer = _parse_cot_map_response(raw_output) + cot_reasoning.append(reasoning) + else: + answer = _parse_map_answer_response(raw_output) + + examples_answers.append(answer) + + return sample_data, examples_answers, cot_reasoning if cot_reasoning else None + + +def bootstrap_demonstrations_for_extract( + multimodal_data: list[dict[str, Any]], + output_cols: dict[str, str | None], + config: DemonstrationConfig, + oracle_model: LM | None = None, +) -> tuple[list[dict[str, Any]], list[dict[str, str]], list[str] | None]: + """ + Bootstrap demonstrations for semantic extract operations. 
+ + Args: + multimodal_data: The full dataset to sample from + output_cols: The columns to extract + config: Configuration for demonstration generation + oracle_model: Oracle model for labeling (if None, uses main model) + + Returns: + Tuple of (examples_multimodal_data, examples_answers, cot_reasoning) + """ + if not config.bootstrap: + raise ValueError("Bootstrap must be enabled in DemonstrationConfig") + + # Sample data for demonstrations + sample_size = min(config.num_demonstrations, len(multimodal_data)) + sample_indices = random.sample(range(len(multimodal_data)), sample_size) + sample_data = [multimodal_data[i] for i in sample_indices] + + # Use oracle model or main model + model = oracle_model or lotus.settings.lm + if model is None: + raise ValueError("No oracle model or main model configured") + + # Generate labels using the oracle + examples_answers = [] + cot_reasoning = [] + + for doc in sample_data: + # Generate with CoT reasoning if needed + if config.oracle_model or hasattr(config, "include_reasoning"): + # Generate with CoT reasoning + prompt = task_instructions.extract_formatter(model, doc, output_cols, strategy=ReasoningStrategy.CoT) + else: + # Generate without reasoning + prompt = task_instructions.extract_formatter(model, doc, output_cols, strategy=None) + + # Get oracle response + response = model([prompt], progress_bar_desc="Bootstrapping demonstrations") + raw_output = response.outputs[0] + + # Parse the response + if config.oracle_model or hasattr(config, "include_reasoning"): + # Extract reasoning and answer from CoT response + reasoning, answer = _parse_cot_extract_response(raw_output) + cot_reasoning.append(reasoning) + else: + answer = _parse_extract_response(raw_output) + + examples_answers.append(answer) + + return sample_data, examples_answers, cot_reasoning if cot_reasoning else None + + +def _parse_cot_response(response: str) -> tuple[str, bool]: + """Parse a CoT response to extract reasoning and boolean answer""" + lines = response.strip().split("\n") + reasoning_lines = [] + answer = True # default + + in_reasoning = False + for line in lines: + line = line.strip() + if line.startswith("Reasoning:"): + in_reasoning = True + reasoning_lines.append(line[10:].strip()) + elif line.startswith("Answer:"): + in_reasoning = False + answer_text = line[7:].strip().lower() + answer = answer_text in ["true", "yes", "1"] + elif in_reasoning: + reasoning_lines.append(line) + + reasoning = "\n".join(reasoning_lines).strip() + return reasoning, answer + + +def _parse_answer_response(response: str) -> bool: + """Parse a simple answer response to extract boolean answer""" + lines = response.strip().split("\n") + for line in lines: + line = line.strip() + if line.startswith("Answer:"): + answer_text = line[7:].strip().lower() + return answer_text in ["true", "yes", "1"] + + # Fallback: check if response contains true/false + response_lower = response.lower() + if "true" in response_lower: + return True + elif "false" in response_lower: + return False + + return True # default + + +def _parse_cot_extract_response(response: str) -> tuple[str, dict[str, str]]: + """Parse a CoT response for extract operations""" + lines = response.strip().split("\n") + reasoning_lines = [] + answer = {} + + in_reasoning = False + for line in lines: + line = line.strip() + if line.startswith("Reasoning:"): + in_reasoning = True + reasoning_lines.append(line[10:].strip()) + elif line.startswith("Answer:"): + in_reasoning = False + # Try to parse JSON answer + try: + import json + + answer_text 
= line[7:].strip()
+                answer = json.loads(answer_text)
+            except (json.JSONDecodeError, ValueError):
+                answer = {"extracted": answer_text}
+        elif in_reasoning:
+            reasoning_lines.append(line)
+
+    reasoning = "\n".join(reasoning_lines).strip()
+    return reasoning, answer
+
+
+def _parse_extract_response(response: str) -> dict[str, str]:
+    """Parse a simple extract response"""
+    lines = response.strip().split("\n")
+    for line in lines:
+        line = line.strip()
+        if line.startswith("Answer:"):
+            # Try to parse JSON answer
+            try:
+                import json
+
+                answer_text = line[7:].strip()
+                return json.loads(answer_text)
+            except (json.JSONDecodeError, ValueError):
+                return {"extracted": answer_text}
+
+    # Fallback: try to parse entire response as JSON
+    try:
+        import json
+
+        return json.loads(response)
+    except (json.JSONDecodeError, ValueError):
+        return {"extracted": response.strip()}
+
+
+def _parse_cot_map_response(response: str) -> tuple[str, str]:
+    """Parse a CoT response to extract reasoning and string answer for map operations"""
+    lines = response.strip().split("\n")
+    reasoning_lines = []
+    answer = ""  # default
+
+    in_reasoning = False
+    for line in lines:
+        line = line.strip()
+        if line.startswith("Reasoning:"):
+            in_reasoning = True
+            reasoning_lines.append(line[10:].strip())
+        elif line.startswith("Answer:"):
+            in_reasoning = False
+            answer = line[7:].strip()
+        elif in_reasoning:
+            reasoning_lines.append(line)
+
+    reasoning = "\n".join(reasoning_lines).strip()
+    return reasoning, answer
+
+
+def _parse_map_answer_response(response: str) -> str:
+    """Parse a simple answer response to extract string answer for map operations"""
+    lines = response.strip().split("\n")
+    for line in lines:
+        line = line.strip()
+        if line.startswith("Answer:"):
+            return line[7:].strip()
+
+    # Fallback: return the entire response
+    return response.strip()
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index 37e07277..16e10bd3 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -6,9 +6,11 @@
 
 import lotus
 from lotus.cache import operator_cache
+from lotus.models import LM
 from lotus.templates import task_instructions
 from lotus.types import (
     CascadeArgs,
+    DemonstrationConfig,
     LMOutput,
     LogprobsForFilterCascade,
     ProxyModel,
@@ -18,6 +20,7 @@
 from lotus.utils import show_safe_mode
 
 from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
+from .demonstration_bootstrap import bootstrap_demonstrations_for_filter
 from .postprocessors import filter_postprocess
 
 
@@ -30,6 +33,7 @@ def sem_filter(
     examples_answers: list[bool] | None = None,
     cot_reasoning: list[str] | None = None,
     strategy: ReasoningStrategy | None = None,
+    demonstration_config: DemonstrationConfig | None = None,
     logprobs: bool = False,
     safe_mode: bool = False,
     show_progress_bar: bool = True,
@@ -47,12 +51,44 @@ def sem_filter(
         examples_multimodal_data (list[dict[str, Any]] | None): The text for examples. Defaults to None.
         examples_answers (list[bool] | None): The answers for examples. Defaults to None.
         cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None.
+        strategy (ReasoningStrategy | None): The reasoning strategy to use. Can be CoT, Demonstrations, or both combined.
+        demonstration_config (DemonstrationConfig | None): Configuration for demonstration bootstrapping.
         logprobs (bool): Whether to return log probabilities. Defaults to False.
         additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "".
 
     Returns:
         SemanticFilterOutput: The True/False outputs, raw outputs, explanations, and log probabilities.
     """
+
+    # Handle demonstration bootstrapping
+    if (
+        strategy in [ReasoningStrategy.Demonstrations, ReasoningStrategy.CoT_Demonstrations]
+        and demonstration_config
+        and demonstration_config.bootstrap
+    ):
+        oracle_model = None
+        if demonstration_config.oracle_model:
+            oracle_model = LM(model=demonstration_config.oracle_model)
+
+        examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations_for_filter(
+            docs, user_instruction, demonstration_config, oracle_model
+        )
+    elif (
+        strategy in [ReasoningStrategy.Demonstrations, ReasoningStrategy.CoT_Demonstrations]
+        and demonstration_config
+        and demonstration_config.examples is not None
+    ):
+        examples_df = demonstration_config.examples
+        assert "Answer" in examples_df.columns, "Answer must be a column in examples dataframe"
+
+        examples_multimodal_data = task_instructions.df2multimodal_info(examples_df, list(examples_df.columns))
+        examples_answers = examples_df["Answer"].tolist()
+
+        if strategy == ReasoningStrategy.CoT_Demonstrations and "Reasoning" in examples_df.columns:
+            cot_reasoning = examples_df["Reasoning"].tolist()
+        elif strategy == ReasoningStrategy.CoT_Demonstrations:
+            cot_reasoning = ["Reasoning omitted"] * len(examples_answers) if examples_answers else []
+
     inputs = []
     for doc in docs:
         prompt = lotus.templates.task_instructions.filter_formatter(
@@ -122,6 +158,7 @@ def learn_filter_cascade_thresholds(
         examples_answers=examples_answers,
         cot_reasoning=cot_reasoning,
         strategy=strategy,
+        demonstration_config=None,  # No demonstration config for threshold learning
         safe_mode=False,
         progress_bar_desc="Running oracle for threshold learning",
         additional_cot_instructions=additional_cot_instructions,
@@ -168,6 +205,7 @@ def __call__(
         examples: pd.DataFrame | None = None,
         helper_examples: pd.DataFrame | None = None,
         strategy: ReasoningStrategy | None = None,
+        demonstration_config: DemonstrationConfig | None = None,
         cascade_args: CascadeArgs | None = None,
         return_stats: bool = False,
         safe_mode: bool = False,
@@ -205,6 +243,7 @@ def __call__(
         lotus.logger.debug(user_instruction)
         col_li = lotus.nl_expression.parse_cols(user_instruction)
         lotus.logger.debug(col_li)
+
         helper_strategy = strategy
 
         # check that column exists
@@ -219,12 +258,18 @@ def __call__(
         examples_multimodal_data = None
         examples_answers = None
         cot_reasoning = None
-        if examples is not None:
+
+        # Create demonstration config if examples are provided but no config exists
+        if examples is not None and demonstration_config is None:
+            demonstration_config = DemonstrationConfig(examples=examples)
             assert "Answer" in examples.columns, "Answer must be a column in examples dataframe"
             examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li)
             examples_answers = examples["Answer"].tolist()
 
-            if strategy == ReasoningStrategy.COT and "Reasoning" in examples.columns:
+            if (
+                strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]
+                and "Reasoning" in examples.columns
+            ):
                 cot_reasoning = examples["Reasoning"].tolist()
 
         pos_cascade_threshold, neg_cascade_threshold = None, None
@@ -237,7 +282,7 @@ def __call__(
             assert "Answer" in helper_examples.columns, "Answer must be a column in examples dataframe"
             helper_examples_multimodal_data = task_instructions.df2multimodal_info(helper_examples, col_li)
             helper_examples_answers = helper_examples["Answer"].tolist()
-            if helper_strategy == ReasoningStrategy.COT and "Reasoning" in helper_examples.columns:
+            if helper_strategy == ReasoningStrategy.CoT and "Reasoning" in helper_examples.columns:
                 helper_cot_reasoning = helper_examples["Reasoning"].tolist()
 
         if cascade_args:
@@ -256,7 +301,7 @@ def __call__(
             if not lotus.settings.helper_lm:
                 raise ValueError("Helper LM must be set in settings")
 
-            if helper_strategy == ReasoningStrategy.COT:
+            if helper_strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]:
                 raise ValueError("CoT not supported for helper models in cascades.")
 
             # Run small LM and get logits
@@ -270,6 +315,7 @@ def __call__(
                 cot_reasoning=helper_cot_reasoning,
                 logprobs=True,
                 strategy=helper_strategy,
+                demonstration_config=None,  # Helper models don't use demonstration config
                 safe_mode=safe_mode,
                 show_progress_bar=True,
                 progress_bar_desc="Running helper LM",
@@ -364,6 +410,7 @@ def __call__(
                 examples_answers=examples_answers,
                 cot_reasoning=cot_reasoning,
                 strategy=strategy,
+                demonstration_config=demonstration_config,
                 safe_mode=safe_mode,
                 progress_bar_desc="Running predicate evals with oracle LM",
                 additional_cot_instructions=additional_cot_instructions,
@@ -387,6 +434,7 @@ def __call__(
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
             strategy=strategy,
+            demonstration_config=demonstration_config,
             safe_mode=safe_mode,
             show_progress_bar=True,
             progress_bar_desc=progress_bar_desc,
diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py
index 81f540a3..fa1d8b12 100644
--- a/lotus/sem_ops/sem_join.py
+++ b/lotus/sem_ops/sem_join.py
@@ -670,7 +670,7 @@ def __call__(
             examples_multimodal_data = task_instructions.df2multimodal_info(examples, [real_left_on, real_right_on])
             examples_answers = examples["Answer"].tolist()
 
-            if strategy == ReasoningStrategy.COT:
+            if strategy == ReasoningStrategy.CoT:
                 return_explanations = True
                 cot_reasoning = examples["Reasoning"].tolist()
 
diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py
index ca8908be..6dac1a02 100644
--- a/lotus/sem_ops/sem_map.py
+++ b/lotus/sem_ops/sem_map.py
@@ -4,10 +4,18 @@
 
 import lotus
 from lotus.cache import operator_cache
+from lotus.models import LM
 from lotus.templates import task_instructions
-from lotus.types import LMOutput, ReasoningStrategy, SemanticMapOutput, SemanticMapPostprocessOutput
+from lotus.types import (
+    DemonstrationConfig,
+    LMOutput,
+    ReasoningStrategy,
+    SemanticMapOutput,
+    SemanticMapPostprocessOutput,
+)
 from lotus.utils import show_safe_mode
 
+from .demonstration_bootstrap import bootstrap_demonstrations_for_map
 from .postprocessors import map_postprocess
 
 
@@ -20,6 +28,7 @@ def sem_map(
     examples_answers: list[str] | None = None,
     cot_reasoning: list[str] | None = None,
     strategy: ReasoningStrategy | None = None,
+    demonstration_config: DemonstrationConfig | None = None,
     safe_mode: bool = False,
     progress_bar_desc: str = "Mapping",
 ) -> SemanticMapOutput:
@@ -34,11 +43,43 @@ def sem_map(
         examples_multimodal_data (list[dict[str, Any]] | None): The text for examples. Defaults to None.
         examples_answers (list[str] | None): The answers for examples. Defaults to None.
         cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None.
+        strategy (ReasoningStrategy | None): The reasoning strategy to use. Can be CoT, Demonstrations, or both combined.
+        demonstration_config (DemonstrationConfig | None): Configuration for demonstration bootstrapping.
 
     Returns:
         SemanticMapOutput: The outputs, raw outputs, and explanations.
""" + # Handle demonstration bootstrapping + if ( + (strategy in [ReasoningStrategy.Demonstrations, ReasoningStrategy.CoT_Demonstrations]) + and demonstration_config + and demonstration_config.bootstrap + ): + oracle_model = None + if demonstration_config.oracle_model: + oracle_model = LM(model=demonstration_config.oracle_model) + + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations_for_map( + docs, user_instruction, demonstration_config, oracle_model + ) + + elif ( + (strategy in [ReasoningStrategy.Demonstrations, ReasoningStrategy.CoT_Demonstrations]) + and demonstration_config + and demonstration_config.examples is not None + ): + examples_df = demonstration_config.examples + assert "Answer" in examples_df.columns, "Answer must be a column in examples dataframe" + + examples_multimodal_data = task_instructions.df2multimodal_info(examples_df, list(examples_df.columns)) + examples_answers = examples_df["Answer"].tolist() + + if strategy == ReasoningStrategy.CoT_Demonstrations and "Reasoning" in examples_df.columns: + cot_reasoning = examples_df["Reasoning"].tolist() + elif strategy == ReasoningStrategy.CoT_Demonstrations: + cot_reasoning = ["Reasoning omitted"] * len(examples_answers) if examples_answers else [] + # prepare model inputs inputs = [] for doc in docs: @@ -60,7 +101,7 @@ def sem_map( # post process results postprocess_output = postprocessor( - lm_output.outputs, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] + lm_output.outputs, model, strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations] ) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") @@ -98,6 +139,7 @@ def __call__( suffix: str = "_map", examples: pd.DataFrame | None = None, strategy: ReasoningStrategy | None = None, + demonstration_config: DemonstrationConfig | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", ) -> pd.DataFrame: @@ -131,18 +173,23 @@ def __call__( multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li) formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) + # Handle examples and demonstrations examples_multimodal_data = None examples_answers = None cot_reasoning = None - if examples is not None: + if examples is not None and demonstration_config is None: + demonstration_config = DemonstrationConfig(examples=examples) assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.COT or strategy == ReasoningStrategy.ZS_COT: + if strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]: return_explanations = True - cot_reasoning = examples["Reasoning"].tolist() + if "Reasoning" in examples.columns: + cot_reasoning = examples["Reasoning"].tolist() + else: + cot_reasoning = ["Reasoning omitted"] * len(examples_answers) output = sem_map( multimodal_data, @@ -153,6 +200,7 @@ def __call__( examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, + demonstration_config=demonstration_config, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, ) diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index bf0cc6c6..d3d04c3d 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -20,7 +20,7 @@ def get_match_prompt_binary( model: lotus.models.LM, 
     strategy: ReasoningStrategy | None = None,
 ) -> list[dict[str, Any]]:
-    if strategy == ReasoningStrategy.ZS_COT:
+    if strategy == ReasoningStrategy.CoT:
         sys_prompt = (
             "Your job is to select and return the most relevant document to the user's question.\n"
             "Carefully read the user's question and the two documents provided below.\n"
@@ -41,7 +41,7 @@
         content_text, content_image_inputs = task_instructions.context_formatter(doc)
         prompt += [{"type": "text", "text": f"\nDocument {idx+1}:\n{content_text}"}, *content_image_inputs]
 
-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    if strategy == ReasoningStrategy.CoT and model.is_deepseek():
         deepseek_instructions = """Please think through your reasoning step by step, then provide your final answer.
         You must put your reasoning inside the <think></think> tags, then provide your final answer after the </think> tag with the
         format: Answer: your answer."""
@@ -558,7 +558,7 @@
         new_df = new_df.reindex(output.indexes).reset_index(drop=True)
         new_df = new_df.head(K)
 
-        if return_explanations and strategy == ReasoningStrategy.ZS_COT:
+        if return_explanations and strategy == ReasoningStrategy.CoT:
             explanations = []
             for idx in output.indexes[:K]:
                 explanation = "No Comparison Made"
diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py
index 81232c1f..a2dfa677 100644
--- a/lotus/templates/task_instructions.py
+++ b/lotus/templates/task_instructions.py
@@ -5,7 +5,10 @@
 
 import lotus
 from lotus.dtype_extensions import ImageDtype
-from lotus.types import ReasoningStrategy, SerializationFormat
+from lotus.types import (
+    ReasoningStrategy,
+    SerializationFormat,
+)
 
 
 def cot_formatter(reasoning, answer):
@@ -100,7 +103,8 @@ def filter_formatter(
     Your job is to determine whether the claim is true for the given context.
     """
-    if strategy == ReasoningStrategy.COT:
+    # Simple strategy checking
+    if strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]:
         sys_instruction += cot_prompt_formatter(
             reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
         )
@@ -131,7 +135,7 @@ def filter_formatter(
             # reasoning as filler if the user wants cot reasoning
             if cot_reasoning:
                 content = cot_formatter(cot_reasoning[idx], str(ex_ans))
-            elif strategy == "cot":
+            elif strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]:
                 content = cot_formatter("Reasoning omitted", str(ex_ans))
             else:
                 content = answer_only_formatter(str(ex_ans))
@@ -145,7 +149,8 @@ def filter_formatter(
                 },
             ]
         )
-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    # Handle DeepSeek CoT formatting (backward compatibility)
+    if strategy == ReasoningStrategy.CoT and model.is_deepseek() and not examples_multimodal_data:
         user_instruction = f"Claim: {user_instruction}\n\n{deepseek_cot_formatter()}"
         messages.append(user_message_formatter(multimodal_data, user_instruction))
     else:
@@ -211,7 +216,7 @@ def map_formatter(
     examples_multimodal_data: list[dict[str, Any]] | None = None,
     examples_answer: list[str] | None = None,
     cot_reasoning: list[str] | None = None,
-    strategy: ReasoningStrategy | str | None = None,
+    strategy: ReasoningStrategy | None = None,
 ) -> list[dict[str, str]]:
     sys_instruction = (
         "The user will provide an instruction and some relevant context.\n"
@@ -222,7 +227,7 @@ def map_formatter(
         return map_formatter_cot(
             multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning
         )
-    elif strategy == ReasoningStrategy.ZS_COT:
+    elif strategy == ReasoningStrategy.CoT and not examples_multimodal_data:
         return map_formatter_zs_cot(multimodal_data, user_instruction)
 
     messages = [
@@ -239,7 +244,8 @@ def map_formatter(
             ]
         )
 
-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    # Handle DeepSeek CoT formatting (backward compatibility)
+    if strategy == ReasoningStrategy.CoT and model.is_deepseek() and not examples_multimodal_data:
         user_instructions = f"Instruction: {user_instruction}\n\n{deepseek_cot_formatter()}"
         messages.append(user_message_formatter(multimodal_data, user_instructions))
     else:
@@ -278,7 +284,7 @@ def extract_formatter(
         user_message_formatter(multimodal_data),
     ]
 
-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    if strategy == ReasoningStrategy.CoT and model.is_deepseek():
         user_instructions = f"Instruction: {deepseek_cot_formatter()}"
         messages.append(user_message_formatter(multimodal_data, user_instructions))
 
diff --git a/lotus/types.py b/lotus/types.py
index 08519729..656b404b 100644
--- a/lotus/types.py
+++ b/lotus/types.py
@@ -216,7 +216,26 @@ class LotusUsageLimitException(LotusException):
 # Reasoning Strategy
 ################################################################################
 class ReasoningStrategy(Enum):
-    DEFAULT = auto()
-    COT = auto()
-    ZS_COT = auto()
-    FEW_SHOT = auto()
+    """
+    Simple, intuitive reasoning strategies for semantic operations.
+ + - CoT: Chain-of-thought reasoning with step-by-step explanations + - CoT_Demonstrations: CoT with few-shot examples (user-provided or bootstrapped) + - Demonstrations: Few-shot examples without explicit reasoning + """ + + CoT = auto() + CoT_Demonstrations = auto() + Demonstrations = auto() + + +@dataclass +class DemonstrationConfig: + """Configuration for demonstration-based reasoning""" + + # User-provided examples (alternative to passing examples directly) + examples: pd.DataFrame | None = None + # Bootstrapping configuration - automatically generate examples + bootstrap: bool = False + num_demonstrations: int = 3 + oracle_model: str | None = None # If None, uses the main model From 27f3e042a7e2906ae1dcd215fba5e32892c0609f Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 21 Aug 2025 00:22:18 -0700 Subject: [PATCH 2/8] testing + extract fixes --- lotus/cache.py | 4 +- lotus/sem_ops/postprocessors.py | 58 +-- lotus/sem_ops/sem_extract.py | 7 +- lotus/templates/task_instructions.py | 4 + tests/deepseek_cot_tests.py | 296 ++++++++++------ tests/test_reasoning_strategies.py | 503 +++++++++++++++++++++++++++ 6 files changed, 735 insertions(+), 137 deletions(-) create mode 100644 tests/test_reasoning_strategies.py diff --git a/lotus/cache.py b/lotus/cache.py index 30ae99ad..6925f754 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -42,7 +42,7 @@ def wrapper(self, *args, **kwargs): def serialize(value: Any) -> Any: """ Serialize a value into a JSON-serializable format. - Supports basic types, pandas DataFrames, and objects with a `dict` or `__dict__` method. + Supports basic types, pandas DataFrames, Enums, and objects with a `dict` or `__dict__` method. """ if value is None or isinstance(value, (str, int, float, bool)): return value @@ -52,6 +52,8 @@ def serialize(value: Any) -> Any: return [serialize(item) for item in value] elif isinstance(value, dict): return {key: serialize(val) for key, val in value.items()} + elif isinstance(value, Enum): + return {"__enum__": value.__class__.__name__, "value": value.name} elif hasattr(value, "dict"): return value.dict() elif hasattr(value, "__dict__"): diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 9d8313d0..0416f913 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -9,8 +9,8 @@ ) -def cot_postprocessor(llm_answers: list[str]): - outputs: list[str | None] = [] +def cot_postprocessor(llm_answers: list[str], for_extract: bool = False): + outputs: list[str | dict | None] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: reasoning_idx = llm_answer.find("Reasoning:\n") @@ -20,11 +20,26 @@ def cot_postprocessor(llm_answers: list[str]): reasoning_idx += len("Reasoning:\n") answer_idx = llm_answer.find("Answer:") - reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") - answer = llm_answer[answer_idx + len("Answer:") :] + if answer_idx == -1: + # No "Answer:" found, assume the whole response is the answer + reasoning = "" + answer = llm_answer.strip() + else: + reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") + answer = llm_answer[answer_idx + len("Answer:") :].strip() explanations.append(reasoning) - outputs.append(answer) + + if for_extract: + try: + json_obj = json.loads(answer) + except json.JSONDecodeError: + lotus.logger.info(f"\t Failed to parse: {answer}") + json_obj = {} + json_obj = {key: str(value) for key, value in json_obj.items()} + outputs.append(json_obj) + else: + outputs.append(answer) return 
outputs, explanations
@@ -51,11 +66,14 @@ def deepseek_cot_postprocessor(llm_answers: list[str], for_extract: bool = False
         if think_start != -1 and think_end != -1:
             # Extract the reasoning between the <think> tags
             reasoning = llm_answer[think_start + len("<think>") : think_end].strip()
-            answer = llm_answer[answer_start + len("Answer:") :].strip()
 
-        answer = answer.strip()
+            if answer_start != -1:
+                answer = llm_answer[answer_start + len("Answer:") :].strip()
+            else:
+                # No "Answer:" found, look for content after </think>
+                answer = llm_answer[think_end + len("</think>") :].strip()
 
-        # If ther is nothing after </think> tag, check if the answer is at the beginning
+        # If there is nothing after </think> tag, check if the answer is at the beginning
         if not answer and think_start > 0:
             answer = llm_answer[:think_start].strip()
@@ -67,9 +85,9 @@ def deepseek_cot_postprocessor(llm_answers: list[str], for_extract: bool = False
 
         if for_extract:
             try:
-                json_obj = json.loads(llm_answer)
+                json_obj = json.loads(answer)
             except json.JSONDecodeError:
-                lotus.logger.info(f"\t Failed to parse: {llm_answer}")
+                lotus.logger.info(f"\t Failed to parse: {answer}")
                 json_obj = {}
             json_obj = {key: str(value) for key, value in json_obj.items()}
             outputs.append(json_obj)
@@ -103,7 +121,7 @@ def get_cot_postprocessor(model: lotus.models.LM, for_extract: bool = False) ->
         base_processor = COT_POSTPROCESSORS[processor_key]
         return lambda llm_answers: base_processor(llm_answers, for_extract=for_extract)
 
-    return cot_postprocessor
+    return lambda llm_answers: cot_postprocessor(llm_answers, for_extract=for_extract)
 
 
 def map_postprocess(
@@ -152,15 +170,15 @@ def extract_postprocess(
     extract_data = []
     explanations = [None] * len(llm_answers)
 
-    for llm_answer in llm_answers:
-        try:
-            output = json.loads(llm_answer)
-        except json.JSONDecodeError:
-            lotus.logger.info(f"\t Failed to parse: {llm_answer}")
-            output = {}
-
-        output = {key: str(value) for key, value in output.items()}
-        extract_data.append(output)
+    for llm_answer in llm_answers:
+        try:
+            output = json.loads(llm_answer)
+        except json.JSONDecodeError:
+            lotus.logger.info(f"\t Failed to parse: {llm_answer}")
+            output = {}
+
+        output = {key: str(value) for key, value in output.items()}
+        extract_data.append(output)
 
     return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data, explanations=explanations)
 
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index ca933a52..ad739d6b 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -43,7 +43,7 @@ def sem_extract(
         extract_quotes (bool, optional): Whether to extract supporting quotes
             from the source text for each extracted value. Defaults to False.
         postprocessor (Callable, optional): A function to post-process the model
-            outputs. Should take (outputs, model, return_explanations) and return
+            outputs. Should take (outputs, model, cot_reasoning) and return
             SemanticExtractPostprocessOutput. Defaults to extract_postprocess.
         safe_mode (bool, optional): Whether to enable safe mode with cost
             estimation. Defaults to False.
@@ -88,7 +88,8 @@ def sem_extract( lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc) # post process results - postprocess_output = postprocessor(lm_output.outputs, model, return_explanations) + cot_reasoning = strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations] + postprocess_output = postprocessor(lm_output.outputs, model, cot_reasoning) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") @@ -174,7 +175,7 @@ def __call__( extract_quotes (bool, optional): Whether to extract supporting quotes from the source text for each extracted value. Defaults to False. postprocessor (Callable, optional): A function to post-process the model - outputs. Should take (outputs, model, return_explanations) and return + outputs. Should take (outputs, model, cot_reasoning) and return SemanticExtractPostprocessOutput. Defaults to extract_postprocess. return_raw_outputs (bool, optional): Whether to include raw model outputs in the output DataFrame. Useful for debugging. diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index a2dfa677..60ce2d92 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -279,6 +279,10 @@ def extract_formatter( f"The response should be valid JSON format with the following fields: {fields_str}.\n" ) + # Add CoT instructions for CoT strategy + if strategy == ReasoningStrategy.CoT: + sys_instruction += "\n\nFor your response, first provide your reasoning, then give your final answer in the specified JSON format." + messages = [ {"role": "system", "content": sys_instruction}, user_message_formatter(multimodal_data), diff --git a/tests/deepseek_cot_tests.py b/tests/deepseek_cot_tests.py index 7697c7db..a83d1b06 100644 --- a/tests/deepseek_cot_tests.py +++ b/tests/deepseek_cot_tests.py @@ -5,7 +5,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import DemonstrationConfig, ReasoningStrategy lotus.logger.setLevel("DEBUG") @@ -15,167 +15,237 @@ @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_filter_cot_basic(): - """Test sem_filter using DeepSeek CoT on a simple filtering task.""" +def test_deepseek_demonstrations_only(): + """Test DeepSeek with demonstrations without CoT reasoning.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = { - "Text": ["I had two apples and still have one left", "I gave away all my apples", "I received an apple today"] - } + data = {"Course": ["Linear Algebra", "Creative Writing", "Calculus", "Art History"]} + df = pd.DataFrame(data) + user_instruction = "{Course} requires mathematical skills" + + # Provide examples without reasoning + examples = pd.DataFrame({"Course": ["Statistics", "Poetry", "Physics"], "Answer": [True, False, True]}) + + result = df.sem_filter( + user_instruction, strategy=ReasoningStrategy.Demonstrations, examples=examples, return_all=True + ) + + assert "filter_label" in result.columns + # Should identify math courses correctly based on examples + math_courses = result[result["filter_label"]]["Course"].tolist() + assert any(course in ["Linear Algebra", "Calculus"] for course in math_courses) + + +@pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not 
enabled") +def test_deepseek_cot_demonstrations_combined(): + """Test DeepSeek with combined CoT and demonstrations.""" + lm = LM(model=MODEL_NAME) + lotus.settings.configure(lm=lm) + + data = {"Product": ["Smartphone", "Book", "Laptop", "Pen"]} + df = pd.DataFrame(data) + user_instruction = "{Product} is an electronic device" + + # Provide examples with reasoning + examples = pd.DataFrame( + { + "Product": ["Tablet", "Magazine", "Smart Watch"], + "Answer": [True, False, True], + "Reasoning": [ + "Tablets are electronic devices with screens and processors", + "Magazines are printed materials, not electronic", + "Smart watches are wearable electronic devices with digital displays", + ], + } + ) + + result = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + examples=examples, + return_explanations=True, + return_all=True, + ) + + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Should identify electronic devices correctly + electronic_devices = result[result["filter_label"]]["Product"].tolist() + assert any(device in ["Smartphone", "Laptop"] for device in electronic_devices) + # Check explanations are provided + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + + +@pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") +def test_deepseek_demonstration_config(): + """Test DeepSeek with DemonstrationConfig.""" + lm = LM(model=MODEL_NAME) + lotus.settings.configure(lm=lm) + + data = {"Animal": ["Dog", "Cat", "Eagle", "Fish"]} df = pd.DataFrame(data) - user_instruction = "{Text} implies I have at least one apple" + user_instruction = "{Animal} can fly" - filtered_df = df.sem_filter(user_instruction, return_explanations=True, return_all=True) + # Provide examples via DemonstrationConfig + examples = pd.DataFrame({"Animal": ["Bird", "Elephant"], "Answer": [True, False]}) - # Check that extra columns are present. - assert "explanation_filter" in filtered_df.columns - assert "filter_label" in filtered_df.columns + demo_config = DemonstrationConfig(examples=examples) - # At least one row should be labeled True. - positive_rows = filtered_df[filtered_df["filter_label"]] - assert len(positive_rows) > 0 + result = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=demo_config, + return_all=True, + ) - # Each explanation should be nonempty for positive rows. 
- for exp in positive_rows["explanation_filter"]: - assert exp is not None and exp != "" + assert "filter_label" in result.columns + # Should identify flying animals correctly + flying_animals = result[result["filter_label"]]["Animal"].tolist() + assert "Eagle" in flying_animals @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_map_cot_basic(): - """Test sem_map using DeepSeek CoT on a basic mapping task.""" +def test_deepseek_bootstrapping(): + """Test DeepSeek with automatic demonstration bootstrapping.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = {"Text": ["Paris is the capital of France", "Berlin is the capital of Germany"]} + data = {"City": ["New York", "London", "Tokyo", "Sydney", "Paris"]} df = pd.DataFrame(data) - user_instruction = "Extract the capital city from the sentence: {Text}" - result = df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT) + user_instruction = "{City} is in Asia" + + # Configure bootstrapping + demo_config = DemonstrationConfig(bootstrap=True, num_demonstrations=2) + + result = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=demo_config, + return_explanations=True, + return_all=True, + ) - # Check that the mapping column and explanation column exist. - assert "_map" in result.columns - assert "explanation_map" in result.columns + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns - # Verify that each mapped output is a string and each explanation is nonempty. - for output, exp in zip(result["_map"], result["explanation_map"]): - assert isinstance(output, str) - assert exp is not None and exp != "" + # Should identify Asian cities correctly + asian_cities = result[result["filter_label"]]["City"].tolist() + assert "Tokyo" in asian_cities + + # Should work even without user-provided examples + assert len(result) == len(df) @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_top_k_with_negative_reviews(): - """Test sem_top_k with a dataset containing negative reviews.""" - lm = LM(model=MODEL_NAME, temperature=0.6) +def test_deepseek_extract_with_cot(): + """Test DeepSeek extract operation with CoT reasoning.""" + lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) data = { "Review": [ - "This vacuum cleaner is the best I've ever owned. Highly recommend it!", - "It's okay, not sure I would buy it again.", - "Terrible experience, broke after a few uses.", - "Amazing build quality and customer support. Would absolutely recommend.", - "I would not recommend this to anyone.", - "This product is amazing! 
I love it.", + "This phone has amazing battery life and great camera quality!", + "The laptop is too slow and overheats frequently.", ] } - df = pd.DataFrame(data) - user_instruction = "{Review} suggests that the user would recommend the product to others" - for method in ["quick", "heap", "naive"]: - sorted_df, stats = df.sem_topk( - user_instruction, - K=2, - method=method, - return_stats=True, - strategy=ReasoningStrategy.ZS_COT, - return_explanations=True, - ) - # Check that the top 2 reviews are positive - top_reviews = sorted_df["Review"].tolist() - assert any( - "recommend" in review.lower() or "best" in review.lower() or "amazing" in review.lower() - for review in top_reviews - ) + output_cols = { + "sentiment": "Overall sentiment (positive/negative)", + "main_feature": "Main feature mentioned in the review", + } - # Check that the stats are correct - assert stats["total_tokens"] > 0 - assert stats["total_llm_calls"] > 0 + input_cols = ["Review"] # Columns to extract from - # Check that each explanation is not empty - for exp in sorted_df["explanation"]: - assert exp is not None and exp != "" + result = df.sem_extract(input_cols, output_cols, strategy=ReasoningStrategy.CoT, return_explanations=True) + + assert "sentiment" in result.columns + assert "main_feature" in result.columns + assert "explanation_extract" in result.columns + + # Check sentiment extraction + sentiments = result["sentiment"].tolist() + assert any("positive" in sent.lower() for sent in sentiments) + assert any("negative" in sent.lower() for sent in sentiments) @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_filter_cot_fewshot(): - """Test sem_filter with few-shot examples to guide filtering decisions.""" +def test_deepseek_backward_compatibility(): + """Test that DeepSeek still works with legacy methods.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = { - "Text": [ - "Sequence: 5, 4, 3", # Not increasing - "Sequence: 1, 2, 3", # Increasing - "Sequence: 8, 7, 6", # Not increasing - ] - } + data = {"Text": ["The weather is sunny today", "It's raining heavily outside"]} df = pd.DataFrame(data) - user_instruction = "{Text} is an increasing sequence" + user_instruction = "{Text} describes good weather" - # Few-shot examples provided as a DataFrame. - examples = pd.DataFrame( - { - "Text": ["Sequence: 1, 2, 3", "Sequence: 3, 2, 1"], - "Answer": [True, False], - "Reasoning": ["Numbers increase steadily", "Numbers decrease"], - } - ) + # Test without explicit strategy (should use default behavior) + result_default = df.sem_filter(user_instruction, return_all=True) - filtered_df = df.sem_filter( - user_instruction, - examples=examples, - return_explanations=True, - return_all=True, - strategy=ReasoningStrategy.COT, - ) + # Test with explicit CoT strategy + result_cot = df.sem_filter(user_instruction, strategy=ReasoningStrategy.CoT, return_all=True) - # Expect that at least the row with "Sequence: 1, 2, 3" is marked positive. 
- positive_rows = filtered_df[filtered_df["filter_label"]] - assert len(positive_rows) >= 1 - for exp in positive_rows["explanation_filter"]: - assert exp is not None and exp != "" + # Both should work and produce results + assert "filter_label" in result_default.columns + assert "filter_label" in result_cot.columns + assert len(result_default) == len(result_cot) == len(df) @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_map_cot_fewshot(): - """Test sem_map with few-shot examples to guide mapping decisions.""" +def test_deepseek_error_handling(): + """Test error handling with DeepSeek and new reasoning strategies.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = {"Text": ["City: New York", "City: Los Angeles"]} + data = {"Text": ["Sample text"]} df = pd.DataFrame(data) - user_instruction = "Determine the state abbreviation for {Text}" + user_instruction = "{Text} is meaningful" - examples = pd.DataFrame( - { - "Text": ["City: Chicago", "City: Houston"], - "Answer": ["IL", "TX"], - "Reasoning": ["Chicago is in Illinois", "Houston is in Texas"], - } - ) + # Test with empty examples + empty_examples = pd.DataFrame(columns=["Text", "Answer"]) - result = df.sem_map( - user_instruction, - examples=examples, - return_explanations=True, - strategy=ReasoningStrategy.COT, + try: + result = df.sem_filter( + user_instruction, strategy=ReasoningStrategy.Demonstrations, examples=empty_examples, return_all=True + ) + # Should handle gracefully + assert "filter_label" in result.columns + except Exception as e: + # If it raises an error, it should be informative + assert len(str(e)) > 0 + + +@pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") +def test_deepseek_multiple_operations_chaining(): + """Test chaining multiple operations with DeepSeek and different strategies.""" + lm = LM(model=MODEL_NAME) + lotus.settings.configure(lm=lm) + + data = {"Product": ["iPhone", "Novel", "MacBook", "Newspaper", "iPad"]} + df = pd.DataFrame(data) + + # First filter with demonstrations + examples = pd.DataFrame({"Product": ["Laptop", "Book"], "Answer": [True, False]}) + + filtered_df = df.sem_filter( + "{Product} is an electronic device", strategy=ReasoningStrategy.Demonstrations, examples=examples ) - # Check that the new column "State" is added and that explanations are nonempty. 
- assert "_map" in result.columns - assert "explanation_map" in result.columns - for output, exp in zip(result["_map"], result["explanation_map"]): - assert isinstance(output, str) - assert exp is not None and exp != "" + # Then map with CoT + if len(filtered_df) > 0: + mapped_df = filtered_df.sem_map( + "What category does {Product} belong to?", strategy=ReasoningStrategy.CoT, return_explanations=True + ) + + assert "_map" in mapped_df.columns + assert "explanation_map" in mapped_df.columns + + # Should have reasonable categorizations + for category in mapped_df["_map"]: + assert isinstance(category, str) + assert len(category) > 0 diff --git a/tests/test_reasoning_strategies.py b/tests/test_reasoning_strategies.py new file mode 100644 index 00000000..919fd659 --- /dev/null +++ b/tests/test_reasoning_strategies.py @@ -0,0 +1,503 @@ +import os + +import pandas as pd +import pytest + +import lotus +from lotus.models import LM +from lotus.types import DemonstrationConfig, ReasoningStrategy +from tests.base_test import BaseTest + +# Skip all tests if no OpenAI API key is available +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +pytestmark = pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not available") + + +@pytest.fixture +def sample_courses_df(): + """Sample course data for testing""" + return pd.DataFrame( + { + "Course Name": [ + "Linear Algebra", + "Poetry Writing", + "Calculus II", + "Art History", + "Statistics", + "Creative Writing", + "Machine Learning", + "Literature Analysis", + "Physics", + "Philosophy", + ], + "Department": [ + "Math", + "English", + "Math", + "Art", + "Math", + "English", + "CS", + "English", + "Physics", + "Philosophy", + ], + "Credits": [3, 3, 4, 3, 3, 3, 4, 3, 4, 3], + } + ) + + +@pytest.fixture +def sample_reviews_df(): + """Sample review data for testing""" + return pd.DataFrame( + { + "Review": [ + "This product is amazing! Highly recommend it to everyone.", + "It's okay, nothing special but does the job.", + "Terrible quality, broke after one day. 
Would not recommend.", + "Great value for money, very satisfied with my purchase.", + "Poor customer service and mediocre product quality.", + "Outstanding performance, exceeded my expectations!", + ], + "Rating": [5, 3, 1, 4, 2, 5], + } + ) + + +@pytest.fixture +def setup_model(): + """Set up a test model""" + lm = LM(model="gpt-4o-mini", temperature=0.1) + lotus.settings.configure(lm=lm) + return lm + + +class TestReasoningStrategies(BaseTest): + """Test suite for reasoning strategies""" + + # ============================================================================= + # Chain-of-Thought (CoT) Tests + # ============================================================================= + + def test_cot_filter_basic(self, sample_courses_df, setup_model): + """Test basic CoT reasoning with sem_filter""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter(instruction, strategy=ReasoningStrategy.CoT, return_explanations=True, return_all=True) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check that explanations are provided + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + # CoT should contain reasoning + assert any(word in explanation.lower() for word in ["reasoning", "because", "since", "therefore"]) + + def test_cot_map_basic(self, sample_courses_df, setup_model): + """Test basic CoT reasoning with sem_map""" + df = sample_courses_df + instruction = "What is the difficulty level of {Course Name}? Answer: Beginner, Intermediate, or Advanced" + + result = df.sem_map(instruction, strategy=ReasoningStrategy.CoT, return_explanations=True) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check that explanations contain reasoning + for explanation in result["explanation_map"]: + assert explanation is not None + assert len(explanation) > 0 + + def test_cot_topk_basic(self, sample_reviews_df, setup_model): + """Test basic CoT reasoning with sem_topk""" + df = sample_reviews_df + instruction = "{Review} is a positive review" + + result, stats = df.sem_topk( + instruction, K=3, strategy=ReasoningStrategy.CoT, return_explanations=True, return_stats=True + ) + + # Check structure + assert len(result) == 3 + assert "explanation" in result.columns + assert stats["total_llm_calls"] > 0 + + # Check explanations + for explanation in result["explanation"]: + assert explanation is not None + assert len(explanation) > 0 + + # ============================================================================= + # Demonstrations (Few-shot) Tests + # ============================================================================= + + def test_demonstrations_filter_basic(self, sample_courses_df, setup_model): + """Test demonstrations strategy with sem_filter""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Provide examples + examples = pd.DataFrame( + {"Course Name": ["Machine Learning", "Literature", "Physics"], "Answer": [True, False, True]} + ) + + result = df.sem_filter( + instruction, strategy=ReasoningStrategy.Demonstrations, examples=examples, return_all=True + ) + + # Check structure + assert "filter_label" in result.columns + + # Should identify math-heavy courses correctly based on examples + math_courses = result[result["filter_label"]]["Course Name"].tolist() + assert any(course in ["Linear Algebra", "Calculus II", 
"Statistics"] for course in math_courses) + + def test_demonstrations_map_basic(self, sample_courses_df, setup_model): + """Test demonstrations strategy with sem_map""" + df = sample_courses_df.head(3) # Use fewer rows for faster testing + instruction = "What department is {Course Name} in?" + + # Provide examples + examples = pd.DataFrame({"Course Name": ["Calculus I", "English Literature"], "Answer": ["Math", "English"]}) + + result = df.sem_map(instruction, strategy=ReasoningStrategy.Demonstrations, examples=examples) + + # Check structure + assert "_map" in result.columns + + # Check that mapping results are reasonable + for mapped_value in result["_map"]: + assert isinstance(mapped_value, str) + assert len(mapped_value) > 0 + + # ============================================================================= + # CoT + Demonstrations Tests + # ============================================================================= + + def test_cot_demonstrations_filter(self, sample_courses_df, setup_model): + """Test combined CoT + Demonstrations with sem_filter""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Provide examples with reasoning + examples = pd.DataFrame( + { + "Course Name": ["Machine Learning", "Literature", "Physics"], + "Answer": [True, False, True], + "Reasoning": [ + "Machine Learning requires linear algebra, calculus, and statistics", + "Literature focuses on reading, writing, and analysis - no math required", + "Physics is fundamentally mathematical with equations and calculations", + ], + } + ) + + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + examples=examples, + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check that explanations are provided and contain reasoning + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + + def test_cot_demonstrations_map(self, sample_courses_df, setup_model): + """Test combined CoT + Demonstrations with sem_map""" + df = sample_courses_df.head(3) + instruction = "What is the difficulty level of {Course Name}?" 
+ + # Provide examples with reasoning + examples = pd.DataFrame( + { + "Course Name": ["Algebra I", "Advanced Calculus"], + "Answer": ["Beginner", "Advanced"], + "Reasoning": [ + "Algebra I is typically an introductory math course", + "Advanced Calculus requires significant mathematical background", + ], + } + ) + + result = df.sem_map( + instruction, strategy=ReasoningStrategy.CoT_Demonstrations, examples=examples, return_explanations=True + ) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check explanations + for explanation in result["explanation_map"]: + assert explanation is not None + assert len(explanation) > 0 + + # ============================================================================= + # DemonstrationConfig and Bootstrapping Tests + # ============================================================================= + + def test_demonstration_config_basic(self, sample_courses_df, setup_model): + """Test DemonstrationConfig with user-provided examples""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Examples provided via DemonstrationConfig + examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]}) + + demo_config = DemonstrationConfig(examples=examples) + + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=demo_config, + return_all=True, + ) + + assert "filter_label" in result.columns + + def test_bootstrapping_basic(self, sample_courses_df, setup_model): + """Test automatic demonstration bootstrapping""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Configure bootstrapping + demo_config = DemonstrationConfig(bootstrap=True, num_demonstrations=2) + + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=demo_config, + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Should work even without user-provided examples + assert len(result) == len(df) + + def test_bootstrapping_with_oracle_model(self, sample_courses_df, setup_model): + """Test bootstrapping with a different oracle model""" + df = sample_courses_df.head(5) # Use fewer rows for faster testing + instruction = "{Course Name} requires a lot of math" + + demo_config = DemonstrationConfig( + bootstrap=True, + num_demonstrations=1, + oracle_model="gpt-4o-mini", # Use same model for testing + ) + + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=demo_config, + return_all=True, + ) + + assert "filter_label" in result.columns + + # ============================================================================= + # Backward Compatibility Tests + # ============================================================================= + + def test_backward_compatibility_examples_param(self, sample_courses_df, setup_model): + """Test that old examples parameter still works""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Old way: passing examples directly + examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]}) + + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.Demonstrations, + examples=examples, # Old parameter name + return_all=True, + ) + + assert 
"filter_label" in result.columns + + def test_no_strategy_specified(self, sample_courses_df, setup_model): + """Test default behavior when no strategy is specified""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter(instruction, return_all=True) + + # Should work with default behavior + assert "filter_label" in result.columns + + # ============================================================================= + # Extract Operation Tests + # ============================================================================= + + def test_cot_extract_basic(self, sample_reviews_df, setup_model): + """Test CoT reasoning with sem_extract""" + df = sample_reviews_df.head(3) # Use fewer rows for faster testing + + input_cols = ["Review"] + output_cols = { + "sentiment": "The sentiment of the review (positive/negative/neutral)", + "key_points": "Main points mentioned in the review", + } + + result = df.sem_extract(input_cols, output_cols, strategy=ReasoningStrategy.CoT, return_explanations=True) + + # Check structure + assert "sentiment" in result.columns + assert "key_points" in result.columns + assert "explanation" in result.columns + + # Check that extractions are reasonable + for sentiment in result["sentiment"]: + assert sentiment.lower() in ["positive", "negative", "neutral"] + + # ============================================================================= + # Error Handling and Edge Cases + # ============================================================================= + + def test_empty_examples(self, sample_courses_df, setup_model): + """Test behavior with empty examples DataFrame""" + df = sample_courses_df.head(3) + instruction = "{Course Name} requires a lot of math" + + # Create properly structured empty examples DataFrame + empty_examples = pd.DataFrame(columns=["Course Name", "Answer"]) + + # Should handle empty examples gracefully + result = df.sem_filter( + instruction, strategy=ReasoningStrategy.Demonstrations, examples=empty_examples, return_all=True + ) + + assert "filter_label" in result.columns + + def test_mismatched_example_columns(self, sample_courses_df, setup_model): + """Test error handling for mismatched example columns""" + df = sample_courses_df.head(3) + instruction = "{Course Name} requires a lot of math" + + # Examples with wrong column names + bad_examples = pd.DataFrame({"WrongColumn": ["Machine Learning", "Literature"], "Answer": [True, False]}) + + # Should handle gracefully or raise informative error + try: + result = df.sem_filter( + instruction, strategy=ReasoningStrategy.Demonstrations, examples=bad_examples, return_all=True + ) + # If it doesn't raise an error, it should still produce results + assert "filter_label" in result.columns + except Exception as e: + # If it raises an error, it should be informative + assert len(str(e)) > 0 + + def test_invalid_strategy_combination(self, sample_courses_df, setup_model): + """Test invalid combinations of parameters""" + df = sample_courses_df.head(3) + instruction = "{Course Name} requires a lot of math" + + # Try to use bootstrapping without CoT_Demonstrations strategy + demo_config = DemonstrationConfig(bootstrap=True) + + try: + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.CoT, # Wrong strategy for bootstrapping + demonstration_config=demo_config, + return_all=True, + ) + # Should either work or raise informative error + assert "filter_label" in result.columns + except Exception as e: + assert len(str(e)) > 0 + + def 
test_large_num_demonstrations(self, sample_courses_df, setup_model): + """Test behavior with large number of demonstrations""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Request more demonstrations than available data + demo_config = DemonstrationConfig( + bootstrap=True, + num_demonstrations=20, # More than df length + ) + + result = df.sem_filter( + instruction, + strategy=ReasoningStrategy.CoT_Demonstrations, + demonstration_config=demo_config, + return_all=True, + ) + + # Should handle gracefully + assert "filter_label" in result.columns + + # ============================================================================= + # Performance and Integration Tests + # ============================================================================= + + def test_multiple_operations_with_strategies(self, sample_courses_df, setup_model): + """Test chaining multiple operations with different strategies""" + df = sample_courses_df + + # First filter with demonstrations + examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]}) + + filtered_df = df.sem_filter( + "{Course Name} requires a lot of math", strategy=ReasoningStrategy.Demonstrations, examples=examples + ) + + # Then map with CoT + if len(filtered_df) > 0: + mapped_df = filtered_df.sem_map( + "What is the difficulty level of {Course Name}?", + strategy=ReasoningStrategy.CoT, + return_explanations=True, + ) + + assert "_map" in mapped_df.columns + assert "explanation_map" in mapped_df.columns + + def test_strategy_with_return_options(self, sample_courses_df, setup_model): + """Test strategies with various return options""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + # Test return_stats=True (returns tuple) + result, stats = df.sem_filter( + instruction, strategy=ReasoningStrategy.CoT, return_all=True, return_explanations=True, return_stats=True + ) + + # Check all expected columns are present in DataFrame + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check stats is returned + assert isinstance(stats, dict) + + # Test without return_stats (returns DataFrame only) + result_no_stats = df.sem_filter( + instruction, strategy=ReasoningStrategy.CoT, return_all=True, return_explanations=True, return_stats=False + ) + + # Should return DataFrame directly + assert "filter_label" in result_no_stats.columns + assert "explanation_filter" in result_no_stats.columns + + # Test that filtering works correctly + positive_results = result[result["filter_label"]] + assert len(positive_results) >= 0 # Should have some math courses From 7ba876553c3638397b1fbfb7fc41279f79d01eb5 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 21 Aug 2025 18:52:37 -0700 Subject: [PATCH 3/8] remove comments --- examples/op_examples/simple_reasoning.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/op_examples/simple_reasoning.py b/examples/op_examples/simple_reasoning.py index 3921f5b3..01a914bc 100644 --- a/examples/op_examples/simple_reasoning.py +++ b/examples/op_examples/simple_reasoning.py @@ -1,13 +1,3 @@ -""" -Simple Reasoning Strategies Demo - -This example shows the new, simplified reasoning system in Lotus: -1. ReasoningStrategy.CoT - Chain-of-thought reasoning -2. ReasoningStrategy.Demonstrations - Few-shot examples -3. ReasoningStrategy.CoT_Demonstrations - Both combined -4. 
Automatic demonstration bootstrapping -""" - import pandas as pd import lotus From 0645be3c84b590682b5245c2c1de195b990e5936 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Fri, 22 Aug 2025 21:01:09 -0700 Subject: [PATCH 4/8] change api for reasoning strat --- lotus/__init__.py | 4 +- lotus/sem_ops/demonstration_bootstrap.py | 20 +++-- lotus/sem_ops/postprocessors.py | 9 +- lotus/sem_ops/sem_extract.py | 42 ++++++--- lotus/sem_ops/sem_filter.py | 49 ++++++----- lotus/sem_ops/sem_join.py | 34 ++++---- lotus/sem_ops/sem_map.py | 32 ++++--- lotus/sem_ops/sem_topk.py | 84 ++++++++++-------- lotus/templates/task_instructions.py | 22 ++--- lotus/types.py | 53 +++++++++--- tests/deepseek_cot_tests.py | 28 +++--- tests/test_reasoning_strategies.py | 105 ++++++++++++----------- 12 files changed, 277 insertions(+), 205 deletions(-) diff --git a/lotus/__init__.py b/lotus/__init__.py index 3a4623f5..d1db9a74 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -22,7 +22,7 @@ ) from lotus.web_search import web_search, WebSearchCorpus from lotus.settings import settings # type: ignore[attr-defined] -from lotus.types import ReasoningStrategy, DemonstrationConfig +from lotus.types import PromptStrategy, DemonstrationConfig logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) @@ -52,6 +52,6 @@ "dtype_extensions", "web_search", "WebSearchCorpus", - "ReasoningStrategy", + "PromptStrategy", "DemonstrationConfig", ] diff --git a/lotus/sem_ops/demonstration_bootstrap.py b/lotus/sem_ops/demonstration_bootstrap.py index ba7fce9c..cddca760 100644 --- a/lotus/sem_ops/demonstration_bootstrap.py +++ b/lotus/sem_ops/demonstration_bootstrap.py @@ -4,7 +4,7 @@ import lotus from lotus.models import LM from lotus.templates import task_instructions -from lotus.types import DemonstrationConfig, ReasoningStrategy +from lotus.types import DemonstrationConfig, PromptStrategy def bootstrap_demonstrations_for_filter( @@ -46,10 +46,12 @@ def bootstrap_demonstrations_for_filter( # Generate with CoT reasoning if needed if config.oracle_model or hasattr(config, "include_reasoning"): # Generate with CoT reasoning - prompt = task_instructions.filter_formatter(model, doc, user_instruction, strategy=ReasoningStrategy.CoT) + prompt = task_instructions.filter_formatter( + model, doc, user_instruction, prompt_strategy=PromptStrategy(cot=True) + ) else: # Generate without reasoning - prompt = task_instructions.filter_formatter(model, doc, user_instruction, strategy=None) + prompt = task_instructions.filter_formatter(model, doc, user_instruction, prompt_strategy=None) # Get oracle response response = model([prompt], progress_bar_desc="Bootstrapping demonstrations") @@ -107,10 +109,12 @@ def bootstrap_demonstrations_for_map( # Generate with CoT reasoning if needed if config.oracle_model or hasattr(config, "include_reasoning"): # Generate with CoT reasoning - prompt = task_instructions.map_formatter(model, doc, user_instruction, strategy=ReasoningStrategy.CoT) + prompt = task_instructions.map_formatter( + model, doc, user_instruction, prompt_strategy=PromptStrategy(cot=True) + ) else: # Generate without reasoning - prompt = task_instructions.map_formatter(model, doc, user_instruction, strategy=None) + prompt = task_instructions.map_formatter(model, doc, user_instruction, prompt_strategy=None) # Get oracle response response = model([prompt], progress_bar_desc="Bootstrapping demonstrations") @@ -168,10 +172,12 @@ def bootstrap_demonstrations_for_extract( # Generate with CoT reasoning if needed if 
config.oracle_model or hasattr(config, "include_reasoning"): # Generate with CoT reasoning - prompt = task_instructions.extract_formatter(model, doc, output_cols, strategy=ReasoningStrategy.CoT) + prompt = task_instructions.extract_formatter( + model, doc, output_cols, prompt_strategy=PromptStrategy(cot=True) + ) else: # Generate without reasoning - prompt = task_instructions.extract_formatter(model, doc, output_cols, strategy=None) + prompt = task_instructions.extract_formatter(model, doc, output_cols, prompt_strategy=None) # Get oracle response response = model([prompt], progress_bar_desc="Bootstrapping demonstrations") diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 0416f913..fe3ea3e7 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -187,6 +187,7 @@ def filter_postprocess( llm_answers: list[str], model: lotus.models.LM, default: bool = True, + cot_reasoning: bool = False, ) -> SemanticFilterPostprocessOutput: """ Postprocess the output of the filter operator. @@ -214,8 +215,12 @@ def process_outputs(answer): lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") return default - postprocessor = get_cot_postprocessor(model) - outputs, explanations = postprocessor(llm_answers) + if cot_reasoning: + postprocessor = get_cot_postprocessor(model) + outputs, explanations = postprocessor(llm_answers) + else: + outputs = llm_answers + explanations = [None] * len(llm_answers) boolean_outputs = [process_outputs(answer) for answer in outputs] diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 422144ab..452234eb 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -6,7 +6,7 @@ from lotus.cache import operator_cache from lotus.models import LM from lotus.templates import task_instructions -from lotus.types import LMOutput, ReasoningStrategy, SemanticExtractOutput, SemanticExtractPostprocessOutput +from lotus.types import LMOutput, PromptStrategy, SemanticExtractOutput, SemanticExtractPostprocessOutput from lotus.utils import show_safe_mode from .postprocessors import extract_postprocess @@ -21,7 +21,7 @@ def sem_extract( safe_mode: bool = False, progress_bar_desc: str = "Extracting", return_explanations: bool = False, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> SemanticExtractOutput: """ Extracts structured attributes and values from a list of documents using a language model. @@ -52,8 +52,9 @@ def sem_extract( return_explanations (bool, optional): Whether to return explanations for the extraction decisions. Useful for debugging and understanding model reasoning. Defaults to False. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. - Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. 
+ Returns: SemanticExtractOutput: An object containing the extracted outputs, raw @@ -69,11 +70,25 @@ def sem_extract( >>> output_cols = {"sentiment": "positive/negative/neutral", "rating": "1-5 scale"} >>> result = sem_extract(docs, model, output_cols) >>> print(result.outputs) # [{"sentiment": "positive", "rating": "5"}] + + >>> # Using PromptStrategy with chain-of-thought + >>> from lotus.types import PromptStrategy + >>> strat = PromptStrategy(cot=True) + >>> result = sem_extract(docs, model, output_cols, prompt_strategy=strat) + + >>> # Using PromptStrategy with demonstrations + >>> import pandas as pd + >>> examples = pd.DataFrame({ + ... 'text': ['Great product!', 'Terrible service'], + ... 'Answer': [{'sentiment': 'positive'}, {'sentiment': 'negative'}] + ... }) + >>> strat = PromptStrategy(cot=True, dems=examples, max_dems=2) + >>> result = sem_extract(docs, model, output_cols, prompt_strategy=strat) """ # prepare model inputs inputs = [] for doc in docs: - prompt = task_instructions.extract_formatter(model, doc, output_cols, extract_quotes, strategy) + prompt = task_instructions.extract_formatter(model, doc, output_cols, extract_quotes, prompt_strategy) lotus.logger.debug(f"input to model: {prompt}") lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") inputs.append(prompt) @@ -88,7 +103,7 @@ def sem_extract( lm_output: LMOutput = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc) # post process results - cot_reasoning = strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations] + cot_reasoning = prompt_strategy is not None and prompt_strategy.cot postprocess_output = postprocessor(lm_output.outputs, model, cot_reasoning) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") @@ -136,8 +151,9 @@ class SemExtractDataFrame: return_explanations (bool, optional): Whether to include explanations in the output DataFrame. Useful for debugging and understanding model reasoning. Defaults to False. - strategy (ReasoningStrategy | None, optional): The reasoning strategy - to use. Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. + Returns: pd.DataFrame: A DataFrame containing the original data plus the @@ -163,10 +179,14 @@ class SemExtractDataFrame: ... ['text'], ... {'sentiment': 'positive/negative/neutral', 'emotion': 'joy/anger/sadness'} ... ) - Extracting: 100%|█████████████████████████████████████████████████████████████████ 2/2 LM calls [00:00<00:00, 2.20it/s] text rating sentiment emotion 0 Great product! 
5 positive joy 1 Terrible service 1 negative anger + + >>> # Using PromptStrategy with chain-of-thought + >>> from lotus.types import PromptStrategy + >>> strat = PromptStrategy(cot=True) + >>> df.sem_extract(['text'], {'sentiment': 'positive/negative/neutral'}, prompt_strategy=strat) """ def __init__(self, pandas_obj: pd.DataFrame): @@ -206,7 +226,7 @@ def __call__( safe_mode: bool = False, progress_bar_desc: str = "Extracting", return_explanations: bool = False, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> pd.DataFrame: if lotus.settings.lm is None: raise ValueError( @@ -229,7 +249,7 @@ def __call__( safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, return_explanations=return_explanations, - strategy=strategy, + prompt_strategy=prompt_strategy, ) new_df = self._obj.copy() diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 82bf1741..ad1ca458 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -12,8 +12,8 @@ DemonstrationConfig, LMOutput, LogprobsForFilterCascade, + PromptStrategy, ProxyModel, - ReasoningStrategy, SemanticFilterOutput, ) from lotus.utils import show_safe_mode @@ -30,7 +30,7 @@ def sem_filter( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, demonstration_config: DemonstrationConfig | None = None, logprobs: bool = False, safe_mode: bool = False, @@ -63,8 +63,8 @@ def sem_filter( cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Used when strategy includes COT reasoning. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. - Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. logprobs (bool, optional): Whether to return log probabilities for the model outputs. Useful for confidence estimation. Defaults to False. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. @@ -99,7 +99,7 @@ def sem_filter( examples_multimodal_data, examples_answers, cot_reasoning, - strategy, + prompt_strategy, reasoning_instructions=additional_cot_instructions, ) lotus.logger.debug(f"input to model: {prompt}") @@ -115,7 +115,9 @@ def sem_filter( inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs ) - postprocess_output = filter_postprocess(lm_output.outputs, model, default) + postprocess_output = filter_postprocess( + lm_output.outputs, model, default, cot_reasoning=(prompt_strategy is not None and prompt_strategy.cot) + ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") @@ -142,7 +144,7 @@ def learn_filter_cascade_thresholds( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, additional_cot_instructions: str = "", ) -> tuple[float, float]: """ @@ -172,7 +174,7 @@ def learn_filter_cascade_thresholds( for the example documents. Defaults to None. 
cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. Defaults to None. additional_cot_instructions (str, optional): Additional instructions for chain-of-thought reasoning. Defaults to "". @@ -203,7 +205,7 @@ def learn_filter_cascade_thresholds( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, demonstration_config=None, # No demonstration config for threshold learning safe_mode=False, progress_bar_desc="Running oracle for threshold learning", @@ -255,8 +257,8 @@ class SemFilterDataframe: input DataFrame plus an "Answer" column. Defaults to None. helper_examples (pd.DataFrame | None, optional): Additional helper examples for cascade filtering. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy - to use. Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. cascade_args (CascadeArgs | None, optional): Configuration for cascade filtering. Includes parameters like recall_target, precision_target, sampling_percentage, and failure_probability. Defaults to None. @@ -297,8 +299,8 @@ class SemFilterDataframe: 0 Great product! 5 # Example 2: with zero-shot chain-of-thought (ZS-COT) reasoning - >>> from lotus.types import ReasoningStrategy - >>> df.sem_filter("The review {text} and {rating} reflect's a positive sentiment ", strategy=ReasoningStrategy.ZS_COT, return_explanations=True, return_all=True) + >>> from lotus.types import PromptStrategy + >>> df.sem_filter("The review {text} and {rating} reflect's a positive sentiment ", prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_all=True) Filtering: 100%|██████████████████████████████████████████████████████████████████ 4/4 LM calls [00:01<00:00, 3.66it/s] Text filter_label explanation_filter 0 I had two apples, then I gave away one True @@ -344,7 +346,7 @@ def __call__( suffix: str = "_filter", examples: pd.DataFrame | None = None, helper_examples: pd.DataFrame | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, demonstration_config: DemonstrationConfig | None = None, cascade_args: CascadeArgs | None = None, return_stats: bool = False, @@ -362,7 +364,7 @@ def __call__( col_li = lotus.nl_expression.parse_cols(user_instruction) lotus.logger.debug(col_li) - helper_strategy = strategy + helper_strategy = prompt_strategy # check that column exists for column in col_li: @@ -384,10 +386,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if ( - strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations] - and "Reasoning" in examples.columns - ): + if prompt_strategy is not None and prompt_strategy.cot and "Reasoning" in examples.columns: cot_reasoning = examples["Reasoning"].tolist() pos_cascade_threshold, neg_cascade_threshold = None, None @@ -400,7 +399,7 @@ def __call__( assert "Answer" in helper_examples.columns, "Answer must be a column in examples dataframe" helper_examples_multimodal_data = 
task_instructions.df2multimodal_info(helper_examples, col_li) helper_examples_answers = helper_examples["Answer"].tolist() - if helper_strategy == ReasoningStrategy.CoT and "Reasoning" in helper_examples.columns: + if helper_strategy is not None and helper_strategy.cot and "Reasoning" in helper_examples.columns: helper_cot_reasoning = helper_examples["Reasoning"].tolist() if cascade_args: @@ -419,7 +418,7 @@ def __call__( if not lotus.settings.helper_lm: raise ValueError("Helper LM must be set in settings") - if helper_strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]: + if helper_strategy is not None and helper_strategy.cot: raise ValueError("CoT not supported for helper models in cascades.") # Run small LM and get logits @@ -432,7 +431,7 @@ def __call__( examples_answers=helper_examples_answers, cot_reasoning=helper_cot_reasoning, logprobs=True, - strategy=helper_strategy, + prompt_strategy=helper_strategy, demonstration_config=None, # Helper models don't use demonstration config safe_mode=safe_mode, show_progress_bar=True, @@ -469,7 +468,7 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, additional_cot_instructions=additional_cot_instructions, ) @@ -527,7 +526,7 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, demonstration_config=demonstration_config, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", @@ -551,7 +550,7 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, demonstration_config=demonstration_config, safe_mode=safe_mode, show_progress_bar=True, diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index dc3d65d7..377f3f7b 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -6,7 +6,7 @@ import lotus from lotus.cache import operator_cache from lotus.templates import task_instructions -from lotus.types import CascadeArgs, ReasoningStrategy, SemanticJoinOutput +from lotus.types import CascadeArgs, PromptStrategy, SemanticJoinOutput from lotus.utils import show_safe_mode from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds @@ -26,7 +26,7 @@ def sem_join( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Join comparisons", @@ -67,7 +67,7 @@ def sem_join( Defaults to None. default (bool, optional): The default value to use when the model output cannot be parsed as a boolean. Defaults to True. - strategy (ReasoningStrategy | None, optional): The reasoning strategy + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. 
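+
+    Example:
+        >>> # Hedged sketch: a PromptStrategy turns on CoT for every pairwise join
+        >>> # comparison; the remaining positional arguments are elided here.
+        >>> from lotus.types import PromptStrategy
+        >>> output = sem_join(..., prompt_strategy=PromptStrategy(cot=True), default=True)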
@@ -111,7 +111,7 @@ def sem_join( examples_multimodal_data, examples_answers, cot_reasoning, - strategy, + prompt_strategy, ) ) estimated_total_calls = len(l1) * len(l2) @@ -142,7 +142,7 @@ def sem_join( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, show_progress_bar=False, ) @@ -193,7 +193,7 @@ def sem_join_cascade( map_examples: pd.DataFrame | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, safe_mode: bool = False, ) -> SemanticJoinOutput: """ @@ -253,7 +253,7 @@ def sem_join_cascade( map_examples=map_examples, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, ) num_helper = len(helper_high_conf) @@ -296,7 +296,7 @@ def sem_join_cascade( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, show_progress_bar=True, ) @@ -428,7 +428,7 @@ def join_optimizer( map_examples: pd.DataFrame | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> tuple[pd.DataFrame, pd.DataFrame, int, int]: """ Find most cost-effective join plan between Search-Filter and Map-Search-Filter @@ -472,7 +472,7 @@ def join_optimizer( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, ) sf_high_conf = sf_helper_join[sf_helper_join["_scores"] >= sf_t_pos] sf_high_conf_neg = len(sf_helper_join[sf_helper_join["_scores"] <= sf_t_neg]) @@ -495,7 +495,7 @@ def join_optimizer( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, ) msf_high_conf = msf_helper_join[msf_helper_join["_scores"] >= msf_t_pos] msf_high_conf_neg = len(msf_helper_join[msf_helper_join["_scores"] <= msf_t_neg]) @@ -538,7 +538,7 @@ def learn_join_cascade_threshold( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> tuple[float, float, int]: """ Extract a small sample of the data and find the optimal threshold pair that satisfies the recall and @@ -582,7 +582,7 @@ def learn_join_cascade_threshold( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, progress_bar_desc="Running oracle for threshold learning", ) @@ -625,7 +625,7 @@ def __call__( how: str = "inner", suffix: str = "_join", examples: pd.DataFrame | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, default: bool = True, cascade_args: CascadeArgs | None = None, return_stats: bool = False, @@ -738,7 +738,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, [real_left_on, real_right_on]) examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.CoT: + if prompt_strategy is not None and prompt_strategy.cot: return_explanations = True cot_reasoning = examples["Reasoning"].tolist() @@ -769,7 +769,7 @@ def __call__( map_examples=cascade_args.map_examples, cot_reasoning=cot_reasoning, 
default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, ) else: @@ -786,7 +786,7 @@ def __call__( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, ) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 76e9eed4..282ccec7 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -8,7 +8,7 @@ from lotus.types import ( DemonstrationConfig, LMOutput, - ReasoningStrategy, + PromptStrategy, SemanticMapOutput, SemanticMapPostprocessOutput, ) @@ -25,7 +25,7 @@ def sem_map( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[str] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, demonstration_config: DemonstrationConfig | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", @@ -57,8 +57,8 @@ def sem_map( cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Used when strategy includes COT reasoning. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. - Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. progress_bar_desc (str, optional): Description for the progress bar. @@ -83,7 +83,13 @@ def sem_map( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.map_formatter( - model, doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy=strategy + model, + doc, + user_instruction, + examples_multimodal_data, + examples_answers, + cot_reasoning, + prompt_strategy=prompt_strategy, ) lotus.logger.debug(f"input to model: {prompt}") lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") @@ -99,9 +105,7 @@ def sem_map( lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc) # post process results - postprocess_output = postprocessor( - lm_output.outputs, model, strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations] - ) + postprocess_output = postprocessor(lm_output.outputs, model, prompt_strategy is not None and prompt_strategy.cot) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") @@ -161,7 +165,7 @@ def __call__( return_raw_outputs: bool = False, suffix: str = "_map", examples: pd.DataFrame | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, demonstration_config: DemonstrationConfig | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", @@ -190,7 +194,7 @@ def __call__( examples (pd.DataFrame | None, optional): Example DataFrame for few-shot learning. Should have the same column structure as the input DataFrame plus an "Answer" column. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. Can be None, COT, or ZS_COT. Defaults to None. 
safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. @@ -223,8 +227,8 @@ def __call__( 1 Harry is feeling nauseous Negative # Example 2: with zero-shot chain-of-thought (ZS-COT) reasoning - >>> from lotus.types import ReasoningStrategy - >>> df.sem_map("Label the sentiment of Harry in the {document} as positive/negative/neutral. Answer in one word.", return_explanations=True, strategy=ReasoningStrategy.ZS_COT) + >>> from lotus.types import PromptStrategy + >>> df.sem_map("Label the sentiment of Harry in the {document} as positive/negative/neutral. Answer in one word.", return_explanations=True, prompt_strategy=PromptStrategy(cot=True)) Mapping: 100%|████████████████████████████████████████████████████████████████████ 2/2 LM calls [00:02<00:00, 1.04s/it] document _map explanation_map 0 Harry is happy and love cats positive Reasoning: The document states that "Harry is ... @@ -256,7 +260,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]: + if prompt_strategy is not None and prompt_strategy.cot: return_explanations = True if "Reasoning" in examples.columns: cot_reasoning = examples["Reasoning"].tolist() @@ -271,7 +275,7 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, demonstration_config=demonstration_config, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 9d3c9ca6..9b1c689a 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -9,7 +9,7 @@ import lotus from lotus.cache import operator_cache from lotus.templates import task_instructions -from lotus.types import LMOutput, ReasoningStrategy, SemanticTopKOutput +from lotus.types import LMOutput, PromptStrategy, SemanticTopKOutput from lotus.utils import show_safe_mode @@ -18,7 +18,7 @@ def get_match_prompt_binary( doc2: dict[str, Any], user_instruction: str, model: lotus.models.LM, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> list[dict[str, Any]]: """ Generate a binary comparison prompt for two documents. @@ -35,7 +35,7 @@ def get_match_prompt_binary( user_instruction (str): The natural language instruction that defines the comparison criteria. model (lotus.models.LM): The language model instance to use for comparison. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, CoT, or Demonstrations. Defaults to None. 
 
     Returns:
         list[dict[str, Any]]: A list of messages formatted for the language model,
             including system instructions and the comparison prompt.
 
     Example:
         >>> doc1 = {"text": "Machine learning is a subset of AI"}
         >>> doc2 = {"text": "Cooking recipes for beginners"}
         >>> model = LM(model="gpt-4o")
         >>> prompt = get_match_prompt_binary(doc1, doc2, "Which is more relevant to AI?", model)
     """
-    if strategy == ReasoningStrategy.CoT:
+    if prompt_strategy is not None and prompt_strategy.cot:
         sys_prompt = (
-            "Your job is to to select and return the most relevant document to the user's question.\n"
+            "Your job is to select and return the most relevant document to the user's question.\n"
             "Carefully read the user's question and the two documents provided below.\n"
@@ -69,7 +69,7 @@
         content_text, content_image_inputs = task_instructions.context_formatter(doc)
         prompt += [{"type": "text", "text": f"\nDocument {idx+1}:\n{content_text}"}, *content_image_inputs]
 
-    if strategy == ReasoningStrategy.CoT and model.is_deepseek():
-        deepseek_instructions = """Please think through your reasoning step by step, then provide your final answer. You must put your reasoning insdie the tags, then provide your final answer after the tag with the format: Answer: your answer."""
+    if prompt_strategy is not None and prompt_strategy.cot and model.is_deepseek():
+        deepseek_instructions = """Please think through your reasoning step by step, then provide your final answer. You must put your reasoning inside the tags, then provide your final answer after the tag with the format: Answer: your answer."""
@@ -133,7 +133,7 @@ def compare_batch_binary(
     pairs: list[tuple[dict[str, Any], dict[str, Any]]],
     model: lotus.models.LM,
     user_instruction: str,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy | None = None,
 ) -> tuple[list[bool], list[str], int]:
     """
     Compare multiple pairs of documents using binary classification.
@@ -147,7 +147,7 @@
         model (lotus.models.LM): The language model instance to use for comparison.
         user_instruction (str): The natural language instruction that defines
             the comparison criteria.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use.
+            Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None.
Returns: @@ -219,7 +221,9 @@ def compare_batch_binary_cascade( match_prompts = [] small_tokens = 0 for doc1, doc2 in pairs: - match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy, model=model)) + match_prompts.append( + get_match_prompt_binary(doc1, doc2, user_instruction, prompt_strategy=prompt_strategy, model=model) + ) small_tokens += model.count_tokens(match_prompts[-1]) helper_lm = lotus.settings.helper_lm @@ -277,7 +281,7 @@ def llm_naive_sort( docs: list[dict[str, Any]], model: lotus.models.LM, user_instruction: str, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, safe_mode: bool = False, ) -> SemanticTopKOutput: """ @@ -293,7 +297,7 @@ def llm_naive_sort( model (lotus.models.LM): The language model instance to use for comparisons. user_instruction (str): The natural language instruction that defines the sorting criteria. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. @@ -319,7 +323,9 @@ def llm_naive_sort( desc="All-pairs comparisons", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} LM calls [{elapsed}<{remaining}]", ) - comparisons, explanations, tokens = compare_batch_binary(pairs, model, user_instruction, strategy=strategy) + comparisons, explanations, tokens = compare_batch_binary( + pairs, model, user_instruction, prompt_strategy=prompt_strategy + ) pbar.update(len(pairs)) pbar.close() if safe_mode: @@ -350,7 +356,7 @@ def llm_quicksort( user_instruction: str, K: int, embedding: bool = False, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, cascade_threshold: float | None = None, safe_mode: bool = False, ) -> SemanticTopKOutput: @@ -370,7 +376,7 @@ def llm_quicksort( K (int): The number of top documents to return. embedding (bool, optional): Whether to use embedding optimization for pivot selection. Defaults to False. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. cascade_threshold (float | None, optional): Confidence threshold for cascade filtering. If provided, uses a two-stage model approach. Defaults to None. 
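+
+    Example:
+        >>> # Hedged sketch: 'docs' are the multimodal dicts built upstream and 'model'
+        >>> # is a configured LM; the instruction text here is illustrative only.
+        >>> from lotus.types import PromptStrategy
+        >>> output = llm_quicksort(docs, model, "Which {Review} is more positive?", K=3,
+        ...                        prompt_strategy=PromptStrategy(cot=True))
+        >>> top_indexes = output.indexes[:3]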
@@ -391,7 +397,9 @@ def llm_quicksort( stats["total_llm_calls"] = 0 stats["explanations"] = {} if safe_mode: - sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy, model=model) + sample_prompt = get_match_prompt_binary( + docs[0], docs[1], user_instruction, prompt_strategy=prompt_strategy, model=model + ) estimated_quickselect_calls = 2 * K estimated_quicksort_calls = 2 * len(docs) * np.log(len(docs)) estimated_total_calls = estimated_quickselect_calls + estimated_quicksort_calls @@ -425,7 +433,9 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int: pairs = [(docs[indexes[j]], pivot) for j in range(low, high)] if cascade_threshold is None: - comparisons, explanations, tokens = compare_batch_binary(pairs, model, user_instruction, strategy=strategy) + comparisons, explanations, tokens = compare_batch_binary( + pairs, model, user_instruction, prompt_strategy=prompt_strategy + ) stats["total_tokens"] += tokens stats["total_llm_calls"] += len(pairs) @@ -440,7 +450,7 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int: model, user_instruction, cascade_threshold, - strategy=strategy, + prompt_strategy=prompt_strategy, ) stats["total_small_tokens"] += small_tokens @@ -499,14 +509,14 @@ class HeapDoc: Attributes: num_calls (int): Class variable tracking total number of LM calls. total_tokens (int): Class variable tracking total tokens used. - strategy (ReasoningStrategy | None): Class variable for reasoning strategy. + prompt_strategy (PromptStrategy | None): Class variable for reasoning strategy. model (lotus.models.LM | None): Class variable for the language model. explanations (dict[int, list[str]]): Class variable storing explanations. """ num_calls: int = 0 total_tokens: int = 0 - strategy: ReasoningStrategy | None = None + prompt_strategy: PromptStrategy | None = None model: lotus.models.LM | None = None explanations: dict[int, list[str]] = {} @@ -541,7 +551,7 @@ def __lt__(self, other: "HeapDoc") -> bool: """ assert HeapDoc.model is not None prompt = get_match_prompt_binary( - self.doc, other.doc, self.user_instruction, strategy=self.strategy, model=HeapDoc.model + self.doc, other.doc, self.user_instruction, prompt_strategy=HeapDoc.prompt_strategy, model=HeapDoc.model ) HeapDoc.num_calls += 1 HeapDoc.total_tokens += HeapDoc.model.count_tokens(prompt) @@ -562,7 +572,7 @@ def llm_heapsort( model: lotus.models.LM, user_instruction: str, K: int, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, safe_mode: bool = False, ) -> SemanticTopKOutput: """ @@ -579,7 +589,7 @@ def llm_heapsort( user_instruction (str): The natural language instruction that defines the sorting criteria. K (int): The number of top documents to return. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. 
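+
+    Example:
+        >>> # Hedged sketch: same inputs as llm_quicksort; heap construction runs one
+        >>> # LM call per pairwise __lt__ comparison, so CoT applies to each of them.
+        >>> from lotus.types import PromptStrategy
+        >>> output = llm_heapsort(docs, model, "Which {Review} is more positive?", K=5,
+        ...                       prompt_strategy=PromptStrategy(cot=True))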
@@ -595,7 +605,9 @@
 
     """
     if safe_mode:
-        sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy, model=model)
+        sample_prompt = get_match_prompt_binary(
+            docs[0], docs[1], user_instruction, prompt_strategy=prompt_strategy, model=model
+        )
         estimated_heap_construction_calls = len(docs) * np.log(len(docs))
         estimated_top_k_extraction_calls = K * np.log(len(docs))
         estimated_total_calls = estimated_heap_construction_calls + estimated_top_k_extraction_calls
@@ -604,7 +616,7 @@
 
     HeapDoc.num_calls = 0
     HeapDoc.total_tokens = 0
-    HeapDoc.strategy = strategy
+    HeapDoc.prompt_strategy = prompt_strategy
     HeapDoc.model = model
     HeapDoc.explanations = {}
     N = len(docs)
@@ -667,18 +679,18 @@ def process_group(args):
 
         Args:
             args (tuple): A tuple containing (group, user_instruction, K, method,
-                strategy, group_by, cascade_threshold, return_stats).
+                prompt_strategy, group_by, cascade_threshold, return_stats).
 
         Returns:
             pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: The top-K results
                 for the group, optionally with statistics.
         """
-        group, user_instruction, K, method, strategy, group_by, cascade_threshold, return_stats = args
+        group, user_instruction, K, method, prompt_strategy, group_by, cascade_threshold, return_stats = args
         return group.sem_topk(
             user_instruction,
             K,
             method=method,
-            strategy=strategy,
+            prompt_strategy=prompt_strategy,
             group_by=None,
             cascade_threshold=cascade_threshold,
             return_stats=return_stats,
@@ -690,7 +702,7 @@
     def __call__(
         self,
         user_instruction: str,
         K: int,
         method: str = "quick",
-        strategy: ReasoningStrategy | None = None,
+        prompt_strategy: PromptStrategy | None = None,
         group_by: list[str] | None = None,
         cascade_threshold: float | None = None,
         return_stats: bool = False,
@@ -714,7 +726,7 @@
               - "naive": Naive quadratic approach
               - "quick-sem": Quicksort with semantic embedding optimization. Requires the passed column
                 to be indexed with sem_index. Defaults to "quick".
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy
-            to use. Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use.
+            Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None.
         group_by (list[str] | None, optional): Column names to group by before
             sorting. Each group will be sorted separately. Defaults to None.
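+
+        Example:
+            >>> # Hedged sketch: assumes a configured lm and a DataFrame with a 'Review' column.
+            >>> from lotus.types import PromptStrategy
+            >>> top3, stats = df.sem_topk("{Review} is a positive review", K=3,
+            ...                           prompt_strategy=PromptStrategy(cot=True), return_stats=True)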
@@ -773,7 +785,7 @@ def __call__( if group_by: grouped = self._obj.groupby(group_by) group_args = [ - (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) + (group, user_instruction, K, method, prompt_strategy, None, cascade_threshold, return_stats) for _, group in grouped ] @@ -808,7 +820,7 @@ def __call__( formatted_usr_instr, K, embedding=method == "quick-sem", - strategy=strategy, + prompt_strategy=prompt_strategy, cascade_threshold=cascade_threshold, safe_mode=safe_mode, ) @@ -818,7 +830,7 @@ def __call__( model, formatted_usr_instr, K, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, ) elif method == "naive": @@ -826,7 +838,7 @@ def __call__( multimodal_data, model, formatted_usr_instr, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, ) else: @@ -836,7 +848,7 @@ def __call__( new_df = new_df.reindex(output.indexes).reset_index(drop=True) new_df = new_df.head(K) - if return_explanations and strategy == ReasoningStrategy.CoT: + if return_explanations and prompt_strategy is not None and prompt_strategy.cot: explanations = [] for idx in output.indexes[:K]: explanation = "No Comparison Made" diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 60ce2d92..953d689c 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -6,7 +6,7 @@ import lotus from lotus.dtype_extensions import ImageDtype from lotus.types import ( - ReasoningStrategy, + PromptStrategy, SerializationFormat, ) @@ -94,7 +94,7 @@ def filter_formatter( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[bool] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, reasoning_instructions: str = "", ) -> list[dict[str, str]]: answer_instructions = "The answer should be either True or False" @@ -104,7 +104,7 @@ def filter_formatter( """ # Simple strategy checking - if strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]: + if prompt_strategy is not None and prompt_strategy.cot: sys_instruction += cot_prompt_formatter( reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions ) @@ -135,7 +135,7 @@ def filter_formatter( # reasoning as filler if the user wants cot reasoning if cot_reasoning: content = cot_formatter(cot_reasoning[idx], str(ex_ans)) - elif strategy in [ReasoningStrategy.CoT, ReasoningStrategy.CoT_Demonstrations]: + elif prompt_strategy is not None and prompt_strategy.cot: content = cot_formatter("Reasoning omitted", str(ex_ans)) else: content = answer_only_formatter(str(ex_ans)) @@ -150,7 +150,7 @@ def filter_formatter( ] ) # Handle DeepSeek CoT formatting (backward compatibility) - if strategy == ReasoningStrategy.CoT and model.is_deepseek() and not examples_multimodal_data: + if prompt_strategy is not None and prompt_strategy.cot and model.is_deepseek() and not examples_multimodal_data: user_instruction = f"Claim: {user_instruction}\n\n{deepseek_cot_formatter()}" messages.append(user_message_formatter(multimodal_data, user_instruction)) else: @@ -216,7 +216,7 @@ def map_formatter( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[str] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> list[dict[str, str]]: sys_instruction = ( "The user will 
provide an instruction and some relevant context.\n" @@ -227,7 +227,7 @@ def map_formatter( return map_formatter_cot( multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning ) - elif strategy == ReasoningStrategy.CoT and not examples_multimodal_data: + elif prompt_strategy is not None and prompt_strategy.cot and not examples_multimodal_data: return map_formatter_zs_cot(multimodal_data, user_instruction) messages = [ @@ -245,7 +245,7 @@ def map_formatter( ) # Handle DeepSeek CoT formatting (backward compatibility) - if strategy == ReasoningStrategy.CoT and model.is_deepseek() and not examples_multimodal_data: + if prompt_strategy is not None and prompt_strategy.cot and model.is_deepseek() and not examples_multimodal_data: user_intructions = f"Instruction: {user_instruction}\n\n{deepseek_cot_formatter()}" messages.append(user_message_formatter(multimodal_data, user_intructions)) else: @@ -258,7 +258,7 @@ def extract_formatter( multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, ) -> list[dict[str, str]]: output_col_names = list(output_cols.keys()) # Set the description to be the key if no value is provided @@ -280,7 +280,7 @@ def extract_formatter( ) # Add CoT instructions for CoT strategy - if strategy == ReasoningStrategy.CoT: + if prompt_strategy is not None and prompt_strategy.cot: sys_instruction += "\n\nFor your response, first provide your reasoning, then give your final answer in the specified JSON format." messages = [ @@ -288,7 +288,7 @@ def extract_formatter( user_message_formatter(multimodal_data), ] - if strategy == ReasoningStrategy.CoT and model.is_deepseek(): + if prompt_strategy is not None and prompt_strategy.cot and model.is_deepseek(): user_intructions = f"Instruction: {deepseek_cot_formatter()}" messages.append(user_message_formatter(multimodal_data, user_intructions)) diff --git a/lotus/types.py b/lotus/types.py index 656b404b..057206d3 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from enum import Enum, auto +from enum import Enum from typing import Any import pandas as pd @@ -213,25 +213,54 @@ class LotusUsageLimitException(LotusException): ################################################################################ -# Reasoning Strategy +# Prompt Strategy ################################################################################ -class ReasoningStrategy(Enum): +@dataclass +class PromptStrategy: """ - Simple, intuitive reasoning strategies for semantic operations. - - - CoT: Chain-of-thought reasoning with step-by-step explanations - - CoT_Demonstrations: CoT with few-shot examples (user-provided or bootstrapped) - - Demonstrations: Few-shot examples without explicit reasoning + Configurable prompt strategy for semantic operations. + + This class encapsulates various prompting techniques including chain-of-thought + reasoning, demonstrations, and bootstrapping configurations. + + Args: + cot (bool): Whether to use chain-of-thought reasoning. Defaults to False. + dems (pd.DataFrame | str | None): Demonstrations to use. Can be: + - pd.DataFrame: User-provided examples with Answer column + - "auto": Automatically bootstrap demonstrations + - None: No demonstrations + max_dems (int): Maximum number of demonstrations to use. Defaults to 3. + teacher_lm: Language model to use for bootstrapping demonstrations. 
+ If None, uses the main model. Defaults to None. + + Example: + >>> # Chain-of-thought with user-provided demonstrations + >>> strat = PromptStrategy(cot=True, dems=examples_df) + >>> df.sem_filter(user_instruction, prompt_strategy=strat) + + >>> # Auto-bootstrap demonstrations with CoT + >>> strat = PromptStrategy( + ... cot=True, + ... dems="auto", + ... max_dems=2, + ... teacher_lm=lotus.models.LM(model="gpt-4o-mini") + ... ) + >>> df.sem_filter(user_instruction, prompt_strategy=strat) """ - CoT = auto() - CoT_Demonstrations = auto() - Demonstrations = auto() + cot: bool = False + dems: pd.DataFrame | str | None = None + max_dems: int = 3 + teacher_lm: Any = None # lotus.models.LM type, but avoiding circular import @dataclass class DemonstrationConfig: - """Configuration for demonstration-based reasoning""" + """ + DEPRECATED: Use PromptStrategy instead. + + Configuration for demonstration-based reasoning + """ # User-provided examples (alternative to passing examples directly) examples: pd.DataFrame | None = None diff --git a/tests/deepseek_cot_tests.py b/tests/deepseek_cot_tests.py index a83d1b06..2bfe3fe6 100644 --- a/tests/deepseek_cot_tests.py +++ b/tests/deepseek_cot_tests.py @@ -5,7 +5,7 @@ import lotus from lotus.models import LM -from lotus.types import DemonstrationConfig, ReasoningStrategy +from lotus.types import DemonstrationConfig, PromptStrategy lotus.logger.setLevel("DEBUG") @@ -27,9 +27,7 @@ def test_deepseek_demonstrations_only(): # Provide examples without reasoning examples = pd.DataFrame({"Course": ["Statistics", "Poetry", "Physics"], "Answer": [True, False, True]}) - result = df.sem_filter( - user_instruction, strategy=ReasoningStrategy.Demonstrations, examples=examples, return_all=True - ) + result = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(dems=examples), return_all=True) assert "filter_label" in result.columns # Should identify math courses correctly based on examples @@ -62,7 +60,7 @@ def test_deepseek_cot_demonstrations_combined(): result = df.sem_filter( user_instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, + prompt_strategy=PromptStrategy(cot=True, dems=examples), examples=examples, return_explanations=True, return_all=True, @@ -98,7 +96,7 @@ def test_deepseek_demonstration_config(): result = df.sem_filter( user_instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, + prompt_strategy=PromptStrategy(cot=True, dems=examples), demonstration_config=demo_config, return_all=True, ) @@ -124,7 +122,7 @@ def test_deepseek_bootstrapping(): result = df.sem_filter( user_instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, + prompt_strategy=PromptStrategy(cot=True, dems="auto"), demonstration_config=demo_config, return_explanations=True, return_all=True, @@ -162,7 +160,7 @@ def test_deepseek_extract_with_cot(): input_cols = ["Review"] # Columns to extract from - result = df.sem_extract(input_cols, output_cols, strategy=ReasoningStrategy.CoT, return_explanations=True) + result = df.sem_extract(input_cols, output_cols, prompt_strategy=PromptStrategy(cot=True), return_explanations=True) assert "sentiment" in result.columns assert "main_feature" in result.columns @@ -188,7 +186,7 @@ def test_deepseek_backward_compatibility(): result_default = df.sem_filter(user_instruction, return_all=True) # Test with explicit CoT strategy - result_cot = df.sem_filter(user_instruction, strategy=ReasoningStrategy.CoT, return_all=True) + result_cot = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True), return_all=True) # 
Both should work and produce results assert "filter_label" in result_default.columns @@ -210,9 +208,7 @@ def test_deepseek_error_handling(): empty_examples = pd.DataFrame(columns=["Text", "Answer"]) try: - result = df.sem_filter( - user_instruction, strategy=ReasoningStrategy.Demonstrations, examples=empty_examples, return_all=True - ) + result = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(dems=empty_examples), return_all=True) # Should handle gracefully assert "filter_label" in result.columns except Exception as e: @@ -232,14 +228,14 @@ def test_deepseek_multiple_operations_chaining(): # First filter with demonstrations examples = pd.DataFrame({"Product": ["Laptop", "Book"], "Answer": [True, False]}) - filtered_df = df.sem_filter( - "{Product} is an electronic device", strategy=ReasoningStrategy.Demonstrations, examples=examples - ) + filtered_df = df.sem_filter("{Product} is an electronic device", prompt_strategy=PromptStrategy(dems=examples)) # Then map with CoT if len(filtered_df) > 0: mapped_df = filtered_df.sem_map( - "What category does {Product} belong to?", strategy=ReasoningStrategy.CoT, return_explanations=True + "What category does {Product} belong to?", + prompt_strategy=PromptStrategy(cot=True), + return_explanations=True, ) assert "_map" in mapped_df.columns diff --git a/tests/test_reasoning_strategies.py b/tests/test_reasoning_strategies.py index 919fd659..e29946c7 100644 --- a/tests/test_reasoning_strategies.py +++ b/tests/test_reasoning_strategies.py @@ -5,7 +5,7 @@ import lotus from lotus.models import LM -from lotus.types import DemonstrationConfig, ReasoningStrategy +from lotus.types import PromptStrategy from tests.base_test import BaseTest # Skip all tests if no OpenAI API key is available @@ -85,7 +85,9 @@ def test_cot_filter_basic(self, sample_courses_df, setup_model): df = sample_courses_df instruction = "{Course Name} requires a lot of math" - result = df.sem_filter(instruction, strategy=ReasoningStrategy.CoT, return_explanations=True, return_all=True) + result = df.sem_filter( + instruction, prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_all=True + ) # Check structure assert "filter_label" in result.columns @@ -95,15 +97,33 @@ def test_cot_filter_basic(self, sample_courses_df, setup_model): for explanation in result["explanation_filter"]: assert explanation is not None assert len(explanation) > 0 - # CoT should contain reasoning - assert any(word in explanation.lower() for word in ["reasoning", "because", "since", "therefore"]) + # CoT should contain substantive reasoning (check for common reasoning indicators or sufficient length) + has_reasoning_words = any( + word in explanation.lower() + for word in [ + "reasoning", + "because", + "since", + "therefore", + "requires", + "involves", + "contains", + "needs", + "mathematical", + "math", + "calculus", + "algebra", + ] + ) + is_substantial = len(explanation.split()) > 5 + assert has_reasoning_words or is_substantial, f"Explanation lacks reasoning indicators: '{explanation}'" def test_cot_map_basic(self, sample_courses_df, setup_model): """Test basic CoT reasoning with sem_map""" df = sample_courses_df instruction = "What is the difficulty level of {Course Name}? 
Answer: Beginner, Intermediate, or Advanced" - result = df.sem_map(instruction, strategy=ReasoningStrategy.CoT, return_explanations=True) + result = df.sem_map(instruction, prompt_strategy=PromptStrategy(cot=True), return_explanations=True) # Check structure assert "_map" in result.columns @@ -120,7 +140,7 @@ def test_cot_topk_basic(self, sample_reviews_df, setup_model): instruction = "{Review} is a positive review" result, stats = df.sem_topk( - instruction, K=3, strategy=ReasoningStrategy.CoT, return_explanations=True, return_stats=True + instruction, K=3, prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_stats=True ) # Check structure @@ -147,9 +167,7 @@ def test_demonstrations_filter_basic(self, sample_courses_df, setup_model): {"Course Name": ["Machine Learning", "Literature", "Physics"], "Answer": [True, False, True]} ) - result = df.sem_filter( - instruction, strategy=ReasoningStrategy.Demonstrations, examples=examples, return_all=True - ) + result = df.sem_filter(instruction, prompt_strategy=PromptStrategy(dems=examples), return_all=True) # Check structure assert "filter_label" in result.columns @@ -166,7 +184,7 @@ def test_demonstrations_map_basic(self, sample_courses_df, setup_model): # Provide examples examples = pd.DataFrame({"Course Name": ["Calculus I", "English Literature"], "Answer": ["Math", "English"]}) - result = df.sem_map(instruction, strategy=ReasoningStrategy.Demonstrations, examples=examples) + result = df.sem_map(instruction, prompt_strategy=PromptStrategy(dems=examples)) # Check structure assert "_map" in result.columns @@ -200,8 +218,7 @@ def test_cot_demonstrations_filter(self, sample_courses_df, setup_model): result = df.sem_filter( instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - examples=examples, + prompt_strategy=PromptStrategy(cot=True, dems=examples), return_explanations=True, return_all=True, ) @@ -233,7 +250,7 @@ def test_cot_demonstrations_map(self, sample_courses_df, setup_model): ) result = df.sem_map( - instruction, strategy=ReasoningStrategy.CoT_Demonstrations, examples=examples, return_explanations=True + instruction, prompt_strategy=PromptStrategy(cot=True, dems=examples), return_explanations=True ) # Check structure @@ -257,12 +274,9 @@ def test_demonstration_config_basic(self, sample_courses_df, setup_model): # Examples provided via DemonstrationConfig examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]}) - demo_config = DemonstrationConfig(examples=examples) - result = df.sem_filter( instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - demonstration_config=demo_config, + prompt_strategy=PromptStrategy(cot=True, dems=examples), return_all=True, ) @@ -274,12 +288,9 @@ def test_bootstrapping_basic(self, sample_courses_df, setup_model): instruction = "{Course Name} requires a lot of math" # Configure bootstrapping - demo_config = DemonstrationConfig(bootstrap=True, num_demonstrations=2) - result = df.sem_filter( instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - demonstration_config=demo_config, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), return_explanations=True, return_all=True, ) @@ -296,16 +307,9 @@ def test_bootstrapping_with_oracle_model(self, sample_courses_df, setup_model): df = sample_courses_df.head(5) # Use fewer rows for faster testing instruction = "{Course Name} requires a lot of math" - demo_config = DemonstrationConfig( - bootstrap=True, - num_demonstrations=1, - oracle_model="gpt-4o-mini", # Use 
same model for testing - ) - result = df.sem_filter( instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - demonstration_config=demo_config, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=1, teacher_lm=LM(model="gpt-4o-mini")), return_all=True, ) @@ -325,7 +329,7 @@ def test_backward_compatibility_examples_param(self, sample_courses_df, setup_mo result = df.sem_filter( instruction, - strategy=ReasoningStrategy.Demonstrations, + prompt_strategy=PromptStrategy(dems=examples), examples=examples, # Old parameter name return_all=True, ) @@ -356,7 +360,9 @@ def test_cot_extract_basic(self, sample_reviews_df, setup_model): "key_points": "Main points mentioned in the review", } - result = df.sem_extract(input_cols, output_cols, strategy=ReasoningStrategy.CoT, return_explanations=True) + result = df.sem_extract( + input_cols, output_cols, prompt_strategy=PromptStrategy(cot=True), return_explanations=True + ) # Check structure assert "sentiment" in result.columns @@ -380,9 +386,7 @@ def test_empty_examples(self, sample_courses_df, setup_model): empty_examples = pd.DataFrame(columns=["Course Name", "Answer"]) # Should handle empty examples gracefully - result = df.sem_filter( - instruction, strategy=ReasoningStrategy.Demonstrations, examples=empty_examples, return_all=True - ) + result = df.sem_filter(instruction, prompt_strategy=PromptStrategy(dems=empty_examples), return_all=True) assert "filter_label" in result.columns @@ -396,9 +400,7 @@ def test_mismatched_example_columns(self, sample_courses_df, setup_model): # Should handle gracefully or raise informative error try: - result = df.sem_filter( - instruction, strategy=ReasoningStrategy.Demonstrations, examples=bad_examples, return_all=True - ) + result = df.sem_filter(instruction, prompt_strategy=PromptStrategy(dems=bad_examples), return_all=True) # If it doesn't raise an error, it should still produce results assert "filter_label" in result.columns except Exception as e: @@ -411,13 +413,10 @@ def test_invalid_strategy_combination(self, sample_courses_df, setup_model): instruction = "{Course Name} requires a lot of math" # Try to use bootstrapping without CoT_Demonstrations strategy - demo_config = DemonstrationConfig(bootstrap=True) - try: result = df.sem_filter( instruction, - strategy=ReasoningStrategy.CoT, # Wrong strategy for bootstrapping - demonstration_config=demo_config, + prompt_strategy=PromptStrategy(cot=True, dems="auto"), # Should use auto for bootstrapping return_all=True, ) # Should either work or raise informative error @@ -431,15 +430,9 @@ def test_large_num_demonstrations(self, sample_courses_df, setup_model): instruction = "{Course Name} requires a lot of math" # Request more demonstrations than available data - demo_config = DemonstrationConfig( - bootstrap=True, - num_demonstrations=20, # More than df length - ) - result = df.sem_filter( instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - demonstration_config=demo_config, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=20), return_all=True, ) @@ -458,14 +451,14 @@ def test_multiple_operations_with_strategies(self, sample_courses_df, setup_mode examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]}) filtered_df = df.sem_filter( - "{Course Name} requires a lot of math", strategy=ReasoningStrategy.Demonstrations, examples=examples + "{Course Name} requires a lot of math", prompt_strategy=PromptStrategy(dems=examples) ) # Then map with CoT if len(filtered_df) > 0: mapped_df = 
filtered_df.sem_map( "What is the difficulty level of {Course Name}?", - strategy=ReasoningStrategy.CoT, + prompt_strategy=PromptStrategy(cot=True), return_explanations=True, ) @@ -479,7 +472,11 @@ def test_strategy_with_return_options(self, sample_courses_df, setup_model): # Test return_stats=True (returns tuple) result, stats = df.sem_filter( - instruction, strategy=ReasoningStrategy.CoT, return_all=True, return_explanations=True, return_stats=True + instruction, + prompt_strategy=PromptStrategy(cot=True), + return_all=True, + return_explanations=True, + return_stats=True, ) # Check all expected columns are present in DataFrame @@ -491,7 +488,11 @@ def test_strategy_with_return_options(self, sample_courses_df, setup_model): # Test without return_stats (returns DataFrame only) result_no_stats = df.sem_filter( - instruction, strategy=ReasoningStrategy.CoT, return_all=True, return_explanations=True, return_stats=False + instruction, + prompt_strategy=PromptStrategy(cot=True), + return_all=True, + return_explanations=True, + return_stats=False, ) # Should return DataFrame directly From 83f7da95cb413455547789368ef285acf935bebd Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Fri, 22 Aug 2025 23:58:25 -0700 Subject: [PATCH 5/8] fix tests + examples --- .github/tests/lm_tests.py | 9 ++++----- examples/model_examples/deepseek.py | 4 ++-- examples/op_examples/filter.py | 3 ++- examples/op_examples/filter_cot.py | 6 +++--- examples/op_examples/map_deepseek_cot.py | 4 ++-- examples/op_examples/multimodal_ops/join.py | 3 ++- examples/op_examples/simple_reasoning.py | 16 +++++++--------- examples/op_examples/top_k_deepseek_cot.py | 4 ++-- 8 files changed, 24 insertions(+), 25 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 7b12fa08..3654960e 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -6,7 +6,7 @@ import lotus from lotus.models import LM, SentenceTransformersRM -from lotus.types import CascadeArgs +from lotus.types import CascadeArgs, PromptStrategy from lotus.vector_store import FaissVS ################################################################################ @@ -269,7 +269,7 @@ def test_filter_operation_cot(setup_models, model): } df = pd.DataFrame(data) user_instruction = "{Text} I have at least one apple" - filtered_df = df.sem_filter(user_instruction, strategy="cot") + filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True)) expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]}) assert filtered_df.equals(expected_df) @@ -302,8 +302,7 @@ def test_filter_operation_cot_fewshot(setup_models, model): user_instruction = "{Sequence} is increasing" filtered_df = df.sem_filter( user_instruction, - strategy="cot", - examples=examples_df, + prompt_strategy=PromptStrategy(cot=True, dems=examples_df), additional_cot_instructions="Assume the most typical or logical case.", ) expected_df = pd.DataFrame( @@ -339,7 +338,7 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df) + filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True, dems=examples_df)) expected_df = pd.DataFrame( { "Sequence": [ diff --git a/examples/model_examples/deepseek.py b/examples/model_examples/deepseek.py index 74384d75..18ecd555 100644 --- 
a/examples/model_examples/deepseek.py +++ b/examples/model_examples/deepseek.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy # Set up model lm = LM(model="ollama/deepseek-r1:7b", temperature=0.6) @@ -33,6 +33,6 @@ ) # Run semantic mapping with CoT strategy -df = df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT) +df = df.sem_map(user_instruction, return_explanations=True, prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f244051a..a6431764 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -2,6 +2,7 @@ import lotus from lotus.models import LM +from lotus.types import PromptStrategy lm = LM(model="gpt-4o-mini") @@ -17,5 +18,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction, strategy="cot") +df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/filter_cot.py b/examples/op_examples/filter_cot.py index b5859083..872b255c 100644 --- a/examples/op_examples/filter_cot.py +++ b/examples/op_examples/filter_cot.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy lm = LM(model="gpt-4o-mini") @@ -20,9 +20,9 @@ } df = pd.DataFrame(data) user_instruction = "{Text} I have at least one apple" -# filtered_df = df.sem_filter(user_instruction, strategy="cot", return_all=True) +# Old way: filtered_df = df.sem_filter(user_instruction, strategy="cot", return_all=True) filtered_df = df.sem_filter( - user_instruction, strategy=ReasoningStrategy.ZS_COT, return_all=True, return_explanations=True + user_instruction, prompt_strategy=PromptStrategy(cot=True), return_all=True, return_explanations=True ) # uncomment to see reasoning chains print(filtered_df) diff --git a/examples/op_examples/map_deepseek_cot.py b/examples/op_examples/map_deepseek_cot.py index 831463fe..df0deae0 100644 --- a/examples/op_examples/map_deepseek_cot.py +++ b/examples/op_examples/map_deepseek_cot.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy lm = LM(model="ollama/deepseek-r1:7b", temperature=0.5) @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "What is a similar course to {Course Name}. Just give the course name." 
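# A minimal before/after sketch of the migration this hunk applies (the enum members
# named here are the pre-refactor API): ReasoningStrategy.ZS_COT and ReasoningStrategy.CoT
# both collapse into PromptStrategy(cot=True), and few-shot examples move from the
# examples= keyword into PromptStrategy(dems=...).
#
#   old: df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT)
#   new: df.sem_map(user_instruction, return_explanations=True, prompt_strategy=PromptStrategy(cot=True))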
-df = df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT) +df = df.sem_map(user_instruction, return_explanations=True, prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/multimodal_ops/join.py b/examples/op_examples/multimodal_ops/join.py index 9e490ea9..c1fed3ad 100644 --- a/examples/op_examples/multimodal_ops/join.py +++ b/examples/op_examples/multimodal_ops/join.py @@ -5,6 +5,7 @@ import lotus from lotus.dtype_extensions import ImageArray from lotus.models import LM +from lotus.types import PromptStrategy lotus.settings.configure(lm=LM(model="gpt-4o-mini")) @@ -17,6 +18,6 @@ image_df = pd.DataFrame({"image": ImageArray(image_paths), "image_path": image_paths}) labels_df = pd.DataFrame({"label": [0, 1]}) -df = image_df.sem_join(labels_df, "{image} represents the number {label}", strategy="zs-cot") +df = image_df.sem_join(labels_df, "{image} represents the number {label}", prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/simple_reasoning.py b/examples/op_examples/simple_reasoning.py index 01a914bc..4ccd13bd 100644 --- a/examples/op_examples/simple_reasoning.py +++ b/examples/op_examples/simple_reasoning.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import DemonstrationConfig, ReasoningStrategy +from lotus.types import PromptStrategy # Configure the language model lm = LM(model="gpt-4o-mini") @@ -23,7 +23,9 @@ # Example 2: Chain-of-Thought reasoning print("=== 2. Chain-of-Thought Reasoning ===") -cot_df = df.sem_filter(user_instruction, strategy=ReasoningStrategy.CoT, return_explanations=True, return_all=True) +cot_df = df.sem_filter( + user_instruction, prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_all=True +) print(cot_df[["Course Name", "filter_label", "explanation_filter"]]) print() @@ -33,8 +35,7 @@ demo_df = df.sem_filter( user_instruction, - strategy=ReasoningStrategy.Demonstrations, - examples=examples, # Still works for backward compatibility + prompt_strategy=PromptStrategy(dems=examples), return_all=True, ) print(demo_df[["Course Name", "filter_label"]]) @@ -56,8 +57,7 @@ combined_df = df.sem_filter( user_instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - examples=examples_with_reasoning, + prompt_strategy=PromptStrategy(cot=True, dems=examples_with_reasoning), return_explanations=True, return_all=True, ) @@ -66,12 +66,10 @@ # Example 5: Automatic demonstration bootstrapping print("=== 5. 
Bootstrapped Demonstrations ===") -bootstrap_config = DemonstrationConfig(bootstrap=True, num_demonstrations=2) bootstrap_df = df.sem_filter( user_instruction, - strategy=ReasoningStrategy.CoT_Demonstrations, - demonstration_config=bootstrap_config, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), return_explanations=True, return_all=True, ) diff --git a/examples/op_examples/top_k_deepseek_cot.py b/examples/op_examples/top_k_deepseek_cot.py index 55113b84..2884def4 100644 --- a/examples/op_examples/top_k_deepseek_cot.py +++ b/examples/op_examples/top_k_deepseek_cot.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy lm = LM(model="ollama/deepseek-r1:7b", temperature=0.6) lotus.settings.configure(lm=lm) @@ -24,7 +24,7 @@ "{Review} suggests that the user would recommend the product to others", K=2, method=method, - strategy=ReasoningStrategy.ZS_COT, + prompt_strategy=PromptStrategy(cot=True), return_stats=True, return_explanations=True, ) From 52c08c64fe48570a87b4108d2a39bc88a9a826d2 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Sat, 23 Aug 2025 00:32:47 -0700 Subject: [PATCH 6/8] fix tests --- .github/tests/lm_tests.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 3654960e..4949e742 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -289,11 +289,12 @@ def test_filter_operation_cot_fewshot(setup_models, model): } df = pd.DataFrame(data) examples = { - "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], - "Answer": [True, True, True], + "Sequence": ["1, 2, 3", "A, B, C", "penny, nickel, dime, quarter", "villiage, town, city"], + "Answer": [True, True, True, True], "Reasoning": [ "1, 2, 3 is an increasing sequence of numbers", - "penny, nickel, dime, quarter is an increasing sequence of coins", + "A, B, C is an increasing sequence of letters in alphabetical order", + "penny, nickel, dime, quarter is an increasing sequence of coins by value", "villiage, town, city is an increasing sequence of settlements", ], } @@ -302,7 +303,8 @@ def test_filter_operation_cot_fewshot(setup_models, model): user_instruction = "{Sequence} is increasing" filtered_df = df.sem_filter( user_instruction, - prompt_strategy=PromptStrategy(cot=True, dems=examples_df), + prompt_strategy=PromptStrategy(cot=True), + examples=examples_df, additional_cot_instructions="Assume the most typical or logical case.", ) expected_df = pd.DataFrame( @@ -332,13 +334,13 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): } df = pd.DataFrame(data) examples = { - "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], - "Answer": [True, True, True], + "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city", "A, B, C"], + "Answer": [True, True, True, True], } examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True, dems=examples_df)) + filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True), examples=examples_df) expected_df = pd.DataFrame( { "Sequence": [ From ebdada713696694efbbb2ad3b8ba40ed7a8b9eec Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Sat, 23 Aug 2025 00:46:51 -0700 Subject: [PATCH 7/8] fix --- .github/tests/multimodality_tests.py | 6 +++--- 1 file 
changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py index db566a38..ef477073 100644 --- a/.github/tests/multimodality_tests.py +++ b/.github/tests/multimodality_tests.py @@ -134,9 +134,9 @@ def test_topk_operation(setup_models, model): ] ) - strategies = ["quick", "heap", "naive"] - for strategy in strategies: - sorted_df = df.sem_topk(user_instruction, K=3, strategy=strategy) + methods = ["quick", "heap", "naive"] + for method in methods: + sorted_df = df.sem_topk(user_instruction, K=3, method=method) top_2_actual = set(sorted_df["image"].values) assert top_2_expected.issubset(top_2_actual) From 811e6f112d0b8c0b24fa4131f194288e2638ce80 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Tue, 26 Aug 2025 17:15:59 -0700 Subject: [PATCH 8/8] fix api and add comprehensive testing --- .github/tests/lm_tests.py | 103 +++++++++++- lotus/sem_ops/cascade_utils.py | 239 ++++++++++++++++++++++++++- lotus/sem_ops/sem_filter.py | 67 +++++--- lotus/sem_ops/sem_map.py | 38 ++++- lotus/types.py | 3 + tests/test_reasoning_strategies.py | 249 +++++++++++++++++++++++++++++ 6 files changed, 669 insertions(+), 30 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 4949e742..3c3e778e 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -160,7 +160,7 @@ def test_map_fewshot(setup_models, model): examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]} examples_df = pd.DataFrame(examples) user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation." - df = df.sem_map(user_instruction, examples=examples_df, suffix="State") + df = df.sem_map(user_instruction, prompt_strategy=PromptStrategy(dems=examples_df), suffix="State") # clean up the state names to be more robust to free-form text df["State"] = df["State"].str[-2:].str.lower() @@ -303,9 +303,9 @@ def test_filter_operation_cot_fewshot(setup_models, model): user_instruction = "{Sequence} is increasing" filtered_df = df.sem_filter( user_instruction, - prompt_strategy=PromptStrategy(cot=True), - examples=examples_df, - additional_cot_instructions="Assume the most typical or logical case.", + prompt_strategy=PromptStrategy( + cot=True, dems=examples_df, additional_cot_instructions="Assume the most typical or logical case." 
+ ), ) expected_df = pd.DataFrame( { @@ -340,7 +340,7 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True), examples=examples_df) + filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True, dems=examples_df)) expected_df = pd.DataFrame( { "Sequence": [ @@ -535,3 +535,96 @@ def test_custom_tokenizer(): tokens = custom_lm.count_tokens("Hello, world!") assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens assert tokens < 100 + + +################################################################################ +# Auto-bootstrapping tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_auto_bootstrapping_filter(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test auto-bootstrapping with filter operation + data = { + "Course Name": [ + "Linear Algebra", + "Poetry Writing", + "Calculus II", + "Art History", + "Statistics", + "Creative Writing", + "Machine Learning", + "Philosophy", + ] + } + df = pd.DataFrame(data) + user_instruction = "{Course Name} requires a lot of math" + + # Test auto-bootstrapping + result = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Should have some math courses identified + math_courses = result[result["filter_label"]]["Course Name"].tolist() + expected_math_courses = ["Linear Algebra", "Calculus II", "Statistics", "Machine Learning"] + assert any(course in expected_math_courses for course in math_courses) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_auto_bootstrapping_map(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test auto-bootstrapping with map operation + data = {"Course Name": ["Linear Algebra", "Poetry Writing", "Calculus II", "Art History"]} + df = pd.DataFrame(data) + user_instruction = "What is the difficulty level of {Course Name}? 
Answer: Beginner, Intermediate, or Advanced" + + # Test auto-bootstrapping + result = df.sem_map( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + ) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check that all difficulty levels are valid + for difficulty in result["_map"]: + assert difficulty.lower() in ["beginner", "intermediate", "advanced"] + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_auto_bootstrapping_with_teacher_model(setup_models, model): + lm = setup_models[model] + teacher_lm = setup_models[model] # Use same model as teacher for testing + lotus.settings.configure(lm=lm) + + data = {"Text": ["I am happy", "I am sad", "I am excited", "I am tired"]} + df = pd.DataFrame(data) + user_instruction = "{Text} expresses a positive emotion" + + # Test auto-bootstrapping with explicit teacher model + result = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2, teacher_lm=teacher_lm), + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + + # Should identify positive emotions + positive_texts = result[result["filter_label"]]["Text"].tolist() + assert any(text in ["I am happy", "I am excited"] for text in positive_texts) diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 6896e4c1..308c36e0 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -1,8 +1,11 @@ +from typing import Any + import numpy as np +import pandas as pd from numpy.typing import NDArray import lotus -from lotus.types import CascadeArgs +from lotus.types import CascadeArgs, PromptStrategy def importance_sampling( @@ -147,3 +150,237 @@ def calculate_tau_neg( def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: true_score = list(np.clip(true_score, 0, 1)) return true_score + + +def bootstrap_demonstrations( + data: pd.DataFrame, + col_li: list[str], + user_instruction: str, + prompt_strategy: PromptStrategy, + operation_type: str = "filter", +) -> tuple[list[dict[str, Any]], list[Any], list[str] | None]: + """ + Bootstrap demonstrations automatically using a teacher model. + + This function samples diverse examples from the input data and uses a teacher + model to generate high-quality answers and reasoning for these examples. 
+ + Args: + data (pd.DataFrame): The input DataFrame to sample from + col_li (list[str]): List of column names to include in the examples + user_instruction (str): The user instruction for the task + prompt_strategy (PromptStrategy): The prompt strategy containing bootstrapping config + operation_type (str): Type of operation ("filter", "map", "extract") + + Returns: + tuple: (examples_multimodal_data, examples_answers, cot_reasoning) + - examples_multimodal_data: List of example documents + - examples_answers: List of answers for the examples + - cot_reasoning: List of reasoning explanations (if CoT enabled) + """ + # Determine teacher model + teacher_lm = prompt_strategy.teacher_lm if prompt_strategy.teacher_lm is not None else lotus.settings.lm + if teacher_lm is None: + raise ValueError("No teacher model available for bootstrapping") + + # Sample diverse examples from the data + max_dems = min(prompt_strategy.max_dems, len(data)) + if max_dems == 0: + return [], [], None + + # Use random sampling + np.random.seed(42) + sample_indices = np.random.choice(len(data), size=max_dems, replace=False) + sample_data = data.iloc[sample_indices] + + lotus.logger.info(f"Bootstrapping {max_dems} demonstrations using teacher model") + + # Convert sampled data to multimodal format + from lotus.templates import task_instructions + + examples_multimodal_data = task_instructions.df2multimodal_info(sample_data, col_li) + + # Generate answers using teacher model + successful_examples_multimodal_data = [] + examples_answers: list[Any] = [] + cot_reasoning: list[str] | None = [] if prompt_strategy.cot else None + + for i, doc in enumerate(examples_multimodal_data): + try: + if operation_type == "filter": + # For filter operations, generate boolean answers + filter_answer, reasoning = _bootstrap_filter_example( + doc, user_instruction, teacher_lm, prompt_strategy.cot + ) + successful_examples_multimodal_data.append(doc) + examples_answers.append(filter_answer) + if cot_reasoning is not None: + cot_reasoning.append(reasoning) + + elif operation_type == "map": + # For map operations, generate string answers + map_answer, reasoning = _bootstrap_map_example(doc, user_instruction, teacher_lm, prompt_strategy.cot) + successful_examples_multimodal_data.append(doc) + examples_answers.append(map_answer) + if cot_reasoning is not None: + cot_reasoning.append(reasoning) + + else: + lotus.logger.warning(f"Bootstrapping not yet implemented for operation type: {operation_type}") + # Fallback to empty examples + return [], [], None + + except Exception as e: + lotus.logger.warning(f"Failed to bootstrap example {i}: {e}") + # Skip this example and continue with the next one + continue + + lotus.logger.info(f"Successfully bootstrapped {len(examples_answers)} demonstrations") + return successful_examples_multimodal_data, examples_answers, cot_reasoning + + +def _bootstrap_filter_example( + doc: dict[str, Any], user_instruction: str, teacher_lm: lotus.models.LM, use_cot: bool +) -> tuple[bool, str]: + """Bootstrap a single filter example using the teacher model.""" + + if use_cot: + # Request reasoning with the answer + messages = [ + { + "role": "system", + "content": "You are a helpful assistant that provides detailed reasoning for classification tasks.", + }, + { + "role": "user", + "content": f"""Please evaluate whether the following claim is true for the given context. + +Claim: {user_instruction} +Context: {doc.get('text', str(doc))} + +First provide your reasoning, then give your final answer as either "True" or "False". 
+ +Format your response as: +Reasoning: [Your detailed reasoning here] +Answer: [True/False]""", + }, + ] + else: + # Just request the answer + messages = [ + { + "role": "system", + "content": "You are a helpful assistant that evaluates claims. Respond with only 'True' or 'False'.", + }, + { + "role": "user", + "content": f"""Claim: {user_instruction} +Context: {doc.get('text', str(doc))} + +Answer (True/False):""", + }, + ] + + # Get response from teacher model + lm_output = teacher_lm([messages]) + response = lm_output.outputs[0] + + if use_cot: + # Parse reasoning and answer + lines = response.strip().split("\n") + reasoning = "" + answer_str = "" + + for line in lines: + if line.startswith("Reasoning:"): + reasoning = line[10:].strip() + elif line.startswith("Answer:"): + answer_str = line[7:].strip() + + # Fallback parsing if format is not followed exactly + if not reasoning or not answer_str: + parts = response.lower().split("answer:") + if len(parts) >= 2: + reasoning = parts[0].strip() + answer_str = parts[1].strip() + else: + reasoning = response + answer_str = "true" if "true" in response.lower() else "false" + + # Convert to boolean + answer = answer_str.lower().strip() in ["true", "yes", "1"] + return answer, reasoning + else: + # Simple boolean conversion + answer = response.lower().strip() in ["true", "yes", "1"] + return answer, "Reasoning omitted" + + +def _bootstrap_map_example( + doc: dict[str, Any], user_instruction: str, teacher_lm: lotus.models.LM, use_cot: bool +) -> tuple[str, str]: + """Bootstrap a single map example using the teacher model.""" + + if use_cot: + # Request reasoning with the answer + messages = [ + { + "role": "system", + "content": "You are a helpful assistant that provides detailed reasoning for transformation tasks.", + }, + { + "role": "user", + "content": f"""Please follow the instruction for the given context. + +Instruction: {user_instruction} +Context: {doc.get('text', str(doc))} + +First provide your reasoning, then give your final answer. 
+ +Format your response as: +Reasoning: [Your detailed reasoning here] +Answer: [Your answer here]""", + }, + ] + else: + # Just request the answer + messages = [ + {"role": "system", "content": "You are a helpful assistant that follows instructions precisely."}, + { + "role": "user", + "content": f"""Instruction: {user_instruction} +Context: {doc.get('text', str(doc))} + +Answer:""", + }, + ] + + # Get response from teacher model + lm_output = teacher_lm([messages]) + response = lm_output.outputs[0] + + if use_cot: + # Parse reasoning and answer + lines = response.strip().split("\n") + reasoning = "" + answer = "" + + for line in lines: + if line.startswith("Reasoning:"): + reasoning = line[10:].strip() + elif line.startswith("Answer:"): + answer = line[7:].strip() + + # Fallback parsing if format is not followed exactly + if not reasoning or not answer: + parts = response.split("Answer:") + if len(parts) >= 2: + reasoning = parts[0].strip() + answer = parts[1].strip() + else: + reasoning = "Reasoning omitted" + answer = response.strip() + + return answer, reasoning + else: + return response.strip(), "Reasoning omitted" diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 50597ce7..4de31503 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -17,7 +17,12 @@ ) from lotus.utils import show_safe_mode -from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds +from .cascade_utils import ( + bootstrap_demonstrations, + calibrate_llm_logprobs, + importance_sampling, + learn_cascade_thresholds, +) from .postprocessors import filter_postprocess @@ -34,7 +39,6 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "", ) -> SemanticFilterOutput: """ Filters a list of documents based on a natural language instruction using a language model. @@ -71,8 +75,6 @@ def sem_filter( processing. Defaults to True. progress_bar_desc (str, optional): Description for the progress bar. Defaults to "Filtering". - additional_cot_instructions (str, optional): Additional instructions for - chain-of-thought reasoning. Defaults to "". Returns: SemanticFilterOutput: An object containing the boolean filter outputs, raw @@ -98,7 +100,7 @@ def sem_filter( examples_answers, cot_reasoning, prompt_strategy, - reasoning_instructions=additional_cot_instructions, + reasoning_instructions=prompt_strategy.additional_cot_instructions if prompt_strategy is not None else "", ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -143,7 +145,6 @@ def learn_filter_cascade_thresholds( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, prompt_strategy: PromptStrategy | None = None, - additional_cot_instructions: str = "", ) -> tuple[float, float]: """ Automatically learns optimal cascade thresholds for filter operations. @@ -174,8 +175,6 @@ def learn_filter_cascade_thresholds( for the example documents. Defaults to None. prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. Defaults to None. - additional_cot_instructions (str, optional): Additional instructions for - chain-of-thought reasoning. Defaults to "". 
Returns: tuple[float, float]: A tuple containing the learned low and high thresholds @@ -206,7 +205,6 @@ def learn_filter_cascade_thresholds( prompt_strategy=prompt_strategy, safe_mode=False, progress_bar_desc="Running oracle for threshold learning", - additional_cot_instructions=additional_cot_instructions, ).outputs best_combination, _ = learn_cascade_thresholds( @@ -265,8 +263,6 @@ class SemFilterDataframe: estimation. Defaults to False. progress_bar_desc (str, optional): Description for the progress bar. Defaults to "Filtering". - additional_cot_instructions (str, optional): Additional instructions - for chain-of-thought reasoning. Defaults to "". Returns: pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: A DataFrame @@ -348,7 +344,6 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "", ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: if lotus.settings.lm is None: raise ValueError( @@ -375,14 +370,43 @@ def __call__( examples_answers = None cot_reasoning = None - # Handle examples - if examples is not None: - assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" - examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) - examples_answers = examples["Answer"].tolist() - - if prompt_strategy is not None and prompt_strategy.cot and "Reasoning" in examples.columns: - cot_reasoning = examples["Reasoning"].tolist() + # Handle examples from PromptStrategy.dems first, then fall back to examples parameter for backward compatibility + if prompt_strategy is not None and prompt_strategy.dems is not None: + if isinstance(prompt_strategy.dems, pd.DataFrame): + # User-provided examples + examples_source = prompt_strategy.dems + assert "Answer" in examples_source.columns, "Answer must be a column in examples dataframe" + examples_multimodal_data = task_instructions.df2multimodal_info(examples_source, col_li) + examples_answers = examples_source["Answer"].tolist() + + if prompt_strategy.cot and "Reasoning" in examples_source.columns: + cot_reasoning = examples_source["Reasoning"].tolist() + + elif prompt_strategy.dems == "auto": + # Auto-bootstrap demonstrations + try: + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=self._obj, + col_li=col_li, + user_instruction=formatted_usr_instr, + prompt_strategy=prompt_strategy, + operation_type="filter", + ) + except Exception as e: + lotus.logger.warning(f"Failed to bootstrap demonstrations: {e}") + # Fall back to no examples + examples_multimodal_data = None + examples_answers = None + cot_reasoning = None + elif examples is not None: + # Backward compatibility: use the old examples parameter + examples_source = examples + assert "Answer" in examples_source.columns, "Answer must be a column in examples dataframe" + examples_multimodal_data = task_instructions.df2multimodal_info(examples_source, col_li) + examples_answers = examples_source["Answer"].tolist() + + if prompt_strategy is not None and prompt_strategy.cot and "Reasoning" in examples_source.columns: + cot_reasoning = examples_source["Reasoning"].tolist() pos_cascade_threshold, neg_cascade_threshold = None, None if cascade_args is not None: @@ -463,7 +487,6 @@ def __call__( examples_answers=examples_answers, cot_reasoning=cot_reasoning, prompt_strategy=prompt_strategy, - additional_cot_instructions=additional_cot_instructions, ) stats["pos_cascade_threshold"] = pos_cascade_threshold @@ -523,7 
+546,6 @@ def __call__( prompt_strategy=prompt_strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", - additional_cot_instructions=additional_cot_instructions, ) for idx, large_idx in enumerate(low_conf_idxs): @@ -547,7 +569,6 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, - additional_cot_instructions=additional_cot_instructions, ) outputs = output.outputs raw_outputs = output.raw_outputs diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index a28fcbf3..6930380f 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -13,6 +13,7 @@ ) from lotus.utils import show_safe_mode +from .cascade_utils import bootstrap_demonstrations from .postprocessors import map_postprocess @@ -241,7 +242,42 @@ def __call__( examples_answers = None cot_reasoning = None - if examples is not None: + # Handle examples from PromptStrategy.dems first, then fall back to examples parameter for backward compatibility + if prompt_strategy is not None and prompt_strategy.dems is not None: + if isinstance(prompt_strategy.dems, pd.DataFrame): + # User-provided examples + examples_source = prompt_strategy.dems + assert "Answer" in examples_source.columns, "Answer must be a column in examples dataframe" + examples_multimodal_data = task_instructions.df2multimodal_info(examples_source, col_li) + examples_answers = examples_source["Answer"].tolist() + + if prompt_strategy.cot: + return_explanations = True + if "Reasoning" in examples_source.columns: + cot_reasoning = examples_source["Reasoning"].tolist() + else: + cot_reasoning = ["Reasoning omitted"] * len(examples_answers) + + elif prompt_strategy.dems == "auto": + # Auto-bootstrap demonstrations + try: + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=self._obj, + col_li=col_li, + user_instruction=formatted_usr_instr, + prompt_strategy=prompt_strategy, + operation_type="map", + ) + if prompt_strategy.cot and examples_answers: + return_explanations = True + except Exception as e: + lotus.logger.warning(f"Failed to bootstrap demonstrations: {e}") + # Fall back to no examples + examples_multimodal_data = None + examples_answers = None + cot_reasoning = None + elif examples is not None: + # Backward compatibility: use the old examples parameter assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() diff --git a/lotus/types.py b/lotus/types.py index 43a1613e..14ba1347 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -232,6 +232,8 @@ class PromptStrategy: max_dems (int): Maximum number of demonstrations to use. Defaults to 3. teacher_lm: Language model to use for bootstrapping demonstrations. If None, uses the main model. Defaults to None. + additional_cot_instructions (str): Additional instructions for + chain-of-thought reasoning. Defaults to "". 
Example: >>> # Chain-of-thought with user-provided demonstrations @@ -252,3 +254,4 @@ class PromptStrategy: dems: pd.DataFrame | str | None = None max_dems: int = 3 teacher_lm: Any = None # lotus.models.LM type, but avoiding circular import + additional_cot_instructions: str = "" diff --git a/tests/test_reasoning_strategies.py b/tests/test_reasoning_strategies.py index b5a69485..bac85650 100644 --- a/tests/test_reasoning_strategies.py +++ b/tests/test_reasoning_strategies.py @@ -315,6 +315,132 @@ def test_bootstrapping_with_oracle_model(self, sample_courses_df, setup_model): assert "filter_label" in result.columns + def test_bootstrapping_map_operation(self, sample_courses_df, setup_model): + """Test auto-bootstrapping with map operation""" + df = sample_courses_df.head(4) # Use fewer rows for faster testing + instruction = "What is the difficulty level of {Course Name}? Answer: Beginner, Intermediate, or Advanced" + + result = df.sem_map( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + ) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check that all difficulty levels are reasonable + for difficulty in result["_map"]: + assert isinstance(difficulty, str) + assert len(difficulty) > 0 + + def test_bootstrapping_without_cot(self, sample_courses_df, setup_model): + """Test auto-bootstrapping without chain-of-thought""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=False, dems="auto", max_dems=2), + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + # Should not have explanations when CoT is disabled + assert "explanation_filter" not in result.columns + + def test_bootstrapping_max_dems_limit(self, sample_courses_df, setup_model): + """Test that max_dems is respected""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Request more demonstrations than available data + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=100), + return_all=True, + ) + + # Should still work and be limited by available data + assert "filter_label" in result.columns + + def test_bootstrapping_with_additional_instructions(self, sample_courses_df, setup_model): + """Test auto-bootstrapping with additional CoT instructions""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy( + cot=True, + dems="auto", + max_dems=2, + additional_cot_instructions="Consider the mathematical content and prerequisites carefully.", + ), + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check that explanations are provided + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + + def test_bootstrapping_error_handling(self, sample_courses_df, setup_model): + """Test that bootstrapping errors are handled gracefully""" + df = sample_courses_df.head(2) + instruction = "{Course Name} requires a lot of math" + + # This should work even if bootstrapping encounters issues + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=1), + return_all=True, 
+ ) + + # Should still produce results + assert "filter_label" in result.columns + assert len(result) == len(df) + + def test_bootstrapping_empty_data_handling(self, setup_model): + """Test bootstrapping with empty or minimal data""" + empty_df = pd.DataFrame({"Course Name": []}) + instruction = "{Course Name} requires a lot of math" + + # Should handle empty data gracefully + result = empty_df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_all=True, + ) + + assert "filter_label" in result.columns + assert len(result) == 0 + + def test_bootstrapping_integration_with_existing_features(self, sample_courses_df, setup_model): + """Test that auto-bootstrapping works with other features like return_stats""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + result, stats = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + return_all=True, + return_stats=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + assert isinstance(stats, dict) + # ============================================================================= # Backward Compatibility Tests # ============================================================================= @@ -502,3 +628,126 @@ def test_strategy_with_return_options(self, sample_courses_df, setup_model): # Test that filtering works correctly positive_results = result[result["filter_label"]] assert len(positive_results) >= 0 # Should have some math courses + + # ============================================================================= + # Direct Bootstrap Function Tests + # ============================================================================= + + def test_bootstrap_demonstrations_function_direct(self, sample_courses_df, setup_model): + """Test the bootstrap_demonstrations function directly""" + import lotus.nl_expression as nle + from lotus.sem_ops.cascade_utils import bootstrap_demonstrations + + df = sample_courses_df.head(4) + user_instruction = "{Course Name} requires a lot of math" + col_li = nle.parse_cols(user_instruction) + formatted_instruction = nle.nle2str(user_instruction, col_li) + + prompt_strategy = PromptStrategy(cot=True, dems="auto", max_dems=2, teacher_lm=setup_model) + + # Test the function directly + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=df, + col_li=col_li, + user_instruction=formatted_instruction, + prompt_strategy=prompt_strategy, + operation_type="filter", + ) + + # Check outputs + assert isinstance(examples_multimodal_data, list) + assert isinstance(examples_answers, list) + assert len(examples_multimodal_data) == len(examples_answers) + assert len(examples_multimodal_data) <= prompt_strategy.max_dems + + # Check CoT reasoning + if prompt_strategy.cot: + assert cot_reasoning is not None + assert isinstance(cot_reasoning, list) + assert len(cot_reasoning) == len(examples_answers) + for reasoning in cot_reasoning: + assert isinstance(reasoning, str) + assert len(reasoning) > 0 + + # Check answer types for filter operation + for answer in examples_answers: + assert isinstance(answer, bool) + + def test_bootstrap_demonstrations_map_operation_direct(self, sample_courses_df, setup_model): + """Test bootstrap_demonstrations function for map operation""" + import lotus.nl_expression as nle + from lotus.sem_ops.cascade_utils import 
bootstrap_demonstrations + + df = sample_courses_df.head(3) + user_instruction = "What is the difficulty level of {Course Name}?" + col_li = nle.parse_cols(user_instruction) + formatted_instruction = nle.nle2str(user_instruction, col_li) + + prompt_strategy = PromptStrategy(cot=False, dems="auto", max_dems=2, teacher_lm=setup_model) + + # Test the function for map operation + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=df, + col_li=col_li, + user_instruction=formatted_instruction, + prompt_strategy=prompt_strategy, + operation_type="map", + ) + + # Check outputs + assert isinstance(examples_multimodal_data, list) + assert isinstance(examples_answers, list) + assert len(examples_multimodal_data) == len(examples_answers) + + # Check answer types for map operation + for answer in examples_answers: + assert isinstance(answer, str) + assert len(answer) > 0 + + # No CoT reasoning expected + assert cot_reasoning is None + + def test_bootstrap_demonstrations_error_cases(self, sample_courses_df, setup_model): + """Test bootstrap_demonstrations function error handling""" + import lotus.nl_expression as nle + from lotus.sem_ops.cascade_utils import bootstrap_demonstrations + + df = sample_courses_df.head(2) + user_instruction = "{Course Name} requires a lot of math" + col_li = nle.parse_cols(user_instruction) + formatted_instruction = nle.nle2str(user_instruction, col_li) + + # Test with no teacher model - temporarily clear lotus.settings.lm + original_lm = lotus.settings.lm + lotus.settings.lm = None + + try: + prompt_strategy_no_teacher = PromptStrategy(cot=True, dems="auto", max_dems=2, teacher_lm=None) + + with pytest.raises(ValueError, match="No teacher model available for bootstrapping"): + bootstrap_demonstrations( + data=df, + col_li=col_li, + user_instruction=formatted_instruction, + prompt_strategy=prompt_strategy_no_teacher, + operation_type="filter", + ) + finally: + # Restore the original LM + lotus.settings.lm = original_lm + + # Test with unsupported operation type + prompt_strategy = PromptStrategy(cot=True, dems="auto", max_dems=1, teacher_lm=setup_model) + + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=df, + col_li=col_li, + user_instruction=formatted_instruction, + prompt_strategy=prompt_strategy, + operation_type="unsupported_operation", + ) + + # Should return empty results for unsupported operations + assert examples_multimodal_data == [] + assert examples_answers == [] + assert cot_reasoning is None
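The bootstrap helpers in this series standardize on a "Reasoning:" / "Answer:" response format from the teacher model. The sketch below restates that parsing contract as a standalone function, with a couple of invented sample responses so the fallback path is easy to eyeball; it is an illustration of the convention, not code from the patch:

def parse_cot_filter_response(response: str) -> tuple[bool, str]:
    # Mirrors the Reasoning:/Answer: parsing in _bootstrap_filter_example.
    reasoning, answer_str = "", ""
    for line in response.strip().split("\n"):
        if line.startswith("Reasoning:"):
            reasoning = line[len("Reasoning:") :].strip()
        elif line.startswith("Answer:"):
            answer_str = line[len("Answer:") :].strip()
    if not reasoning or not answer_str:
        # Fallback when the teacher model ignores the requested format.
        parts = response.lower().split("answer:")
        if len(parts) >= 2:
            reasoning, answer_str = parts[0].strip(), parts[1].strip()
        else:
            reasoning = response
            answer_str = "true" if "true" in response.lower() else "false"
    return answer_str.lower().strip() in ["true", "yes", "1"], reasoning


assert parse_cot_filter_response("Reasoning: calculus is math\nAnswer: True") == (True, "calculus is math")
assert parse_cot_filter_response("The claim holds, answer: TRUE") == (True, "the claim holds,")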