Skip to content
Merged
22 changes: 21 additions & 1 deletion .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import CascadeArgs
from lotus.types import CascadeArgs, ProxyModel
from lotus.vector_store import FaissVS

################################################################################
Expand Down Expand Up @@ -712,3 +712,23 @@ def test_pairwise_judge(setup_models, model):
)
assert list(df["_judge_0"].values) == ["A", "B"]
assert list(df["_judge_1"].values) == ["A", "B"]


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a start, but we need more than one CI test for cascade testing. We should cover edge cases (e.g., when there's an existing col A and col B)

def test_sem_filter_cascade_rejects_non_single_token_output_tokens(setup_models, model):
    """Cascade filtering must reject output-token strings that encode to more than one token id."""
    # Configure the session's main LM from the parametrized fixture.
    lotus.settings.configure(lm=setup_models[model])

    cascade_args = CascadeArgs(
        proxy_model=ProxyModel.EMBEDDING_MODEL,
        filter_pos_cascade_threshold=0.9,
        filter_neg_cascade_threshold=0.1,
    )
    frame = pd.DataFrame({"Text": ["hello"]})

    # "Column A"/"Column B" tokenize to multiple ids, so validation should
    # raise before any LM calls are made.
    with pytest.raises(ValueError, match="single token"):
        frame.sem_filter(
            "{Text} is positive",
            cascade_args=cascade_args,
            output_tokens=("Column A", "Column B"),
        )
10 changes: 6 additions & 4 deletions lotus/evals/pairwise_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@


def _unique_col_names(existing_columns: pd.Index) -> tuple[str, str]:
"""Pick col_A / col_B names that don't collide with existing columns."""
base_a, base_b = "col_A", "col_B"
"""Pick A / B names that don't collide with existing columns."""
base_a, base_b = "A", "B"
if base_a not in existing_columns and base_b not in existing_columns:
return base_a, base_b
i = 1
while True:
candidate_a = f"{base_a}_{i}"
candidate_b = f"{base_b}_{i}"
candidate_a = f"{base_a}{i}"
candidate_b = f"{base_b}{i}"
if candidate_a not in existing_columns and candidate_b not in existing_columns:
return candidate_a, candidate_b
i += 1
Expand Down Expand Up @@ -240,6 +240,8 @@ def _run_trial(i: int):
output_df = output
output_df = output_df.drop(columns=[c for c in renamed_columns if c in output_df.columns])
for col_name in output_df.columns:
if col_name.startswith("raw_output") or col_name.startswith("explanation"):
continue
output_df[col_name] = output_df[col_name].map({True: "A", False: "B"})
all_output_df.append(output_df)
new_df = self._obj.copy()
Expand Down
16 changes: 16 additions & 0 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,28 @@ def __call__(
proxy_model: ProxyModel | None = None
helper_output: SemanticFilterOutput | None = None
if cascade_args:
for token_str in output_tokens:
token_ids = lotus.settings.lm.encode_text(token_str)
if len(token_ids) != 1:
raise ValueError(
f"Output token '{token_str}' encodes to {len(token_ids)} tokens with the main LM. "
f"Cascade requires each output token to be a single token."
)

proxy_model = cascade_args.proxy_model
# Get the proxy scores
if proxy_model == ProxyModel.HELPER_LM:
if not lotus.settings.helper_lm:
raise ValueError("Helper LM must be set in settings")

for token_str in output_tokens:
token_ids = lotus.settings.helper_lm.encode_text(token_str)
if len(token_ids) != 1:
raise ValueError(
f"Output token '{token_str}' encodes to {len(token_ids)} tokens with the helper LM. "
f"Cascade requires each output token to be a single token."
)

if helper_strategy == ReasoningStrategy.COT or helper_strategy == ReasoningStrategy.ZS_COT:
raise ValueError("CoT not supported for helper models in cascades.")

Expand Down
22 changes: 13 additions & 9 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,20 @@ def filter_formatter(
"""
sys_instruction = system_prompt or default_sys_instruction

if strategy == ReasoningStrategy.COT:
sys_instruction += cot_prompt_formatter(
reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
)
elif strategy == ReasoningStrategy.ZS_COT:
sys_instruction += cot_prompt_formatter(
reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
)
if system_prompt:
sys_instruction = system_prompt
else:
sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions)
sys_instruction = default_sys_instruction
if strategy == ReasoningStrategy.COT:
sys_instruction += cot_prompt_formatter(
reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
)
elif strategy == ReasoningStrategy.ZS_COT:
sys_instruction += cot_prompt_formatter(
reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
)
else:
sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions)

messages = [
{"role": "system", "content": sys_instruction},
Expand Down
Loading