From 0be21c7f37096b9fffaf723265089f62a4f4a20e Mon Sep 17 00:00:00 2001 From: IsmaelKabore Date: Thu, 16 Oct 2025 17:07:06 -0700 Subject: [PATCH 1/3] Implement semantic map + filter --- .github/tests/test_sem_filter.py | 69 +++++++++++++ .github/tests/test_sem_map.py | 60 ++++++++++++ .vscode/settings.json | 7 ++ CONTRIBUTING.md | 4 + README.md | 11 +-- lotus/models/lm.py | 10 +- lotus/sampling_utils.py | 161 +++++++++++++++++++++++++++++++ lotus/sem_ops/sem_filter.py | 47 +++++++-- lotus/sem_ops/sem_map.py | 49 ++++++++-- pyproject.toml | 2 +- 10 files changed, 395 insertions(+), 25 deletions(-) create mode 100644 .github/tests/test_sem_filter.py create mode 100644 .github/tests/test_sem_map.py create mode 100644 .vscode/settings.json create mode 100644 lotus/sampling_utils.py diff --git a/.github/tests/test_sem_filter.py b/.github/tests/test_sem_filter.py new file mode 100644 index 00000000..408950bd --- /dev/null +++ b/.github/tests/test_sem_filter.py @@ -0,0 +1,69 @@ +import os + +import pandas as pd +import pytest + +import lotus +from lotus.models import LM + +# Gate the whole module behind an env var so devs/CI without Ollama skip cleanly. +ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" + +pytestmark = pytest.mark.skipif( + not ENABLE_OLLAMA_TESTS, + reason="Set ENABLE_OLLAMA_TESTS=true to run Ollama-backed tests", +) + +MODEL_NAME = "ollama/llama3.1" + + +@pytest.fixture(scope="session") +def setup_models(): + return {MODEL_NAME: LM(model=MODEL_NAME)} + + +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_models): + yield + for _, m in setup_models.items(): + m.print_total_usage() + m.reset_stats() + m.reset_cache() + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_basic(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame( + { + "Text": [ + "I am really excited to go to class today!", + "I am very sad", + ] + } + ) + user_instruction = "{Text} is a positive sentiment" + + filtered = df.sem_filter(user_instruction) + + # + assert isinstance(filtered, pd.DataFrame) + assert "Text" in filtered.columns + assert len(filtered) >= 1 + assert "I am very sad" not in filtered["Text"].tolist() + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_with_sampling(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Today is fantastic!", "This is terrible.", "Pretty good overall."]}) + user_instruction = "{Text} is a positive sentiment" + + # Exercise sampling/temperature path; keep assertions tolerant + filtered = df.sem_filter(user_instruction, n_sample=2, temperature=0.9) + + assert isinstance(filtered, pd.DataFrame) + assert "Text" in filtered.columns + assert len(filtered) >= 1 diff --git a/.github/tests/test_sem_map.py b/.github/tests/test_sem_map.py new file mode 100644 index 00000000..430ef144 --- /dev/null +++ b/.github/tests/test_sem_map.py @@ -0,0 +1,60 @@ +import os + +import pandas as pd +import pytest + +import lotus +from lotus.models import LM + +ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" + +pytestmark = pytest.mark.skipif( + not ENABLE_OLLAMA_TESTS, + reason="Set ENABLE_OLLAMA_TESTS=true to run Ollama-backed tests", +) + +MODEL_NAME = "ollama/llama3.1" + + +@pytest.fixture(scope="session") +def setup_models(): + return {MODEL_NAME: LM(model=MODEL_NAME)} + + +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_models): + yield + 
for _, m in setup_models.items(): + m.print_total_usage() + m.reset_stats() + m.reset_cache() + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_map_basic(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"document": ["Alice likes cats.", "Bob likes dogs."]}) + mapped = df.sem_map("Summarize {document} in three words.", suffix="_map") + + assert isinstance(mapped, pd.DataFrame) + assert "_map" in mapped.columns + assert len(mapped) == 2 + assert all(isinstance(x, str) and len(x) > 0 for x in mapped["_map"].tolist()) + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_map_with_sampling(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"document": ["The sky is blue."]}) + mapped = df.sem_map( + "Paraphrase {document} compactly.", + n_sample=2, + temperature=0.8, + suffix="_map", + ) + + assert "_map" in mapped.columns + assert len(mapped) == 1 + assert isinstance(mapped.iloc[0]["_map"], str) and len(mapped.iloc[0]["_map"]) > 0 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..a41de757 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "docs" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0ed63351..64122ee3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -288,6 +288,10 @@ lm = LM( - **Discussions**: https://github.com/lotus-data/lotus/discussions - **Issues**: https://github.com/lotus-data/lotus/issues +## Code of Conduct + +We are committed to providing a welcoming and inclusive environment for all contributors. Please be respectful and constructive in all interactions. + --- diff --git a/README.md b/README.md index 37932060..75788e0a 100644 --- a/README.md +++ b/README.md @@ -154,18 +154,13 @@ For recent updates related to LOTUS, follow [@lianapatel_](https://x.com/lianapa If you find LOTUS or semantic operators useful, we'd appreciate if you can please cite this work as follows: ```bibtex -@article{patel2025semanticoptimization, - title = {Semantic Operators and Their Optimization: Enabling LLM-Based Data Processing with Accuracy Guarantees in LOTUS}, - author = {Patel, Liana and Jha, Siddharth and Pan, Melissa and Gupta, Harshit and Asawa, Parth and Guestrin, Carlos and Zaharia, Matei}, - year = {2025}, - journal = {Proc. 
VLDB Endow.}, - url = {https://doi.org/10.14778/3749646.3749685}, -} -@article{patel2024semanticoperators, +@misc{patel2024semanticoperators, title={Semantic Operators: A Declarative Model for Rich, AI-based Analytics Over Text Data}, author={Liana Patel and Siddharth Jha and Parth Asawa and Melissa Pan and Carlos Guestrin and Matei Zaharia}, year={2024}, eprint={2407.11418}, + archivePrefix={arXiv}, + primaryClass={cs.DB}, url={https://arxiv.org/abs/2407.11418}, } ``` diff --git a/lotus/models/lm.py b/lotus/models/lm.py index b59414d3..953f0ba4 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -282,6 +282,10 @@ def _process_with_rate_limiting( # Each request should be spaced by min_interval_per_request required_time_for_batch = len(sub_batch) * min_interval_per_request + # Only sleep if the batch was faster than the required time + # Each request should be spaced by min_interval_per_request + required_time_for_batch = len(sub_batch) * min_interval_per_request + # Only sleep if the batch was faster than the required time if i < num_batches - 1: # Don't sleep after the last batch to_sleep = required_time_for_batch - elapsed @@ -382,8 +386,10 @@ def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompleti choice = response.choices[0] assert isinstance(choice, Choices) - assert choice.logprobs is not None and isinstance(choice.logprobs, ChoiceLogprobs) - logprobs = choice.logprobs["content"] + if choice.logprobs is not None and isinstance(choice.logprobs, ChoiceLogprobs): + logprobs = choice.logprobs["content"] + else: + logprobs = [] return logprobs def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: diff --git a/lotus/sampling_utils.py b/lotus/sampling_utils.py new file mode 100644 index 00000000..847b41b4 --- /dev/null +++ b/lotus/sampling_utils.py @@ -0,0 +1,161 @@ +# lotus/sampling_utils.py + +from collections import Counter +from typing import Any, Callable, List, Optional, Sequence, Tuple + + +def _norm(x: Any) -> str: + """ + Normalize assorted truthy/falsy strings to 'yes'/'no' when possible; + otherwise return the lowercased string. + + Examples: + " YES " -> "yes" + "False" -> "no" + "maybe" -> "maybe" + """ + t = str(x).strip().lower() + if t in {"yes", "y", "true", "t", "1"}: + return "yes" + if t in {"no", "n", "false", "f", "0"}: + return "no" + return t + + +def ensemble_majority_vote(samples_for_one_item: Sequence[str], *, default_yes: bool) -> str: + """ + Majority vote over arbitrary label strings (after normalization). + If the vote is tied: + - If it's exactly 'yes' == 'no', fall back to default_yes. + - Otherwise, deterministically break ties by lexicographic order. + """ + c = Counter(_norm(s) for s in samples_for_one_item) + if not c: + return "yes" if default_yes else "no" + + top = c.most_common() + + # single unique label + if len(top) == 1: + return top[0][0] + + # tie on counts + if top[0][1] == top[1][1]: + # special case: exactly yes==no + if "yes" in c and "no" in c and c["yes"] == c["no"]: + return "yes" if default_yes else "no" + # otherwise deterministic tie-break among the tied labels + tied_labels = sorted([lab for lab, cnt in top if cnt == top[0][1]]) + return tied_labels[0] + + # clear winner + return top[0][0] + + +def ensemble_mean_bool(samples_for_one_item: Sequence[str]) -> str: + """ + Average yes/no votes; returns 'yes' if mean >= 0.5 else 'no'. + Only meaningful when outputs semantically map to yes/no. 
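+
+    Examples (illustrative):
+        ["yes", "no", "yes"] -> "yes"   (mean 2/3)
+        ["no", "no", "yes"]  -> "no"    (mean 1/3)
+        ["yes", "no"]        -> "yes"   (an exact tie gives mean 0.5, which meets the >= 0.5 cutoff)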
+ """ + vals = [1 if _norm(s) == "yes" else 0 for s in samples_for_one_item] + mean = sum(vals) / max(1, len(vals)) + return "yes" if mean >= 0.5 else "no" + + +def apply_ensemble(strategy: Optional[str], all_outputs: List[List[str]], *, default_yes: bool) -> List[str]: + """ + Collapse shape [n_sample][batch] -> [batch]. + + Behavior: + - If strategy is None or n_sample==1, return the first run unchanged. + - 'majority_vote' / 'majority': mode over strings (works for general labels). + - 'mean_prob' / 'average_prob' / 'avg_prob': treat strings as yes/no and average. + + Args: + strategy: name of ensemble rule (or None). + all_outputs: list of runs, each run is list[str] of length 'batch'. + default_yes: used to break exact yes/no ties in majority voting. + + Returns: + List[str]: one output per batch item. + """ + if not all_outputs: + return [] + if not strategy or len(all_outputs) == 1: + return list(all_outputs[0]) + + # Sanity check: all runs must have same batch size + batch = len(all_outputs[0]) + for run in all_outputs[1:]: + if len(run) != batch: + raise ValueError(f"Inconsistent batch sizes across runs: expected {batch}, got {len(run)}") + + n_sample = len(all_outputs) + per_item = [[all_outputs[k][i] for k in range(n_sample)] for i in range(batch)] + s = strategy.lower() + + out: List[str] = [] + for samples in per_item: + if s in {"majority_vote", "majority"}: + out.append(ensemble_majority_vote(samples, default_yes=default_yes)) + elif s in {"mean_prob", "average_prob", "avg_prob"}: + out.append(ensemble_mean_bool(samples)) + else: + raise ValueError( + f"Unknown ensemble strategy: {strategy}. " + "Use 'majority_vote' for general strings; 'average_prob' only for yes/no." + ) + return out + + +def resample_batch( + call_once: Callable[[bool, Optional[float], bool, str], Any], + *, + n_sample: int, + want_logprobs: bool, + show_progress_bar: bool, + progress_bar_desc: str, + temperature: Optional[float], +) -> Tuple[List[List[str]], Optional[List[Any]]]: + """ + Run the same batch through the model multiple times and collect outputs. + + Expected signature for `call_once`: + call_once(want_logprobs, temperature, show_progress_bar, progress_bar_desc) -> lm_out + where lm_out must have: + - lm_out.outputs: List[str] # batch-sized list of strings + - lm_out.logprobs: Optional[Any] # present if want_logprobs=True (provider-specific shape) + + Returns: + (all_outputs, chosen_logprobs) + - all_outputs: List of runs; each run is outputs[List[str]] of length 'batch' + - chosen_logprobs: logprobs from the FIRST run if requested, else None. + + Notes: + - We validate consistent batch sizes across runs. + - We intentionally return ONLY one set of logprobs to keep the API simple. 
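+
+    Example (illustrative sketch; `model` and `inputs` stand in for a lotus.models.LM
+    instance and a prepared prompt batch):
+        def _call_once(want_logprobs, temp, spb, desc):
+            return model(inputs, logprobs=want_logprobs, temperature=temp,
+                         show_progress_bar=spb, progress_bar_desc=desc)
+
+        all_outputs, first_run_logprobs = resample_batch(
+            _call_once, n_sample=3, want_logprobs=True,
+            show_progress_bar=True, progress_bar_desc="Filtering", temperature=0.7,
+        )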
+ """ + all_outputs: List[List[str]] = [] + logs: List[Any] = [] # non-optional; we append only if available + + for _ in range(max(1, n_sample)): + lm_out = call_once(want_logprobs, temperature, show_progress_bar, progress_bar_desc) + + if not hasattr(lm_out, "outputs"): + raise ValueError("LM call did not return an object with an 'outputs' attribute.") + if not isinstance(lm_out.outputs, list): + raise ValueError("LM call returned outputs that are not a list.") + + all_outputs.append(lm_out.outputs) + + if want_logprobs and hasattr(lm_out, "logprobs") and lm_out.logprobs is not None: + logs.append(lm_out.logprobs) + + # Ensure consistent batch size + batch = len(all_outputs[0]) + for run in all_outputs[1:]: + if len(run) != batch: + raise ValueError(f"Inconsistent batch sizes across runs: expected {batch}, got {len(run)}") + + chosen: Optional[List[Any]] = logs[0] if (want_logprobs and len(logs) > 0) else None + return all_outputs, chosen diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 3a26506d..45bae3f5 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -6,10 +6,10 @@ import lotus from lotus.cache import operator_cache +from lotus.sampling_utils import apply_ensemble, resample_batch from lotus.templates import task_instructions from lotus.types import ( CascadeArgs, - LMOutput, LogprobsForFilterCascade, ProxyModel, ReasoningStrategy, @@ -35,6 +35,9 @@ def sem_filter( show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", + n_sample: int = 1, # number of samples per item + ensemble: str | None = None, # "majority_vote", "mean_prob", None + temperature: float | None = None, # if None, use model default ) -> SemanticFilterOutput: """ Filters a list of documents based on a natural language instruction using a language model. 
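For orientation, a minimal sketch of how the new arguments are meant to compose at the DataFrame level, mirroring the tests above (the Ollama model name and the sentiment instruction are the ones the tests assume):

    import pandas as pd
    import lotus
    from lotus.models import LM

    lotus.settings.configure(lm=LM(model="ollama/llama3.1"))
    df = pd.DataFrame({"Text": ["Great product!", "Terrible service."]})
    # Sample each row three times at temperature 0.9 and collapse by majority vote
    filtered = df.sem_filter(
        "{Text} is a positive sentiment",
        n_sample=3,
        ensemble="majority_vote",
        temperature=0.9,
    )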
@@ -102,18 +105,34 @@ def sem_filter( ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) - kwargs: dict[str, Any] = {"logprobs": logprobs} if safe_mode: - estimated_total_calls = len(docs) - estimated_total_cost = sum(model.count_tokens(input) for input in inputs) + estimated_total_calls = len(docs) * max(1, n_sample) + estimated_total_cost = sum(model.count_tokens(input) for input in inputs) * max(1, n_sample) show_safe_mode(estimated_total_cost, estimated_total_calls) - lm_output: LMOutput = model( - inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs + def _call_once(want_logprobs, temp, spb, desc): + return model( + inputs, + logprobs=want_logprobs, + temperature=temp, + show_progress_bar=spb, + progress_bar_desc=desc if n_sample == 1 else f"{desc} (x{n_sample})", + ) + + all_runs, chosen_logprobs = resample_batch( + _call_once, + n_sample=n_sample, + want_logprobs=logprobs, + temperature=temperature, + show_progress_bar=show_progress_bar, + progress_bar_desc=progress_bar_desc, ) - postprocess_output = filter_postprocess(lm_output.outputs, model, default) + # Collapse [n_sample][batch] -> [batch] + final_texts = apply_ensemble(ensemble, all_runs, default_yes=default) + + postprocess_output = filter_postprocess(final_texts, model, default) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") @@ -125,7 +144,7 @@ def sem_filter( raw_outputs=postprocess_output.raw_outputs, outputs=postprocess_output.outputs, explanations=postprocess_output.explanations, - logprobs=lm_output.logprobs if logprobs else None, + logprobs=chosen_logprobs if logprobs else None, ) @@ -347,6 +366,9 @@ def __call__( safe_mode: bool = False, progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", + n_sample: int = 1, # number of samples per item + ensemble: str | None = None, # "majority_vote", "mean_prob", None + temperature: float | None = None, # if None, use model default ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: if lotus.settings.lm is None: raise ValueError( @@ -425,6 +447,9 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc="Running helper LM", + n_sample=n_sample, # NEW + ensemble=ensemble, # NEW + temperature=temperature, # NEW ) _, helper_logprobs = helper_output.outputs, helper_output.logprobs assert helper_logprobs is not None @@ -519,6 +544,9 @@ def __call__( safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", additional_cot_instructions=additional_cot_instructions, + n_sample=n_sample, # NEW + ensemble=ensemble, # NEW + temperature=temperature, # NEW ) for idx, large_idx in enumerate(low_conf_idxs): @@ -543,6 +571,9 @@ def __call__( show_progress_bar=True, progress_bar_desc=progress_bar_desc, additional_cot_instructions=additional_cot_instructions, + n_sample=n_sample, # NEW + ensemble=ensemble, # NEW + temperature=temperature, # NEW ) outputs = output.outputs raw_outputs = output.raw_outputs diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 2c8b7074..9d66c580 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -4,8 +4,9 @@ import lotus from lotus.cache import operator_cache +from lotus.sampling_utils import apply_ensemble, resample_batch from lotus.templates import task_instructions -from lotus.types import LMOutput, ReasoningStrategy, SemanticMapOutput, 
SemanticMapPostprocessOutput +from lotus.types import ReasoningStrategy, SemanticMapOutput, SemanticMapPostprocessOutput from lotus.utils import show_safe_mode from .postprocessors import map_postprocess @@ -23,6 +24,9 @@ def sem_map( strategy: ReasoningStrategy | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", + n_sample: int = 1, # NEW + ensemble: str | None = None, # NEW + temperature: float | None = None, # NEW **model_kwargs: Any, ) -> SemanticMapOutput: """ @@ -59,6 +63,13 @@ def sem_map( Defaults to False. progress_bar_desc (str, optional): Description for the progress bar. Defaults to "Mapping". + n_sample (int): number of repeated LM calls per item. If >1, outputs are combined + using the specified ensemble strategy. + ensemble (str|None): "majority_vote", "average_prob"/"mean_prob"/"avg_prob", or None. + Strategy to combine multiple samples per item. If None, no ensembling is done. + Only used if n_sample > 1. + temperature (float|None): test-time temperature; None uses model default. + Defaults to None. **model_kwargs: Any: Additional keyword arguments to pass to the model. Returns: SemanticMapOutput: An object containing the processed outputs, raw outputs, @@ -94,18 +105,38 @@ def sem_map( # check if safe_mode is enabled if safe_mode: - estimated_cost = sum(model.count_tokens(input) for input in inputs) - estimated_LM_calls = len(docs) + estimated_cost = sum(model.count_tokens(input) for input in inputs) * max(1, n_sample) + estimated_LM_calls = len(docs) * max(1, n_sample) show_safe_mode(estimated_cost, estimated_LM_calls) # call model - lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc, **model_kwargs) + def _call_once(_want_logprobs, temp, spb, desc): + # sem_map doesn’t use logprobs, so _want_logprobs is ignored + return model( + inputs, + temperature=temp, + show_progress_bar=spb, + progress_bar_desc=desc if n_sample == 1 else f"{desc} (x{n_sample})", + **model_kwargs, + ) + + all_runs, _ = resample_batch( + _call_once, + n_sample=n_sample, + want_logprobs=False, + temperature=temperature, + show_progress_bar=True, + progress_bar_desc=progress_bar_desc, + ) + + # Collapse [n_sample][batch] -> [batch] using the chosen ensemble strategy + final_texts = apply_ensemble(ensemble, all_runs, default_yes=True) # default_yes only matters for ties # post process results postprocess_output = postprocessor( - lm_output.outputs, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] + final_texts, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] ) - lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") + lotus.logger.debug(f"raw_outputs: {final_texts}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") if safe_mode: @@ -224,6 +255,9 @@ def __call__( strategy: ReasoningStrategy | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", + n_sample: int = 1, # NEW + ensemble: str | None = None, # NEW + temperature: float | None = None, # NEW **model_kwargs: Any, ) -> pd.DataFrame: if lotus.settings.lm is None: @@ -266,6 +300,9 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, + n_sample=n_sample, # NEW + ensemble=ensemble, # NEW + temperature=temperature, # NEW **model_kwargs, ) diff --git a/pyproject.toml b/pyproject.toml index ae80a86d..bda68504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name 
= "lotus-ai" -version = "1.1.4" +version = "1.1.3" description = "lotus" readme = "README.md" authors = [ From ca92e0a4b811121f131b4caf6e20c1874a9e3cee Mon Sep 17 00:00:00 2001 From: IsmaelKabore Date: Sun, 26 Oct 2025 01:33:47 -0700 Subject: [PATCH 2/3] sem_filter: sampling + ensembling; sampling_utils: enums & chosen-run logprobs; tests --- .github/tests/test_sem_filter.py | 89 +++++++++++- .github/tests/test_sem_map.py | 60 -------- README.md | 13 +- lotus/models/lm.py | 10 +- lotus/sampling_utils.py | 241 +++++++++++++++++-------------- lotus/sem_ops/sem_filter.py | 86 +++++++---- lotus/sem_ops/sem_map.py | 49 +------ lotus/types.py | 9 ++ 8 files changed, 300 insertions(+), 257 deletions(-) delete mode 100644 .github/tests/test_sem_map.py diff --git a/.github/tests/test_sem_filter.py b/.github/tests/test_sem_filter.py index 408950bd..b56653e2 100644 --- a/.github/tests/test_sem_filter.py +++ b/.github/tests/test_sem_filter.py @@ -5,6 +5,7 @@ import lotus from lotus.models import LM +from lotus.types import EnsembleStrategy, ReasoningStrategy # Gate the whole module behind an env var so devs/CI without Ollama skip cleanly. ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" @@ -47,23 +48,103 @@ def test_df_sem_filter_basic(setup_models, model): filtered = df.sem_filter(user_instruction) - # assert isinstance(filtered, pd.DataFrame) assert "Text" in filtered.columns assert len(filtered) >= 1 + # Tolerant: ensure obviously negative text is unlikely to pass assert "I am very sad" not in filtered["Text"].tolist() @pytest.mark.parametrize("model", [MODEL_NAME]) -def test_df_sem_filter_with_sampling(setup_models, model): +def test_df_sem_filter_with_sampling_pick_first(setup_models, model): lotus.settings.configure(lm=setup_models[model]) df = pd.DataFrame({"Text": ["Today is fantastic!", "This is terrible.", "Pretty good overall."]}) user_instruction = "{Text} is a positive sentiment" - # Exercise sampling/temperature path; keep assertions tolerant - filtered = df.sem_filter(user_instruction, n_sample=2, temperature=0.9) + # Exercise resampling path (n_sample>1) but keep ensemble as PICK_FIRST. + filtered = df.sem_filter( + user_instruction, + n_sample=2, + temperature=0.9, + # explicit for clarity (default is PICK_FIRST) + ensemble=EnsembleStrategy.PICK_FIRST, + ) assert isinstance(filtered, pd.DataFrame) assert "Text" in filtered.columns assert len(filtered) >= 1 + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_with_majority_ensemble_runs(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Great job!", "Awful service.", "Fine, I guess."]}) + user_instruction = "{Text} is a positive sentiment" + + # Exercise the MAJORITY voting path (boolean labels after postprocess). + # We keep assertions tolerant because model outputs can vary. 
+ filtered = df.sem_filter( + user_instruction, + n_sample=3, + temperature=0.9, + ensemble=EnsembleStrategy.MAJORITY, + ) + + assert isinstance(filtered, pd.DataFrame) + assert "Text" in filtered.columns + # Should not error; size can vary depending on votes + assert len(filtered) >= 0 + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_invalid_n_sample_raises(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Neutral thing"]}) + user_instruction = "{Text} is a positive sentiment" + + with pytest.raises(ValueError): + df.sem_filter(user_instruction, n_sample=0) + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_return_all_and_explanations(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["I love sunshine", "I hate rain"]}) + user_instruction = "{Text} is a positive sentiment" + + # Ask for all rows + explanations (ZS_COT forces explanation path on) + full_df = df.sem_filter( + user_instruction, + return_all=True, + return_explanations=True, + strategy=ReasoningStrategy.ZS_COT, + ) + + assert isinstance(full_df, pd.DataFrame) + # When return_all=True, an extra 'filter_label' column is added (no suffix) + assert "filter_label" in full_df.columns + # Explanation column uses the default suffix "_filter" + assert "explanation_filter" in full_df.columns + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_return_stats_tuple(setup_models, model): + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["I love this", "I dislike that"]}) + user_instruction = "{Text} is a positive sentiment" + + result = df.sem_filter( + user_instruction, + return_stats=True, + ) + + # When return_stats=True, we get (DataFrame, stats_dict) + assert isinstance(result, tuple) and len(result) == 2 + out_df, stats = result + assert isinstance(out_df, pd.DataFrame) + assert isinstance(stats, dict) diff --git a/.github/tests/test_sem_map.py b/.github/tests/test_sem_map.py deleted file mode 100644 index 430ef144..00000000 --- a/.github/tests/test_sem_map.py +++ /dev/null @@ -1,60 +0,0 @@ -import os - -import pandas as pd -import pytest - -import lotus -from lotus.models import LM - -ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" - -pytestmark = pytest.mark.skipif( - not ENABLE_OLLAMA_TESTS, - reason="Set ENABLE_OLLAMA_TESTS=true to run Ollama-backed tests", -) - -MODEL_NAME = "ollama/llama3.1" - - -@pytest.fixture(scope="session") -def setup_models(): - return {MODEL_NAME: LM(model=MODEL_NAME)} - - -@pytest.fixture(autouse=True) -def print_usage_after_each_test(setup_models): - yield - for _, m in setup_models.items(): - m.print_total_usage() - m.reset_stats() - m.reset_cache() - - -@pytest.mark.parametrize("model", [MODEL_NAME]) -def test_df_sem_map_basic(setup_models, model): - lotus.settings.configure(lm=setup_models[model]) - - df = pd.DataFrame({"document": ["Alice likes cats.", "Bob likes dogs."]}) - mapped = df.sem_map("Summarize {document} in three words.", suffix="_map") - - assert isinstance(mapped, pd.DataFrame) - assert "_map" in mapped.columns - assert len(mapped) == 2 - assert all(isinstance(x, str) and len(x) > 0 for x in mapped["_map"].tolist()) - - -@pytest.mark.parametrize("model", [MODEL_NAME]) -def test_df_sem_map_with_sampling(setup_models, model): - lotus.settings.configure(lm=setup_models[model]) - - df = pd.DataFrame({"document": ["The sky is blue."]}) - 
mapped = df.sem_map( - "Paraphrase {document} compactly.", - n_sample=2, - temperature=0.8, - suffix="_map", - ) - - assert "_map" in mapped.columns - assert len(mapped) == 1 - assert isinstance(mapped.iloc[0]["_map"], str) and len(mapped.iloc[0]["_map"]) > 0 diff --git a/README.md b/README.md index 75788e0a..4d96537f 100644 --- a/README.md +++ b/README.md @@ -154,13 +154,18 @@ For recent updates related to LOTUS, follow [@lianapatel_](https://x.com/lianapa If you find LOTUS or semantic operators useful, we'd appreciate if you can please cite this work as follows: ```bibtex -@misc{patel2024semanticoperators, +@article{patel2025semanticoptimization, + title = {Semantic Operators and Their Optimization: Enabling LLM-Based Data Processing with Accuracy Guarantees in LOTUS}, + author = {Patel, Liana and Jha, Siddharth and Pan, Melissa and Gupta, Harshit and Asawa, Parth and Guestrin, Carlos and Zaharia, Matei}, + year = {2025}, + journal = {Proc. VLDB Endow.}, + url = {https://doi.org/10.14778/3749646.3749685}, +} +@article{patel2024semanticoperators, title={Semantic Operators: A Declarative Model for Rich, AI-based Analytics Over Text Data}, author={Liana Patel and Siddharth Jha and Parth Asawa and Melissa Pan and Carlos Guestrin and Matei Zaharia}, year={2024}, eprint={2407.11418}, - archivePrefix={arXiv}, - primaryClass={cs.DB}, url={https://arxiv.org/abs/2407.11418}, } -``` +``` \ No newline at end of file diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 953f0ba4..b59414d3 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -282,10 +282,6 @@ def _process_with_rate_limiting( # Each request should be spaced by min_interval_per_request required_time_for_batch = len(sub_batch) * min_interval_per_request - # Only sleep if the batch was faster than the required time - # Each request should be spaced by min_interval_per_request - required_time_for_batch = len(sub_batch) * min_interval_per_request - # Only sleep if the batch was faster than the required time if i < num_batches - 1: # Don't sleep after the last batch to_sleep = required_time_for_batch - elapsed @@ -386,10 +382,8 @@ def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompleti choice = response.choices[0] assert isinstance(choice, Choices) - if choice.logprobs is not None and isinstance(choice.logprobs, ChoiceLogprobs): - logprobs = choice.logprobs["content"] - else: - logprobs = [] + assert choice.logprobs is not None and isinstance(choice.logprobs, ChoiceLogprobs) + logprobs = choice.logprobs["content"] return logprobs def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade: diff --git a/lotus/sampling_utils.py b/lotus/sampling_utils.py index 847b41b4..05aea636 100644 --- a/lotus/sampling_utils.py +++ b/lotus/sampling_utils.py @@ -1,161 +1,188 @@ # lotus/sampling_utils.py from collections import Counter -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast +from litellm.types.utils import ChatCompletionTokenLogprob -def _norm(x: Any) -> str: - """ - Normalize assorted truthy/falsy strings to 'yes'/'no' when possible; - otherwise return the lowercased string. 
+from lotus.types import EnsembleStrategy - Examples: - " YES " -> "yes" - "False" -> "no" - "maybe" -> "maybe" - """ - t = str(x).strip().lower() - if t in {"yes", "y", "true", "t", "1"}: - return "yes" - if t in {"no", "n", "false", "f", "0"}: - return "no" - return t +T = TypeVar("T") # item type for ensembling, e.g., bool for filter, str for map -def ensemble_majority_vote(samples_for_one_item: Sequence[str], *, default_yes: bool) -> str: +def _majority_vote_one(samples: Sequence[T], *, default_yes: Optional[bool]) -> T: """ - Majority vote over arbitrary label strings (after normalization). - If the vote is tied: - - If it's exactly 'yes' == 'no', fall back to default_yes. - - Otherwise, deterministically break ties by lexicographic order. + Generic majority vote over hashable samples. + - If there is a tie and samples are booleans {True, False}, use default_yes when provided. + - Otherwise, break ties deterministically by string order of the label. + + Assumes samples are already canonical (e.g., booleans for filter). """ - c = Counter(_norm(s) for s in samples_for_one_item) - if not c: - return "yes" if default_yes else "no" + if not samples: + raise ValueError("majority vote received an empty sample list") + c = Counter(samples) top = c.most_common() - # single unique label + # Single unique label if len(top) == 1: return top[0][0] - # tie on counts - if top[0][1] == top[1][1]: - # special case: exactly yes==no - if "yes" in c and "no" in c and c["yes"] == c["no"]: - return "yes" if default_yes else "no" - # otherwise deterministic tie-break among the tied labels - tied_labels = sorted([lab for lab, cnt in top if cnt == top[0][1]]) + # Tie in counts + if len(top) >= 2 and top[0][1] == top[1][1]: + # If exactly boolean tie, allow default_yes to decide + if set(c.keys()) == {True, False} and default_yes is not None: + return bool(default_yes) # type: ignore[return-value] + + # Otherwise, deterministic tie-break among tied labels by string representation + tied_labels = sorted([lab for lab, cnt in top if cnt == top[0][1]], key=lambda x: str(x)) return tied_labels[0] - # clear winner + # Clear winner return top[0][0] -def ensemble_mean_bool(samples_for_one_item: Sequence[str]) -> str: +def _mean_bool_one(samples: Sequence[bool]) -> bool: """ - Average yes/no votes; returns 'yes' if mean >= 0.5 else 'no'. - Only meaningful when outputs semantically map to yes/no. + Average a list of booleans; True if mean >= 0.5 else False. """ - vals = [1 if _norm(s) == "yes" else 0 for s in samples_for_one_item] - mean = sum(vals) / max(1, len(vals)) - return "yes" if mean >= 0.5 else "no" + if not samples: + raise ValueError("mean_bool received an empty sample list") + s = sum(1 for x in samples if x) + return (s / len(samples)) >= 0.5 -def apply_ensemble(strategy: Optional[str], all_outputs: List[List[str]], *, default_yes: bool) -> List[str]: +def apply_ensemble( + strategy: EnsembleStrategy, + all_outputs: List[List[T]], # shape: [n_sample][batch] + *, + default_yes: Optional[bool] = None, + return_indices: bool = False, +) -> Union[List[T], Tuple[List[T], List[int]]]: """ - Collapse shape [n_sample][batch] -> [batch]. - - Behavior: - - If strategy is None or n_sample==1, return the first run unchanged. - - 'majority_vote' / 'majority': mode over strings (works for general labels). - - 'mean_prob' / 'average_prob' / 'avg_prob': treat strings as yes/no and average. + Collapse shape [n_sample][batch] -> [batch] according to strategy. - Args: - strategy: name of ensemble rule (or None). 
- all_outputs: list of runs, each run is list[str] of length 'batch'. - default_yes: used to break exact yes/no ties in majority voting. + Assumptions: + - Inputs in `all_outputs` are already canonical (e.g., booleans for filter). + - MAJORITY is generic for any hashable type; MEAN_BOOL requires boolean labels. Returns: - List[str]: one output per batch item. + - If return_indices=False: List[T] (chosen output per item) + - If return_indices=True: (List[T], List[int]) (plus chosen run index per item) """ if not all_outputs: - return [] - if not strategy or len(all_outputs) == 1: - return list(all_outputs[0]) + return [] if not return_indices else ([], []) # type: ignore[return-value] - # Sanity check: all runs must have same batch size batch = len(all_outputs[0]) - for run in all_outputs[1:]: + for run in all_outputs: if len(run) != batch: - raise ValueError(f"Inconsistent batch sizes across runs: expected {batch}, got {len(run)}") + raise ValueError("Inconsistent batch sizes across runs") n_sample = len(all_outputs) - per_item = [[all_outputs[k][i] for k in range(n_sample)] for i in range(batch)] - s = strategy.lower() - - out: List[str] = [] - for samples in per_item: - if s in {"majority_vote", "majority"}: - out.append(ensemble_majority_vote(samples, default_yes=default_yes)) - elif s in {"mean_prob", "average_prob", "avg_prob"}: - out.append(ensemble_mean_bool(samples)) - else: - raise ValueError( - f"Unknown ensemble strategy: {strategy}. " - "Use 'majority_vote' for general strings; 'average_prob' only for yes/no." - ) - return out + + if strategy == EnsembleStrategy.PICK_FIRST or n_sample == 1: + chosen_labels: List[T] = list(all_outputs[0]) + return chosen_labels if not return_indices else (chosen_labels, [0] * batch) + + per_item: List[List[T]] = [[all_outputs[k][i] for k in range(n_sample)] for i in range(batch)] + + final_labels: List[T] = [] + chosen_indices: List[int] = [] + + if strategy == EnsembleStrategy.MAJORITY: + for samples in per_item: + label = _majority_vote_one(samples, default_yes=default_yes) + final_labels.append(label) + winner_idx = next(idx for idx, v in enumerate(samples) if v == label) + chosen_indices.append(winner_idx) + + elif strategy == EnsembleStrategy.MEAN_BOOL: + if not all(isinstance(x, bool) for run in all_outputs for x in run): + raise ValueError("MEAN_BOOL can only be applied to boolean outputs.") + for samples in per_item: + label_bool = _mean_bool_one(cast(Sequence[bool], samples)) + # cast back to T to satisfy the generic return type + final_labels.append(cast(T, label_bool)) + # compare as bools to find the first matching run + winner_idx = next(idx for idx, v in enumerate(samples) if bool(v) == label_bool) + chosen_indices.append(winner_idx) + else: + raise ValueError(f"Unknown EnsembleStrategy: {strategy}") + + return final_labels if not return_indices else (final_labels, chosen_indices) + + +def pick_logprobs_for_choices( + all_logprobs: Optional[List[Optional[List[List[ChatCompletionTokenLogprob]]]]], # [n_sample][batch][tokens] + chosen_indices: List[int], # [batch] +) -> Optional[List[List[ChatCompletionTokenLogprob]]]: + """ + Given logprobs for each run and the chosen run index per item, + return the per-item logprobs of the finally chosen outputs. + + all_logprobs is provider-specific; we only route to the chosen run/item. 
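+
+    Example (illustrative shapes): with 2 runs over a batch of 3 items and
+    chosen_indices=[0, 1, 0], item i receives all_logprobs[chosen_indices[i]][i];
+    a run whose logprobs are None contributes an empty list for its items.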
+ """ + if all_logprobs is None or not all_logprobs: + return None + + # Determine batch size from first non-None run and validate all runs + batch = None + for run_logs in all_logprobs: + if run_logs is not None: + batch = len(run_logs) + break + if batch is None: + return None + for run_logs in all_logprobs: + if run_logs is not None and len(run_logs) != batch: + raise ValueError("Inconsistent batch sizes in all_logprobs runs") + + if len(chosen_indices) != batch: + raise ValueError("chosen_indices length does not match batch size") + + chosen_per_item: List[List[ChatCompletionTokenLogprob]] = [] + for i, winner_run in enumerate(chosen_indices): + run_logs = all_logprobs[winner_run] if winner_run < len(all_logprobs) else None + chosen_per_item.append(run_logs[i] if (run_logs is not None) else []) + return chosen_per_item def resample_batch( - call_once: Callable[[bool, Optional[float], bool, str], Any], - *, + call_once: Callable[..., Any], n_sample: int, - want_logprobs: bool, - show_progress_bar: bool, - progress_bar_desc: str, - temperature: Optional[float], -) -> Tuple[List[List[str]], Optional[List[Any]]]: + *args: Any, + **kwargs: Any, +) -> Tuple[List[List[str]], Optional[List[Optional[List[List[ChatCompletionTokenLogprob]]]]]]: """ - Run the same batch through the model multiple times and collect outputs. + Run the same batch multiple times and collect outputs (+logprobs if produced). - Expected signature for `call_once`: - call_once(want_logprobs, temperature, show_progress_bar, progress_bar_desc) -> lm_out - where lm_out must have: - - lm_out.outputs: List[str] # batch-sized list of strings - - lm_out.logprobs: Optional[Any] # present if want_logprobs=True (provider-specific shape) - - Returns: - (all_outputs, chosen_logprobs) - - all_outputs: List of runs; each run is outputs[List[str]] of length 'batch' - - chosen_logprobs: logprobs from the FIRST run if requested, else None. - - Notes: - - We validate consistent batch sizes across runs. - - We intentionally return ONLY one set of logprobs to keep the API simple. + We do not prescribe call_once signature; we pass through *args/**kwargs. + Expected from call_once: + - returns an object with `.outputs: List[str]` + - may have `.logprobs` (provider-specific shape), or None/absent. 
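+
+    Example (illustrative):
+        all_outputs, all_logs = resample_batch(lambda: model(inputs), n_sample=3)
+    where all_outputs has shape [3][batch] and all_logs holds one entry per run
+    (None for runs that produced no logprobs).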
""" + if n_sample <= 0: + raise ValueError("n_sample must be >= 1") + all_outputs: List[List[str]] = [] - logs: List[Any] = [] # non-optional; we append only if available + all_logs: List[Optional[List[List[ChatCompletionTokenLogprob]]]] = [] - for _ in range(max(1, n_sample)): - lm_out = call_once(want_logprobs, temperature, show_progress_bar, progress_bar_desc) + for _ in range(n_sample): + lm_out = call_once(*args, **kwargs) - if not hasattr(lm_out, "outputs"): - raise ValueError("LM call did not return an object with an 'outputs' attribute.") - if not isinstance(lm_out.outputs, list): - raise ValueError("LM call returned outputs that are not a list.") + if not hasattr(lm_out, "outputs") or not isinstance(lm_out.outputs, list): + raise ValueError("LM call did not return a list-like 'outputs'.") all_outputs.append(lm_out.outputs) - if want_logprobs and hasattr(lm_out, "logprobs") and lm_out.logprobs is not None: - logs.append(lm_out.logprobs) + logs = getattr(lm_out, "logprobs", None) + # Accept None for runs with no logprobs + all_logs.append(logs if logs is not None else None) - # Ensure consistent batch size + # Validate consistent batch batch = len(all_outputs[0]) for run in all_outputs[1:]: if len(run) != batch: - raise ValueError(f"Inconsistent batch sizes across runs: expected {batch}, got {len(run)}") + raise ValueError("Inconsistent batch sizes across runs") - chosen: Optional[List[Any]] = logs[0] if (want_logprobs and len(logs) > 0) else None - return all_outputs, chosen + return all_outputs, all_logs diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 45bae3f5..1e04c86d 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -6,10 +6,11 @@ import lotus from lotus.cache import operator_cache -from lotus.sampling_utils import apply_ensemble, resample_batch +from lotus.sampling_utils import apply_ensemble, pick_logprobs_for_choices, resample_batch from lotus.templates import task_instructions from lotus.types import ( CascadeArgs, + EnsembleStrategy, LogprobsForFilterCascade, ProxyModel, ReasoningStrategy, @@ -36,7 +37,7 @@ def sem_filter( progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", n_sample: int = 1, # number of samples per item - ensemble: str | None = None, # "majority_vote", "mean_prob", None + ensemble: EnsembleStrategy = EnsembleStrategy.PICK_FIRST, # "majority_vote", "mean_prob", temperature: float | None = None, # if None, use model default ) -> SemanticFilterOutput: """ @@ -91,6 +92,8 @@ def sem_filter( >>> result = sem_filter(docs, model, "Is this a positive sentiment?") >>> print(result.outputs) # [True, False] """ + if n_sample <= 0: + raise ValueError("n_sample must be >= 1 for semantic filtering.") inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( @@ -107,44 +110,63 @@ def sem_filter( inputs.append(prompt) if safe_mode: - estimated_total_calls = len(docs) * max(1, n_sample) - estimated_total_cost = sum(model.count_tokens(input) for input in inputs) * max(1, n_sample) + estimated_total_calls = len(docs) * n_sample + estimated_total_cost = sum(model.count_tokens(input) for input in inputs) * n_sample show_safe_mode(estimated_total_cost, estimated_total_calls) - def _call_once(want_logprobs, temp, spb, desc): - return model( - inputs, - logprobs=want_logprobs, - temperature=temp, - show_progress_bar=spb, - progress_bar_desc=desc if n_sample == 1 else f"{desc} (x{n_sample})", - ) - - all_runs, chosen_logprobs = resample_batch( - _call_once, - 
n_sample=n_sample, - want_logprobs=logprobs, - temperature=temperature, - show_progress_bar=show_progress_bar, - progress_bar_desc=progress_bar_desc, + # Define a single-call function using closure to capture outer variables + def _call_once(): + call_kwargs: dict[str, Any] = { + "show_progress_bar": show_progress_bar, + "progress_bar_desc": (progress_bar_desc if n_sample == 1 else f"{progress_bar_desc} (x{n_sample})"), + } + if logprobs: + call_kwargs["logprobs"] = True + if temperature is not None: + call_kwargs["temperature"] = temperature + + return model(inputs, **call_kwargs) + + # Run the model n_sample times (sampling) + all_runs_texts, all_runs_logprobs = resample_batch(_call_once, n_sample=n_sample) + + # Postprocess each run independently to get canonical booleans + postprocessed_runs = [filter_postprocess(run_texts, model, default) for run_texts in all_runs_texts] + canonical_runs_bool: list[list[bool]] = [pp.outputs for pp in postprocessed_runs] # [n_sample][batch] + + final_labels, chosen_run_idx = apply_ensemble( + ensemble, + canonical_runs_bool, # shape [n_sample][batch], booleans only + default_yes=default, # tie-break only for exact boolean ties + return_indices=True, ) - # Collapse [n_sample][batch] -> [batch] - final_texts = apply_ensemble(ensemble, all_runs, default_yes=default) + # Type assertion to help mypy understand the tuple unpacking + assert isinstance(final_labels, list) and isinstance(chosen_run_idx, list) + + batch = len(final_labels) + + # Pick raw outputs & explanations from the winning run per item + raw_outputs: list[str] = [postprocessed_runs[chosen_run_idx[i]].raw_outputs[i] for i in range(batch)] + explanations: list[str | None] = [postprocessed_runs[chosen_run_idx[i]].explanations[i] for i in range(batch)] + + # Pick logprobs from the winning run per item, if requested + chosen_logprobs = pick_logprobs_for_choices(all_runs_logprobs, chosen_run_idx) if logprobs else None - postprocess_output = filter_postprocess(final_texts, model, default) - lotus.logger.debug(f"outputs: {postprocess_output.outputs}") - lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") - lotus.logger.debug(f"explanations: {postprocess_output.explanations}") + # Debug logging + lotus.logger.debug(f"final_labels: {final_labels}") + lotus.logger.debug(f"chosen_run_idx: {chosen_run_idx}") + lotus.logger.debug(f"raw_outputs (winners): {raw_outputs}") + lotus.logger.debug(f"explanations (winners): {explanations}") if safe_mode: model.print_total_usage() return SemanticFilterOutput( - raw_outputs=postprocess_output.raw_outputs, - outputs=postprocess_output.outputs, - explanations=postprocess_output.explanations, - logprobs=chosen_logprobs if logprobs else None, + raw_outputs=raw_outputs, # from winners + outputs=final_labels, # canonical booleans after ensembling + explanations=explanations, # from winners + logprobs=chosen_logprobs, # from winners (or None) ) @@ -367,13 +389,15 @@ def __call__( progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", n_sample: int = 1, # number of samples per item - ensemble: str | None = None, # "majority_vote", "mean_prob", None + ensemble: EnsembleStrategy = EnsembleStrategy.PICK_FIRST, # Changed from str | None temperature: float | None = None, # if None, use model default ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: if lotus.settings.lm is None: raise ValueError( "The language model must be an instance of LM. 
Please configure a valid language model using lotus.settings.configure()" ) + if n_sample <= 0: + raise ValueError("n_sample must be >= 1 for semantic filtering.") stats: dict[str, float] = {} lotus.logger.debug(user_instruction) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 9d66c580..2c8b7074 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -4,9 +4,8 @@ import lotus from lotus.cache import operator_cache -from lotus.sampling_utils import apply_ensemble, resample_batch from lotus.templates import task_instructions -from lotus.types import ReasoningStrategy, SemanticMapOutput, SemanticMapPostprocessOutput +from lotus.types import LMOutput, ReasoningStrategy, SemanticMapOutput, SemanticMapPostprocessOutput from lotus.utils import show_safe_mode from .postprocessors import map_postprocess @@ -24,9 +23,6 @@ def sem_map( strategy: ReasoningStrategy | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", - n_sample: int = 1, # NEW - ensemble: str | None = None, # NEW - temperature: float | None = None, # NEW **model_kwargs: Any, ) -> SemanticMapOutput: """ @@ -63,13 +59,6 @@ def sem_map( Defaults to False. progress_bar_desc (str, optional): Description for the progress bar. Defaults to "Mapping". - n_sample (int): number of repeated LM calls per item. If >1, outputs are combined - using the specified ensemble strategy. - ensemble (str|None): "majority_vote", "average_prob"/"mean_prob"/"avg_prob", or None. - Strategy to combine multiple samples per item. If None, no ensembling is done. - Only used if n_sample > 1. - temperature (float|None): test-time temperature; None uses model default. - Defaults to None. **model_kwargs: Any: Additional keyword arguments to pass to the model. Returns: SemanticMapOutput: An object containing the processed outputs, raw outputs, @@ -105,38 +94,18 @@ def sem_map( # check if safe_mode is enabled if safe_mode: - estimated_cost = sum(model.count_tokens(input) for input in inputs) * max(1, n_sample) - estimated_LM_calls = len(docs) * max(1, n_sample) + estimated_cost = sum(model.count_tokens(input) for input in inputs) + estimated_LM_calls = len(docs) show_safe_mode(estimated_cost, estimated_LM_calls) # call model - def _call_once(_want_logprobs, temp, spb, desc): - # sem_map doesn’t use logprobs, so _want_logprobs is ignored - return model( - inputs, - temperature=temp, - show_progress_bar=spb, - progress_bar_desc=desc if n_sample == 1 else f"{desc} (x{n_sample})", - **model_kwargs, - ) - - all_runs, _ = resample_batch( - _call_once, - n_sample=n_sample, - want_logprobs=False, - temperature=temperature, - show_progress_bar=True, - progress_bar_desc=progress_bar_desc, - ) - - # Collapse [n_sample][batch] -> [batch] using the chosen ensemble strategy - final_texts = apply_ensemble(ensemble, all_runs, default_yes=True) # default_yes only matters for ties + lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc, **model_kwargs) # post process results postprocess_output = postprocessor( - final_texts, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] + lm_output.outputs, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] ) - lotus.logger.debug(f"raw_outputs: {final_texts}") + lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") if safe_mode: @@ -255,9 +224,6 @@ def __call__( strategy: ReasoningStrategy | None = None, 
safe_mode: bool = False, progress_bar_desc: str = "Mapping", - n_sample: int = 1, # NEW - ensemble: str | None = None, # NEW - temperature: float | None = None, # NEW **model_kwargs: Any, ) -> pd.DataFrame: if lotus.settings.lm is None: @@ -300,9 +266,6 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, - n_sample=n_sample, # NEW - ensemble=ensemble, # NEW - temperature=temperature, # NEW **model_kwargs, ) diff --git a/lotus/types.py b/lotus/types.py index 08519729..6cbaab70 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -220,3 +220,12 @@ class ReasoningStrategy(Enum): COT = auto() ZS_COT = auto() FEW_SHOT = auto() + + +################################################################################ +# Ensemble strategy +################################################################################ +class EnsembleStrategy(Enum): + PICK_FIRST = "pick_first" # Always choose run 0 + MAJORITY = "majority" # For canonical boolean labels + MEAN_BOOL = "mean_bool" # Average booleans >= .5 -> True From f6b78555228658b185da5fd3d4460a7cd8162646 Mon Sep 17 00:00:00 2001 From: IsmaelKabore Date: Tue, 25 Nov 2025 21:43:52 -0800 Subject: [PATCH 3/3] add per-run rollout tracking to sem_filter --- .github/tests/test_sem_filter.py | 107 +++++++++++++++++++++++++++++++ lotus/sampling_utils.py | 2 +- lotus/sem_ops/sem_filter.py | 41 ++++++++++++ lotus/types.py | 10 +++ 4 files changed, 159 insertions(+), 1 deletion(-) diff --git a/.github/tests/test_sem_filter.py b/.github/tests/test_sem_filter.py index b56653e2..20a2bc32 100644 --- a/.github/tests/test_sem_filter.py +++ b/.github/tests/test_sem_filter.py @@ -148,3 +148,110 @@ def test_df_sem_filter_return_stats_tuple(setup_models, model): out_df, stats = result assert isinstance(out_df, pd.DataFrame) assert isinstance(stats, dict) + + +# NEW TESTS FOR MULTI-RUN ROLLOUT FUNCTIONALITY + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_with_rollout_columns(setup_models, model): + """Test that per-run rollout columns are correctly added when n_sample > 1.""" + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Great product!", "Terrible service.", "It's okay."]}) + user_instruction = "{Text} is a positive sentiment" + + # Test with return_all=True to see all rollout data + full_df = df.sem_filter( + user_instruction, + n_sample=3, + ensemble=EnsembleStrategy.MAJORITY, + return_all=True, + return_explanations=True, + temperature=0.9, + ) + + assert isinstance(full_df, pd.DataFrame) + + # Check that per-run columns exist + assert "raw_output_1_filter" in full_df.columns + assert "raw_output_2_filter" in full_df.columns + assert "raw_output_3_filter" in full_df.columns + + assert "parsed_output_1_filter" in full_df.columns + assert "parsed_output_2_filter" in full_df.columns + assert "parsed_output_3_filter" in full_df.columns + + assert "explanation_1_filter" in full_df.columns + assert "explanation_2_filter" in full_df.columns + assert "explanation_3_filter" in full_df.columns + + # Check that ensemble answer column exists + assert "ensemble_answer_filter" in full_df.columns + + # Verify that all rows are present (return_all=True) + assert len(full_df) == 3 + + # Verify data types + assert full_df["parsed_output_1_filter"].dtype == bool + assert full_df["parsed_output_2_filter"].dtype == bool + assert full_df["parsed_output_3_filter"].dtype == bool + assert full_df["ensemble_answer_filter"].dtype == bool + + +@pytest.mark.parametrize("model", [MODEL_NAME]) 
+def test_df_sem_filter_rollout_columns_filtered_rows(setup_models, model): + """Test that per-run columns work correctly when return_all=False (filtered rows only).""" + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Amazing experience!", "Worst ever.", "Pretty good."]}) + user_instruction = "{Text} is a positive sentiment" + + # Test with return_all=False (default) - only rows that pass the filter + filtered_df = df.sem_filter( + user_instruction, + n_sample=3, + ensemble=EnsembleStrategy.MAJORITY, + return_all=False, # explicit for clarity + temperature=0.9, + ) + + assert isinstance(filtered_df, pd.DataFrame) + + # Check that per-run columns exist + assert "raw_output_1_filter" in filtered_df.columns + assert "parsed_output_1_filter" in filtered_df.columns + + # Check ensemble answer column + assert "ensemble_answer_filter" in filtered_df.columns + + # All rows in filtered result should have ensemble_answer=True + assert all(filtered_df["ensemble_answer_filter"]) + + # Verify we got at least some positive samples + assert len(filtered_df) >= 1 + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_no_rollout_columns_when_n_sample_1(setup_models, model): + """Test that per-run columns are NOT added when n_sample=1 (default behavior).""" + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Great!", "Bad."]}) + user_instruction = "{Text} is a positive sentiment" + + # Default n_sample=1, so no rollout columns should appear + full_df = df.sem_filter( + user_instruction, + return_all=True, + ) + + assert isinstance(full_df, pd.DataFrame) + + # These columns should NOT exist when n_sample=1 + assert "raw_output_1_filter" not in full_df.columns + assert "parsed_output_1_filter" not in full_df.columns + assert "ensemble_answer_filter" not in full_df.columns + + # But the regular filter_label should exist + assert "filter_label" in full_df.columns diff --git a/lotus/sampling_utils.py b/lotus/sampling_utils.py index 05aea636..806bebdf 100644 --- a/lotus/sampling_utils.py +++ b/lotus/sampling_utils.py @@ -71,7 +71,7 @@ def apply_ensemble( - If return_indices=True: (List[T], List[int]) (plus chosen run index per item) """ if not all_outputs: - return [] if not return_indices else ([], []) # type: ignore[return-value] + return [] if not return_indices else ([], []) batch = len(all_outputs[0]) for run in all_outputs: diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 1e04c86d..7f1195e6 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -13,6 +13,7 @@ EnsembleStrategy, LogprobsForFilterCascade, ProxyModel, + RawOutputs, ReasoningStrategy, SemanticFilterOutput, ) @@ -132,6 +133,19 @@ def _call_once(): # Postprocess each run independently to get canonical booleans postprocessed_runs = [filter_postprocess(run_texts, model, default) for run_texts in all_runs_texts] + + # NEW: package all runs into RawOutputs objects + all_runs_packaged = [] + for run_idx, pp in enumerate(postprocessed_runs): + all_runs_packaged.append( + RawOutputs( + preds=pp.raw_outputs, + logprobs=all_runs_logprobs[run_idx] if all_runs_logprobs else None, + parsed_outputs=pp.outputs, + explanations=pp.explanations, + ) + ) + canonical_runs_bool: list[list[bool]] = [pp.outputs for pp in postprocessed_runs] # [n_sample][batch] final_labels, chosen_run_idx = apply_ensemble( @@ -167,6 +181,7 @@ def _call_once(): outputs=final_labels, # canonical booleans after ensembling explanations=explanations, # 
from winners logprobs=chosen_logprobs, # from winners (or None) + raw_outputs_all_runs=all_runs_packaged, # all runs packaged ) @@ -542,6 +557,7 @@ def __call__( for idx in high_conf_idxs: outputs[idx] = proxy_outputs[idx] + all_runs: list[RawOutputs] = [] # If using helper LM, get raw outputs and explanations if proxy_model == ProxyModel.HELPER_LM: @@ -602,6 +618,7 @@ def __call__( outputs = output.outputs raw_outputs = output.raw_outputs explanations = output.explanations + all_runs = output.raw_outputs_all_runs if output.raw_outputs_all_runs else [] if not return_all: # find indices where output is True @@ -633,6 +650,30 @@ def get_out_col_name(df, col_name): filtered_explanations = explanations filtered_raw_outputs = raw_outputs + # NEW — Add per-run rollout columns if multiple runs exist + + if all_runs and n_sample > 1: + for run_idx, run_data in enumerate(all_runs, start=1): + # CASE 1 — return_all=True → show all rows + if return_all: + new_df[f"raw_output_{run_idx}{suffix}"] = run_data.preds + new_df[f"parsed_output_{run_idx}{suffix}"] = run_data.parsed_outputs + + if return_explanations: + new_df[f"explanation_{run_idx}{suffix}"] = run_data.explanations + + # CASE 2 — return_all=False → only rows that passed the filter + else: + new_df[f"raw_output_{run_idx}{suffix}"] = [run_data.preds[i] for i in ids] + new_df[f"parsed_output_{run_idx}{suffix}"] = [run_data.parsed_outputs[i] for i in ids] + + if return_explanations: + new_df[f"explanation_{run_idx}{suffix}"] = [run_data.explanations[i] for i in ids] + + # Finally add the ensemble answer as a column + ensemble_col = [outputs[i] for i in ids] if not return_all else outputs + new_df[f"ensemble_answer{suffix}"] = ensemble_col + # return rows where output is True if return_explanations and return_raw_outputs: new_df["explanation" + suffix] = filtered_explanations diff --git a/lotus/types.py b/lotus/types.py index 6cbaab70..6568832c 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -63,6 +63,15 @@ class LogprobsForFilterCascade: confidences: list[list[float]] +# Raw outputs +@dataclass +class RawOutputs: + preds: list[str] + logprobs: list[list[ChatCompletionTokenLogprob]] | None + parsed_outputs: list[bool] + explanations: list[str | None] + + ################################################################################ # Semantic operation outputs ################################################################################ @@ -108,6 +117,7 @@ class SemanticFilterOutput: explanations: list[str | None] stats: dict[str, Any] | None = None logprobs: list[list[ChatCompletionTokenLogprob]] | None = None + raw_outputs_all_runs: list[RawOutputs] | None = None @dataclass
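As a closing illustration, a sketch of how the per-run rollout columns added in this patch could be inspected downstream. It assumes the same Ollama-backed setup as the tests and the default "_filter" suffix; `disagreements` is a hypothetical name for rows where any single run differs from the ensembled answer:

    import pandas as pd
    import lotus
    from lotus.models import LM
    from lotus.types import EnsembleStrategy

    lotus.settings.configure(lm=LM(model="ollama/llama3.1"))
    df = pd.DataFrame({"Text": ["Great product!", "Terrible service.", "It's okay."]})
    out = df.sem_filter(
        "{Text} is a positive sentiment",
        n_sample=3,
        ensemble=EnsembleStrategy.MAJORITY,
        temperature=0.9,
        return_all=True,
    )
    # Rows where at least one run disagrees with the ensembled answer
    disagreements = out[
        (out["parsed_output_1_filter"] != out["ensemble_answer_filter"])
        | (out["parsed_output_2_filter"] != out["ensemble_answer_filter"])
        | (out["parsed_output_3_filter"] != out["ensemble_answer_filter"])
    ]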