From 618b913296cf9f75c54256c40a55d948d0f28ed1 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Wed, 19 Feb 2025 17:40:01 -0800 Subject: [PATCH 1/6] Initial work --- examples/op_examples/filter_cascade.py | 4 +- lotus/models/lm.py | 1 + lotus/sem_ops/sem_filter.py | 21 ++++++- tests/test_filter.py | 78 ++++++++++++++------------ 4 files changed, 64 insertions(+), 40 deletions(-) diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index 2ff61106..52b261d6 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -3,12 +3,14 @@ import lotus from lotus.models import LM, LiteLLMRM from lotus.types import CascadeArgs, ProxyModel +from lotus.vector_store import FaissVS gpt_4o_mini = LM("gpt-4o-mini") gpt_4o = LM("gpt-4o") rm = LiteLLMRM(model="text-embedding-3-small") +vs = FaissVS() -lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm) +lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm, vs=vs) data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 616dc143..644a05ac 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -178,6 +178,7 @@ def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: if "True" in token_probs and "False" in token_probs: true_prob = token_probs["True"] false_prob = token_probs["False"] + print(f"True probability: {true_prob}, False probability: {false_prob}") return true_prob / (true_prob + false_prob) return None diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 0f30874f..2405098e 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -173,7 +173,6 @@ def __call__( sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1. failure_probability (float): The failure probability when cascading. Defaults to 0.2. return_stats (bool): Whether to return statistics. Defaults to False. - Returns: pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: The filtered dataframe or a tuple containing the filtered dataframe and statistics. """ @@ -182,7 +181,7 @@ def __call__( "The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()" ) - stats = {} + stats: dict[str, Any] = {} lotus.logger.debug(user_instruction) col_li = lotus.nl_expression.parse_cols(user_instruction) lotus.logger.debug(col_li) @@ -321,8 +320,13 @@ def __call__( raw_outputs: list[str] = [""] * len(multimodal_data) explanations: list[str | None] = [None] * len(multimodal_data) + if return_stats: + stats["probs"] = [0.0] * len(multimodal_data) + for idx in high_conf_idxs: outputs[idx] = proxy_outputs[idx] + if return_stats: + stats["probs"][idx] = proxy_scores[idx] # If using helper LM, get raw outputs and explanations if proxy_model == ProxyModel.HELPER_LM: @@ -348,12 +352,19 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", + logprobs=return_stats, ) + if return_stats and large_output.logprobs: + formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(large_output.logprobs) + large_probs = formatted_logprobs.true_probs + for idx, large_idx in enumerate(low_conf_idxs): outputs[large_idx] = large_output.outputs[idx] raw_outputs[large_idx] = large_output.raw_outputs[idx] explanations[large_idx] = large_output.explanations[idx] + if return_stats: + stats["probs"][large_idx] = large_probs[idx] stats["filters_resolved_by_helper_model"] += len(high_conf_idxs) stats["filters_resolved_by_large_model"] += len(low_conf_idxs) @@ -371,11 +382,17 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, + logprobs=return_stats, # stats includes logprobs ) outputs = output.outputs raw_outputs = output.raw_outputs explanations = output.explanations + if return_stats: + assert output.logprobs is not None, "logprobs must be returned to get stats" + formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(output.logprobs) + stats["probs"] = formatted_logprobs.true_probs + # find indices where output is True ids = [i for i, x in enumerate(outputs) if x] idx_ids = [self._obj.index[i] for i, x in enumerate(outputs) if x] diff --git a/tests/test_filter.py b/tests/test_filter.py index 1611340a..47e3e6dd 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,31 +1,29 @@ import pandas as pd import pytest +import lotus +from lotus.models import LM from tests.base_test import BaseTest @pytest.fixture def sample_df(): - return pd.DataFrame({ - "Course Name": [ - "Introduction to Programming", - "Advanced Programming", - "Cooking Basics", - "Advanced Culinary Arts", - "Data Structures", - "Algorithms", - "French Cuisine", - "Italian Cooking" - ], - "Department": [ - "CS", "CS", "Culinary", "Culinary", - "CS", "CS", "Culinary", "Culinary" - ], - "Level": [ - 100, 200, 100, 200, - 300, 300, 200, 200 - ] - }) + return pd.DataFrame( + { + "Course Name": [ + "Introduction to Programming", + "Advanced Programming", + "Cooking Basics", + "Advanced Culinary Arts", + "Data Structures", + "Algorithms", + "French Cuisine", + "Italian Cooking", + ], + "Department": ["CS", "CS", "Culinary", "Culinary", "CS", "CS", "Culinary", "Culinary"], + "Level": [100, 200, 100, 200, 300, 300, 200, 200], + } + ) class TestSearch(BaseTest): @@ -41,11 +39,11 @@ def test_filtered_search_relational(self, sample_df): """Test semantic search with relational filter""" # Index the dataframe df = sample_df.sem_index("Course Name", "course_index") - + # Apply relational filter and search filtered_df = df[df["Department"] == "CS"] result = filtered_df.sem_search("Course Name", "advanced courses", K=2) - + assert len(result) == 2 # Should only return CS courses assert all(dept == "CS" for dept in result["Department"]) @@ -55,11 +53,11 @@ def test_filtered_search_semantic(self, sample_df): """Test semantic search after semantic filter""" # Index the dataframe df = sample_df.sem_index("Course Name", "course_index") - + # Apply semantic filter and search filtered_df = df.sem_filter("{Course Name} is related to cooking") result = filtered_df.sem_search("Course Name", "advanced level courses", K=2) - + assert len(result) == 2 # Should only return cooking-related courses assert all(dept == "Culinary" for dept in result["Department"]) @@ -69,12 +67,12 @@ def test_filtered_search_combined(self, sample_df): """Test semantic search with both relational and semantic filters""" # Index the dataframe df = sample_df.sem_index("Course Name", "course_index") - + # Apply both filters and search filtered_df = df[df["Level"] >= 200] # relational filter filtered_df = filtered_df.sem_filter("{Course Name} is related to computer science") # semantic filter result = filtered_df.sem_search("Course Name", "data structures and algorithms", K=2) - + assert len(result) == 2 # Should only return advanced CS courses assert all(dept == "CS" for dept in result["Department"]) @@ -85,26 +83,32 @@ def test_filtered_search_combined(self, sample_df): def test_filtered_search_empty_result(self, sample_df): """Test semantic search when filter returns empty result""" df = sample_df.sem_index("Course Name", "course_index") - + # Apply filter that should return no results filtered_df = df[df["Level"] > 1000] result = filtered_df.sem_search("Course Name", "any course", K=2) - + assert len(result) == 0 def test_filtered_search_with_scores(self, sample_df): """Test filtered semantic search with similarity scores""" df = sample_df.sem_index("Course Name", "course_index") - + filtered_df = df[df["Department"] == "CS"] - result = filtered_df.sem_search( - "Course Name", - "programming courses", - K=2, - return_scores=True - ) - + result = filtered_df.sem_search("Course Name", "programming courses", K=2, return_scores=True) + assert "vec_scores_sim_score" in result.columns assert len(result["vec_scores_sim_score"]) == 2 # Scores should be between 0 and 1 - assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"]) \ No newline at end of file + assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"]) + + +class TestFilterWithProbs(BaseTest): + def test_filter_with_probs(self, sample_df): + """Test semantic filter with probabilities returned to the user""" + lm = LM(model="gpt-4o-mini") + lotus.settings.configure(lm=lm) + result, stats = sample_df.sem_filter("{Course Name} is useful for life", return_stats=True) + print(result) + print(stats) + assert "probs" in stats From 9000328fedd9d351904129acb3488a9f759c88ee Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Wed, 19 Feb 2025 20:36:59 -0800 Subject: [PATCH 2/6] More work --- lotus/models/lm.py | 44 +++++++++++++++++++++++-------------- lotus/sem_ops/sem_filter.py | 33 +++++++++++++++++----------- tests/test_filter.py | 5 ++--- 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 644a05ac..07cd1804 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -172,28 +172,40 @@ def format_logprobs_for_filter_cascade( ) -> LogprobsForFilterCascade: # Get base cascade format first base_cascade = self.format_logprobs_for_cascade(logprobs) - all_true_probs = [] + all_true_probs: list[float] = [] def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: - if "True" in token_probs and "False" in token_probs: - true_prob = token_probs["True"] - false_prob = token_probs["False"] - print(f"True probability: {true_prob}, False probability: {false_prob}") + # Normalize keys by converting to lowercase and stripping whitespace + # Take the max probability for each key (e.g. True and true both map to "true" so we take the max of the two) + normalized_probs: dict[str, float] = {} + for k, v in token_probs.items(): + normalized_key = k.lower().strip() + normalized_probs[normalized_key] = max(v, normalized_probs.get(normalized_key, float("-inf"))) + + # Look for true/false in normalized keys + if "true" in normalized_probs and "false" in normalized_probs: + true_prob = normalized_probs["true"] + false_prob = normalized_probs["false"] return true_prob / (true_prob + false_prob) return None - # Get true probabilities for filter cascade - for resp_idx, response_logprobs in enumerate(logprobs): - true_prob = None - for logprob in response_logprobs: - token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs} - true_prob = get_normalized_true_prob(token_probs) - if true_prob is not None: - break + for response_logprobs in logprobs: + true_prob: float = 1 # Default if no true/false token found + + for logprob in reversed(response_logprobs): + cleaned_token = logprob.token.lower().strip() + if cleaned_token not in ["true", "false"]: + continue - # Default to 1 if "True" in tokens, 0 if not - if true_prob is None: - true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0 + token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs} + normalized_prob = get_normalized_true_prob(token_probs) + + # Fall back to binary true/false if normalization fails + if normalized_prob is None: + true_prob = 1 if cleaned_token == "true" else 0 + else: + true_prob = normalized_prob + break all_true_probs.append(true_prob) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 5a049184..ca5fe774 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -162,6 +162,7 @@ def __call__( strategy: str | None = None, cascade_args: CascadeArgs | None = None, return_stats: bool = False, + return_probs: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", @@ -183,6 +184,7 @@ def __call__( sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1. failure_probability (float): The failure probability when cascading. Defaults to 0.2. return_stats (bool): Whether to return statistics. Defaults to False. + return_probs (bool): Whether to return probabilities. Defaults to False. additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". Returns: @@ -330,14 +332,11 @@ def __call__( outputs: list[bool] = [False] * len(multimodal_data) raw_outputs: list[str] = [""] * len(multimodal_data) explanations: list[str | None] = [None] * len(multimodal_data) - - if return_stats: - stats["probs"] = [0.0] * len(multimodal_data) + probs: list[float] = [0.0] * len(multimodal_data) for idx in high_conf_idxs: outputs[idx] = proxy_outputs[idx] - if return_stats: - stats["probs"][idx] = proxy_scores[idx] + probs[idx] = proxy_scores[idx] # If using helper LM, get raw outputs and explanations if proxy_model == ProxyModel.HELPER_LM: @@ -363,11 +362,12 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", - logprobs=return_stats, + logprobs=return_probs, additional_cot_instructions=additional_cot_instructions, ) - if return_stats and large_output.logprobs: + if return_probs: + assert large_output.logprobs is not None, "Logprobs must be returned if return_probs is True" formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(large_output.logprobs) large_probs = formatted_logprobs.true_probs @@ -375,8 +375,8 @@ def __call__( outputs[large_idx] = large_output.outputs[idx] raw_outputs[large_idx] = large_output.raw_outputs[idx] explanations[large_idx] = large_output.explanations[idx] - if return_stats: - stats["probs"][large_idx] = large_probs[idx] + if return_probs: + probs[large_idx] = large_probs[idx] stats["filters_resolved_by_helper_model"] += len(high_conf_idxs) stats["filters_resolved_by_large_model"] += len(low_conf_idxs) @@ -394,17 +394,19 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, - logprobs=return_stats, # stats includes logprobs + logprobs=return_probs, additional_cot_instructions=additional_cot_instructions, ) outputs = output.outputs raw_outputs = output.raw_outputs explanations = output.explanations - if return_stats: - assert output.logprobs is not None, "logprobs must be returned to get stats" + if return_probs: + assert output.logprobs is not None, "Logprobs must be returned if return_probs is True" formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(output.logprobs) - stats["probs"] = formatted_logprobs.true_probs + probs = formatted_logprobs.true_probs + else: + probs = [0.0] * len(multimodal_data) if not return_all: # find indices where output is True @@ -416,6 +418,7 @@ def __call__( [outputs[i] for i in ids] filtered_explanations = [explanations[i] for i in ids] filtered_raw_outputs = [raw_outputs[i] for i in ids] + filtered_probs = [probs[i] for i in ids] lotus.logger.debug(f"filtered_raw_outputs: {filtered_raw_outputs}") new_df = self._obj.iloc[ids] @@ -435,6 +438,7 @@ def get_out_col_name(df, col_name): new_df[get_out_col_name(new_df, "filter_label")] = outputs filtered_explanations = explanations filtered_raw_outputs = raw_outputs + filtered_probs = probs # return rows where output is True if return_explanations and return_raw_outputs: @@ -445,6 +449,9 @@ def get_out_col_name(df, col_name): elif return_raw_outputs: new_df["raw_output" + suffix] = filtered_raw_outputs + if return_probs: + new_df["probs" + suffix] = filtered_probs + if return_stats: return new_df, stats diff --git a/tests/test_filter.py b/tests/test_filter.py index 47e3e6dd..228b986a 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -108,7 +108,6 @@ def test_filter_with_probs(self, sample_df): """Test semantic filter with probabilities returned to the user""" lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - result, stats = sample_df.sem_filter("{Course Name} is useful for life", return_stats=True) + result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True) print(result) - print(stats) - assert "probs" in stats + assert "probs_filter" in result.columns From 1619657bed7b2014e4be3051a725143931d13c3b Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Wed, 19 Feb 2025 20:42:09 -0800 Subject: [PATCH 3/6] Add another test --- tests/test_filter.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_filter.py b/tests/test_filter.py index 228b986a..1c518aac 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -111,3 +111,17 @@ def test_filter_with_probs(self, sample_df): result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True) print(result) assert "probs_filter" in result.columns + + def test_filter_with_probs_and_return_all(self, sample_df): + """Test semantic filter with probabilities returned to the user""" + lm = LM(model="gpt-4o-mini") + lotus.settings.configure(lm=lm) + result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True, return_all=True) + print(result) + assert "probs_filter" in result.columns + + for idx, row in result.iterrows(): + if row["filter_label"]: + assert row["probs_filter"] > 0.5 + else: + assert row["probs_filter"] <= 0.5 From 956db5466e70c478b06eef19f44fe09e1e7dfc67 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Wed, 19 Feb 2025 20:59:55 -0800 Subject: [PATCH 4/6] Refactor --- lotus/models/lm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lotus/models/lm.py b/lotus/models/lm.py index 07cd1804..cfd1f0c9 100644 --- a/lotus/models/lm.py +++ b/lotus/models/lm.py @@ -192,10 +192,14 @@ def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: for response_logprobs in logprobs: true_prob: float = 1 # Default if no true/false token found - for logprob in reversed(response_logprobs): - cleaned_token = logprob.token.lower().strip() - if cleaned_token not in ["true", "false"]: - continue + # Find last true/false token by normalizing all tokens and searching in reverse + cleaned_tokens = [logprob.token.lower().strip() for logprob in response_logprobs] + true_false_indices = [i for i, token in enumerate(cleaned_tokens) if token in ["true", "false"]] + + if true_false_indices: + last_true_false_idx = true_false_indices[-1] + logprob = response_logprobs[last_true_false_idx] + cleaned_token = cleaned_tokens[last_true_false_idx] token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs} normalized_prob = get_normalized_true_prob(token_probs) @@ -205,7 +209,6 @@ def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None: true_prob = 1 if cleaned_token == "true" else 0 else: true_prob = normalized_prob - break all_true_probs.append(true_prob) From 5587074e783bf98e8ab647b80a90ebaef9bb2bbb Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Wed, 19 Feb 2025 22:53:00 -0800 Subject: [PATCH 5/6] Switch to scores and add score_method --- lotus/sem_ops/sem_filter.py | 47 +++++++++++++++++++++++-------------- tests/test_filter.py | 24 ++++++++++--------- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index ca5fe774..f22f7ee5 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -162,7 +162,7 @@ def __call__( strategy: str | None = None, cascade_args: CascadeArgs | None = None, return_stats: bool = False, - return_probs: bool = False, + return_scores: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", additional_cot_instructions: str = "", @@ -184,7 +184,7 @@ def __call__( sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1. failure_probability (float): The failure probability when cascading. Defaults to 0.2. return_stats (bool): Whether to return statistics. Defaults to False. - return_probs (bool): Whether to return probabilities. Defaults to False. + return_scores (bool): Whether to return probabilities. Defaults to False. additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". Returns: @@ -332,11 +332,16 @@ def __call__( outputs: list[bool] = [False] * len(multimodal_data) raw_outputs: list[str] = [""] * len(multimodal_data) explanations: list[str | None] = [None] * len(multimodal_data) - probs: list[float] = [0.0] * len(multimodal_data) + scores: list[float] = [0.0] * len(multimodal_data) + score_methods: list[str] = [""] * len(multimodal_data) for idx in high_conf_idxs: outputs[idx] = proxy_outputs[idx] - probs[idx] = proxy_scores[idx] + scores[idx] = proxy_scores[idx] + if proxy_model == ProxyModel.HELPER_LM: + score_methods[idx] = "HELPER_LM" + else: + score_methods[idx] = "EMBEDDING_MODEL" # If using helper LM, get raw outputs and explanations if proxy_model == ProxyModel.HELPER_LM: @@ -362,12 +367,12 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", - logprobs=return_probs, + logprobs=return_scores, additional_cot_instructions=additional_cot_instructions, ) - if return_probs: - assert large_output.logprobs is not None, "Logprobs must be returned if return_probs is True" + if return_scores: + assert large_output.logprobs is not None, "Logprobs must be returned if return_scores is True" formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(large_output.logprobs) large_probs = formatted_logprobs.true_probs @@ -375,8 +380,9 @@ def __call__( outputs[large_idx] = large_output.outputs[idx] raw_outputs[large_idx] = large_output.raw_outputs[idx] explanations[large_idx] = large_output.explanations[idx] - if return_probs: - probs[large_idx] = large_probs[idx] + if return_scores: + scores[large_idx] = large_probs[idx] + score_methods[large_idx] = "LM" stats["filters_resolved_by_helper_model"] += len(high_conf_idxs) stats["filters_resolved_by_large_model"] += len(low_conf_idxs) @@ -394,19 +400,21 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, - logprobs=return_probs, + logprobs=return_scores, additional_cot_instructions=additional_cot_instructions, ) outputs = output.outputs raw_outputs = output.raw_outputs explanations = output.explanations - if return_probs: - assert output.logprobs is not None, "Logprobs must be returned if return_probs is True" + if return_scores: + assert output.logprobs is not None, "Logprobs must be returned if return_scores is True" formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(output.logprobs) - probs = formatted_logprobs.true_probs + scores = formatted_logprobs.true_probs + score_methods = ["LM"] * len(multimodal_data) else: - probs = [0.0] * len(multimodal_data) + scores = [0.0] * len(multimodal_data) + score_methods = [""] * len(multimodal_data) if not return_all: # find indices where output is True @@ -418,7 +426,8 @@ def __call__( [outputs[i] for i in ids] filtered_explanations = [explanations[i] for i in ids] filtered_raw_outputs = [raw_outputs[i] for i in ids] - filtered_probs = [probs[i] for i in ids] + filtered_scores = [scores[i] for i in ids] + filtered_score_methods = [score_methods[i] for i in ids] lotus.logger.debug(f"filtered_raw_outputs: {filtered_raw_outputs}") new_df = self._obj.iloc[ids] @@ -438,7 +447,8 @@ def get_out_col_name(df, col_name): new_df[get_out_col_name(new_df, "filter_label")] = outputs filtered_explanations = explanations filtered_raw_outputs = raw_outputs - filtered_probs = probs + filtered_scores = scores + filtered_score_methods = score_methods # return rows where output is True if return_explanations and return_raw_outputs: @@ -449,8 +459,9 @@ def get_out_col_name(df, col_name): elif return_raw_outputs: new_df["raw_output" + suffix] = filtered_raw_outputs - if return_probs: - new_df["probs" + suffix] = filtered_probs + if return_scores: + new_df["score"] = filtered_scores + new_df["score_method"] = filtered_score_methods if return_stats: return new_df, stats diff --git a/tests/test_filter.py b/tests/test_filter.py index 1c518aac..03e97ae7 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -103,25 +103,27 @@ def test_filtered_search_with_scores(self, sample_df): assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"]) -class TestFilterWithProbs(BaseTest): - def test_filter_with_probs(self, sample_df): - """Test semantic filter with probabilities returned to the user""" +class TestFilterWithScores(BaseTest): + def test_filter_with_scores(self, sample_df): + """Test semantic filter with scores returned to the user""" lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True) + result = sample_df.sem_filter("{Course Name} will be fun", return_scores=True) print(result) - assert "probs_filter" in result.columns + assert "score" in result.columns + assert "score_method" in result.columns - def test_filter_with_probs_and_return_all(self, sample_df): - """Test semantic filter with probabilities returned to the user""" + def test_filter_with_scores_and_return_all(self, sample_df): + """Test semantic filter with scores returned to the user""" lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True, return_all=True) + result = sample_df.sem_filter("{Course Name} will be fun", return_scores=True, return_all=True) print(result) - assert "probs_filter" in result.columns + assert "score" in result.columns + assert "score_method" in result.columns for idx, row in result.iterrows(): if row["filter_label"]: - assert row["probs_filter"] > 0.5 + assert row["score"] > 0.5 else: - assert row["probs_filter"] <= 0.5 + assert row["score"] <= 0.5 From 8fee42387b00c5954aff085ab992128d0703fc17 Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Tue, 1 Apr 2025 09:12:57 -0700 Subject: [PATCH 6/6] Address comments --- lotus/sem_ops/sem_filter.py | 32 +++++++++---------------- lotus/utils.py | 20 ++++++++++++++++ tests/test_filter.py | 47 ++++++++++++++++++++++++++++++++++++- 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index f22f7ee5..34fde2fb 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -8,7 +8,7 @@ from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, ProxyModel, SemanticFilterOutput -from lotus.utils import show_safe_mode +from lotus.utils import get_out_col_name, show_safe_mode from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds from .postprocessors import filter_postprocess @@ -339,9 +339,9 @@ def __call__( outputs[idx] = proxy_outputs[idx] scores[idx] = proxy_scores[idx] if proxy_model == ProxyModel.HELPER_LM: - score_methods[idx] = "HELPER_LM" + score_methods[idx] = "HELPER_LM_TOKEN_PROB" else: - score_methods[idx] = "EMBEDDING_MODEL" + score_methods[idx] = "RM_SIM_SCORE" # If using helper LM, get raw outputs and explanations if proxy_model == ProxyModel.HELPER_LM: @@ -382,7 +382,7 @@ def __call__( explanations[large_idx] = large_output.explanations[idx] if return_scores: scores[large_idx] = large_probs[idx] - score_methods[large_idx] = "LM" + score_methods[large_idx] = "LM_TOKEN_PROB" stats["filters_resolved_by_helper_model"] += len(high_conf_idxs) stats["filters_resolved_by_large_model"] += len(low_conf_idxs) @@ -411,7 +411,7 @@ def __call__( assert output.logprobs is not None, "Logprobs must be returned if return_scores is True" formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(output.logprobs) scores = formatted_logprobs.true_probs - score_methods = ["LM"] * len(multimodal_data) + score_methods = ["LM_TOKEN_PROB"] * len(multimodal_data) else: scores = [0.0] * len(multimodal_data) score_methods = [""] * len(multimodal_data) @@ -433,16 +433,6 @@ def __call__( new_df = self._obj.iloc[ids] new_df.attrs["index_dirs"] = self._obj.attrs.get("index_dirs", None) else: - - def get_out_col_name(df, col_name): - if col_name in df.columns: - i = 1 - while f"{col_name}_{i}" in new_df.columns: - i += 1 - return f"{col_name}_{i}" - else: - return col_name - new_df = self._obj.copy() new_df[get_out_col_name(new_df, "filter_label")] = outputs filtered_explanations = explanations @@ -452,16 +442,16 @@ def get_out_col_name(df, col_name): # return rows where output is True if return_explanations and return_raw_outputs: - new_df["explanation" + suffix] = filtered_explanations - new_df["raw_output" + suffix] = filtered_raw_outputs + new_df[get_out_col_name(new_df, "explanation" + suffix)] = filtered_explanations + new_df[get_out_col_name(new_df, "raw_output" + suffix)] = filtered_raw_outputs elif return_explanations: - new_df["explanation" + suffix] = filtered_explanations + new_df[get_out_col_name(new_df, "explanation" + suffix)] = filtered_explanations elif return_raw_outputs: - new_df["raw_output" + suffix] = filtered_raw_outputs + new_df[get_out_col_name(new_df, "raw_output" + suffix)] = filtered_raw_outputs if return_scores: - new_df["score"] = filtered_scores - new_df["score_method"] = filtered_score_methods + new_df[get_out_col_name(new_df, "score")] = filtered_scores + new_df[get_out_col_name(new_df, "score_method")] = filtered_score_methods if return_stats: return new_df, stats diff --git a/lotus/utils.py b/lotus/utils.py index bf75748b..b1038744 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -132,3 +132,23 @@ def show_safe_mode(estimated_cost, estimated_LM_calls): except KeyboardInterrupt: print("\nExecution cancelled by user") exit(0) + + +def get_out_col_name(df: pd.DataFrame, col_name: str) -> str: + """ + Gets a unique column name by appending an index if the column already exists. + + Args: + df (pd.DataFrame): The DataFrame to check for existing columns. + col_name (str): The base column name. + + Returns: + str: A unique column name. + """ + if col_name in df.columns: + i = 1 + while f"{col_name}_{i}" in df.columns: + i += 1 + return f"{col_name}_{i}" + else: + return col_name diff --git a/tests/test_filter.py b/tests/test_filter.py index 03e97ae7..75822b47 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -124,6 +124,51 @@ def test_filter_with_scores_and_return_all(self, sample_df): for idx, row in result.iterrows(): if row["filter_label"]: - assert row["score"] > 0.5 + assert row["score"] >= 0.5 else: assert row["score"] <= 0.5 + + def test_filter_twice(self, sample_df): + """Test filtering twice to verify column name indexing works correctly""" + lm = LM(model="gpt-4o-mini") + lotus.settings.configure(lm=lm) + + # First filter + result = sample_df.sem_filter( + "{Course Name} is related to programming", + return_all=True, + return_explanations=True, + return_raw_outputs=True, + return_scores=True, + ) + print(f"First filter result: {result}") + + # Verify first filter columns + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + assert "raw_output_filter" in result.columns + assert "score" in result.columns + assert "score_method" in result.columns + + # Second filter + result = result.sem_filter( + "{Course Name} is related to programming", + return_all=True, + return_explanations=True, + return_raw_outputs=True, + return_scores=True, + ) + print(f"Second filter result: {result}") + # Verify second filter columns have indices + assert "filter_label_1" in result.columns + assert "explanation_filter_1" in result.columns + assert "raw_output_filter_1" in result.columns + assert "score_1" in result.columns + assert "score_method_1" in result.columns + + # Verify original columns still exist + assert "filter_label" in result.columns + assert "explanation_filter" in result.columns + assert "raw_output_filter" in result.columns + assert "score" in result.columns + assert "score_method" in result.columns