diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 7fe73f18..b97f2871 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -7,7 +7,7 @@ import lotus from lotus.models import LM, SentenceTransformersRM -from lotus.types import CascadeArgs +from lotus.types import CascadeArgs, PromptStrategy from lotus.vector_store import FaissVS ################################################################################ @@ -161,7 +161,7 @@ def test_map_fewshot(setup_models, model): examples = {"School": ["Stanford", "MIT"], "Answer": ["CA", "MA"]} examples_df = pd.DataFrame(examples) user_instruction = "What state is {School} in? Respond only with the two-letter abbreviation." - df = df.sem_map(user_instruction, examples=examples_df, suffix="State") + df = df.sem_map(user_instruction, prompt_strategy=PromptStrategy(dems=examples_df), suffix="State") # clean up the state names to be more robust to free-form text df["State"] = df["State"].str[-2:].str.lower() @@ -283,7 +283,7 @@ def test_filter_operation_cot(setup_models, model): } df = pd.DataFrame(data) user_instruction = "{Text} I have at least one apple" - filtered_df = df.sem_filter(user_instruction, strategy="cot") + filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True)) expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]}) assert filtered_df.equals(expected_df) @@ -303,11 +303,12 @@ def test_filter_operation_cot_fewshot(setup_models, model): } df = pd.DataFrame(data) examples = { - "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], - "Answer": [True, True, True], + "Sequence": ["1, 2, 3", "A, B, C", "penny, nickel, dime, quarter", "villiage, town, city"], + "Answer": [True, True, True, True], "Reasoning": [ "1, 2, 3 is an increasing sequence of numbers", - "penny, nickel, dime, quarter is an increasing sequence of coins", + "A, B, C is an increasing sequence of letters in alphabetical order", + "penny, nickel, dime, quarter is an increasing sequence of coins by value", "villiage, town, city is an increasing sequence of settlements", ], } @@ -316,9 +317,9 @@ def test_filter_operation_cot_fewshot(setup_models, model): user_instruction = "{Sequence} is increasing" filtered_df = df.sem_filter( user_instruction, - strategy="cot", - examples=examples_df, - additional_cot_instructions="Assume the most typical or logical case.", + prompt_strategy=PromptStrategy( + cot=True, dems=examples_df, additional_cot_instructions="Assume the most typical or logical case." 
+ ), ) expected_df = pd.DataFrame( { @@ -347,13 +348,13 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): } df = pd.DataFrame(data) examples = { - "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], - "Answer": [True, True, True], + "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city", "A, B, C"], + "Answer": [True, True, True, True], } examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df) + filtered_df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True, dems=examples_df)) expected_df = pd.DataFrame( { "Sequence": [ @@ -550,6 +551,99 @@ def test_custom_tokenizer(): assert tokens < 100 +################################################################################ +# Auto-bootstrapping tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_auto_bootstrapping_filter(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test auto-bootstrapping with filter operation + data = { + "Course Name": [ + "Linear Algebra", + "Poetry Writing", + "Calculus II", + "Art History", + "Statistics", + "Creative Writing", + "Machine Learning", + "Philosophy", + ] + } + df = pd.DataFrame(data) + user_instruction = "{Course Name} requires a lot of math" + + # Test auto-bootstrapping + result = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Should have some math courses identified + math_courses = result[result["filter_label"]]["Course Name"].tolist() + expected_math_courses = ["Linear Algebra", "Calculus II", "Statistics", "Machine Learning"] + assert any(course in expected_math_courses for course in math_courses) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_auto_bootstrapping_map(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test auto-bootstrapping with map operation + data = {"Course Name": ["Linear Algebra", "Poetry Writing", "Calculus II", "Art History"]} + df = pd.DataFrame(data) + user_instruction = "What is the difficulty level of {Course Name}? 
Answer: Beginner, Intermediate, or Advanced" + + # Test auto-bootstrapping + result = df.sem_map( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + ) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check that all difficulty levels are valid + for difficulty in result["_map"]: + assert difficulty.lower() in ["beginner", "intermediate", "advanced"] + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_auto_bootstrapping_with_teacher_model(setup_models, model): + lm = setup_models[model] + teacher_lm = setup_models[model] # Use same model as teacher for testing + lotus.settings.configure(lm=lm) + + data = {"Text": ["I am happy", "I am sad", "I am excited", "I am tired"]} + df = pd.DataFrame(data) + user_instruction = "{Text} expresses a positive emotion" + + # Test auto-bootstrapping with explicit teacher model + result = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2, teacher_lm=teacher_lm), + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + + # Should identify positive emotions + positive_texts = result[result["filter_label"]]["Text"].tolist() + assert any(text in ["I am happy", "I am excited"] for text in positive_texts) + + ################################################################################ # Eval tests ################################################################################ diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py index db566a38..ef477073 100644 --- a/.github/tests/multimodality_tests.py +++ b/.github/tests/multimodality_tests.py @@ -134,9 +134,9 @@ def test_topk_operation(setup_models, model): ] ) - strategies = ["quick", "heap", "naive"] - for strategy in strategies: - sorted_df = df.sem_topk(user_instruction, K=3, strategy=strategy) + methods = ["quick", "heap", "naive"] + for method in methods: + sorted_df = df.sem_topk(user_instruction, K=3, method=method) top_2_actual = set(sorted_df["image"].values) assert top_2_expected.issubset(top_2_actual) diff --git a/examples/model_examples/deepseek.py b/examples/model_examples/deepseek.py index 74384d75..18ecd555 100644 --- a/examples/model_examples/deepseek.py +++ b/examples/model_examples/deepseek.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy # Set up model lm = LM(model="ollama/deepseek-r1:7b", temperature=0.6) @@ -33,6 +33,6 @@ ) # Run semantic mapping with CoT strategy -df = df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT) +df = df.sem_map(user_instruction, return_explanations=True, prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f244051a..a6431764 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -2,6 +2,7 @@ import lotus from lotus.models import LM +from lotus.types import PromptStrategy lm = LM(model="gpt-4o-mini") @@ -17,5 +18,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction, strategy="cot") +df = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/filter_cot.py b/examples/op_examples/filter_cot.py index b5859083..872b255c 100644 
--- a/examples/op_examples/filter_cot.py +++ b/examples/op_examples/filter_cot.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy lm = LM(model="gpt-4o-mini") @@ -20,9 +20,9 @@ } df = pd.DataFrame(data) user_instruction = "{Text} I have at least one apple" -# filtered_df = df.sem_filter(user_instruction, strategy="cot", return_all=True) +# Old way: filtered_df = df.sem_filter(user_instruction, strategy="cot", return_all=True) filtered_df = df.sem_filter( - user_instruction, strategy=ReasoningStrategy.ZS_COT, return_all=True, return_explanations=True + user_instruction, prompt_strategy=PromptStrategy(cot=True), return_all=True, return_explanations=True ) # uncomment to see reasoning chains print(filtered_df) diff --git a/examples/op_examples/map_deepseek_cot.py b/examples/op_examples/map_deepseek_cot.py index 831463fe..df0deae0 100644 --- a/examples/op_examples/map_deepseek_cot.py +++ b/examples/op_examples/map_deepseek_cot.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy lm = LM(model="ollama/deepseek-r1:7b", temperature=0.5) @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "What is a similar course to {Course Name}. Just give the course name." -df = df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT) +df = df.sem_map(user_instruction, return_explanations=True, prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/multimodal_ops/join.py b/examples/op_examples/multimodal_ops/join.py index 9e490ea9..c1fed3ad 100644 --- a/examples/op_examples/multimodal_ops/join.py +++ b/examples/op_examples/multimodal_ops/join.py @@ -5,6 +5,7 @@ import lotus from lotus.dtype_extensions import ImageArray from lotus.models import LM +from lotus.types import PromptStrategy lotus.settings.configure(lm=LM(model="gpt-4o-mini")) @@ -17,6 +18,6 @@ image_df = pd.DataFrame({"image": ImageArray(image_paths), "image_path": image_paths}) labels_df = pd.DataFrame({"label": [0, 1]}) -df = image_df.sem_join(labels_df, "{image} represents the number {label}", strategy="zs-cot") +df = image_df.sem_join(labels_df, "{image} represents the number {label}", prompt_strategy=PromptStrategy(cot=True)) print(df) diff --git a/examples/op_examples/simple_reasoning.py b/examples/op_examples/simple_reasoning.py new file mode 100644 index 00000000..4ccd13bd --- /dev/null +++ b/examples/op_examples/simple_reasoning.py @@ -0,0 +1,78 @@ +import pandas as pd + +import lotus +from lotus.models import LM +from lotus.types import PromptStrategy + +# Configure the language model +lm = LM(model="gpt-4o-mini") +lotus.settings.configure(lm=lm) + +# Sample data +data = { + "Course Name": ["Linear Algebra", "Poetry Writing", "Calculus II", "Art History", "Statistics", "Creative Writing"] +} +df = pd.DataFrame(data) +user_instruction = "{Course Name} requires a lot of math" + +# Example 1: Basic filtering (no reasoning) +print("=== 1. Basic Filtering ===") +basic_df = df.sem_filter(user_instruction, return_all=True) +print(basic_df[["Course Name", "filter_label"]]) +print() + +# Example 2: Chain-of-Thought reasoning +print("=== 2. 
Chain-of-Thought Reasoning ===") +cot_df = df.sem_filter( + user_instruction, prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_all=True +) +print(cot_df[["Course Name", "filter_label", "explanation_filter"]]) +print() + +# Example 3: Few-shot examples (demonstrations) +print("=== 3. Few-shot Examples ===") +examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature", "Physics"], "Answer": [True, False, True]}) + +demo_df = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(dems=examples), + return_all=True, +) +print(demo_df[["Course Name", "filter_label"]]) +print() + +# Example 4: CoT + Demonstrations (the powerful combination) +print("=== 4. CoT + Demonstrations ===") +examples_with_reasoning = pd.DataFrame( + { + "Course Name": ["Machine Learning", "Literature", "Physics"], + "Answer": [True, False, True], + "Reasoning": [ + "Machine Learning requires linear algebra, calculus, and statistics", + "Literature focuses on reading, writing, and analysis - no math required", + "Physics is fundamentally mathematical with equations and calculations", + ], + } +) + +combined_df = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems=examples_with_reasoning), + return_explanations=True, + return_all=True, +) +print(combined_df[["Course Name", "filter_label", "explanation_filter"]]) +print() + +# Example 5: Automatic demonstration bootstrapping +print("=== 5. Bootstrapped Demonstrations ===") + +bootstrap_df = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + return_all=True, +) +print("Automatically generated demonstrations:") +print(bootstrap_df[["Course Name", "filter_label", "explanation_filter"]]) +print() diff --git a/examples/op_examples/top_k_deepseek_cot.py b/examples/op_examples/top_k_deepseek_cot.py index 55113b84..2884def4 100644 --- a/examples/op_examples/top_k_deepseek_cot.py +++ b/examples/op_examples/top_k_deepseek_cot.py @@ -2,7 +2,7 @@ import lotus from lotus.models import LM -from lotus.types import ReasoningStrategy +from lotus.types import PromptStrategy lm = LM(model="ollama/deepseek-r1:7b", temperature=0.6) lotus.settings.configure(lm=lm) @@ -24,7 +24,7 @@ "{Review} suggests that the user would recommend the product to others", K=2, method=method, - strategy=ReasoningStrategy.ZS_COT, + prompt_strategy=PromptStrategy(cot=True), return_stats=True, return_explanations=True, ) diff --git a/lotus/__init__.py b/lotus/__init__.py index 858fed48..f55c8907 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -23,6 +23,7 @@ from lotus.evals import llm_as_judge, pairwise_judge from lotus.web_search import web_search, WebSearchCorpus from lotus.settings import settings # type: ignore[attr-defined] +from lotus.types import PromptStrategy logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO) @@ -52,6 +53,7 @@ "dtype_extensions", "web_search", "WebSearchCorpus", + "PromptStrategy", "llm_as_judge", "pairwise_judge", ] diff --git a/lotus/cache.py b/lotus/cache.py index a14fceed..35bf002d 100644 --- a/lotus/cache.py +++ b/lotus/cache.py @@ -43,7 +43,7 @@ def wrapper(self, *args, **kwargs): def serialize(value: Any) -> Any: """ Serialize a value into a JSON-serializable format. - Supports basic types, pandas DataFrames, and objects with a `dict` or `__dict__` method. + Supports basic types, pandas DataFrames, Enums, and objects with a `dict` or `__dict__` method. 
""" if value is None or isinstance(value, (str, int, float, bool)): return value @@ -57,6 +57,8 @@ def serialize(value: Any) -> Any: return [serialize(item) for item in value] elif isinstance(value, dict): return {key: serialize(val) for key, val in value.items()} + elif isinstance(value, Enum): + return {"__enum__": value.__class__.__name__, "value": value.name} elif hasattr(value, "dict"): return value.dict() elif hasattr(value, "__dict__"): diff --git a/lotus/evals/llm_as_judge.py b/lotus/evals/llm_as_judge.py index 930d52e6..bd7e4610 100644 --- a/lotus/evals/llm_as_judge.py +++ b/lotus/evals/llm_as_judge.py @@ -10,7 +10,7 @@ from lotus.sem_ops.postprocessors import map_postprocess from lotus.sem_ops.sem_map import sem_map from lotus.templates import task_instructions -from lotus.types import ReasoningStrategy, SemanticMapOutput, SemanticMapPostprocessOutput +from lotus.types import PromptStrategy, SemanticMapOutput, SemanticMapPostprocessOutput def llm_as_judge( @@ -24,7 +24,7 @@ def llm_as_judge( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[str] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), safe_mode: bool = False, progress_bar_desc: str = "Evaluating", **model_kwargs: Any, @@ -56,7 +56,7 @@ def llm_as_judge( cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Used when strategy includes COT reasoning. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. @@ -73,7 +73,7 @@ def llm_as_judge( "Your job is to judge the output given the criteria, context and grading scale." ) - if response_format is not None and strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT]: + if response_format is not None and prompt_strategy.cot: raise ValueError( "Response format is not supported for COT or ZS_COT strategies. Use a non-COT strategy instead with reasoning field in the response format." ) @@ -92,7 +92,7 @@ def llm_as_judge( examples_multimodal_data, examples_answers, cot_reasoning, - strategy, + prompt_strategy, safe_mode, progress_bar_desc, response_format=response_format, @@ -139,7 +139,7 @@ class LLMAsJudgeDataframe: examples (pd.DataFrame | None, optional): Example DataFrame for few-shot learning. Should have the same column structure as the input DataFrame plus an "Answer" column. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. extra_cols_to_include (list[str] | None, optional): Extra columns to include in the input for judge. Defaults to None. @@ -197,7 +197,7 @@ def __call__( suffix: str = "_judge", examples: pd.DataFrame | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, extra_cols_to_include: list[str] | None = None, safe_mode: bool = False, progress_bar_desc: str = "Evaluating", @@ -208,7 +208,7 @@ def __call__( "The language model must be an instance of LM. 
Please configure a valid language model using lotus.settings.configure()" ) - if response_format is not None and strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT]: + if response_format is not None and prompt_strategy is not None and prompt_strategy.cot: raise ValueError( "Response format is not supported for COT or ZS_COT strategies. Use a non-COT strategy instead with reasoning field in the response format." ) @@ -239,7 +239,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.COT or strategy == ReasoningStrategy.ZS_COT: + if prompt_strategy is not None and prompt_strategy.cot: cot_reasoning = examples["Reasoning"].tolist() output = llm_as_judge( @@ -253,7 +253,7 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy if prompt_strategy is not None else PromptStrategy(), safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, **model_kwargs, diff --git a/lotus/evals/pairwise_judge.py b/lotus/evals/pairwise_judge.py index 3f2a43c9..c538169b 100644 --- a/lotus/evals/pairwise_judge.py +++ b/lotus/evals/pairwise_judge.py @@ -7,7 +7,7 @@ import lotus.models from lotus.cache import operator_cache from lotus.sem_ops.postprocessors import map_postprocess -from lotus.types import ReasoningStrategy, SemanticMapPostprocessOutput +from lotus.types import PromptStrategy, SemanticMapPostprocessOutput @pd.api.extensions.register_dataframe_accessor("pairwise_judge") @@ -38,7 +38,7 @@ class PairwiseJudgeDataframe: examples (pd.DataFrame | None, optional): Example DataFrame for few-shot learning. Should have the same column structure as the input DataFrame plus an "Answer" column. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. @@ -82,7 +82,7 @@ def __call__( suffix: str = "_judge", examples: pd.DataFrame | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy | None = None, safe_mode: bool = False, progress_bar_desc: str = "Evaluating", **model_kwargs: Any, ) -> pd.DataFrame: if lotus.settings.lm is None or not isinstance(lotus.settings.lm, lotus.models.LM): raise ValueError( "The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()" ) - if response_format is not None and strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT]: + if response_format is not None and prompt_strategy is not None and prompt_strategy.cot: raise ValueError( "Response format is not supported for COT or ZS_COT strategies. Use a non-COT strategy instead with reasoning field in the response format." 
) @@ -120,7 +120,7 @@ def __call__( suffix=suffix + "_" + c1 + "_" + c2, examples=examples, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, **model_kwargs, @@ -150,7 +150,7 @@ def __call__( suffix=suffix, examples=examples, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, extra_cols_to_include=[col1, col2], safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, diff --git a/lotus/sem_ops/cascade_utils.py b/lotus/sem_ops/cascade_utils.py index 6896e4c1..720a15a3 100644 --- a/lotus/sem_ops/cascade_utils.py +++ b/lotus/sem_ops/cascade_utils.py @@ -1,8 +1,11 @@ +from typing import Any + import numpy as np +import pandas as pd from numpy.typing import NDArray import lotus -from lotus.types import CascadeArgs +from lotus.types import CascadeArgs, PromptStrategy def importance_sampling( @@ -147,3 +150,113 @@ def calculate_tau_neg( def calibrate_sem_sim_join(true_score: list[float]) -> list[float]: true_score = list(np.clip(true_score, 0, 1)) return true_score + + +def bootstrap_demonstrations( + data: pd.DataFrame, + col_li: list[str], + user_instruction: str, + prompt_strategy: PromptStrategy = PromptStrategy(), + operation_type: str = "filter", +) -> tuple[list[dict[str, Any]], list[Any], list[str] | None]: + """ + Bootstrap demonstrations automatically using a teacher model and sem_ops. + + This function samples diverse examples from the input data and uses a teacher + model with the appropriate sem_op to generate high-quality answers and reasoning. + + Args: + data (pd.DataFrame): The input DataFrame to sample from + col_li (list[str]): List of column names to include in the examples + user_instruction (str): The user instruction for the task + prompt_strategy (PromptStrategy): The prompt strategy containing bootstrapping config + operation_type (str): Type of operation ("filter", "map", "extract") + + Returns: + tuple: (examples_multimodal_data, examples_answers, cot_reasoning) + - examples_multimodal_data: List of example documents + - examples_answers: List of answers for the examples + - cot_reasoning: List of reasoning explanations (if CoT enabled) + """ + # Determine teacher model + teacher_lm = prompt_strategy.teacher_lm if prompt_strategy.teacher_lm is not None else lotus.settings.lm + if teacher_lm is None: + raise ValueError("No teacher model available for bootstrapping") + + # Sample diverse examples from the data + max_dems = min(prompt_strategy.max_dems, len(data)) + if max_dems == 0: + return [], [], None + + # Use random sampling + np.random.seed(42) + sample_indices = np.random.choice(len(data), size=max_dems, replace=False) + sample_data = data.iloc[sample_indices] + + lotus.logger.info(f"Bootstrapping {max_dems} demonstrations using teacher model with sem_ops") + + # Convert sampled data to multimodal format + from lotus.templates import task_instructions + + examples_multimodal_data = task_instructions.df2multimodal_info(sample_data, col_li) + + # Create teacher prompt strategy for bootstrapping (without auto-demos to avoid recursion) + teacher_prompt_strategy = PromptStrategy( + cot=prompt_strategy.cot, dems=None, additional_cot_instructions=prompt_strategy.additional_cot_instructions + ) + + examples_answers: list[Any] = [] + cot_reasoning: list[str] | None = None + + try: + if operation_type == "filter": + # Use sem_filter with teacher model + from lotus.sem_ops.sem_filter import sem_filter + + # For filter, we need to call the function 
directly, since the DataFrame accessor has a different signature + filter_result = sem_filter( + docs=examples_multimodal_data, + model=teacher_lm, + user_instruction=user_instruction, + prompt_strategy=teacher_prompt_strategy, + show_progress_bar=False, + ) + examples_answers = filter_result.outputs + if prompt_strategy.cot and filter_result.explanations is not None: + filtered_explanations = [exp for exp in filter_result.explanations if exp is not None] + cot_reasoning = filtered_explanations if filtered_explanations else None + else: + cot_reasoning = None + + elif operation_type == "map": + # Use sem_map with teacher model + from lotus.sem_ops.sem_map import sem_map + + # For map, we need to call the function directly, since the DataFrame accessor has a different signature + map_result = sem_map( + docs=examples_multimodal_data, + model=teacher_lm, + user_instruction=user_instruction, + prompt_strategy=teacher_prompt_strategy, + ) + examples_answers = map_result.outputs + if prompt_strategy.cot and map_result.explanations is not None: + filtered_explanations = [exp for exp in map_result.explanations if exp is not None] + cot_reasoning = filtered_explanations if filtered_explanations else None + else: + cot_reasoning = None + + elif operation_type == "extract": + lotus.logger.warning("Bootstrapping for extract operation requires output_cols parameter") + return [], [], None + + else: + lotus.logger.warning(f"Bootstrapping not yet implemented for operation type: {operation_type}") + return [], [], None + + except Exception as e: + lotus.logger.warning(f"Failed to bootstrap demonstrations: {e}") + return [], [], None + + lotus.logger.info(f"Successfully bootstrapped {len(examples_answers)} demonstrations") + return examples_multimodal_data, examples_answers, cot_reasoning diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 34a95bb0..3d6e53c6 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -10,7 +10,7 @@ def cot_postprocessor(llm_answers: list[str], for_extract: bool = False): - outputs: list[Union[str, Dict[str, Any], None]] = [] + outputs: list[Union[str, Dict[Any, Any], None]] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: reasoning_idx = llm_answer.find("Reasoning:\n") @@ -53,7 +53,7 @@ def deepseek_cot_postprocessor(llm_answers: list[str], for_extract: bool = False Returns: Tuple: (outputs, explanations) """ - outputs: list[Union[str, Dict[str, Any], None]] = [] + outputs: list[Union[str, Dict[Any, Any], None]] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: @@ -65,11 +65,14 @@ def deepseek_cot_postprocessor(llm_answers: list[str], for_extract: bool = False if think_start != -1 and think_end != -1: # Extract the reasoning between the <think> tags reasoning = llm_answer[think_start + len("<think>") : think_end].strip() - answer = llm_answer[answer_start + len("Answer:") :].strip() - answer = answer.strip() + if answer_start != -1: + answer = llm_answer[answer_start + len("Answer:") :].strip() + else: + # No "Answer:" found, look for content after </think> + answer = llm_answer[think_end + len("</think>") :].strip() - # If ther is nothing after </think> tag, check if the answer is at the beginning + # If there is nothing after the </think> tag, check if the answer is at the beginning if not answer and think_start > 0: answer = llm_answer[:think_start].strip() @@ -183,6 +186,7 @@ def filter_postprocess( llm_answers: list[str], model: lotus.models.LM, default: bool = True, + cot_reasoning: bool = False, ) -> SemanticFilterPostprocessOutput: 
""" Postprocess the output of the filter operator. @@ -210,8 +214,12 @@ def process_outputs(answer): lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") return default - postprocessor = get_cot_postprocessor(model) - outputs, explanations = postprocessor(llm_answers) + if cot_reasoning: + postprocessor = get_cot_postprocessor(model) + outputs, explanations = postprocessor(llm_answers) + else: + outputs = llm_answers + explanations = [None] * len(llm_answers) boolean_outputs = [process_outputs(answer) for answer in outputs] diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 0e60bbb6..dd5eb1e8 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -6,7 +6,7 @@ from lotus.cache import operator_cache from lotus.models import LM from lotus.templates import task_instructions -from lotus.types import LMOutput, ReasoningStrategy, SemanticExtractOutput, SemanticExtractPostprocessOutput +from lotus.types import LMOutput, PromptStrategy, SemanticExtractOutput, SemanticExtractPostprocessOutput from lotus.utils import show_safe_mode from .postprocessors import extract_postprocess @@ -21,7 +21,7 @@ def sem_extract( safe_mode: bool = False, progress_bar_desc: str = "Extracting", return_explanations: bool = False, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), ) -> SemanticExtractOutput: """ Extracts structured attributes and values from a list of documents using a language model. @@ -43,7 +43,7 @@ def sem_extract( extract_quotes (bool, optional): Whether to extract supporting quotes from the source text for each extracted value. Defaults to False. postprocessor (Callable, optional): A function to post-process the model - outputs. Should take (outputs, model, return_explanations) and return + outputs. Should take (outputs, model, cot_reasoning) and return SemanticExtractPostprocessOutput. Defaults to extract_postprocess. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. @@ -52,8 +52,9 @@ def sem_extract( return_explanations (bool, optional): Whether to return explanations for the extraction decisions. Useful for debugging and understanding model reasoning. Defaults to False. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. - Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to PromptStrategy(). + Returns: SemanticExtractOutput: An object containing the extracted outputs, raw @@ -69,11 +70,25 @@ def sem_extract( >>> output_cols = {"sentiment": "positive/negative/neutral", "rating": "1-5 scale"} >>> result = sem_extract(docs, model, output_cols) >>> print(result.outputs) # [{"sentiment": "positive", "rating": "5"}] + + >>> # Using PromptStrategy with chain-of-thought + >>> from lotus.types import PromptStrategy + >>> strat = PromptStrategy(cot=True) + >>> result = sem_extract(docs, model, output_cols, prompt_strategy=strat) + + >>> # Using PromptStrategy with demonstrations + >>> import pandas as pd + >>> examples = pd.DataFrame({ + ... 'text': ['Great product!', 'Terrible service'], + ... 'Answer': [{'sentiment': 'positive'}, {'sentiment': 'negative'}] + ... 
}) + >>> strat = PromptStrategy(cot=True, dems=examples, max_dems=2) + >>> result = sem_extract(docs, model, output_cols, prompt_strategy=strat) """ # prepare model inputs inputs = [] for doc in docs: - prompt = task_instructions.extract_formatter(model, doc, output_cols, extract_quotes, strategy) + prompt = task_instructions.extract_formatter(model, doc, output_cols, extract_quotes, prompt_strategy) lotus.logger.debug(f"input to model: {prompt}") lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") inputs.append(prompt) @@ -86,14 +101,13 @@ def sem_extract( # call model # Don't use JSON response format when CoT reasoning is enabled, as it prevents reasoning text - if strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT]: + if prompt_strategy.cot: lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc) else: lm_output = model(inputs, response_format={"type": "json_object"}, progress_bar_desc=progress_bar_desc) # post process results - # Check if CoT reasoning is being used - cot_reasoning = strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] + cot_reasoning = prompt_strategy.cot postprocess_output = postprocessor(lm_output.outputs, model, cot_reasoning) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") @@ -141,8 +155,9 @@ class SemExtractDataFrame: return_explanations (bool, optional): Whether to include explanations in the output DataFrame. Useful for debugging and understanding model reasoning. Defaults to False. - strategy (ReasoningStrategy | None, optional): The reasoning strategy - to use. Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to PromptStrategy(). + Returns: pd.DataFrame: A DataFrame containing the original data plus the @@ -168,10 +183,14 @@ class SemExtractDataFrame: ... ['text'], ... {'sentiment': 'positive/negative/neutral', 'emotion': 'joy/anger/sadness'} ... ) - Extracting: 100%|█████████████████████████████████████████████████████████████████ 2/2 LM calls [00:00<00:00, 2.20it/s] text rating sentiment emotion 0 Great product! 
5 positive joy 1 Terrible service 1 negative anger + + >>> # Using PromptStrategy with chain-of-thought + >>> from lotus.types import PromptStrategy + >>> strat = PromptStrategy(cot=True) + >>> df.sem_extract(['text'], {'sentiment': 'positive/negative/neutral'}, prompt_strategy=strat) """ def __init__(self, pandas_obj: pd.DataFrame): @@ -211,7 +230,7 @@ def __call__( safe_mode: bool = False, progress_bar_desc: str = "Extracting", return_explanations: bool = False, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), ) -> pd.DataFrame: if lotus.settings.lm is None: raise ValueError( @@ -234,7 +253,7 @@ def __call__( safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, return_explanations=return_explanations, - strategy=strategy, + prompt_strategy=prompt_strategy, ) new_df = self._obj.copy() diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 3a26506d..e97401bb 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -11,13 +11,18 @@ CascadeArgs, LMOutput, LogprobsForFilterCascade, + PromptStrategy, ProxyModel, - ReasoningStrategy, SemanticFilterOutput, ) from lotus.utils import show_safe_mode -from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds +from .cascade_utils import ( + bootstrap_demonstrations, + calibrate_llm_logprobs, + importance_sampling, + learn_cascade_thresholds, +) from .postprocessors import filter_postprocess @@ -29,12 +34,11 @@ def sem_filter( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), logprobs: bool = False, safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "", ) -> SemanticFilterOutput: """ Filters a list of documents based on a natural language instruction using a language model. @@ -61,8 +65,8 @@ def sem_filter( cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Used when strategy includes COT reasoning. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. - Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. logprobs (bool, optional): Whether to return log probabilities for the model outputs. Useful for confidence estimation. Defaults to False. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. @@ -71,8 +75,6 @@ def sem_filter( processing. Defaults to True. progress_bar_desc (str, optional): Description for the progress bar. Defaults to "Filtering". - additional_cot_instructions (str, optional): Additional instructions for - chain-of-thought reasoning. Defaults to "". 
Returns: SemanticFilterOutput: An object containing the boolean filter outputs, raw @@ -97,8 +99,8 @@ def sem_filter( examples_multimodal_data, examples_answers, cot_reasoning, - strategy, - reasoning_instructions=additional_cot_instructions, + prompt_strategy, + reasoning_instructions=prompt_strategy.additional_cot_instructions, ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -113,7 +115,7 @@ def sem_filter( inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs ) - postprocess_output = filter_postprocess(lm_output.outputs, model, default) + postprocess_output = filter_postprocess(lm_output.outputs, model, default, cot_reasoning=prompt_strategy.cot) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") @@ -140,8 +142,7 @@ def learn_filter_cascade_thresholds( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, - additional_cot_instructions: str = "", + prompt_strategy: PromptStrategy = PromptStrategy(), ) -> tuple[float, float]: """ Automatically learns optimal cascade thresholds for filter operations. @@ -170,10 +171,8 @@ def learn_filter_cascade_thresholds( for the example documents. Defaults to None. cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. Defaults to None. - additional_cot_instructions (str, optional): Additional instructions for - chain-of-thought reasoning. Defaults to "". Returns: tuple[float, float]: A tuple containing the learned low and high thresholds @@ -201,10 +200,9 @@ def learn_filter_cascade_thresholds( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=False, progress_bar_desc="Running oracle for threshold learning", - additional_cot_instructions=additional_cot_instructions, ).outputs best_combination, _ = learn_cascade_thresholds( @@ -252,8 +250,8 @@ class SemFilterDataframe: input DataFrame plus an "Answer" column. Defaults to None. helper_examples (pd.DataFrame | None, optional): Additional helper examples for cascade filtering. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy - to use. Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to None. cascade_args (CascadeArgs | None, optional): Configuration for cascade filtering. Includes parameters like recall_target, precision_target, sampling_percentage, and failure_probability. Defaults to None. @@ -263,8 +261,6 @@ class SemFilterDataframe: estimation. Defaults to False. progress_bar_desc (str, optional): Description for the progress bar. Defaults to "Filtering". - additional_cot_instructions (str, optional): Additional instructions - for chain-of-thought reasoning. Defaults to "". Returns: pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: A DataFrame @@ -294,8 +290,8 @@ class SemFilterDataframe: 0 Great product! 
5 # Example 2: with zero-shot chain-of-thought (ZS-COT) reasoning - >>> from lotus.types import ReasoningStrategy - >>> df.sem_filter("The review {text} and {rating} reflect's a positive sentiment ", strategy=ReasoningStrategy.ZS_COT, return_explanations=True, return_all=True) + >>> from lotus.types import PromptStrategy + >>> df.sem_filter("The review {text} and {rating} reflect's a positive sentiment ", prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_all=True) Filtering: 100%|██████████████████████████████████████████████████████████████████ 4/4 LM calls [00:01<00:00, 3.66it/s] Text filter_label explanation_filter 0 I had two apples, then I gave away one True @@ -341,12 +337,11 @@ def __call__( suffix: str = "_filter", examples: pd.DataFrame | None = None, helper_examples: pd.DataFrame | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), cascade_args: CascadeArgs | None = None, return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "", ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: if lotus.settings.lm is None: raise ValueError( @@ -357,7 +352,8 @@ def __call__( lotus.logger.debug(user_instruction) col_li = lotus.nl_expression.parse_cols(user_instruction) lotus.logger.debug(col_li) - helper_strategy = strategy + + helper_strategy = prompt_strategy # check that column exists for column in col_li: @@ -371,13 +367,44 @@ def __call__( examples_multimodal_data = None examples_answers = None cot_reasoning = None - if examples is not None: - assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" - examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) - examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.COT and "Reasoning" in examples.columns: - cot_reasoning = examples["Reasoning"].tolist() + # Handle examples from PromptStrategy.dems first, then fall back to examples parameter for backward compatibility + if prompt_strategy.dems is not None: + if isinstance(prompt_strategy.dems, pd.DataFrame): + # User-provided examples + examples_source = prompt_strategy.dems + assert "Answer" in examples_source.columns, "Answer must be a column in examples dataframe" + examples_multimodal_data = task_instructions.df2multimodal_info(examples_source, col_li) + examples_answers = examples_source["Answer"].tolist() + + if prompt_strategy.cot and "Reasoning" in examples_source.columns: + cot_reasoning = examples_source["Reasoning"].tolist() + + elif prompt_strategy.dems == "auto": + # Auto-bootstrap demonstrations + try: + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=self._obj, + col_li=col_li, + user_instruction=formatted_usr_instr, + prompt_strategy=prompt_strategy, + operation_type="filter", + ) + except Exception as e: + lotus.logger.warning(f"Failed to bootstrap demonstrations: {e}") + # Fall back to no examples + examples_multimodal_data = None + examples_answers = None + cot_reasoning = None + elif examples is not None: + # Backward compatibility: use the old examples parameter + examples_source = examples + assert "Answer" in examples_source.columns, "Answer must be a column in examples dataframe" + examples_multimodal_data = task_instructions.df2multimodal_info(examples_source, col_li) + examples_answers = examples_source["Answer"].tolist() + + if prompt_strategy.cot and "Reasoning" in 
examples_source.columns: + cot_reasoning = examples_source["Reasoning"].tolist() pos_cascade_threshold, neg_cascade_threshold = None, None if cascade_args is not None: @@ -389,7 +416,7 @@ def __call__( assert "Answer" in helper_examples.columns, "Answer must be a column in examples dataframe" helper_examples_multimodal_data = task_instructions.df2multimodal_info(helper_examples, col_li) helper_examples_answers = helper_examples["Answer"].tolist() - if helper_strategy == ReasoningStrategy.COT and "Reasoning" in helper_examples.columns: + if helper_strategy is not None and helper_strategy.cot and "Reasoning" in helper_examples.columns: helper_cot_reasoning = helper_examples["Reasoning"].tolist() if cascade_args: @@ -408,7 +435,7 @@ def __call__( if not lotus.settings.helper_lm: raise ValueError("Helper LM must be set in settings") - if helper_strategy == ReasoningStrategy.COT: + if helper_strategy is not None and helper_strategy.cot: raise ValueError("CoT not supported for helper models in cascades.") # Run small LM and get logits @@ -421,7 +448,7 @@ def __call__( examples_answers=helper_examples_answers, cot_reasoning=helper_cot_reasoning, logprobs=True, - strategy=helper_strategy, + prompt_strategy=helper_strategy, safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc="Running helper LM", @@ -457,8 +484,7 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, - additional_cot_instructions=additional_cot_instructions, + prompt_strategy=prompt_strategy, ) stats["pos_cascade_threshold"] = pos_cascade_threshold @@ -515,10 +541,9 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", - additional_cot_instructions=additional_cot_instructions, ) for idx, large_idx in enumerate(low_conf_idxs): @@ -538,11 +563,10 @@ def __call__( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, - additional_cot_instructions=additional_cot_instructions, ) outputs = output.outputs raw_outputs = output.raw_outputs diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index 14a4881a..1cc5421d 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -6,7 +6,7 @@ import lotus from lotus.cache import operator_cache from lotus.templates import task_instructions -from lotus.types import CascadeArgs, ReasoningStrategy, SemanticJoinOutput +from lotus.types import CascadeArgs, PromptStrategy, SemanticJoinOutput from lotus.utils import show_safe_mode from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds @@ -26,7 +26,7 @@ def sem_join( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Join comparisons", @@ -67,7 +67,7 @@ def sem_join( Defaults to None. default (bool, optional): The default value to use when the model output cannot be parsed as a boolean. Defaults to True. 
- strategy (ReasoningStrategy | None, optional): The reasoning strategy + prompt_strategy (PromptStrategy | None, optional): The reasoning strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. @@ -111,7 +111,7 @@ def sem_join( examples_multimodal_data, examples_answers, cot_reasoning, - strategy, + prompt_strategy, ) ) estimated_total_calls = len(l1) * len(l2) @@ -142,7 +142,7 @@ def sem_join( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, show_progress_bar=False, ) @@ -193,7 +193,7 @@ def sem_join_cascade( map_examples: pd.DataFrame | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), safe_mode: bool = False, ) -> SemanticJoinOutput: """ @@ -253,7 +253,7 @@ def sem_join_cascade( map_examples=map_examples, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, ) num_helper = len(helper_high_conf) @@ -296,7 +296,7 @@ def sem_join_cascade( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, show_progress_bar=True, ) @@ -428,7 +428,7 @@ def join_optimizer( map_examples: pd.DataFrame | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), ) -> tuple[pd.DataFrame, pd.DataFrame, int, int]: """ Find most cost-effective join plan between Search-Filter and Map-Search-Filter @@ -472,7 +472,7 @@ def join_optimizer( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, ) sf_high_conf = sf_helper_join[sf_helper_join["_scores"] >= sf_t_pos] sf_high_conf_neg = len(sf_helper_join[sf_helper_join["_scores"] <= sf_t_neg]) @@ -495,7 +495,7 @@ def join_optimizer( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, ) msf_high_conf = msf_helper_join[msf_helper_join["_scores"] >= msf_t_pos] msf_high_conf_neg = len(msf_helper_join[msf_helper_join["_scores"] <= msf_t_neg]) @@ -538,7 +538,7 @@ def learn_join_cascade_threshold( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, default: bool = True, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), ) -> tuple[float, float, int]: """ Extract a small sample of the data and find the optimal threshold pair that satisfies the recall and @@ -582,7 +582,7 @@ def learn_join_cascade_threshold( examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, progress_bar_desc="Running oracle for threshold learning", ) @@ -675,7 +675,7 @@ def __call__( how: str = "inner", suffix: str = "_join", examples: pd.DataFrame | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), default: bool = True, cascade_args: CascadeArgs | None = None, return_stats: bool = False, @@ -737,7 +737,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, [real_left_on, real_right_on]) 
examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.COT: + if prompt_strategy.cot: return_explanations = True cot_reasoning = examples["Reasoning"].tolist() @@ -768,7 +768,7 @@ def __call__( map_examples=cascade_args.map_examples, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, ) else: @@ -785,7 +785,7 @@ def __call__( examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy, + prompt_strategy=prompt_strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, ) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 2c8b7074..e7a73e38 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -5,9 +5,15 @@ import lotus from lotus.cache import operator_cache from lotus.templates import task_instructions -from lotus.types import LMOutput, ReasoningStrategy, SemanticMapOutput, SemanticMapPostprocessOutput +from lotus.types import ( + LMOutput, + PromptStrategy, + SemanticMapOutput, + SemanticMapPostprocessOutput, +) from lotus.utils import show_safe_mode +from .cascade_utils import bootstrap_demonstrations from .postprocessors import map_postprocess @@ -20,7 +26,7 @@ def sem_map( examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[str] | None = None, cot_reasoning: list[str] | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), safe_mode: bool = False, progress_bar_desc: str = "Mapping", **model_kwargs: Any, @@ -53,8 +59,8 @@ def sem_map( cot_reasoning (list[str] | None, optional): Chain-of-thought reasoning for the example documents. Used when strategy includes COT reasoning. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy to use. - Can be None, COT, or ZS_COT. Defaults to None. + prompt_strategy (PromptStrategy, optional): The prompt strategy to use. + Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to PromptStrategy(). safe_mode (bool, optional): Whether to enable safe mode with cost estimation. Defaults to False. progress_bar_desc (str, optional): Description for the progress bar. @@ -85,7 +91,7 @@ def sem_map( examples_multimodal_data, examples_answers, cot_reasoning, - strategy=strategy, + prompt_strategy=prompt_strategy, system_prompt=system_prompt, ) lotus.logger.debug(f"input to model: {prompt}") @@ -102,9 +108,7 @@ def sem_map( lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc, **model_kwargs) # post process results - postprocess_output = postprocessor( - lm_output.outputs, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] - ) + postprocess_output = postprocessor(lm_output.outputs, model, prompt_strategy.cot) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") @@ -145,7 +149,7 @@ class SemMapDataframe: examples (pd.DataFrame | None, optional): Example DataFrame for few-shot learning. Should have the same column structure as the input DataFrame plus an "Answer" column. Defaults to None. - strategy (ReasoningStrategy | None, optional): The reasoning strategy + prompt_strategy (PromptStrategy | None, optional): The prompt strategy to use. Can be None, COT, or ZS_COT. Defaults to None. safe_mode (bool, optional): Whether to enable safe mode with cost estimation. 
Defaults to False. @@ -179,8 +183,8 @@ class SemMapDataframe: 1 Harry is feeling nauseous Negative # Example 2: with zero-shot chain-of-thought (ZS-COT) reasoning - >>> from lotus.types import ReasoningStrategy - >>> df.sem_map("Label the sentiment of Harry in the {document} as positive/negative/neutral. Answer in one word.", return_explanations=True, strategy=ReasoningStrategy.ZS_COT) + >>> from lotus.types import PromptStrategy + >>> df.sem_map("Label the sentiment of Harry in the {document} as positive/negative/neutral. Answer in one word.", return_explanations=True, prompt_strategy=PromptStrategy(cot=True)) Mapping: 100%|████████████████████████████████████████████████████████████████████ 2/2 LM calls [00:02<00:00, 1.04s/it] document _map explanation_map 0 Harry is happy and love cats positive Reasoning: The document states that "Harry is ... @@ -221,7 +225,7 @@ def __call__( return_raw_outputs: bool = False, suffix: str = "_map", examples: pd.DataFrame | None = None, - strategy: ReasoningStrategy | None = None, + prompt_strategy: PromptStrategy = PromptStrategy(), safe_mode: bool = False, progress_bar_desc: str = "Mapping", **model_kwargs: Any, @@ -241,18 +245,57 @@ def __call__( multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li) formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) + # Handle examples and demonstrations examples_multimodal_data = None examples_answers = None cot_reasoning = None - if examples is not None: + # Handle examples from PromptStrategy.dems first, then fall back to examples parameter for backward compatibility + if prompt_strategy.dems is not None: + if isinstance(prompt_strategy.dems, pd.DataFrame): + # User-provided examples + examples_source = prompt_strategy.dems + assert "Answer" in examples_source.columns, "Answer must be a column in examples dataframe" + examples_multimodal_data = task_instructions.df2multimodal_info(examples_source, col_li) + examples_answers = examples_source["Answer"].tolist() + + if prompt_strategy.cot: + return_explanations = True + if "Reasoning" in examples_source.columns: + cot_reasoning = examples_source["Reasoning"].tolist() + else: + cot_reasoning = ["Reasoning omitted"] * len(examples_answers) + + elif prompt_strategy.dems == "auto": + # Auto-bootstrap demonstrations + try: + examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations( + data=self._obj, + col_li=col_li, + user_instruction=formatted_usr_instr, + prompt_strategy=prompt_strategy, + operation_type="map", + ) + if prompt_strategy.cot and examples_answers: + return_explanations = True + except Exception as e: + lotus.logger.warning(f"Failed to bootstrap demonstrations: {e}") + # Fall back to no examples + examples_multimodal_data = None + examples_answers = None + cot_reasoning = None + elif examples is not None: + # Backward compatibility: use the old examples parameter assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.COT or strategy == ReasoningStrategy.ZS_COT: + if prompt_strategy.cot: return_explanations = True - cot_reasoning = examples["Reasoning"].tolist() + if "Reasoning" in examples.columns: + cot_reasoning = examples["Reasoning"].tolist() + else: + cot_reasoning = ["Reasoning omitted"] * len(examples_answers) output = sem_map( multimodal_data, @@ -263,7 +306,7 @@ def __call__( 
             examples_multimodal_data=examples_multimodal_data,
             examples_answers=examples_answers,
             cot_reasoning=cot_reasoning,
-            strategy=strategy,
+            prompt_strategy=prompt_strategy,
             safe_mode=safe_mode,
             progress_bar_desc=progress_bar_desc,
             **model_kwargs,
diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py
index 3da9291e..58de67bc 100644
--- a/lotus/sem_ops/sem_topk.py
+++ b/lotus/sem_ops/sem_topk.py
@@ -9,7 +9,7 @@
 import lotus
 from lotus.cache import operator_cache
 from lotus.templates import task_instructions
-from lotus.types import LMOutput, ReasoningStrategy, SemanticTopKOutput
+from lotus.types import LMOutput, PromptStrategy, SemanticTopKOutput
 from lotus.utils import show_safe_mode
@@ -18,7 +18,7 @@ def get_match_prompt_binary(
     doc2: dict[str, Any],
     user_instruction: str,
     model: lotus.models.LM,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
 ) -> list[dict[str, Any]]:
     """
     Generate a binary comparison prompt for two documents.
@@ -35,8 +35,8 @@ def get_match_prompt_binary(
         user_instruction (str): The natural language instruction that defines
             the comparison criteria.
         model (lotus.models.LM): The language model instance to use for comparison.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy to use.
+            Configures chain-of-thought and demonstrations. Defaults to PromptStrategy().

     Returns:
         list[dict[str, Any]]: A list of message dictionaries formatted for the
@@ -48,7 +48,7 @@ def get_match_prompt_binary(
         >>> model = LM(model="gpt-4o")
         >>> prompt = get_match_prompt_binary(doc1, doc2, "Which is more relevant to AI?", model)
     """
-    if strategy == ReasoningStrategy.ZS_COT:
+    if prompt_strategy.cot:
         sys_prompt = (
-            "Your job is to to select and return the most relevant document to the user's question.\n"
+            "Your job is to select and return the most relevant document to the user's question.\n"
             "Carefully read the user's question and the two documents provided below.\n"
@@ -69,7 +69,7 @@ def get_match_prompt_binary(
         content_text, content_image_inputs = task_instructions.context_formatter(doc)
         prompt += [{"type": "text", "text": f"\nDocument {idx+1}:\n{content_text}"}, *content_image_inputs]

-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    if prompt_strategy.cot and model.is_deepseek():
         deepseek_instructions = """Please think through your reasoning step by step, then provide your final answer.
-You must put your reasoning insdie the tags, then provide your final answer after the tag with the format: Answer: your answer."""
+You must put your reasoning inside the tags, then provide your final answer after the tag with the format: Answer: your answer."""
@@ -133,7 +133,7 @@ def compare_batch_binary(
     pairs: list[tuple[dict[str, Any], dict[str, Any]]],
     model: lotus.models.LM,
     user_instruction: str,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
 ) -> tuple[list[bool], list[str], int]:
     """
     Compare multiple pairs of documents using binary classification.
@@ -147,8 +147,8 @@ def compare_batch_binary(
         model (lotus.models.LM): The language model instance to use for comparison.
         user_instruction (str): The natural language instruction that defines
             the comparison criteria.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy to use.
+            Configures chain-of-thought and demonstrations. Defaults to PromptStrategy().

     Returns:
         tuple[list[bool], list[str], int]: A tuple containing:
@@ -164,7 +164,9 @@ def compare_batch_binary(
     match_prompts = []
     tokens = 0
     for doc1, doc2 in pairs:
-        match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy, model=model))
+        match_prompts.append(
+            get_match_prompt_binary(doc1, doc2, user_instruction, prompt_strategy=prompt_strategy, model=model)
+        )
         tokens += model.count_tokens(match_prompts[-1])

     lm_results: LMOutput = model(match_prompts, show_progress_bar=False)
     result_explanations = list(map(parse_ans_binary, lm_results.outputs))
@@ -178,7 +180,7 @@ def compare_batch_binary_cascade(
     model: lotus.models.LM,
     user_instruction: str,
     cascade_threshold: float,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
 ) -> tuple[list[bool], list[str], int, int, int]:
     """
     Compare multiple pairs of documents using a cascade approach.
@@ -195,8 +197,8 @@ def compare_batch_binary_cascade(
             the comparison criteria.
         cascade_threshold (float): Confidence threshold for using the large model.
             Cases below this threshold will use the helper model.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy to use.
+            Configures chain-of-thought and demonstrations. Defaults to PromptStrategy().

     Returns:
         tuple[list[bool], list[str], int, int, int]: A tuple containing:
@@ -219,7 +221,9 @@ def compare_batch_binary_cascade(
     match_prompts = []
     small_tokens = 0
     for doc1, doc2 in pairs:
-        match_prompts.append(get_match_prompt_binary(doc1, doc2, user_instruction, strategy=strategy, model=model))
+        match_prompts.append(
+            get_match_prompt_binary(doc1, doc2, user_instruction, prompt_strategy=prompt_strategy, model=model)
+        )
         small_tokens += model.count_tokens(match_prompts[-1])

     helper_lm = lotus.settings.helper_lm
@@ -277,7 +281,7 @@ def llm_naive_sort(
     docs: list[dict[str, Any]],
     model: lotus.models.LM,
     user_instruction: str,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
     safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
@@ -293,8 +297,8 @@ def llm_naive_sort(
         model (lotus.models.LM): The language model instance to use for comparisons.
         user_instruction (str): The natural language instruction that defines
             the sorting criteria.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy to use.
+            Configures chain-of-thought and demonstrations. Defaults to PromptStrategy().
         safe_mode (bool, optional): Whether to enable safe mode with cost estimation.
             Defaults to False.
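A minimal usage sketch of the renamed argument, for review convenience. It is illustrative only: it mirrors test_cot_topk_basic in the new test suite below and assumes an OpenAI key plus the gpt-4o-mini model used throughout this PR.

import pandas as pd

import lotus
from lotus.models import LM
from lotus.types import PromptStrategy

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))

df = pd.DataFrame({"Review": ["Outstanding performance!", "Broke after one day.", "Does the job."]})

# Zero-shot chain-of-thought ranking; return_stats also surfaces token and LM-call counts.
top2, stats = df.sem_topk(
    "{Review} is a positive review",
    K=2,
    prompt_strategy=PromptStrategy(cot=True),
    return_explanations=True,
    return_stats=True,
)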
@@ -319,7 +323,9 @@ def llm_naive_sort(
         desc="All-pairs comparisons",
         bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} LM calls [{elapsed}<{remaining}]",
     )
-    comparisons, explanations, tokens = compare_batch_binary(pairs, model, user_instruction, strategy=strategy)
+    comparisons, explanations, tokens = compare_batch_binary(
+        pairs, model, user_instruction, prompt_strategy=prompt_strategy
+    )
     pbar.update(len(pairs))
     pbar.close()
     if safe_mode:
@@ -350,7 +356,7 @@ def llm_quicksort(
     user_instruction: str,
     K: int,
     embedding: bool = False,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
     cascade_threshold: float | None = None,
     safe_mode: bool = False,
 ) -> SemanticTopKOutput:
@@ -370,8 +376,8 @@ def llm_quicksort(
         K (int): The number of top documents to return.
         embedding (bool, optional): Whether to use embedding optimization for
             pivot selection. Defaults to False.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy to use.
+            Configures chain-of-thought and demonstrations. Defaults to PromptStrategy().
         cascade_threshold (float | None, optional): Confidence threshold for cascade
             filtering. If provided, uses a two-stage model approach. Defaults to None.
         safe_mode (bool, optional): Whether to enable safe mode with cost estimation.
@@ -391,7 +397,9 @@ def llm_quicksort(
     stats["total_llm_calls"] = 0
     stats["explanations"] = {}
     if safe_mode:
-        sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy, model=model)
+        sample_prompt = get_match_prompt_binary(
+            docs[0], docs[1], user_instruction, prompt_strategy=prompt_strategy, model=model
+        )
         estimated_quickselect_calls = 2 * K
         estimated_quicksort_calls = 2 * len(docs) * np.log(len(docs))
         estimated_total_calls = estimated_quickselect_calls + estimated_quicksort_calls
@@ -425,7 +433,9 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int:
         pairs = [(docs[indexes[j]], pivot) for j in range(low, high)]

         if cascade_threshold is None:
-            comparisons, explanations, tokens = compare_batch_binary(pairs, model, user_instruction, strategy=strategy)
+            comparisons, explanations, tokens = compare_batch_binary(
+                pairs, model, user_instruction, prompt_strategy=prompt_strategy
+            )
             stats["total_tokens"] += tokens
             stats["total_llm_calls"] += len(pairs)

@@ -440,7 +450,7 @@ def partition(indexes: list[int], low: int, high: int, K: int) -> int:
                 model,
                 user_instruction,
                 cascade_threshold,
-                strategy=strategy,
+                prompt_strategy=prompt_strategy,
             )

             stats["total_small_tokens"] += small_tokens
@@ -499,14 +509,14 @@ class HeapDoc:
     Attributes:
         num_calls (int): Class variable tracking total number of LM calls.
         total_tokens (int): Class variable tracking total tokens used.
-        strategy (ReasoningStrategy | None): Class variable for reasoning strategy.
+        prompt_strategy (PromptStrategy): Class variable for the prompt strategy.
         model (lotus.models.LM | None): Class variable for the language model.
         explanations (dict[int, list[str]]): Class variable storing explanations.
     """

     num_calls: int = 0
     total_tokens: int = 0
-    strategy: ReasoningStrategy | None = None
+    prompt_strategy: PromptStrategy = PromptStrategy()
     model: lotus.models.LM | None = None
     explanations: dict[int, list[str]] = {}
@@ -541,7 +551,7 @@ def __lt__(self, other: "HeapDoc") -> bool:
         """
         assert HeapDoc.model is not None
         prompt = get_match_prompt_binary(
-            self.doc, other.doc, self.user_instruction, strategy=self.strategy, model=HeapDoc.model
+            self.doc, other.doc, self.user_instruction, prompt_strategy=HeapDoc.prompt_strategy, model=HeapDoc.model
         )
         HeapDoc.num_calls += 1
         HeapDoc.total_tokens += HeapDoc.model.count_tokens(prompt)
@@ -562,7 +572,7 @@ def llm_heapsort(
     model: lotus.models.LM,
     user_instruction: str,
     K: int,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
     safe_mode: bool = False,
 ) -> SemanticTopKOutput:
     """
@@ -579,8 +589,8 @@ def llm_heapsort(
         user_instruction (str): The natural language instruction that defines
             the sorting criteria.
         K (int): The number of top documents to return.
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy to use.
-            Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy to use.
+            Configures chain-of-thought and demonstrations. Defaults to PromptStrategy().
         safe_mode (bool, optional): Whether to enable safe mode with cost estimation.
             Defaults to False.
@@ -595,7 +605,9 @@ def llm_heapsort(
     """
     if safe_mode:
-        sample_prompt = get_match_prompt_binary(docs[0], docs[1], user_instruction, strategy=strategy, model=model)
+        sample_prompt = get_match_prompt_binary(
+            docs[0], docs[1], user_instruction, prompt_strategy=prompt_strategy, model=model
+        )
         estimated_heap_construction_calls = len(docs) * np.log(len(docs))
         estimated_top_k_extraction_calls = K * np.log(len(docs))
         estimated_total_calls = estimated_heap_construction_calls + estimated_top_k_extraction_calls
@@ -604,7 +616,7 @@ def llm_heapsort(
     HeapDoc.num_calls = 0
     HeapDoc.total_tokens = 0
-    HeapDoc.strategy = strategy
+    HeapDoc.prompt_strategy = prompt_strategy
     HeapDoc.model = model
     HeapDoc.explanations = {}
     N = len(docs)
@@ -640,8 +652,8 @@ class SemTopKDataframe:
             - "naive": Naive quadratic approach
             - "quick-sem": Quicksort with semantic embedding optimization. Requires
               the passed column to be indexed with sem_index. Defaults to "quick".
-        strategy (ReasoningStrategy | None, optional): The reasoning strategy
-            to use. Can be None, COT, or ZS_COT. Defaults to None.
+        prompt_strategy (PromptStrategy, optional): The prompt strategy
+            to use. Configures chain-of-thought, demonstrations, and bootstrapping. Defaults to PromptStrategy().
         group_by (list[str] | None, optional): Column names to group by before
             sorting. Each group will be sorted separately. Defaults to None.
         cascade_threshold (float | None, optional): Confidence threshold for
@@ -714,18 +726,18 @@ def process_group(args):
     Args:
         args (tuple): A tuple containing (group, user_instruction, K, method,
-            strategy, group_by, cascade_threshold, return_stats).
+            prompt_strategy, group_by, cascade_threshold, return_stats).

     Returns:
         pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: The top-K results
             for the group, optionally with statistics.
     """
-    group, user_instruction, K, method, strategy, group_by, cascade_threshold, return_stats = args
+    group, user_instruction, K, method, prompt_strategy, group_by, cascade_threshold, return_stats = args
     return group.sem_topk(
         user_instruction,
         K,
         method=method,
-        strategy=strategy,
+        prompt_strategy=prompt_strategy,
         group_by=None,
         cascade_threshold=cascade_threshold,
         return_stats=return_stats,
@@ -737,7 +749,7 @@ def __call__(
     user_instruction: str,
     K: int,
     method: str = "quick",
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
     group_by: list[str] | None = None,
     cascade_threshold: float | None = None,
     return_stats: bool = False,
@@ -763,7 +775,7 @@ def __call__(
     if group_by:
         grouped = self._obj.groupby(group_by)
         group_args = [
-            (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats)
+            (group, user_instruction, K, method, prompt_strategy, None, cascade_threshold, return_stats)
             for _, group in grouped
         ]
@@ -798,7 +810,7 @@ def __call__(
             formatted_usr_instr,
             K,
             embedding=method == "quick-sem",
-            strategy=strategy,
+            prompt_strategy=prompt_strategy,
             cascade_threshold=cascade_threshold,
             safe_mode=safe_mode,
         )
@@ -808,7 +820,7 @@ def __call__(
             model,
             formatted_usr_instr,
             K,
-            strategy=strategy,
+            prompt_strategy=prompt_strategy,
             safe_mode=safe_mode,
         )
     elif method == "naive":
@@ -816,7 +828,7 @@ def __call__(
             multimodal_data,
             model,
             formatted_usr_instr,
-            strategy=strategy,
+            prompt_strategy=prompt_strategy,
             safe_mode=safe_mode,
         )
     else:
@@ -826,7 +838,7 @@ def __call__(
     new_df = new_df.reindex(output.indexes).reset_index(drop=True)
     new_df = new_df.head(K)

-    if return_explanations and strategy == ReasoningStrategy.ZS_COT:
+    if return_explanations and prompt_strategy.cot:
         explanations = []
         for idx in output.indexes[:K]:
             explanation = "No Comparison Made"
diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py
index 6ab05139..f998f6b1 100644
--- a/lotus/templates/task_instructions.py
+++ b/lotus/templates/task_instructions.py
@@ -5,7 +5,10 @@
 import lotus
 from lotus.dtype_extensions import ImageDtype
-from lotus.types import ReasoningStrategy, SerializationFormat
+from lotus.types import (
+    PromptStrategy,
+    SerializationFormat,
+)


 def cot_formatter(reasoning, answer):
@@ -91,7 +94,7 @@ def filter_formatter(
     examples_multimodal_data: list[dict[str, Any]] | None = None,
     examples_answer: list[bool] | None = None,
     cot_reasoning: list[str] | None = None,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy | None = None,
     reasoning_instructions: str = "",
 ) -> list[dict[str, str]]:
     answer_instructions = "The answer should be either True or False"
 Your job is to determine whether the claim is true for the given context.
 """
-    if strategy == ReasoningStrategy.COT:
-        sys_instruction += cot_prompt_formatter(
-            reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
-        )
-    elif strategy == ReasoningStrategy.ZS_COT:
+    # Any chain-of-thought variant enables the CoT system prompt
+    if prompt_strategy is not None and prompt_strategy.cot:
         sys_instruction += cot_prompt_formatter(
             reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
         )
@@ -135,7 +135,7 @@ def filter_formatter(
             # reasoning as filler if the user wants cot reasoning
             if cot_reasoning:
                 content = cot_formatter(cot_reasoning[idx], str(ex_ans))
-            elif strategy == ReasoningStrategy.COT:
+            elif prompt_strategy is not None and prompt_strategy.cot:
                 content = cot_formatter("Reasoning omitted", str(ex_ans))
             else:
                 content = answer_only_formatter(str(ex_ans))
@@ -149,7 +149,8 @@ def filter_formatter(
                 },
             ]
         )
-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    # DeepSeek needs explicit CoT tags; skip when demonstrations already set the format
+    if prompt_strategy is not None and prompt_strategy.cot and model.is_deepseek() and not examples_multimodal_data:
         user_instruction = f"Claim: {user_instruction}\n\n{deepseek_cot_formatter()}"
         messages.append(user_message_formatter(multimodal_data, user_instruction))
     else:
@@ -217,7 +218,7 @@ def map_formatter(
     examples_multimodal_data: list[dict[str, Any]] | None = None,
     examples_answer: list[str] | None = None,
     cot_reasoning: list[str] | None = None,
-    strategy: ReasoningStrategy | str | None = None,
+    prompt_strategy: PromptStrategy | None = None,
     system_prompt: str | None = None,
 ) -> list[dict[str, str]]:
     sys_instruction = system_prompt or (
@@ -229,7 +230,7 @@ def map_formatter(
         return map_formatter_cot(
             multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning, system_prompt
         )
-    elif strategy == ReasoningStrategy.ZS_COT:
+    elif prompt_strategy is not None and prompt_strategy.cot:
         return map_formatter_zs_cot(multimodal_data, user_instruction, system_prompt)

     messages = [
@@ -246,7 +247,8 @@ def map_formatter(
             ]
         )

-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    # DeepSeek needs explicit CoT tags; skip when demonstrations already set the format
+    if prompt_strategy is not None and prompt_strategy.cot and model.is_deepseek() and not examples_multimodal_data:
-        user_intructions = f"Instruction: {user_instruction}\n\n{deepseek_cot_formatter()}"
-        messages.append(user_message_formatter(multimodal_data, user_intructions))
+        user_instructions = f"Instruction: {user_instruction}\n\n{deepseek_cot_formatter()}"
+        messages.append(user_message_formatter(multimodal_data, user_instructions))
     else:
@@ -259,7 +261,7 @@ def extract_formatter(
     multimodal_data: dict[str, Any],
     output_cols: dict[str, str | None],
     extract_quotes: bool = True,
-    strategy: ReasoningStrategy | None = None,
+    prompt_strategy: PromptStrategy = PromptStrategy(),
 ) -> list[dict[str, str]]:
     output_col_names = list(output_cols.keys())
     # Set the description to be the key if no value is provided
@@ -274,14 +276,10 @@ def extract_formatter(
     fields_str = ", ".join(all_fields)

     # Add CoT reasoning instructions to the system prompt if needed
-    if strategy == ReasoningStrategy.COT:
-        reasoning_instructions = "Think through each extraction step by step."
-        answer_instructions = f"Provide the JSON response with fields: {fields_str}"
-        cot_instruction = cot_prompt_formatter(
-            reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
-        )
-    elif strategy == ReasoningStrategy.ZS_COT:
+    if prompt_strategy.cot:
         reasoning_instructions = "Think through each extraction step by step."
+        if prompt_strategy.additional_cot_instructions:
+            reasoning_instructions += f" {prompt_strategy.additional_cot_instructions}"
         answer_instructions = f"Provide the JSON response with fields: {fields_str}"
         cot_instruction = cot_prompt_formatter(
             reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
@@ -309,12 +307,16 @@ def extract_formatter(
     if cot_instruction:
         sys_instruction += "\n" + cot_instruction

+    # Reiterate the reasoning-then-answer contract for CoT extraction
+    if prompt_strategy.cot:
+        sys_instruction += "\n\nFor your response, first provide your reasoning, then give your final answer in the specified JSON format."
+
     messages = [
         {"role": "system", "content": sys_instruction},
         user_message_formatter(multimodal_data),
     ]

-    if strategy == ReasoningStrategy.ZS_COT and model.is_deepseek():
+    if prompt_strategy.cot and model.is_deepseek():
-        user_intructions = f"Instruction: {deepseek_cot_formatter()}"
-        messages.append(user_message_formatter(multimodal_data, user_intructions))
+        user_instructions = f"Instruction: {deepseek_cot_formatter()}"
+        messages.append(user_message_formatter(multimodal_data, user_instructions))
diff --git a/lotus/types.py b/lotus/types.py
index 08519729..14ba1347 100644
--- a/lotus/types.py
+++ b/lotus/types.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from enum import Enum, auto
+from enum import Enum
 from typing import Any

 import pandas as pd
@@ -213,10 +213,45 @@ class LotusUsageLimitException(LotusException):


 ################################################################################
-# Reasoning Strategy
+# Prompt Strategy
 ################################################################################
-class ReasoningStrategy(Enum):
-    DEFAULT = auto()
-    COT = auto()
-    ZS_COT = auto()
-    FEW_SHOT = auto()
+@dataclass
+class PromptStrategy:
+    """
+    Configurable prompt strategy for semantic operations.
+
+    This class encapsulates various prompting techniques including chain-of-thought
+    reasoning, demonstrations, and bootstrapping configurations.
+
+    Args:
+        cot (bool): Whether to use chain-of-thought reasoning. Defaults to False.
+        dems (pd.DataFrame | str | None): Demonstrations to use. Can be:
+            - pd.DataFrame: User-provided examples with an "Answer" column
+            - "auto": Automatically bootstrap demonstrations
+            - None: No demonstrations
+        max_dems (int): Maximum number of demonstrations to use. Defaults to 3.
+        teacher_lm (lotus.models.LM | None): Language model used to bootstrap
+            demonstrations. If None, uses the main model. Defaults to None.
+        additional_cot_instructions (str): Additional instructions for
+            chain-of-thought reasoning. Defaults to "".
+
+    Example:
+        >>> # Chain-of-thought with user-provided demonstrations
+        >>> strat = PromptStrategy(cot=True, dems=examples_df)
+        >>> df.sem_filter(user_instruction, prompt_strategy=strat)
+
+        >>> # Auto-bootstrap demonstrations with CoT
+        >>> strat = PromptStrategy(
+        ...     cot=True,
+        ...     dems="auto",
+        ...     max_dems=2,
+        ...     teacher_lm=lotus.models.LM(model="gpt-4o-mini")
+        ... )
+        >>> df.sem_filter(user_instruction, prompt_strategy=strat)
+    """
+
+    cot: bool = False
+    dems: pd.DataFrame | str | None = None
+    max_dems: int = 3
+    teacher_lm: Any = None  # lotus.models.LM type, but avoiding circular import
+    additional_cot_instructions: str = ""
diff --git a/tests/deepseek_cot_tests.py b/tests/deepseek_cot_tests.py
index 7697c7db..96391f9c 100644
--- a/tests/deepseek_cot_tests.py
+++ b/tests/deepseek_cot_tests.py
@@ -5,7 +5,7 @@
 import lotus
 from lotus.models import LM
-from lotus.types import ReasoningStrategy
+from lotus.types import PromptStrategy

 lotus.logger.setLevel("DEBUG")
@@ -15,167 +15,226 @@
 @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled")
-def test_deepseek_filter_cot_basic():
-    """Test sem_filter using DeepSeek CoT on a simple filtering task."""
+def test_deepseek_demonstrations_only():
+    """Test DeepSeek with demonstrations without CoT reasoning."""
     lm = LM(model=MODEL_NAME)
     lotus.settings.configure(lm=lm)

-    data = {
-        "Text": ["I had two apples and still have one left", "I gave away all my apples", "I received an apple today"]
-    }
+    data = {"Course": ["Linear Algebra", "Creative Writing", "Calculus", "Art History"]}
+    df = pd.DataFrame(data)
+    user_instruction = "{Course} requires mathematical skills"
+
+    # Provide examples without reasoning
+    examples = pd.DataFrame({"Course": ["Statistics", "Poetry", "Physics"], "Answer": [True, False, True]})
+
+    result = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(dems=examples), return_all=True)
+
+    assert "filter_label" in result.columns
+    # Should identify math courses correctly based on examples
+    math_courses = result[result["filter_label"]]["Course"].tolist()
+    assert any(course in ["Linear Algebra", "Calculus"] for course in math_courses)
+
+
+@pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled")
+def test_deepseek_cot_demonstrations_combined():
+    """Test DeepSeek with combined CoT and demonstrations."""
+    lm = LM(model=MODEL_NAME)
+    lotus.settings.configure(lm=lm)
+
+    data = {"Product": ["Smartphone", "Book", "Laptop", "Pen"]}
     df = pd.DataFrame(data)
-    user_instruction = "{Text} implies I have at least one apple"
+    user_instruction = "{Product} is an electronic device"

-    filtered_df = df.sem_filter(user_instruction, return_explanations=True, return_all=True)
+    # Provide examples with reasoning
+    examples = pd.DataFrame(
+        {
+            "Product": ["Tablet", "Magazine", "Smart Watch"],
+            "Answer": [True, False, True],
+            "Reasoning": [
+                "Tablets are electronic devices with screens and processors",
+                "Magazines are printed materials, not electronic",
+                "Smart watches are wearable electronic devices with digital displays",
+            ],
+        }
+    )
+
+    result = df.sem_filter(
+        user_instruction,
+        prompt_strategy=PromptStrategy(cot=True, dems=examples),
+        return_explanations=True,
+        return_all=True,
+    )

-    # Check that extra columns are present.
-    assert "explanation_filter" in filtered_df.columns
-    assert "filter_label" in filtered_df.columns
+    assert "filter_label" in result.columns
+    assert "explanation_filter" in result.columns

-    # At least one row should be labeled True.
- positive_rows = filtered_df[filtered_df["filter_label"]] - assert len(positive_rows) > 0 + # Should identify electronic devices correctly + electronic_devices = result[result["filter_label"]]["Product"].tolist() + assert any(device in ["Smartphone", "Laptop"] for device in electronic_devices) - # Each explanation should be nonempty for positive rows. - for exp in positive_rows["explanation_filter"]: - assert exp is not None and exp != "" + # Check explanations are provided + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_map_cot_basic(): - """Test sem_map using DeepSeek CoT on a basic mapping task.""" +def test_deepseek_demonstration_config(): + """Test DeepSeek with examples.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = {"Text": ["Paris is the capital of France", "Berlin is the capital of Germany"]} + data = {"Animal": ["Dog", "Cat", "Eagle", "Fish"]} df = pd.DataFrame(data) - user_instruction = "Extract the capital city from the sentence: {Text}" - result = df.sem_map(user_instruction, return_explanations=True, strategy=ReasoningStrategy.ZS_COT) + user_instruction = "{Animal} can fly" - # Check that the mapping column and explanation column exist. - assert "_map" in result.columns - assert "explanation_map" in result.columns + # Provide examples + examples = pd.DataFrame({"Animal": ["Bird", "Elephant"], "Answer": [True, False]}) - # Verify that each mapped output is a string and each explanation is nonempty. - for output, exp in zip(result["_map"], result["explanation_map"]): - assert isinstance(output, str) - assert exp is not None and exp != "" + result = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems=examples), + return_all=True, + ) + + assert "filter_label" in result.columns + # Should identify flying animals correctly + flying_animals = result[result["filter_label"]]["Animal"].tolist() + assert "Eagle" in flying_animals + + +@pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") +def test_deepseek_bootstrapping(): + """Test DeepSeek with automatic demonstration bootstrapping.""" + lm = LM(model=MODEL_NAME) + lotus.settings.configure(lm=lm) + + data = {"City": ["New York", "London", "Tokyo", "Sydney", "Paris"]} + df = pd.DataFrame(data) + user_instruction = "{City} is in Asia" + + # Configure bootstrapping + result = df.sem_filter( + user_instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + return_all=True, + ) + + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Should identify Asian cities correctly + asian_cities = result[result["filter_label"]]["City"].tolist() + assert "Tokyo" in asian_cities + + # Should work even without user-provided examples + assert len(result) == len(df) @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_top_k_with_negative_reviews(): - """Test sem_top_k with a dataset containing negative reviews.""" - lm = LM(model=MODEL_NAME, temperature=0.6) +def test_deepseek_extract_with_cot(): + """Test DeepSeek extract operation with CoT reasoning.""" + lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) data = { "Review": [ - "This vacuum cleaner is the best I've ever owned. 
Highly recommend it!", - "It's okay, not sure I would buy it again.", - "Terrible experience, broke after a few uses.", - "Amazing build quality and customer support. Would absolutely recommend.", - "I would not recommend this to anyone.", - "This product is amazing! I love it.", + "This phone has amazing battery life and great camera quality!", + "The laptop is too slow and overheats frequently.", ] } - df = pd.DataFrame(data) - user_instruction = "{Review} suggests that the user would recommend the product to others" - for method in ["quick", "heap", "naive"]: - sorted_df, stats = df.sem_topk( - user_instruction, - K=2, - method=method, - return_stats=True, - strategy=ReasoningStrategy.ZS_COT, - return_explanations=True, - ) - # Check that the top 2 reviews are positive - top_reviews = sorted_df["Review"].tolist() - assert any( - "recommend" in review.lower() or "best" in review.lower() or "amazing" in review.lower() - for review in top_reviews - ) + output_cols = { + "sentiment": "Overall sentiment (positive/negative)", + "main_feature": "Main feature mentioned in the review", + } - # Check that the stats are correct - assert stats["total_tokens"] > 0 - assert stats["total_llm_calls"] > 0 + input_cols = ["Review"] # Columns to extract from - # Check that each explanation is not empty - for exp in sorted_df["explanation"]: - assert exp is not None and exp != "" + result = df.sem_extract(input_cols, output_cols, prompt_strategy=PromptStrategy(cot=True), return_explanations=True) + + assert "sentiment" in result.columns + assert "main_feature" in result.columns + assert "explanation_extract" in result.columns + + # Check sentiment extraction + sentiments = result["sentiment"].tolist() + assert any("positive" in sent.lower() for sent in sentiments) + assert any("negative" in sent.lower() for sent in sentiments) @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_filter_cot_fewshot(): - """Test sem_filter with few-shot examples to guide filtering decisions.""" +def test_deepseek_backward_compatibility(): + """Test that DeepSeek still works with legacy methods.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = { - "Text": [ - "Sequence: 5, 4, 3", # Not increasing - "Sequence: 1, 2, 3", # Increasing - "Sequence: 8, 7, 6", # Not increasing - ] - } + data = {"Text": ["The weather is sunny today", "It's raining heavily outside"]} df = pd.DataFrame(data) - user_instruction = "{Text} is an increasing sequence" + user_instruction = "{Text} describes good weather" - # Few-shot examples provided as a DataFrame. - examples = pd.DataFrame( - { - "Text": ["Sequence: 1, 2, 3", "Sequence: 3, 2, 1"], - "Answer": [True, False], - "Reasoning": ["Numbers increase steadily", "Numbers decrease"], - } - ) + # Test without explicit strategy (should use default behavior) + result_default = df.sem_filter(user_instruction, return_all=True) - filtered_df = df.sem_filter( - user_instruction, - examples=examples, - return_explanations=True, - return_all=True, - strategy=ReasoningStrategy.COT, - ) + # Test with explicit CoT strategy + result_cot = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(cot=True), return_all=True) - # Expect that at least the row with "Sequence: 1, 2, 3" is marked positive. 
- positive_rows = filtered_df[filtered_df["filter_label"]] - assert len(positive_rows) >= 1 - for exp in positive_rows["explanation_filter"]: - assert exp is not None and exp != "" + # Both should work and produce results + assert "filter_label" in result_default.columns + assert "filter_label" in result_cot.columns + assert len(result_default) == len(result_cot) == len(df) @pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") -def test_deepseek_map_cot_fewshot(): - """Test sem_map with few-shot examples to guide mapping decisions.""" +def test_deepseek_error_handling(): + """Test error handling with DeepSeek and new reasoning strategies.""" lm = LM(model=MODEL_NAME) lotus.settings.configure(lm=lm) - data = {"Text": ["City: New York", "City: Los Angeles"]} + data = {"Text": ["Sample text"]} df = pd.DataFrame(data) - user_instruction = "Determine the state abbreviation for {Text}" + user_instruction = "{Text} is meaningful" - examples = pd.DataFrame( - { - "Text": ["City: Chicago", "City: Houston"], - "Answer": ["IL", "TX"], - "Reasoning": ["Chicago is in Illinois", "Houston is in Texas"], - } - ) + # Test with empty examples + empty_examples = pd.DataFrame(columns=["Text", "Answer"]) - result = df.sem_map( - user_instruction, - examples=examples, - return_explanations=True, - strategy=ReasoningStrategy.COT, - ) + try: + result = df.sem_filter(user_instruction, prompt_strategy=PromptStrategy(dems=empty_examples), return_all=True) + # Should handle gracefully + assert "filter_label" in result.columns + except Exception as e: + # If it raises an error, it should be informative + assert len(str(e)) > 0 + + +@pytest.mark.skipif(not ENABLE_OLLAMA_TESTS, reason="Skipping test because Ollama tests are not enabled") +def test_deepseek_multiple_operations_chaining(): + """Test chaining multiple operations with DeepSeek and different strategies.""" + lm = LM(model=MODEL_NAME) + lotus.settings.configure(lm=lm) + + data = {"Product": ["iPhone", "Novel", "MacBook", "Newspaper", "iPad"]} + df = pd.DataFrame(data) + + # First filter with demonstrations + examples = pd.DataFrame({"Product": ["Laptop", "Book"], "Answer": [True, False]}) + + filtered_df = df.sem_filter("{Product} is an electronic device", prompt_strategy=PromptStrategy(dems=examples)) + + # Then map with CoT + if len(filtered_df) > 0: + mapped_df = filtered_df.sem_map( + "What category does {Product} belong to?", + prompt_strategy=PromptStrategy(cot=True), + return_explanations=True, + ) + + assert "_map" in mapped_df.columns + assert "explanation_map" in mapped_df.columns - # Check that the new column "State" is added and that explanations are nonempty. 
- assert "_map" in result.columns - assert "explanation_map" in result.columns - for output, exp in zip(result["_map"], result["explanation_map"]): - assert isinstance(output, str) - assert exp is not None and exp != "" + # Should have reasonable categorizations + for category in mapped_df["_map"]: + assert isinstance(category, str) + assert len(category) > 0 diff --git a/tests/test_reasoning_strategies.py b/tests/test_reasoning_strategies.py new file mode 100644 index 00000000..0eaaa2d8 --- /dev/null +++ b/tests/test_reasoning_strategies.py @@ -0,0 +1,787 @@ +import os + +import pandas as pd +import pytest + +import lotus +import lotus.nl_expression as nle +from lotus.models import LM +from lotus.sem_ops.cascade_utils import bootstrap_demonstrations +from lotus.types import PromptStrategy +from tests.base_test import BaseTest + +# Skip all tests if no OpenAI API key is available +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +pytestmark = pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not available") + + +@pytest.fixture +def sample_courses_df(): + """Sample course data for testing""" + return pd.DataFrame( + { + "Course Name": [ + "Linear Algebra", + "Poetry Writing", + "Calculus II", + "Art History", + "Statistics", + "Creative Writing", + "Machine Learning", + "Literature Analysis", + "Physics", + "Philosophy", + ], + "Department": [ + "Math", + "English", + "Math", + "Art", + "Math", + "English", + "CS", + "English", + "Physics", + "Philosophy", + ], + "Credits": [3, 3, 4, 3, 3, 3, 4, 3, 4, 3], + } + ) + + +@pytest.fixture +def sample_reviews_df(): + """Sample review data for testing""" + return pd.DataFrame( + { + "Review": [ + "This product is amazing! Highly recommend it to everyone.", + "It's okay, nothing special but does the job.", + "Terrible quality, broke after one day. 
Would not recommend.", + "Great value for money, very satisfied with my purchase.", + "Poor customer service and mediocre product quality.", + "Outstanding performance, exceeded my expectations!", + ], + "Rating": [5, 3, 1, 4, 2, 5], + } + ) + + +@pytest.fixture +def setup_model(): + """Set up a test model""" + lm = LM(model="gpt-4o-mini", temperature=0.1) + lotus.settings.configure(lm=lm) + return lm + + +class TestReasoningStrategies(BaseTest): + """Test suite for reasoning strategies""" + + # ============================================================================= + # Chain-of-Thought (CoT) Tests + # ============================================================================= + + def test_cot_filter_basic(self, sample_courses_df, setup_model): + """Test basic CoT reasoning with sem_filter""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter( + instruction, prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_all=True + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check that explanations are provided + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + # CoT should contain substantive reasoning (check for common reasoning indicators or sufficient length) + has_reasoning_words = any( + word in explanation.lower() + for word in [ + "reasoning", + "because", + "since", + "therefore", + "requires", + "involves", + "contains", + "needs", + "mathematical", + "math", + "calculus", + "algebra", + ] + ) + is_substantial = len(explanation.split()) > 5 + assert has_reasoning_words or is_substantial, f"Explanation lacks reasoning indicators: '{explanation}'" + + def test_cot_map_basic(self, sample_courses_df, setup_model): + """Test basic CoT reasoning with sem_map""" + df = sample_courses_df + instruction = "What is the difficulty level of {Course Name}? 
Answer: Beginner, Intermediate, or Advanced" + + result = df.sem_map(instruction, prompt_strategy=PromptStrategy(cot=True), return_explanations=True) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check that explanations contain reasoning + for explanation in result["explanation_map"]: + assert explanation is not None + assert len(explanation) > 0 + + def test_cot_topk_basic(self, sample_reviews_df, setup_model): + """Test basic CoT reasoning with sem_topk""" + df = sample_reviews_df + instruction = "{Review} is a positive review" + + result, stats = df.sem_topk( + instruction, K=3, prompt_strategy=PromptStrategy(cot=True), return_explanations=True, return_stats=True + ) + + # Check structure + assert len(result) == 3 + assert "explanation" in result.columns + assert stats["total_llm_calls"] > 0 + + # Check explanations + for explanation in result["explanation"]: + assert explanation is not None + assert len(explanation) > 0 + + # ============================================================================= + # Demonstrations (Few-shot) Tests + # ============================================================================= + + def test_demonstrations_filter_basic(self, sample_courses_df, setup_model): + """Test demonstrations strategy with sem_filter""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Provide examples + examples = pd.DataFrame( + {"Course Name": ["Machine Learning", "Literature", "Physics"], "Answer": [True, False, True]} + ) + + result = df.sem_filter(instruction, prompt_strategy=PromptStrategy(dems=examples), return_all=True) + + # Check structure + assert "filter_label" in result.columns + + # Should identify math-heavy courses correctly based on examples + math_courses = result[result["filter_label"]]["Course Name"].tolist() + assert any(course in ["Linear Algebra", "Calculus II", "Statistics"] for course in math_courses) + + def test_demonstrations_map_basic(self, sample_courses_df, setup_model): + """Test demonstrations strategy with sem_map""" + df = sample_courses_df.head(3) # Use fewer rows for faster testing + instruction = "What department is {Course Name} in?" 
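+        # The demonstrations DataFrame must mirror the column referenced in the
+        # instruction ("Course Name") and add an "Answer" column; sem_map asserts
+        # on "Answer" before serializing the examples into the prompt.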
+ + # Provide examples + examples = pd.DataFrame({"Course Name": ["Calculus I", "English Literature"], "Answer": ["Math", "English"]}) + + result = df.sem_map(instruction, prompt_strategy=PromptStrategy(dems=examples)) + + # Check structure + assert "_map" in result.columns + + # Check that mapping results are reasonable + for mapped_value in result["_map"]: + assert isinstance(mapped_value, str) + assert len(mapped_value) > 0 + + # ============================================================================= + # CoT + Demonstrations Tests + # ============================================================================= + + def test_cot_demonstrations_filter(self, sample_courses_df, setup_model): + """Test combined CoT + Demonstrations with sem_filter""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Provide examples with reasoning + examples = pd.DataFrame( + { + "Course Name": ["Machine Learning", "Literature", "Physics"], + "Answer": [True, False, True], + "Reasoning": [ + "Machine Learning requires linear algebra, calculus, and statistics", + "Literature focuses on reading, writing, and analysis - no math required", + "Physics is fundamentally mathematical with equations and calculations", + ], + } + ) + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems=examples), + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check that explanations are provided and contain reasoning + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + + def test_cot_demonstrations_map(self, sample_courses_df, setup_model): + """Test combined CoT + Demonstrations with sem_map""" + df = sample_courses_df.head(3) + instruction = "What is the difficulty level of {Course Name}?" 
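+        # With cot=True, the "Reasoning" column below supplies a worked rationale
+        # per demonstration; if it were omitted, the operators would fall back to
+        # a "Reasoning omitted" placeholder (see sem_map.py above).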
+ + # Provide examples with reasoning + examples = pd.DataFrame( + { + "Course Name": ["Algebra I", "Advanced Calculus"], + "Answer": ["Beginner", "Advanced"], + "Reasoning": [ + "Algebra I is typically an introductory math course", + "Advanced Calculus requires significant mathematical background", + ], + } + ) + + result = df.sem_map( + instruction, prompt_strategy=PromptStrategy(cot=True, dems=examples), return_explanations=True + ) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check explanations + for explanation in result["explanation_map"]: + assert explanation is not None + assert len(explanation) > 0 + + # ============================================================================= + # Examples and Bootstrapping Tests + # ============================================================================= + + def test_demonstration_basic(self, sample_courses_df, setup_model): + """Test with user-provided examples""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Examples provided + examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]}) + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems=examples), + return_all=True, + ) + + assert "filter_label" in result.columns + + def test_bootstrapping_basic(self, sample_courses_df, setup_model): + """Test automatic demonstration bootstrapping""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Configure bootstrapping + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Should work even without user-provided examples + assert len(result) == len(df) + + def test_bootstrapping_with_oracle_model(self, sample_courses_df, setup_model): + """Test bootstrapping with a different oracle model""" + df = sample_courses_df.head(5) # Use fewer rows for faster testing + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=1, teacher_lm=LM(model="gpt-4o-mini")), + return_all=True, + ) + + assert "filter_label" in result.columns + + def test_bootstrapping_map_operation(self, sample_courses_df, setup_model): + """Test auto-bootstrapping with map operation""" + df = sample_courses_df.head(4) # Use fewer rows for faster testing + instruction = "What is the difficulty level of {Course Name}? 
Answer: Beginner, Intermediate, or Advanced" + + result = df.sem_map( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_explanations=True, + ) + + # Check structure + assert "_map" in result.columns + assert "explanation_map" in result.columns + + # Check that all difficulty levels are reasonable + for difficulty in result["_map"]: + assert isinstance(difficulty, str) + assert len(difficulty) > 0 + + def test_bootstrapping_without_cot(self, sample_courses_df, setup_model): + """Test auto-bootstrapping without chain-of-thought""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=False, dems="auto", max_dems=2), + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + # Should not have explanations when CoT is disabled + assert "explanation_filter" not in result.columns + + def test_bootstrapping_max_dems_limit(self, sample_courses_df, setup_model): + """Test that max_dems is respected""" + df = sample_courses_df + instruction = "{Course Name} requires a lot of math" + + # Request more demonstrations than available data + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=100), + return_all=True, + ) + + # Should still work and be limited by available data + assert "filter_label" in result.columns + + def test_bootstrapping_with_additional_instructions(self, sample_courses_df, setup_model): + """Test auto-bootstrapping with additional CoT instructions""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy( + cot=True, + dems="auto", + max_dems=2, + additional_cot_instructions="Consider the mathematical content and prerequisites carefully.", + ), + return_explanations=True, + return_all=True, + ) + + # Check structure + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + + # Check that explanations are provided + for explanation in result["explanation_filter"]: + assert explanation is not None + assert len(explanation) > 0 + + def test_bootstrapping_error_handling(self, sample_courses_df, setup_model): + """Test that bootstrapping errors are handled gracefully""" + df = sample_courses_df.head(2) + instruction = "{Course Name} requires a lot of math" + + # This should work even if bootstrapping encounters issues + result = df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=1), + return_all=True, + ) + + # Should still produce results + assert "filter_label" in result.columns + assert len(result) == len(df) + + def test_bootstrapping_empty_data_handling(self, setup_model): + """Test bootstrapping with empty or minimal data""" + empty_df = pd.DataFrame({"Course Name": []}) + instruction = "{Course Name} requires a lot of math" + + # Should handle empty data gracefully + result = empty_df.sem_filter( + instruction, + prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2), + return_all=True, + ) + + assert "filter_label" in result.columns + assert len(result) == 0 + + def test_bootstrapping_integration_with_existing_features(self, sample_courses_df, setup_model): + """Test that auto-bootstrapping works with other features like return_stats""" + df = sample_courses_df.head(4) + instruction = "{Course Name} requires a lot of math" + + result, stats = 
df.sem_filter(
+            instruction,
+            prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=2),
+            return_explanations=True,
+            return_all=True,
+            return_stats=True,
+        )
+
+        # Check structure
+        assert "filter_label" in result.columns
+        assert "explanation_filter" in result.columns
+        assert isinstance(stats, dict)
+
+    # =============================================================================
+    # Backward Compatibility Tests
+    # =============================================================================
+
+    def test_backward_compatibility_examples_param(self, sample_courses_df, setup_model):
+        """Test that old examples parameter still works"""
+        df = sample_courses_df
+        instruction = "{Course Name} requires a lot of math"
+
+        # Old way: passing examples directly
+        examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]})
+
+        result = df.sem_filter(
+            instruction,
+            examples=examples,  # Old parameter name
+            return_all=True,
+        )
+
+        assert "filter_label" in result.columns
+
+    def test_no_strategy_specified(self, sample_courses_df, setup_model):
+        """Test default behavior when no strategy is specified"""
+        df = sample_courses_df
+        instruction = "{Course Name} requires a lot of math"
+
+        result = df.sem_filter(instruction, return_all=True)
+
+        # Should work with default behavior
+        assert "filter_label" in result.columns
+
+    # =============================================================================
+    # Extract Operation Tests
+    # =============================================================================
+
+    def test_cot_extract_basic(self, sample_reviews_df, setup_model):
+        """Test CoT reasoning with sem_extract"""
+        df = sample_reviews_df.head(3)  # Use fewer rows for faster testing
+
+        input_cols = ["Review"]
+        output_cols = {
+            "sentiment": "The sentiment of the review (positive/negative/neutral)",
+            "key_points": "Main points mentioned in the review",
+        }
+
+        result = df.sem_extract(
+            input_cols, output_cols, prompt_strategy=PromptStrategy(cot=True), return_explanations=True
+        )
+
+        # Check structure
+        assert "sentiment" in result.columns
+        assert "key_points" in result.columns
+        assert "explanation" in result.columns
+
+        # Check that extractions are reasonable
+        for sentiment in result["sentiment"]:
+            assert sentiment.lower() in ["positive", "negative", "neutral"]
+
+    # =============================================================================
+    # Error Handling and Edge Cases
+    # =============================================================================
+
+    def test_empty_examples(self, sample_courses_df, setup_model):
+        """Test behavior with empty examples DataFrame"""
+        df = sample_courses_df.head(3)
+        instruction = "{Course Name} requires a lot of math"
+
+        # Create properly structured empty examples DataFrame
+        empty_examples = pd.DataFrame(columns=["Course Name", "Answer"])
+
+        # Should handle empty examples gracefully
+        result = df.sem_filter(instruction, prompt_strategy=PromptStrategy(dems=empty_examples), return_all=True)
+
+        assert "filter_label" in result.columns
+
+    def test_mismatched_example_columns(self, sample_courses_df, setup_model):
+        """Test error handling for mismatched example columns"""
+        df = sample_courses_df.head(3)
+        instruction = "{Course Name} requires a lot of math"
+
+        # Examples with wrong column names
+        bad_examples = pd.DataFrame({"WrongColumn": ["Machine Learning", "Literature"], "Answer": [True, False]})
+
+        # Should handle gracefully or raise an informative error
+        try:
+            result = df.sem_filter(instruction, prompt_strategy=PromptStrategy(dems=bad_examples), return_all=True)
+            # If it doesn't raise an error, it should still produce results
+            assert "filter_label" in result.columns
+        except Exception as e:
+            # If it raises an error, it should be informative
+            assert len(str(e)) > 0
+
+    def test_invalid_strategy_combination(self, sample_courses_df, setup_model):
+        """Test invalid combinations of parameters"""
+        df = sample_courses_df.head(3)
+        instruction = "{Course Name} requires a lot of math"
+
+        # Auto-bootstrapping requested with no user-provided demonstrations
+        try:
+            result = df.sem_filter(
+                instruction,
+                prompt_strategy=PromptStrategy(cot=True, dems="auto"),  # Should use auto for bootstrapping
+                return_all=True,
+            )
+            # Should either work or raise informative error
+            assert "filter_label" in result.columns
+        except Exception as e:
+            assert len(str(e)) > 0
+
+    def test_large_num_demonstrations(self, sample_courses_df, setup_model):
+        """Test behavior with large number of demonstrations"""
+        df = sample_courses_df
+        instruction = "{Course Name} requires a lot of math"
+
+        # Request more demonstrations than available data
+        result = df.sem_filter(
+            instruction,
+            prompt_strategy=PromptStrategy(cot=True, dems="auto", max_dems=20),
+            return_all=True,
+        )
+
+        # Should handle gracefully
+        assert "filter_label" in result.columns
+
+    # =============================================================================
+    # Performance and Integration Tests
+    # =============================================================================
+
+    def test_multiple_operations_with_strategies(self, sample_courses_df, setup_model):
+        """Test chaining multiple operations with different strategies"""
+        df = sample_courses_df
+
+        # First filter with demonstrations
+        examples = pd.DataFrame({"Course Name": ["Machine Learning", "Literature"], "Answer": [True, False]})
+
+        filtered_df = df.sem_filter(
+            "{Course Name} requires a lot of math", prompt_strategy=PromptStrategy(dems=examples)
+        )
+
+        # Then map with CoT
+        if len(filtered_df) > 0:
+            mapped_df = filtered_df.sem_map(
+                "What is the difficulty level of {Course Name}?",
+                prompt_strategy=PromptStrategy(cot=True),
+                return_explanations=True,
+            )
+
+            assert "_map" in mapped_df.columns
+            assert "explanation_map" in mapped_df.columns
+
+    def test_strategy_with_return_options(self, sample_courses_df, setup_model):
+        """Test strategies with various return options"""
+        df = sample_courses_df.head(4)
+        instruction = "{Course Name} requires a lot of math"
+
+        # Test return_stats=True (returns tuple)
+        result, stats = df.sem_filter(
+            instruction,
+            prompt_strategy=PromptStrategy(cot=True),
+            return_all=True,
+            return_explanations=True,
+            return_stats=True,
+        )
+
+        # Check all expected columns are present in DataFrame
+        assert "filter_label" in result.columns
+        assert "explanation_filter" in result.columns
+
+        # Check stats is returned
+        assert isinstance(stats, dict)
+
+        # Test without return_stats (returns DataFrame only)
+        result_no_stats = df.sem_filter(
+            instruction,
+            prompt_strategy=PromptStrategy(cot=True),
+            return_all=True,
+            return_explanations=True,
+            return_stats=False,
+        )
+
+        # Should return DataFrame directly
+        assert "filter_label" in result_no_stats.columns
+        assert "explanation_filter" in result_no_stats.columns
+
+        # Test that filtering works correctly
+        positive_results = result[result["filter_label"]]
+        assert len(positive_results) >= 0  # Smoke check; the mask is valid even if no rows match
+
+    # =============================================================================
+    # Direct Bootstrap Function Tests
+    # =============================================================================
+
+    def test_bootstrap_demonstrations_function_direct(self, sample_courses_df, setup_model):
+        """Test the bootstrap_demonstrations function directly (now uses sem_ops internally)"""
+        df = sample_courses_df.head(4)
+        user_instruction = "{Course Name} requires a lot of math"
+        col_li = nle.parse_cols(user_instruction)
+        formatted_instruction = nle.nle2str(user_instruction, col_li)
+
+        prompt_strategy = PromptStrategy(cot=True, dems="auto", max_dems=2, teacher_lm=setup_model)
+
+        # Test the function directly
+        examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations(
+            data=df,
+            col_li=col_li,
+            user_instruction=formatted_instruction,
+            prompt_strategy=prompt_strategy,
+            operation_type="filter",
+        )
+
+        # Check outputs
+        assert isinstance(examples_multimodal_data, list)
+        assert isinstance(examples_answers, list)
+        assert len(examples_multimodal_data) == len(examples_answers)
+        assert len(examples_multimodal_data) <= prompt_strategy.max_dems
+
+        # Check CoT reasoning
+        if prompt_strategy.cot:
+            # Note: CoT reasoning might be None if the model output doesn't follow expected format
+            if cot_reasoning is not None:
+                assert isinstance(cot_reasoning, list)
+                assert len(cot_reasoning) == len(examples_answers)
+                for reasoning in cot_reasoning:
+                    if reasoning is not None:  # Individual reasoning items might be None
+                        assert isinstance(reasoning, str)
+
+        # Check answer types for filter operation
+        for answer in examples_answers:
+            assert isinstance(answer, bool)
+
+    def test_bootstrap_demonstrations_map_operation_direct(self, sample_courses_df, setup_model):
+        """Test bootstrap_demonstrations function for map operation (now uses sem_map internally)"""
+        df = sample_courses_df.head(3)
+        user_instruction = "What is the difficulty level of {Course Name}?"
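+        # nle.parse_cols pulls the {Course Name} placeholder out of the instruction
+        # and nle2str formats it; bootstrap_demonstrations (per the asserts below)
+        # then labels at most max_dems rows with the teacher model to build the
+        # few-shot examples.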
+
+    def test_bootstrap_demonstrations_map_operation_direct(self, sample_courses_df, setup_model):
+        """Test bootstrap_demonstrations function for map operation (now uses sem_map internally)"""
+        import lotus.nl_expression as nle
+        from lotus.sem_ops.cascade_utils import bootstrap_demonstrations
+
+        df = sample_courses_df.head(3)
+        user_instruction = "What is the difficulty level of {Course Name}?"
+        col_li = nle.parse_cols(user_instruction)
+        formatted_instruction = nle.nle2str(user_instruction, col_li)
+
+        prompt_strategy = PromptStrategy(cot=False, dems="auto", max_dems=2, teacher_lm=setup_model)
+
+        # Test the function for the map operation
+        examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations(
+            data=df,
+            col_li=col_li,
+            user_instruction=formatted_instruction,
+            prompt_strategy=prompt_strategy,
+            operation_type="map",
+        )
+
+        # Check outputs
+        assert isinstance(examples_multimodal_data, list)
+        assert isinstance(examples_answers, list)
+        assert len(examples_multimodal_data) == len(examples_answers)
+
+        # Check answer types for the map operation
+        for answer in examples_answers:
+            assert isinstance(answer, str)
+            assert len(answer) > 0
+
+        # No CoT reasoning expected since cot=False
+        assert cot_reasoning is None
+
+    def test_bootstrap_demonstrations_default_prompt_strategy(self, sample_courses_df, setup_model):
+        """Test bootstrap_demonstrations function with the default PromptStrategy"""
+        import lotus.nl_expression as nle
+        from lotus.sem_ops.cascade_utils import bootstrap_demonstrations
+
+        df = sample_courses_df.head(2)
+        user_instruction = "{Course Name} requires a lot of math"
+        col_li = nle.parse_cols(user_instruction)
+        formatted_instruction = nle.nle2str(user_instruction, col_li)
+
+        # Set up the teacher model in settings since the default PromptStrategy has teacher_lm=None
+        original_lm = lotus.settings.lm
+        lotus.settings.lm = setup_model
+
+        try:
+            # Test with the default PromptStrategy (no prompt_strategy parameter)
+            examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations(
+                data=df,
+                col_li=col_li,
+                user_instruction=formatted_instruction,
+                # prompt_strategy parameter omitted to test the default
+                operation_type="filter",
+            )
+
+            # Check that it works with the default strategy
+            assert isinstance(examples_multimodal_data, list)
+            assert isinstance(examples_answers, list)
+            # The default PromptStrategy has cot=False, so no reasoning is expected
+            assert cot_reasoning is None
+
+        finally:
+            # Restore the original LM
+            lotus.settings.lm = original_lm
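+
+    # Note: the save/restore of lotus.settings.lm above (and in the error-case
+    # test below) keeps the globally configured model from leaking between
+    # tests; pytest's monkeypatch fixture would achieve the same isolation.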
+
+    def test_bootstrap_demonstrations_error_cases(self, sample_courses_df, setup_model):
+        """Test bootstrap_demonstrations function error handling"""
+        import lotus.nl_expression as nle
+        from lotus.sem_ops.cascade_utils import bootstrap_demonstrations
+
+        df = sample_courses_df.head(2)
+        user_instruction = "{Course Name} requires a lot of math"
+        col_li = nle.parse_cols(user_instruction)
+        formatted_instruction = nle.nle2str(user_instruction, col_li)
+
+        # Test with no teacher model - temporarily clear lotus.settings.lm
+        original_lm = lotus.settings.lm
+        lotus.settings.lm = None
+
+        try:
+            prompt_strategy_no_teacher = PromptStrategy(cot=True, dems="auto", max_dems=2, teacher_lm=None)
+
+            with pytest.raises(ValueError, match="No teacher model available for bootstrapping"):
+                bootstrap_demonstrations(
+                    data=df,
+                    col_li=col_li,
+                    user_instruction=formatted_instruction,
+                    prompt_strategy=prompt_strategy_no_teacher,
+                    operation_type="filter",
+                )
+        finally:
+            # Restore the original LM
+            lotus.settings.lm = original_lm
+
+        # Test with an unsupported operation type
+        prompt_strategy = PromptStrategy(cot=True, dems="auto", max_dems=1, teacher_lm=setup_model)
+
+        examples_multimodal_data, examples_answers, cot_reasoning = bootstrap_demonstrations(
+            data=df,
+            col_li=col_li,
+            user_instruction=formatted_instruction,
+            prompt_strategy=prompt_strategy,
+            operation_type="unsupported_operation",
+        )
+
+        # Should return empty results for unsupported operations
+        assert examples_multimodal_data == []
+        assert examples_answers == []
+        assert cot_reasoning is None
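+
+        # Note: unsupported operation types fail silently with empty results
+        # rather than raising; callers that depend on bootstrapping should
+        # therefore check for an empty demonstration list, e.g. (illustrative):
+        #
+        #     if not examples_multimodal_data:
+        #         ...  # fall back to zero-shot prompting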