diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py
index 2ff61106..52b261d6 100644
--- a/examples/op_examples/filter_cascade.py
+++ b/examples/op_examples/filter_cascade.py
@@ -3,12 +3,14 @@ import lotus
 from lotus.models import LM, LiteLLMRM
 from lotus.types import CascadeArgs, ProxyModel
+from lotus.vector_store import FaissVS
 
 gpt_4o_mini = LM("gpt-4o-mini")
 gpt_4o = LM("gpt-4o")
 rm = LiteLLMRM(model="text-embedding-3-small")
+vs = FaissVS()
 
-lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm)
+lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm, vs=vs)
 
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/lotus/models/lm.py b/lotus/models/lm.py
index e8c4d087..9d26e5d7 100644
--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -194,7 +194,9 @@ def _update_stats(self, response: ModelResponse, is_cached: bool = False):
         except Exception as e:
             # Handle any other unexpected errors when calculating cost
             lotus.logger.debug(f"Unexpected error calculating completion cost: {e}")
-            raise Warning("Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details.")
+            lotus.logger.warning(
+                "Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details."
+            )
             cost = None
 
         # Always update virtual usage
@@ -234,27 +236,43 @@ def format_logprobs_for_filter_cascade(
     ) -> LogprobsForFilterCascade:
         # Get base cascade format first
         base_cascade = self.format_logprobs_for_cascade(logprobs)
-        all_true_probs = []
+        all_true_probs: list[float] = []
 
         def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None:
-            if "True" in token_probs and "False" in token_probs:
-                true_prob = token_probs["True"]
-                false_prob = token_probs["False"]
+            # Normalize keys by converting to lowercase and stripping whitespace.
+            # Take the max probability for each key (e.g. "True" and "true" both map to "true", so we take the max of the two).
+            normalized_probs: dict[str, float] = {}
+            for k, v in token_probs.items():
+                normalized_key = k.lower().strip()
+                normalized_probs[normalized_key] = max(v, normalized_probs.get(normalized_key, float("-inf")))
+
+            # Look for true/false in normalized keys
+            if "true" in normalized_probs and "false" in normalized_probs:
+                true_prob = normalized_probs["true"]
+                false_prob = normalized_probs["false"]
                 return true_prob / (true_prob + false_prob)
             return None
 
-        # Get true probabilities for filter cascade
-        for resp_idx, response_logprobs in enumerate(logprobs):
-            true_prob = None
-            for logprob in response_logprobs:
+        for response_logprobs in logprobs:
+            true_prob: float = 1  # Default if no true/false token is found
+
+            # Find the last true/false token by normalizing all tokens and searching in reverse
+            cleaned_tokens = [logprob.token.lower().strip() for logprob in response_logprobs]
+            true_false_indices = [i for i, token in enumerate(cleaned_tokens) if token in ["true", "false"]]
+
+            if true_false_indices:
+                last_true_false_idx = true_false_indices[-1]
+                logprob = response_logprobs[last_true_false_idx]
+                cleaned_token = cleaned_tokens[last_true_false_idx]
+
                 token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs}
-                true_prob = get_normalized_true_prob(token_probs)
-                if true_prob is not None:
-                    break
+                normalized_prob = get_normalized_true_prob(token_probs)
 
-            # Default to 1 if "True" in tokens, 0 if not
-            if true_prob is None:
-                true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0
+                # Fall back to binary true/false if normalization fails
+                if normalized_prob is None:
+                    true_prob = 1 if cleaned_token == "true" else 0
+                else:
+                    true_prob = normalized_prob
 
             all_true_probs.append(true_prob)
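For intuition, the key normalization added to format_logprobs_for_filter_cascade can be exercised in isolation. This is a minimal sketch, not part of the patch: the function name normalized_true_prob and the sample probabilities are made up for illustration.

import math


def normalized_true_prob(token_probs: dict[str, float]) -> float | None:
    # Collapse casing/whitespace variants ("True", " true") onto one key,
    # keeping the highest probability seen for each normalized key.
    normalized: dict[str, float] = {}
    for token, prob in token_probs.items():
        key = token.lower().strip()
        normalized[key] = max(prob, normalized.get(key, float("-inf")))
    # Renormalize P(true) against the combined true/false mass.
    if "true" in normalized and "false" in normalized:
        return normalized["true"] / (normalized["true"] + normalized["false"])
    return None


# Hypothetical top logprobs of -0.1 for "True" and -2.5 for " false":
print(normalized_true_prob({"True": math.exp(-0.1), " false": math.exp(-2.5)}))  # ~0.917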
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index c0dd4be3..34fde2fb 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -8,7 +8,7 @@
 from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, ProxyModel, SemanticFilterOutput
-from lotus.utils import show_safe_mode
+from lotus.utils import get_out_col_name, show_safe_mode
 
 from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
 from .postprocessors import filter_postprocess
@@ -162,6 +162,7 @@ def __call__(
         strategy: str | None = None,
         cascade_args: CascadeArgs | None = None,
         return_stats: bool = False,
+        return_scores: bool = False,
         safe_mode: bool = False,
         progress_bar_desc: str = "Filtering",
         additional_cot_instructions: str = "",
@@ -183,6 +184,7 @@
             sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1.
             failure_probability (float): The failure probability when cascading. Defaults to 0.2.
             return_stats (bool): Whether to return statistics. Defaults to False.
+            return_scores (bool): Whether to return the predicate probability (score) for each row. Defaults to False.
             additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "".
 
         Returns:
@@ -193,7 +195,7 @@
                 "The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()"
             )
 
-        stats: dict[str, float] = {}
+        stats: dict[str, Any] = {}
         lotus.logger.debug(user_instruction)
         col_li = lotus.nl_expression.parse_cols(user_instruction)
         lotus.logger.debug(col_li)
@@ -330,9 +332,16 @@
            outputs: list[bool] = [False] * len(multimodal_data)
            raw_outputs: list[str] = [""] * len(multimodal_data)
            explanations: list[str | None] = [None] * len(multimodal_data)
+           scores: list[float] = [0.0] * len(multimodal_data)
+           score_methods: list[str] = [""] * len(multimodal_data)
 
            for idx in high_conf_idxs:
                outputs[idx] = proxy_outputs[idx]
+               scores[idx] = proxy_scores[idx]
+               if proxy_model == ProxyModel.HELPER_LM:
+                   score_methods[idx] = "HELPER_LM_TOKEN_PROB"
+               else:
+                   score_methods[idx] = "RM_SIM_SCORE"
 
            # If using helper LM, get raw outputs and explanations
            if proxy_model == ProxyModel.HELPER_LM:
@@ -358,13 +367,22 @@
                strategy=strategy,
                safe_mode=safe_mode,
                progress_bar_desc="Running predicate evals with oracle LM",
+               logprobs=return_scores,
                additional_cot_instructions=additional_cot_instructions,
            )
 
+           if return_scores:
+               assert large_output.logprobs is not None, "Logprobs must be returned if return_scores is True"
+               formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(large_output.logprobs)
+               large_probs = formatted_logprobs.true_probs
+
            for idx, large_idx in enumerate(low_conf_idxs):
                outputs[large_idx] = large_output.outputs[idx]
                raw_outputs[large_idx] = large_output.raw_outputs[idx]
                explanations[large_idx] = large_output.explanations[idx]
+               if return_scores:
+                   scores[large_idx] = large_probs[idx]
+                   score_methods[large_idx] = "LM_TOKEN_PROB"
 
            stats["filters_resolved_by_helper_model"] += len(high_conf_idxs)
            stats["filters_resolved_by_large_model"] += len(low_conf_idxs)
@@ -382,12 +400,22 @@
                safe_mode=safe_mode,
                show_progress_bar=True,
                progress_bar_desc=progress_bar_desc,
+               logprobs=return_scores,
                additional_cot_instructions=additional_cot_instructions,
            )
            outputs = output.outputs
            raw_outputs = output.raw_outputs
            explanations = output.explanations
+           if return_scores:
+               assert output.logprobs is not None, "Logprobs must be returned if return_scores is True"
+               formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(output.logprobs)
+               scores = formatted_logprobs.true_probs
+               score_methods = ["LM_TOKEN_PROB"] * len(multimodal_data)
+           else:
+               scores = [0.0] * len(multimodal_data)
+               score_methods = [""] * len(multimodal_data)
+
        if not return_all:
            # find indices where output is True
            ids = [i for i, x in enumerate(outputs) if x]
@@ -398,34 +426,32 @@
            [outputs[i] for i in ids]
            filtered_explanations = [explanations[i] for i in ids]
            filtered_raw_outputs = [raw_outputs[i] for i in ids]
+           filtered_scores = [scores[i] for i in ids]
+           filtered_score_methods = [score_methods[i] for i in ids]
            lotus.logger.debug(f"filtered_raw_outputs: {filtered_raw_outputs}")
 
            new_df = self._obj.iloc[ids]
            new_df.attrs["index_dirs"] = self._obj.attrs.get("index_dirs", None)
        else:
-
-           def get_out_col_name(df, col_name):
-               if col_name in df.columns:
-                   i = 1
-                   while f"{col_name}_{i}" in new_df.columns:
-                       i += 1
-                   return f"{col_name}_{i}"
-               else:
-                   return col_name
-
            new_df = self._obj.copy()
            new_df[get_out_col_name(new_df, "filter_label")] = outputs
            filtered_explanations = explanations
            filtered_raw_outputs = raw_outputs
+           filtered_scores = scores
+           filtered_score_methods = score_methods
 
        # return rows where output is True
        if return_explanations and return_raw_outputs:
-           new_df["explanation" + suffix] = filtered_explanations
-           new_df["raw_output" + suffix] = filtered_raw_outputs
+           new_df[get_out_col_name(new_df, "explanation" + suffix)] = filtered_explanations
+           new_df[get_out_col_name(new_df, "raw_output" + suffix)] = filtered_raw_outputs
        elif return_explanations:
-           new_df["explanation" + suffix] = filtered_explanations
+           new_df[get_out_col_name(new_df, "explanation" + suffix)] = filtered_explanations
        elif return_raw_outputs:
-           new_df["raw_output" + suffix] = filtered_raw_outputs
+           new_df[get_out_col_name(new_df, "raw_output" + suffix)] = filtered_raw_outputs
+
+       if return_scores:
+           new_df[get_out_col_name(new_df, "score")] = filtered_scores
+           new_df[get_out_col_name(new_df, "score_method")] = filtered_score_methods
 
        if return_stats:
            return new_df, stats
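A quick usage sketch of the new flag; the DataFrame contents and model choice are placeholders, while the "score" and "score_method" column names come from the diff above.

import pandas as pd

import lotus
from lotus.models import LM

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))
df = pd.DataFrame({"Course Name": ["Compilers", "Art History"]})

# Each returned row carries "score" (the renormalized P(true) from token
# logprobs) and a "score_method" tag such as "LM_TOKEN_PROB".
filtered = df.sem_filter("{Course Name} is about computer science", return_scores=True)
print(filtered[["Course Name", "score", "score_method"]])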
diff --git a/lotus/utils.py b/lotus/utils.py
index bf75748b..b1038744 100644
--- a/lotus/utils.py
+++ b/lotus/utils.py
@@ -132,3 +132,23 @@ def show_safe_mode(estimated_cost, estimated_LM_calls):
     except KeyboardInterrupt:
         print("\nExecution cancelled by user")
         exit(0)
+
+
+def get_out_col_name(df: pd.DataFrame, col_name: str) -> str:
+    """
+    Gets a unique column name by appending an index if the column already exists.
+
+    Args:
+        df (pd.DataFrame): The DataFrame to check for existing columns.
+        col_name (str): The base column name.
+
+    Returns:
+        str: A unique column name.
+    """
+    if col_name in df.columns:
+        i = 1
+        while f"{col_name}_{i}" in df.columns:
+            i += 1
+        return f"{col_name}_{i}"
+    else:
+        return col_name
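Behavior sketch for the relocated helper; the sample frame is made up.

import pandas as pd

from lotus.utils import get_out_col_name

df = pd.DataFrame({"score": [0.9], "score_1": [0.8]})
print(get_out_col_name(df, "score"))   # "score_2" -- first unused suffix
print(get_out_col_name(df, "rating"))  # "rating" -- returned unchanged when absent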
diff --git a/tests/test_filter.py b/tests/test_filter.py
index b51319ca..75822b47 100644
--- a/tests/test_filter.py
+++ b/tests/test_filter.py
@@ -1,6 +1,8 @@
 import pandas as pd
 import pytest
 
+import lotus
+from lotus.models import LM
 from tests.base_test import BaseTest
 
@@ -99,3 +101,74 @@ def test_filtered_search_with_scores(self, sample_df):
         assert len(result["vec_scores_sim_score"]) == 2
         # Scores should be between 0 and 1
         assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"])
+
+
+class TestFilterWithScores(BaseTest):
+    def test_filter_with_scores(self, sample_df):
+        """Test semantic filter with scores returned to the user"""
+        lm = LM(model="gpt-4o-mini")
+        lotus.settings.configure(lm=lm)
+        result = sample_df.sem_filter("{Course Name} will be fun", return_scores=True)
+        print(result)
+        assert "score" in result.columns
+        assert "score_method" in result.columns
+
+    def test_filter_with_scores_and_return_all(self, sample_df):
+        """Test semantic filter with scores and return_all=True"""
+        lm = LM(model="gpt-4o-mini")
+        lotus.settings.configure(lm=lm)
+        result = sample_df.sem_filter("{Course Name} will be fun", return_scores=True, return_all=True)
+        print(result)
+        assert "score" in result.columns
+        assert "score_method" in result.columns
+
+        for idx, row in result.iterrows():
+            if row["filter_label"]:
+                assert row["score"] >= 0.5
+            else:
+                assert row["score"] <= 0.5
+
+    def test_filter_twice(self, sample_df):
+        """Test filtering twice to verify column name indexing works correctly"""
+        lm = LM(model="gpt-4o-mini")
+        lotus.settings.configure(lm=lm)
+
+        # First filter
+        result = sample_df.sem_filter(
+            "{Course Name} is related to programming",
+            return_all=True,
+            return_explanations=True,
+            return_raw_outputs=True,
+            return_scores=True,
+        )
+        print(f"First filter result: {result}")
+
+        # Verify first filter columns
+        assert "filter_label" in result.columns
+        assert "explanation_filter" in result.columns
+        assert "raw_output_filter" in result.columns
+        assert "score" in result.columns
+        assert "score_method" in result.columns
+
+        # Second filter
+        result = result.sem_filter(
+            "{Course Name} is related to programming",
+            return_all=True,
+            return_explanations=True,
+            return_raw_outputs=True,
+            return_scores=True,
+        )
+        print(f"Second filter result: {result}")
+        # Verify second filter columns have indices
+        assert "filter_label_1" in result.columns
+        assert "explanation_filter_1" in result.columns
+        assert "raw_output_filter_1" in result.columns
+        assert "score_1" in result.columns
+        assert "score_method_1" in result.columns
+
+        # Verify original columns still exist
+        assert "filter_label" in result.columns
+        assert "explanation_filter" in result.columns
+        assert "raw_output_filter" in result.columns
+        assert "score" in result.columns
+        assert "score_method" in result.columns
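For reference, the column-suffixing behavior exercised by test_filter_twice looks like this in practice; the data and predicate below are illustrative only.

import pandas as pd

import lotus
from lotus.models import LM

lotus.settings.configure(lm=LM(model="gpt-4o-mini"))
df = pd.DataFrame({"Course Name": ["Compilers", "Art History"]})

once = df.sem_filter("{Course Name} is technical", return_all=True, return_scores=True)
# once has: filter_label, score, score_method
twice = once.sem_filter("{Course Name} is technical", return_all=True, return_scores=True)
# twice adds filter_label_1, score_1, and score_method_1 alongside the
# originals, since get_out_col_name picks the first free suffix.
print(sorted(twice.columns))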