Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/op_examples/filter_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import lotus
from lotus.models import LM, LiteLLMRM
from lotus.types import CascadeArgs, ProxyModel
from lotus.vector_store import FaissVS

gpt_4o_mini = LM("gpt-4o-mini")
gpt_4o = LM("gpt-4o")
rm = LiteLLMRM(model="text-embedding-3-small")
vs = FaissVS()

lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm)
lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm, vs=vs)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to my logic changes. But vs needs to be set here for the example to run, given the recent merge.

data = {
"Course Name": [
"Probability and Random Processes",
Expand Down
48 changes: 33 additions & 15 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def _update_stats(self, response: ModelResponse, is_cached: bool = False):
except Exception as e:
# Handle any other unexpected errors when calculating cost
lotus.logger.debug(f"Unexpected error calculating completion cost: {e}")
raise Warning("Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details.")
raise Warning(
"Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details."
)
cost = None

# Always update virtual usage
Expand Down Expand Up @@ -234,27 +236,43 @@ def format_logprobs_for_filter_cascade(
) -> LogprobsForFilterCascade:
# Get base cascade format first
base_cascade = self.format_logprobs_for_cascade(logprobs)
all_true_probs = []
all_true_probs: list[float] = []

def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None:
if "True" in token_probs and "False" in token_probs:
true_prob = token_probs["True"]
false_prob = token_probs["False"]
# Normalize keys by converting to lowercase and stripping whitespace
# Take the max probability for each key (e.g. True and true both map to "true" so we take the max of the two)
normalized_probs: dict[str, float] = {}
for k, v in token_probs.items():
normalized_key = k.lower().strip()
normalized_probs[normalized_key] = max(v, normalized_probs.get(normalized_key, float("-inf")))

# Look for true/false in normalized keys
if "true" in normalized_probs and "false" in normalized_probs:
true_prob = normalized_probs["true"]
false_prob = normalized_probs["false"]
return true_prob / (true_prob + false_prob)
return None

# Get true probabilities for filter cascade
for resp_idx, response_logprobs in enumerate(logprobs):
true_prob = None
for logprob in response_logprobs:
for response_logprobs in logprobs:
true_prob: float = 1 # Default if no true/false token found

# Find last true/false token by normalizing all tokens and searching in reverse
cleaned_tokens = [logprob.token.lower().strip() for logprob in response_logprobs]
true_false_indices = [i for i, token in enumerate(cleaned_tokens) if token in ["true", "false"]]

if true_false_indices:
last_true_false_idx = true_false_indices[-1]
logprob = response_logprobs[last_true_false_idx]
cleaned_token = cleaned_tokens[last_true_false_idx]

token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs}
true_prob = get_normalized_true_prob(token_probs)
if true_prob is not None:
break
normalized_prob = get_normalized_true_prob(token_probs)

# Default to 1 if "True" in tokens, 0 if not
if true_prob is None:
true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0
# Fall back to binary true/false if normalization fails
if normalized_prob is None:
true_prob = 1 if cleaned_token == "true" else 0
else:
true_prob = normalized_prob

all_true_probs.append(true_prob)

Expand Down
58 changes: 42 additions & 16 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, ProxyModel, SemanticFilterOutput
from lotus.utils import show_safe_mode
from lotus.utils import get_out_col_name, show_safe_mode

from .cascade_utils import calibrate_llm_logprobs, importance_sampling, learn_cascade_thresholds
from .postprocessors import filter_postprocess
Expand Down Expand Up @@ -162,6 +162,7 @@ def __call__(
strategy: str | None = None,
cascade_args: CascadeArgs | None = None,
return_stats: bool = False,
return_scores: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Filtering",
additional_cot_instructions: str = "",
Expand All @@ -183,6 +184,7 @@ def __call__(
sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1.
failure_probability (float): The failure probability when cascading. Defaults to 0.2.
return_stats (bool): Whether to return statistics. Defaults to False.
return_scores (bool): Whether to return probabilities. Defaults to False.
additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "".

Returns:
Expand All @@ -193,7 +195,7 @@ def __call__(
"The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()"
)

stats: dict[str, float] = {}
stats: dict[str, Any] = {}
lotus.logger.debug(user_instruction)
col_li = lotus.nl_expression.parse_cols(user_instruction)
lotus.logger.debug(col_li)
Expand Down Expand Up @@ -330,9 +332,16 @@ def __call__(
outputs: list[bool] = [False] * len(multimodal_data)
raw_outputs: list[str] = [""] * len(multimodal_data)
explanations: list[str | None] = [None] * len(multimodal_data)
scores: list[float] = [0.0] * len(multimodal_data)
score_methods: list[str] = [""] * len(multimodal_data)

for idx in high_conf_idxs:
outputs[idx] = proxy_outputs[idx]
scores[idx] = proxy_scores[idx]
if proxy_model == ProxyModel.HELPER_LM:
score_methods[idx] = "HELPER_LM_TOKEN_PROB"
else:
score_methods[idx] = "RM_SIM_SCORE"

# If using helper LM, get raw outputs and explanations
if proxy_model == ProxyModel.HELPER_LM:
Expand All @@ -358,13 +367,22 @@ def __call__(
strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc="Running predicate evals with oracle LM",
logprobs=return_scores,
additional_cot_instructions=additional_cot_instructions,
)

if return_scores:
assert large_output.logprobs is not None, "Logprobs must be returned if return_scores is True"
formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(large_output.logprobs)
large_probs = formatted_logprobs.true_probs

for idx, large_idx in enumerate(low_conf_idxs):
outputs[large_idx] = large_output.outputs[idx]
raw_outputs[large_idx] = large_output.raw_outputs[idx]
explanations[large_idx] = large_output.explanations[idx]
if return_scores:
scores[large_idx] = large_probs[idx]
score_methods[large_idx] = "LM_TOKEN_PROB"

stats["filters_resolved_by_helper_model"] += len(high_conf_idxs)
stats["filters_resolved_by_large_model"] += len(low_conf_idxs)
Expand All @@ -382,12 +400,22 @@ def __call__(
safe_mode=safe_mode,
show_progress_bar=True,
progress_bar_desc=progress_bar_desc,
logprobs=return_scores,
additional_cot_instructions=additional_cot_instructions,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
explanations = output.explanations

if return_scores:
assert output.logprobs is not None, "Logprobs must be returned if return_scores is True"
formatted_logprobs = lotus.settings.lm.format_logprobs_for_filter_cascade(output.logprobs)
scores = formatted_logprobs.true_probs
score_methods = ["LM_TOKEN_PROB"] * len(multimodal_data)
else:
scores = [0.0] * len(multimodal_data)
score_methods = [""] * len(multimodal_data)

if not return_all:
# find indices where output is True
ids = [i for i, x in enumerate(outputs) if x]
Expand All @@ -398,34 +426,32 @@ def __call__(
[outputs[i] for i in ids]
filtered_explanations = [explanations[i] for i in ids]
filtered_raw_outputs = [raw_outputs[i] for i in ids]
filtered_scores = [scores[i] for i in ids]
filtered_score_methods = [score_methods[i] for i in ids]
lotus.logger.debug(f"filtered_raw_outputs: {filtered_raw_outputs}")

new_df = self._obj.iloc[ids]
new_df.attrs["index_dirs"] = self._obj.attrs.get("index_dirs", None)
else:

def get_out_col_name(df, col_name):
if col_name in df.columns:
i = 1
while f"{col_name}_{i}" in new_df.columns:
i += 1
return f"{col_name}_{i}"
else:
return col_name

new_df = self._obj.copy()
new_df[get_out_col_name(new_df, "filter_label")] = outputs
filtered_explanations = explanations
filtered_raw_outputs = raw_outputs
filtered_scores = scores
filtered_score_methods = score_methods

# return rows where output is True
if return_explanations and return_raw_outputs:
new_df["explanation" + suffix] = filtered_explanations
new_df["raw_output" + suffix] = filtered_raw_outputs
new_df[get_out_col_name(new_df, "explanation" + suffix)] = filtered_explanations
new_df[get_out_col_name(new_df, "raw_output" + suffix)] = filtered_raw_outputs
elif return_explanations:
new_df["explanation" + suffix] = filtered_explanations
new_df[get_out_col_name(new_df, "explanation" + suffix)] = filtered_explanations
elif return_raw_outputs:
new_df["raw_output" + suffix] = filtered_raw_outputs
new_df[get_out_col_name(new_df, "raw_output" + suffix)] = filtered_raw_outputs

if return_scores:
new_df[get_out_col_name(new_df, "score")] = filtered_scores
new_df[get_out_col_name(new_df, "score_method")] = filtered_score_methods

if return_stats:
return new_df, stats
Expand Down
20 changes: 20 additions & 0 deletions lotus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,23 @@ def show_safe_mode(estimated_cost, estimated_LM_calls):
except KeyboardInterrupt:
print("\nExecution cancelled by user")
exit(0)


def get_out_col_name(df: pd.DataFrame, col_name: str) -> str:
    """
    Return a column name that does not collide with any column already in *df*.

    If ``col_name`` is free it is returned unchanged; otherwise a numeric
    suffix is appended (``col_name_1``, ``col_name_2``, ...) and incremented
    until an unused name is found.

    Args:
        df (pd.DataFrame): The DataFrame whose columns are checked.
        col_name (str): The desired base column name.

    Returns:
        str: A column name not currently present in ``df``.
    """
    # Snapshot the existing names once; membership tests on a set are O(1).
    taken = set(df.columns)
    if col_name not in taken:
        return col_name
    suffix = 1
    while f"{col_name}_{suffix}" in taken:
        suffix += 1
    return f"{col_name}_{suffix}"
73 changes: 73 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd
import pytest

import lotus
from lotus.models import LM
from tests.base_test import BaseTest


Expand Down Expand Up @@ -99,3 +101,74 @@ def test_filtered_search_with_scores(self, sample_df):
assert len(result["vec_scores_sim_score"]) == 2
# Scores should be between 0 and 1
assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"])


class TestFilterWithScores(BaseTest):
    """Tests for ``sem_filter`` when the caller requests per-row scores."""

    def test_filter_with_scores(self, sample_df):
        """Test semantic filter with scores returned to the user"""
        language_model = LM(model="gpt-4o-mini")
        lotus.settings.configure(lm=language_model)

        result = sample_df.sem_filter("{Course Name} will be fun", return_scores=True)
        print(result)

        # Both score columns must be attached to the filtered frame.
        assert {"score", "score_method"} <= set(result.columns)

    def test_filter_with_scores_and_return_all(self, sample_df):
        """Test semantic filter with scores returned to the user"""
        language_model = LM(model="gpt-4o-mini")
        lotus.settings.configure(lm=language_model)

        result = sample_df.sem_filter("{Course Name} will be fun", return_scores=True, return_all=True)
        print(result)

        assert {"score", "score_method"} <= set(result.columns)

        # A True label should coincide with a high score and vice versa.
        for _, row in result.iterrows():
            if row["filter_label"]:
                assert row["score"] >= 0.5
            else:
                assert row["score"] <= 0.5

    def test_filter_twice(self, sample_df):
        """Test filtering twice to verify column name indexing works correctly"""
        language_model = LM(model="gpt-4o-mini")
        lotus.settings.configure(lm=language_model)

        base_columns = ("filter_label", "explanation_filter", "raw_output_filter", "score", "score_method")

        # First filter
        result = sample_df.sem_filter(
            "{Course Name} is related to programming",
            return_all=True,
            return_explanations=True,
            return_raw_outputs=True,
            return_scores=True,
        )
        print(f"First filter result: {result}")

        # Verify first filter columns
        for column in base_columns:
            assert column in result.columns

        # Second filter
        result = result.sem_filter(
            "{Course Name} is related to programming",
            return_all=True,
            return_explanations=True,
            return_raw_outputs=True,
            return_scores=True,
        )
        print(f"Second filter result: {result}")

        # Verify second filter columns have indices
        for column in base_columns:
            assert f"{column}_1" in result.columns

        # Verify original columns still exist
        for column in base_columns:
            assert column in result.columns