diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index bc9b490b..0015f777 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -7,7 +7,7 @@
 import lotus
 from lotus.models import LM, SentenceTransformersRM
-from lotus.types import CascadeArgs
+from lotus.types import CascadeArgs, ProxyModel
 from lotus.vector_store import FaissVS
 
 
 ################################################################################
@@ -712,3 +712,23 @@ def test_pairwise_judge(setup_models, model):
     )
     assert list(df["_judge_0"].values) == ["A", "B"]
     assert list(df["_judge_1"].values) == ["A", "B"]
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_sem_filter_cascade_rejects_non_single_token_output_tokens(setup_models, model):
+    """Cascade filtering requires each output token string to encode to exactly one token id."""
+    lm = setup_models[model]
+    lotus.settings.configure(lm=lm)
+
+    df = pd.DataFrame({"Text": ["hello"]})
+    cascade = CascadeArgs(
+        proxy_model=ProxyModel.EMBEDDING_MODEL,
+        filter_pos_cascade_threshold=0.9,
+        filter_neg_cascade_threshold=0.1,
+    )
+    with pytest.raises(ValueError, match="single token"):
+        df.sem_filter(
+            "{Text} is positive",
+            cascade_args=cascade,
+            output_tokens=("Column A", "Column B"),
+        )
diff --git a/lotus/evals/pairwise_judge.py b/lotus/evals/pairwise_judge.py
index 57796ad1..8a3638fe 100644
--- a/lotus/evals/pairwise_judge.py
+++ b/lotus/evals/pairwise_judge.py
@@ -11,14 +11,14 @@
 
 def _unique_col_names(existing_columns: pd.Index) -> tuple[str, str]:
-    """Pick col_A / col_B names that don't collide with existing columns."""
-    base_a, base_b = "col_A", "col_B"
+    """Pick A / B names that don't collide with existing columns."""
+    base_a, base_b = "A", "B"
     if base_a not in existing_columns and base_b not in existing_columns:
         return base_a, base_b
     i = 1
     while True:
-        candidate_a = f"{base_a}_{i}"
-        candidate_b = f"{base_b}_{i}"
+        candidate_a = f"{base_a}{i}"
+        candidate_b = f"{base_b}{i}"
         if candidate_a not in existing_columns and candidate_b not in existing_columns:
             return candidate_a, candidate_b
         i += 1
 
@@ -240,6 +240,8 @@ def _run_trial(i: int):
                 output_df = output
             output_df = output_df.drop(columns=[c for c in renamed_columns if c in output_df.columns])
             for col_name in output_df.columns:
+                if col_name.startswith("raw_output") or col_name.startswith("explanation"):
+                    continue
                 output_df[col_name] = output_df[col_name].map({True: "A", False: "B"})
             all_output_df.append(output_df)
             new_df = self._obj.copy()
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index c5ac4839..850bb03f 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -433,12 +433,28 @@ def __call__(
 
         proxy_model: ProxyModel | None = None
         helper_output: SemanticFilterOutput | None = None
         if cascade_args:
+            for token_str in output_tokens:
+                token_ids = lotus.settings.lm.encode_text(token_str)
+                if len(token_ids) != 1:
+                    raise ValueError(
+                        f"Output token '{token_str}' encodes to {len(token_ids)} tokens with the main LM. "
+                        f"Cascade requires each output token to be a single token."
+                    )
+
             proxy_model = cascade_args.proxy_model
             # Get the proxy scores
             if proxy_model == ProxyModel.HELPER_LM:
                 if not lotus.settings.helper_lm:
                     raise ValueError("Helper LM must be set in settings")
+                for token_str in output_tokens:
+                    token_ids = lotus.settings.helper_lm.encode_text(token_str)
+                    if len(token_ids) != 1:
+                        raise ValueError(
+                            f"Output token '{token_str}' encodes to {len(token_ids)} tokens with the helper LM. "
+                            f"Cascade requires each output token to be a single token."
+                        )
+
                 if helper_strategy == ReasoningStrategy.COT or helper_strategy == ReasoningStrategy.ZS_COT:
                     raise ValueError("CoT not supported for helper models in cascades.")
 
diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py
index 994657bc..4b6cdff1 100644
--- a/lotus/templates/task_instructions.py
+++ b/lotus/templates/task_instructions.py
@@ -104,16 +104,20 @@ def filter_formatter(
 
     """
     sys_instruction = system_prompt or default_sys_instruction
-    if strategy == ReasoningStrategy.COT:
-        sys_instruction += cot_prompt_formatter(
-            reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
-        )
-    elif strategy == ReasoningStrategy.ZS_COT:
-        sys_instruction += cot_prompt_formatter(
-            reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
-        )
+    if system_prompt:
+        sys_instruction = system_prompt
     else:
-        sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions)
+        sys_instruction = default_sys_instruction
+        if strategy == ReasoningStrategy.COT:
+            sys_instruction += cot_prompt_formatter(
+                reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
+            )
+        elif strategy == ReasoningStrategy.ZS_COT:
+            sys_instruction += cot_prompt_formatter(
+                reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
+            )
+        else:
+            sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions)
 
     messages = [
         {"role": "system", "content": sys_instruction},