diff --git a/.github/tests/test_sem_filter.py b/.github/tests/test_sem_filter.py
new file mode 100644
index 00000000..20a2bc32
--- /dev/null
+++ b/.github/tests/test_sem_filter.py
@@ -0,0 +1,257 @@
+import os
+
+import pandas as pd
+import pytest
+
+import lotus
+from lotus.models import LM
+from lotus.types import EnsembleStrategy, ReasoningStrategy
+
+# Gate the whole module behind an env var so devs/CI without Ollama skip cleanly.
+ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true"
+
+pytestmark = pytest.mark.skipif(
+    not ENABLE_OLLAMA_TESTS,
+    reason="Set ENABLE_OLLAMA_TESTS=true to run Ollama-backed tests",
+)
+
+MODEL_NAME = "ollama/llama3.1"
+
+
+@pytest.fixture(scope="session")
+def setup_models():
+    return {MODEL_NAME: LM(model=MODEL_NAME)}
+
+
+@pytest.fixture(autouse=True)
+def print_usage_after_each_test(setup_models):
+    yield
+    for _, m in setup_models.items():
+        m.print_total_usage()
+        m.reset_stats()
+        m.reset_cache()
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_basic(setup_models, model):
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame(
+        {
+            "Text": [
+                "I am really excited to go to class today!",
+                "I am very sad",
+            ]
+        }
+    )
+    user_instruction = "{Text} is a positive sentiment"
+
+    filtered = df.sem_filter(user_instruction)
+
+    assert isinstance(filtered, pd.DataFrame)
+    assert "Text" in filtered.columns
+    assert len(filtered) >= 1
+    # Tolerant: ensure obviously negative text is unlikely to pass
+    assert "I am very sad" not in filtered["Text"].tolist()
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_with_sampling_pick_first(setup_models, model):
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame({"Text": ["Today is fantastic!", "This is terrible.", "Pretty good overall."]})
+    user_instruction = "{Text} is a positive sentiment"
+
+    # Exercise resampling path (n_sample>1) but keep ensemble as PICK_FIRST.
+    filtered = df.sem_filter(
+        user_instruction,
+        n_sample=2,
+        temperature=0.9,
+        # explicit for clarity (default is PICK_FIRST)
+        ensemble=EnsembleStrategy.PICK_FIRST,
+    )
+
+    assert isinstance(filtered, pd.DataFrame)
+    assert "Text" in filtered.columns
+    assert len(filtered) >= 1
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_with_majority_ensemble_runs(setup_models, model):
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame({"Text": ["Great job!", "Awful service.", "Fine, I guess."]})
+    user_instruction = "{Text} is a positive sentiment"
+
+    # Exercise the MAJORITY voting path (boolean labels after postprocess).
+    # We keep assertions tolerant because model outputs can vary.
+    filtered = df.sem_filter(
+        user_instruction,
+        n_sample=3,
+        temperature=0.9,
+        ensemble=EnsembleStrategy.MAJORITY,
+    )
+
+    assert isinstance(filtered, pd.DataFrame)
+    assert "Text" in filtered.columns
+    # Should not error; size can vary depending on votes
+    assert len(filtered) >= 0
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_invalid_n_sample_raises(setup_models, model):
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame({"Text": ["Neutral thing"]})
+    user_instruction = "{Text} is a positive sentiment"
+
+    with pytest.raises(ValueError):
+        df.sem_filter(user_instruction, n_sample=0)
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_return_all_and_explanations(setup_models, model):
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame({"Text": ["I love sunshine", "I hate rain"]})
+    user_instruction = "{Text} is a positive sentiment"
+
+    # Ask for all rows + explanations (ZS_COT forces explanation path on)
+    full_df = df.sem_filter(
+        user_instruction,
+        return_all=True,
+        return_explanations=True,
+        strategy=ReasoningStrategy.ZS_COT,
+    )
+
+    assert isinstance(full_df, pd.DataFrame)
+    # When return_all=True, an extra 'filter_label' column is added (no suffix)
+    assert "filter_label" in full_df.columns
+    # Explanation column uses the default suffix "_filter"
+    assert "explanation_filter" in full_df.columns
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_return_stats_tuple(setup_models, model):
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame({"Text": ["I love this", "I dislike that"]})
+    user_instruction = "{Text} is a positive sentiment"
+
+    result = df.sem_filter(
+        user_instruction,
+        return_stats=True,
+    )
+
+    # When return_stats=True, we get (DataFrame, stats_dict)
+    assert isinstance(result, tuple) and len(result) == 2
+    out_df, stats = result
+    assert isinstance(out_df, pd.DataFrame)
+    assert isinstance(stats, dict)
+
+
+# NEW TESTS FOR MULTI-RUN ROLLOUT FUNCTIONALITY
+
+
+@pytest.mark.parametrize("model", [MODEL_NAME])
+def test_df_sem_filter_with_rollout_columns(setup_models, model):
+    """Test that per-run rollout columns are correctly added when n_sample > 1."""
+    lotus.settings.configure(lm=setup_models[model])
+
+    df = pd.DataFrame({"Text": ["Great product!", "Terrible service.", "It's okay."]})
+    user_instruction = "{Text} is a positive sentiment"
+
+    # Test with return_all=True to see all rollout data
+    full_df = df.sem_filter(
+        user_instruction,
+        n_sample=3,
+        ensemble=EnsembleStrategy.MAJORITY,
+        return_all=True,
+        return_explanations=True,
+        temperature=0.9,
+    )
+
+    assert isinstance(full_df, pd.DataFrame)
+
+    # Check that per-run columns exist
+    assert "raw_output_1_filter" in full_df.columns
+    assert "raw_output_2_filter" in full_df.columns
+    assert "raw_output_3_filter" in full_df.columns
+
+    assert "parsed_output_1_filter" in full_df.columns
+    assert "parsed_output_2_filter" in full_df.columns
+    assert "parsed_output_3_filter" in full_df.columns
+
+    assert "explanation_1_filter" in full_df.columns
+    assert "explanation_2_filter" in full_df.columns
+    assert "explanation_3_filter" in full_df.columns
+
+    # Check that ensemble answer column exists
+    assert "ensemble_answer_filter" in full_df.columns
+
+    # Verify that all rows are present (return_all=True)
+    assert len(full_df) == 3
+
+    # Verify data types
+    assert full_df["parsed_output_1_filter"].dtype == bool
+    assert full_df["parsed_output_2_filter"].dtype == bool
assert full_df["parsed_output_3_filter"].dtype == bool + assert full_df["ensemble_answer_filter"].dtype == bool + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_rollout_columns_filtered_rows(setup_models, model): + """Test that per-run columns work correctly when return_all=False (filtered rows only).""" + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Amazing experience!", "Worst ever.", "Pretty good."]}) + user_instruction = "{Text} is a positive sentiment" + + # Test with return_all=False (default) - only rows that pass the filter + filtered_df = df.sem_filter( + user_instruction, + n_sample=3, + ensemble=EnsembleStrategy.MAJORITY, + return_all=False, # explicit for clarity + temperature=0.9, + ) + + assert isinstance(filtered_df, pd.DataFrame) + + # Check that per-run columns exist + assert "raw_output_1_filter" in filtered_df.columns + assert "parsed_output_1_filter" in filtered_df.columns + + # Check ensemble answer column + assert "ensemble_answer_filter" in filtered_df.columns + + # All rows in filtered result should have ensemble_answer=True + assert all(filtered_df["ensemble_answer_filter"]) + + # Verify we got at least some positive samples + assert len(filtered_df) >= 1 + + +@pytest.mark.parametrize("model", [MODEL_NAME]) +def test_df_sem_filter_no_rollout_columns_when_n_sample_1(setup_models, model): + """Test that per-run columns are NOT added when n_sample=1 (default behavior).""" + lotus.settings.configure(lm=setup_models[model]) + + df = pd.DataFrame({"Text": ["Great!", "Bad."]}) + user_instruction = "{Text} is a positive sentiment" + + # Default n_sample=1, so no rollout columns should appear + full_df = df.sem_filter( + user_instruction, + return_all=True, + ) + + assert isinstance(full_df, pd.DataFrame) + + # These columns should NOT exist when n_sample=1 + assert "raw_output_1_filter" not in full_df.columns + assert "parsed_output_1_filter" not in full_df.columns + assert "ensemble_answer_filter" not in full_df.columns + + # But the regular filter_label should exist + assert "filter_label" in full_df.columns diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..a41de757 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "docs" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0ed63351..64122ee3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -288,6 +288,10 @@ lm = LM( - **Discussions**: https://github.com/lotus-data/lotus/discussions - **Issues**: https://github.com/lotus-data/lotus/issues +## Code of Conduct + +We are committed to providing a welcoming and inclusive environment for all contributors. Please be respectful and constructive in all interactions. 
+
 ---
 
diff --git a/README.md b/README.md
index 37932060..4d96537f 100644
--- a/README.md
+++ b/README.md
@@ -168,4 +168,4 @@ If you find LOTUS or semantic operators useful, we'd appreciate if you can pleas
   eprint={2407.11418},
   url={https://arxiv.org/abs/2407.11418},
 }
-```
+```
\ No newline at end of file
diff --git a/lotus/sampling_utils.py b/lotus/sampling_utils.py
new file mode 100644
index 00000000..806bebdf
--- /dev/null
+++ b/lotus/sampling_utils.py
@@ -0,0 +1,188 @@
+# lotus/sampling_utils.py
+
+from collections import Counter
+from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast
+
+from litellm.types.utils import ChatCompletionTokenLogprob
+
+from lotus.types import EnsembleStrategy
+
+T = TypeVar("T")  # item type for ensembling, e.g., bool for filter, str for map
+
+
+def _majority_vote_one(samples: Sequence[T], *, default_yes: Optional[bool]) -> T:
+    """
+    Generic majority vote over hashable samples.
+    - If there is a tie and samples are booleans {True, False}, use default_yes when provided.
+    - Otherwise, break ties deterministically by string order of the label.
+
+    Assumes samples are already canonical (e.g., booleans for filter).
+    """
+    if not samples:
+        raise ValueError("majority vote received an empty sample list")
+
+    c = Counter(samples)
+    top = c.most_common()
+
+    # Single unique label
+    if len(top) == 1:
+        return top[0][0]
+
+    # Tie in counts
+    if len(top) >= 2 and top[0][1] == top[1][1]:
+        # If exactly boolean tie, allow default_yes to decide
+        if set(c.keys()) == {True, False} and default_yes is not None:
+            return bool(default_yes)  # type: ignore[return-value]
+
+        # Otherwise, deterministic tie-break among tied labels by string representation
+        tied_labels = sorted([lab for lab, cnt in top if cnt == top[0][1]], key=lambda x: str(x))
+        return tied_labels[0]
+
+    # Clear winner
+    return top[0][0]
+
+
+def _mean_bool_one(samples: Sequence[bool]) -> bool:
+    """
+    Average a list of booleans; True if mean >= 0.5 else False.
+    """
+    if not samples:
+        raise ValueError("mean_bool received an empty sample list")
+    s = sum(1 for x in samples if x)
+    return (s / len(samples)) >= 0.5
+
+
+def apply_ensemble(
+    strategy: EnsembleStrategy,
+    all_outputs: List[List[T]],  # shape: [n_sample][batch]
+    *,
+    default_yes: Optional[bool] = None,
+    return_indices: bool = False,
+) -> Union[List[T], Tuple[List[T], List[int]]]:
+    """
+    Collapse shape [n_sample][batch] -> [batch] according to strategy.
+
+    Assumptions:
+    - Inputs in `all_outputs` are already canonical (e.g., booleans for filter).
+    - MAJORITY is generic for any hashable type; MEAN_BOOL requires boolean labels.
+
+    Returns:
+    - If return_indices=False: List[T] (chosen output per item)
+    - If return_indices=True: (List[T], List[int]) (plus chosen run index per item)
+    """
+    if not all_outputs:
+        return [] if not return_indices else ([], [])
+
+    batch = len(all_outputs[0])
+    for run in all_outputs:
+        if len(run) != batch:
+            raise ValueError("Inconsistent batch sizes across runs")
+
+    n_sample = len(all_outputs)
+
+    if strategy == EnsembleStrategy.PICK_FIRST or n_sample == 1:
+        chosen_labels: List[T] = list(all_outputs[0])
+        return chosen_labels if not return_indices else (chosen_labels, [0] * batch)
+
+    per_item: List[List[T]] = [[all_outputs[k][i] for k in range(n_sample)] for i in range(batch)]
+
+    final_labels: List[T] = []
+    chosen_indices: List[int] = []
+
+    if strategy == EnsembleStrategy.MAJORITY:
+        for samples in per_item:
+            label = _majority_vote_one(samples, default_yes=default_yes)
+            final_labels.append(label)
+            winner_idx = next(idx for idx, v in enumerate(samples) if v == label)
+            chosen_indices.append(winner_idx)
+
+    elif strategy == EnsembleStrategy.MEAN_BOOL:
+        if not all(isinstance(x, bool) for run in all_outputs for x in run):
+            raise ValueError("MEAN_BOOL can only be applied to boolean outputs.")
+        for samples in per_item:
+            label_bool = _mean_bool_one(cast(Sequence[bool], samples))
+            # cast back to T to satisfy the generic return type
+            final_labels.append(cast(T, label_bool))
+            # compare as bools to find the first matching run
+            winner_idx = next(idx for idx, v in enumerate(samples) if bool(v) == label_bool)
+            chosen_indices.append(winner_idx)
+    else:
+        raise ValueError(f"Unknown EnsembleStrategy: {strategy}")
+
+    return final_labels if not return_indices else (final_labels, chosen_indices)
+
+
+def pick_logprobs_for_choices(
+    all_logprobs: Optional[List[Optional[List[List[ChatCompletionTokenLogprob]]]]],  # [n_sample][batch][tokens]
+    chosen_indices: List[int],  # [batch]
+) -> Optional[List[List[ChatCompletionTokenLogprob]]]:
+    """
+    Given logprobs for each run and the chosen run index per item,
+    return the per-item logprobs of the finally chosen outputs.
+
+    all_logprobs is provider-specific; we only route to the chosen run/item.
+    """
+    if all_logprobs is None or not all_logprobs:
+        return None
+
+    # Determine batch size from first non-None run and validate all runs
+    batch = None
+    for run_logs in all_logprobs:
+        if run_logs is not None:
+            batch = len(run_logs)
+            break
+    if batch is None:
+        return None
+    for run_logs in all_logprobs:
+        if run_logs is not None and len(run_logs) != batch:
+            raise ValueError("Inconsistent batch sizes in all_logprobs runs")
+
+    if len(chosen_indices) != batch:
+        raise ValueError("chosen_indices length does not match batch size")
+
+    chosen_per_item: List[List[ChatCompletionTokenLogprob]] = []
+    for i, winner_run in enumerate(chosen_indices):
+        run_logs = all_logprobs[winner_run] if winner_run < len(all_logprobs) else None
+        chosen_per_item.append(run_logs[i] if (run_logs is not None) else [])
+    return chosen_per_item
+
+
+def resample_batch(
+    call_once: Callable[..., Any],
+    n_sample: int,
+    *args: Any,
+    **kwargs: Any,
+) -> Tuple[List[List[str]], Optional[List[Optional[List[List[ChatCompletionTokenLogprob]]]]]]:
+    """
+    Run the same batch multiple times and collect outputs (+logprobs if produced).
+
+    We do not prescribe call_once's signature; we pass through *args/**kwargs.
+    Expected from call_once:
+    - returns an object with `.outputs: List[str]`
+    - may have `.logprobs` (provider-specific shape), or None/absent.
+ """ + if n_sample <= 0: + raise ValueError("n_sample must be >= 1") + + all_outputs: List[List[str]] = [] + all_logs: List[Optional[List[List[ChatCompletionTokenLogprob]]]] = [] + + for _ in range(n_sample): + lm_out = call_once(*args, **kwargs) + + if not hasattr(lm_out, "outputs") or not isinstance(lm_out.outputs, list): + raise ValueError("LM call did not return a list-like 'outputs'.") + + all_outputs.append(lm_out.outputs) + + logs = getattr(lm_out, "logprobs", None) + # Accept None for runs with no logprobs + all_logs.append(logs if logs is not None else None) + + # Validate consistent batch + batch = len(all_outputs[0]) + for run in all_outputs[1:]: + if len(run) != batch: + raise ValueError("Inconsistent batch sizes across runs") + + return all_outputs, all_logs diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 3a26506d..7f1195e6 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -6,12 +6,14 @@ import lotus from lotus.cache import operator_cache +from lotus.sampling_utils import apply_ensemble, pick_logprobs_for_choices, resample_batch from lotus.templates import task_instructions from lotus.types import ( CascadeArgs, - LMOutput, + EnsembleStrategy, LogprobsForFilterCascade, ProxyModel, + RawOutputs, ReasoningStrategy, SemanticFilterOutput, ) @@ -35,6 +37,9 @@ def sem_filter( show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", + n_sample: int = 1, # number of samples per item + ensemble: EnsembleStrategy = EnsembleStrategy.PICK_FIRST, # "majority_vote", "mean_prob", + temperature: float | None = None, # if None, use model default ) -> SemanticFilterOutput: """ Filters a list of documents based on a natural language instruction using a language model. 
@@ -88,6 +93,8 @@
     >>> result = sem_filter(docs, model, "Is this a positive sentiment?")
     >>> print(result.outputs)  # [True, False]
     """
+    if n_sample <= 0:
+        raise ValueError("n_sample must be >= 1 for semantic filtering.")
     inputs = []
     for doc in docs:
         prompt = lotus.templates.task_instructions.filter_formatter(
@@ -102,30 +109,79 @@
         )
         lotus.logger.debug(f"input to model: {prompt}")
         inputs.append(prompt)
-    kwargs: dict[str, Any] = {"logprobs": logprobs}
 
     if safe_mode:
-        estimated_total_calls = len(docs)
-        estimated_total_cost = sum(model.count_tokens(input) for input in inputs)
+        estimated_total_calls = len(docs) * n_sample
+        estimated_total_cost = sum(model.count_tokens(input) for input in inputs) * n_sample
         show_safe_mode(estimated_total_cost, estimated_total_calls)
 
-    lm_output: LMOutput = model(
-        inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs
+    # Define a single-call function using closure to capture outer variables
+    def _call_once():
+        call_kwargs: dict[str, Any] = {
+            "show_progress_bar": show_progress_bar,
+            "progress_bar_desc": (progress_bar_desc if n_sample == 1 else f"{progress_bar_desc} (x{n_sample})"),
+        }
+        if logprobs:
+            call_kwargs["logprobs"] = True
+        if temperature is not None:
+            call_kwargs["temperature"] = temperature
+
+        return model(inputs, **call_kwargs)
+
+    # Run the model n_sample times (sampling)
+    all_runs_texts, all_runs_logprobs = resample_batch(_call_once, n_sample=n_sample)
+
+    # Postprocess each run independently to get canonical booleans
+    postprocessed_runs = [filter_postprocess(run_texts, model, default) for run_texts in all_runs_texts]
+
+    # NEW: package all runs into RawOutputs objects
+    all_runs_packaged = []
+    for run_idx, pp in enumerate(postprocessed_runs):
+        all_runs_packaged.append(
+            RawOutputs(
+                preds=pp.raw_outputs,
+                logprobs=all_runs_logprobs[run_idx] if all_runs_logprobs else None,
+                parsed_outputs=pp.outputs,
+                explanations=pp.explanations,
+            )
+        )
+
+    canonical_runs_bool: list[list[bool]] = [pp.outputs for pp in postprocessed_runs]  # [n_sample][batch]
+
+    final_labels, chosen_run_idx = apply_ensemble(
+        ensemble,
+        canonical_runs_bool,  # shape [n_sample][batch], booleans only
+        default_yes=default,  # tie-break only for exact boolean ties
+        return_indices=True,
     )
-    postprocess_output = filter_postprocess(lm_output.outputs, model, default)
-    lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
-    lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
-    lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
+    # Type assertion to help mypy understand the tuple unpacking
+    assert isinstance(final_labels, list) and isinstance(chosen_run_idx, list)
+
+    batch = len(final_labels)
+
+    # Pick raw outputs & explanations from the winning run per item
+    raw_outputs: list[str] = [postprocessed_runs[chosen_run_idx[i]].raw_outputs[i] for i in range(batch)]
+    explanations: list[str | None] = [postprocessed_runs[chosen_run_idx[i]].explanations[i] for i in range(batch)]
+
+    # Pick logprobs from the winning run per item, if requested
+    chosen_logprobs = pick_logprobs_for_choices(all_runs_logprobs, chosen_run_idx) if logprobs else None
+
+    # Debug logging
+    lotus.logger.debug(f"final_labels: {final_labels}")
+    lotus.logger.debug(f"chosen_run_idx: {chosen_run_idx}")
+    lotus.logger.debug(f"raw_outputs (winners): {raw_outputs}")
+    lotus.logger.debug(f"explanations (winners): {explanations}")
 
     if safe_mode:
         model.print_total_usage()
 
     return SemanticFilterOutput(
-        raw_outputs=postprocess_output.raw_outputs,
-        outputs=postprocess_output.outputs,
-        explanations=postprocess_output.explanations,
-        logprobs=lm_output.logprobs if logprobs else None,
+        raw_outputs=raw_outputs,  # from winners
+        outputs=final_labels,  # canonical booleans after ensembling
+        explanations=explanations,  # from winners
+        logprobs=chosen_logprobs,  # from winners (or None)
+        raw_outputs_all_runs=all_runs_packaged,  # all runs packaged
     )
 
 
@@ -347,11 +403,16 @@ def __call__(
         safe_mode: bool = False,
         progress_bar_desc: str = "Filtering",
         additional_cot_instructions: str = "",
+        n_sample: int = 1,  # number of samples per item
+        ensemble: EnsembleStrategy = EnsembleStrategy.PICK_FIRST,  # Changed from str | None
+        temperature: float | None = None,  # if None, use model default
     ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
         if lotus.settings.lm is None:
             raise ValueError(
                 "The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()"
             )
+        if n_sample <= 0:
+            raise ValueError("n_sample must be >= 1 for semantic filtering.")
 
         stats: dict[str, float] = {}
         lotus.logger.debug(user_instruction)
@@ -425,6 +486,9 @@ def __call__(
                 safe_mode=safe_mode,
                 show_progress_bar=True,
                 progress_bar_desc="Running helper LM",
+                n_sample=n_sample,  # NEW
+                ensemble=ensemble,  # NEW
+                temperature=temperature,  # NEW
             )
             _, helper_logprobs = helper_output.outputs, helper_output.logprobs
             assert helper_logprobs is not None
@@ -493,6 +557,7 @@ def __call__(
             for idx in high_conf_idxs:
                 outputs[idx] = proxy_outputs[idx]
 
+            all_runs: list[RawOutputs] = []
             # If using helper LM, get raw outputs and explanations
             if proxy_model == ProxyModel.HELPER_LM:
@@ -519,6 +584,9 @@ def __call__(
                 safe_mode=safe_mode,
                 progress_bar_desc="Running predicate evals with oracle LM",
                 additional_cot_instructions=additional_cot_instructions,
+                n_sample=n_sample,  # NEW
+                ensemble=ensemble,  # NEW
+                temperature=temperature,  # NEW
             )
 
             for idx, large_idx in enumerate(low_conf_idxs):
@@ -543,10 +611,14 @@ def __call__(
                 show_progress_bar=True,
                 progress_bar_desc=progress_bar_desc,
                 additional_cot_instructions=additional_cot_instructions,
+                n_sample=n_sample,  # NEW
+                ensemble=ensemble,  # NEW
+                temperature=temperature,  # NEW
            )
             outputs = output.outputs
             raw_outputs = output.raw_outputs
             explanations = output.explanations
+            all_runs = output.raw_outputs_all_runs if output.raw_outputs_all_runs else []
 
         if not return_all:
             # find indices where output is True
@@ -578,6 +650,30 @@ def get_out_col_name(df, col_name):
             filtered_explanations = explanations
             filtered_raw_outputs = raw_outputs
 
+        # NEW — Add per-run rollout columns if multiple runs exist
+
+        if all_runs and n_sample > 1:
+            for run_idx, run_data in enumerate(all_runs, start=1):
+                # CASE 1 — return_all=True → show all rows
+                if return_all:
+                    new_df[f"raw_output_{run_idx}{suffix}"] = run_data.preds
+                    new_df[f"parsed_output_{run_idx}{suffix}"] = run_data.parsed_outputs
+
+                    if return_explanations:
+                        new_df[f"explanation_{run_idx}{suffix}"] = run_data.explanations
+
+                # CASE 2 — return_all=False → only rows that passed the filter
+                else:
+                    new_df[f"raw_output_{run_idx}{suffix}"] = [run_data.preds[i] for i in ids]
+                    new_df[f"parsed_output_{run_idx}{suffix}"] = [run_data.parsed_outputs[i] for i in ids]
+
+                    if return_explanations:
+                        new_df[f"explanation_{run_idx}{suffix}"] = [run_data.explanations[i] for i in ids]
+
+        # Finally add the ensemble answer as a column
+        ensemble_col = [outputs[i] for i in ids] if not return_all else outputs
+
new_df[f"ensemble_answer{suffix}"] = ensemble_col + # return rows where output is True if return_explanations and return_raw_outputs: new_df["explanation" + suffix] = filtered_explanations diff --git a/lotus/types.py b/lotus/types.py index 08519729..6568832c 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -63,6 +63,15 @@ class LogprobsForFilterCascade: confidences: list[list[float]] +# Raw outputs +@dataclass +class RawOutputs: + preds: list[str] + logprobs: list[list[ChatCompletionTokenLogprob]] | None + parsed_outputs: list[bool] + explanations: list[str | None] + + ################################################################################ # Semantic operation outputs ################################################################################ @@ -108,6 +117,7 @@ class SemanticFilterOutput: explanations: list[str | None] stats: dict[str, Any] | None = None logprobs: list[list[ChatCompletionTokenLogprob]] | None = None + raw_outputs_all_runs: list[RawOutputs] | None = None @dataclass @@ -220,3 +230,12 @@ class ReasoningStrategy(Enum): COT = auto() ZS_COT = auto() FEW_SHOT = auto() + + +################################################################################ +# Ensemble strategy +################################################################################ +class EnsembleStrategy(Enum): + PICK_FIRST = "pick_first" # Always choose run 0 + MAJORITY = "majority" # For canonical boolean labels + MEAN_BOOL = "mean_bool" # Average booleans >= .5 -> True diff --git a/pyproject.toml b/pyproject.toml index ae80a86d..bda68504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lotus-ai" -version = "1.1.4" +version = "1.1.3" description = "lotus" readme = "README.md" authors = [