diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 7b12fa08..0e48970c 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -6,7 +6,7 @@ import lotus from lotus.models import LM, SentenceTransformersRM -from lotus.types import CascadeArgs +from lotus.types import CascadeArgs, ReasoningStrategy from lotus.vector_store import FaissVS ################################################################################ @@ -269,7 +269,7 @@ def test_filter_operation_cot(setup_models, model): } df = pd.DataFrame(data) user_instruction = "{Text} I have at least one apple" - filtered_df = df.sem_filter(user_instruction, strategy="cot") + filtered_df = df.sem_filter(user_instruction, strategy=ReasoningStrategy.ZS_COT) expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]}) assert filtered_df.equals(expected_df) @@ -302,7 +302,7 @@ def test_filter_operation_cot_fewshot(setup_models, model): user_instruction = "{Sequence} is increasing" filtered_df = df.sem_filter( user_instruction, - strategy="cot", + strategy=ReasoningStrategy.COT, examples=examples_df, additional_cot_instructions="Assume the most typical or logical case.", ) @@ -339,7 +339,7 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): examples_df = pd.DataFrame(examples) user_instruction = "{Sequence} is increasing" - filtered_df = df.sem_filter(user_instruction, strategy="cot", examples=examples_df) + filtered_df = df.sem_filter(user_instruction, strategy=ReasoningStrategy.ZS_COT, examples=examples_df) expected_df = pd.DataFrame( { "Sequence": [ @@ -352,6 +352,178 @@ def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): assert filtered_df.equals(expected_df) +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_filter_operation_cot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on 
an easy dataframe + data = { + "Text": [ + "I had two apples, then I gave away one", + "My friend gave me an apple", + "I gave away both of my apples", + "I gave away my apple, then a friend gave me his apple, then I threw my apple away", + ] + } + df = pd.DataFrame(data) + user_instruction = "{Text} I have at least one apple" + filtered_df = df.sem_filter(user_instruction, strategy=ReasoningStrategy.ZS_COT) + expected_df = pd.DataFrame({"Text": ["I had two apples, then I gave away one", "My friend gave me an apple"]}) + assert filtered_df.equals(expected_df) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_filter_operation_cot_fewshot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Sequence": [ + "Five, Four, Three", + "A, B, C", + "Pond, Lake, Ocean", + ] + } + df = pd.DataFrame(data) + examples = { + "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], + "Answer": [True, True, True], + "Reasoning": [ + "1, 2, 3 is an increasing sequence of numbers", + "penny, nickel, dime, quarter is an increasing sequence of coins", + "villiage, town, city is an increasing sequence of settlements", + ], + } + examples_df = pd.DataFrame(examples) + + user_instruction = "{Sequence} is increasing" + filtered_df = df.sem_filter( + user_instruction, + strategy=ReasoningStrategy.COT, + examples=examples_df, + additional_cot_instructions="Assume the most typical or logical case.", + ) + expected_df = pd.DataFrame( + { + "Sequence": [ + "A, B, C", + "Pond, Lake, Ocean", + ] + }, + index=[1, 2], + ) + assert filtered_df.equals(expected_df) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_filter_operation_cot_fewshot_no_reasoning(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + 
"Sequence": [ + "Five, Four, Three", + "A, B, C", + "Pond, Lake, Ocean", + ] + } + df = pd.DataFrame(data) + examples = { + "Sequence": ["1, 2, 3", "penny, nickel, dime, quarter", "villiage, town, city"], + "Answer": [True, True, True], + } + examples_df = pd.DataFrame(examples) + + user_instruction = "{Sequence} is increasing" + filtered_df = df.sem_filter(user_instruction, strategy=ReasoningStrategy.ZS_COT, examples=examples_df) + expected_df = pd.DataFrame( + { + "Sequence": [ + "A, B, C", + "Pond, Lake, Ocean", + ] + }, + index=[1, 2], + ) + assert filtered_df.equals(expected_df) + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_map_operation_cot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Sequence": [ + "Alpha, Bravo, Charlie", + "One, Two, Three", + "Triangle, Square, Pentagon", + ] + } + df = pd.DataFrame(data) + user_instruction = "What should be the next item in the sequence: {Sequence}" + mapped_df = df.sem_map(user_instruction, strategy=ReasoningStrategy.ZS_COT) + expected_df = pd.DataFrame({"_map": ["Delta", "Four", "Hexagon"]}) + assert mapped_df["_map"].equals(expected_df["_map"]) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_map_operation_cot_fewshot(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Sequence": [ + "Alpha, Bravo, Charlie", + "One, Two, Three", + "Triangle, Square, Pentagon", + ] + } + df = pd.DataFrame(data) + examples = { + "Sequence": ["A, B, C", "Kindergarten, First Grade, Second Grade"], + "Answer": ["D", "Third Grade"], + "Reasoning": [ + "D is the next letter in the alphabet after C", + "Third Grade is the next grade after Second Grade", + ], + } + examples_df = pd.DataFrame(examples) + user_instruction = "What should be the next item in the 
sequence: {Sequence}" + mapped_df = df.sem_map(user_instruction, strategy=ReasoningStrategy.COT, examples=examples_df) + expected_df = pd.DataFrame({"_map": ["Delta", "Four", "Hexagon"]}) + assert mapped_df["_map"].equals(expected_df["_map"]) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini", "ollama/llama3.1")) +def test_map_operation_cot_fewshot_no_reasoning(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + data = { + "Sequence": [ + "Alpha, Bravo, Charlie", + "One, Two, Three", + "Triangle, Square, Pentagon", + ] + } + df = pd.DataFrame(data) + examples = { + "Sequence": ["A, B, C", "Kindergarten, First Grade, Second Grade"], + "Answer": ["D", "Third Grade"], + } + examples_df = pd.DataFrame(examples) + user_instruction = "What should be the next item in the sequence: {Sequence}" + mapped_df = df.sem_map(user_instruction, strategy=ReasoningStrategy.ZS_COT, examples=examples_df) + expected_df = pd.DataFrame({"_map": ["Delta", "Four", "Hexagon"]}) + assert mapped_df["_map"].equals(expected_df["_map"]) + + ################################################################################ # Cascade tests ################################################################################ diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 3981e65f..364cf594 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -10,8 +10,8 @@ def cot_postprocessor(llm_answers: list[str]): - outputs: list[str | None] = [] - explanations: list[str | None] = [] + outputs: list[str] = [] + explanations: list[str] = [] for llm_answer in llm_answers: reasoning_idx = llm_answer.find("Reasoning:\n") if reasoning_idx == -1: @@ -19,13 +19,19 @@ def cot_postprocessor(llm_answers: list[str]): else: reasoning_idx += len("Reasoning:\n") - answer_idx = llm_answer.find("Answer:") + answer_idx = llm_answer.find("Answer: ") + if answer_idx == -1: 
+ answer_idx = 0 + else: + answer_idx += len("Answer: ") + + reasoning = llm_answer[reasoning_idx:answer_idx].removesuffix("Answer: ").rstrip("\n").lstrip("\n") - answer = llm_answer[answer_idx + len("Answer:") :] + answer = llm_answer[answer_idx:].rstrip("\n").lstrip("\n") explanations.append(reasoning) outputs.append(answer) - + return outputs, explanations @@ -106,57 +112,21 @@ def get_cot_postprocessor(model: lotus.models.LM, for_extract: bool = False) -> return cot_postprocessor -def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: - """ - Postprocess the output of the map operator with CoT reasoning. - - Args: - llm_answers (list[str]): The list of llm answers. - - Returns: - SemanticMapPostprocessOutput - """ - outputs: list[str] = [] - explanations: list[str | None] = [] - - for llm_answer in llm_answers: - reasoning_idx = llm_answer.find("Reasoning:\n") - if reasoning_idx == -1: - reasoning_idx = 0 - else: - reasoning_idx += len("Reasoning:\n") - - answer_idx = llm_answer.find("Answer:") - reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") - answer = llm_answer[answer_idx + len("Answer:") :] - outputs.append(answer) - explanations.append(reasoning) - - return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) - - -def map_postprocess( - llm_answers: list[str], - model: lotus.models.LM, - cot_reasoning: bool = False, -) -> SemanticMapPostprocessOutput: +def map_postprocess(llm_answers: list[str], model: lotus.models.LM, cot_reasoning: bool = False, default: str = "") -> SemanticMapPostprocessOutput: """ Postprocess the output of the map operator. Args: llm_answers (list[str]): The list of llm answers. - cot_reasoning (bool): Whether there is CoT reasoning. + default (str): The default value to use if we fail to parse the answer. 
Returns: SemanticMapPostprocessOutput """ - if cot_reasoning: - postprocessor = get_cot_postprocessor(model) - outputs, explanations = postprocessor(llm_answers) - else: - outputs = llm_answers - explanations = [None] * len(llm_answers) + postprocessor = get_cot_postprocessor(model) + outputs, explanations = postprocessor(llm_answers) + outputs = [output if output is not None else default for output in outputs] return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index ca8908be..03cc47be 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -15,13 +15,15 @@ def sem_map( docs: list[dict[str, Any]], model: lotus.models.LM, user_instruction: str, - postprocessor: Callable[[list[str], lotus.models.LM, bool], SemanticMapPostprocessOutput] = map_postprocess, + postprocessor: Callable[[list[str], lotus.models.LM, bool, str], SemanticMapPostprocessOutput] = map_postprocess, + default: str = "", examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[str] | None = None, cot_reasoning: list[str] | None = None, strategy: ReasoningStrategy | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", + additional_cot_instructions: str = "", ) -> SemanticMapOutput: """ Maps a list of documents to a list of outputs using a model. @@ -34,6 +36,11 @@ def sem_map( examples_multimodal_data (list[dict[str, Any]] | None): The text for examples. Defaults to None. examples_answers (list[str] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. + additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". + strategy (str | None): The reasoning strategy. Defaults to None. + safe_mode (bool): Whether to use safe mode. Defaults to False. + progress_bar_desc (str): The description for the progress bar. 
Defaults to "Mapping". + default (str): The default value to use if we fail to parse the answer. Returns: SemanticMapOutput: The outputs, raw outputs, and explanations. @@ -43,7 +50,13 @@ def sem_map( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.map_formatter( - model, doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy=strategy + model, doc, + user_instruction, + examples_multimodal_data, + examples_answers, + cot_reasoning, + strategy=strategy, + reasoning_instructions=additional_cot_instructions, ) lotus.logger.debug(f"input to model: {prompt}") lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") @@ -60,7 +73,7 @@ def sem_map( # post process results postprocess_output = postprocessor( - lm_output.outputs, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT] + lm_output.outputs, model, strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT], default ) lotus.logger.debug(f"raw_outputs: {lm_output.outputs}") lotus.logger.debug(f"outputs: {postprocess_output.outputs}") @@ -92,7 +105,7 @@ def _validate(obj: pd.DataFrame) -> None: def __call__( self, user_instruction: str, - postprocessor: Callable[[list[str], lotus.models.LM, bool], SemanticMapPostprocessOutput] = map_postprocess, + postprocessor: Callable[[list[str], lotus.models.LM, bool, str], SemanticMapPostprocessOutput] = map_postprocess, return_explanations: bool = False, return_raw_outputs: bool = False, suffix: str = "_map", @@ -100,6 +113,8 @@ def __call__( strategy: ReasoningStrategy | None = None, safe_mode: bool = False, progress_bar_desc: str = "Mapping", + additional_cot_instructions: str = "", + default: str = "", ) -> pd.DataFrame: """ Applies semantic map over a dataframe. @@ -112,6 +127,10 @@ def __call__( suffix (str): The suffix for the new columns. Defaults to "_map". examples (pd.DataFrame | None): The examples dataframe. Defaults to None. 
strategy (str | None): The reasoning strategy. Defaults to None. + safe_mode (bool): Whether to use safe mode. Defaults to False. + progress_bar_desc (str): The description for the progress bar. Defaults to "Mapping". + additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". + default (str): The default value to use if we fail to parse the answer. Returns: pd.DataFrame: The dataframe with the new mapped columns. @@ -140,8 +159,7 @@ def __call__( examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() - if strategy == ReasoningStrategy.COT or strategy == ReasoningStrategy.ZS_COT: - return_explanations = True + if (strategy == ReasoningStrategy.COT or strategy == ReasoningStrategy.ZS_COT) and "Reasoning" in examples.columns: + cot_reasoning = examples["Reasoning"].tolist() output = sem_map( @@ -155,6 +173,8 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc=progress_bar_desc, + additional_cot_instructions=additional_cot_instructions, + default=default, ) new_df = self._obj.copy() diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index bcc94119..30d4f156 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -9,11 +9,11 @@ def cot_formatter(reasoning, answer): - return f"""Reasoning:\n{reasoning}\n\nAnswer: {answer}""" + return f"""\n\nReasoning:\n{reasoning}\n\nAnswer: {answer}""" def answer_only_formatter(answer): - return f"""Answer: {answer}""" + return f"""\n\nAnswer: {answer}""" def deepseek_cot_formatter(): @@ -23,17 +23,17 @@ def deepseek_cot_formatter(): def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: - reasoning_instructions = f"" - answer_instructions = f"" - return f"""Let's think step by step. 
Use the following format to provide your answer: - {cot_formatter(reasoning_instructions, answer_instructions)} + reasoning_placeholder = f"" + answer_placeholder = f"" + return f"""\n\nLet's think step by step. Use the following format to provide your answer: + {cot_formatter(reasoning_placeholder, answer_placeholder)} """ def non_cot_prompt_formatter(answer_instructions: str = "") -> str: - answer_instructions = f"" - return f"""Use the following format to provide your answer: - {answer_only_formatter(answer_instructions)} + answer_placeholder = f"" + return f"""\n\nUse the following format to provide your answer: + {answer_only_formatter(answer_placeholder)} """ @@ -153,57 +153,6 @@ def filter_formatter( return messages -def map_formatter_cot( - multimodal_data: dict[str, Any], - user_instruction: str, - examples_multimodal_data: list[dict[str, Any]], - examples_answer: list[str], - cot_reasoning: list[str], -) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide an instruction and some relevant context.\n" - "Your job is to answer the user's instruction given the context." - "You must give your reasoning and then your final answer" - ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] - - for idx in range(len(examples_multimodal_data)): - ex_df_txt = examples_multimodal_data[idx] - ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] - messages.extend( - [ - user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), - { - "role": "assistant", - "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", - }, - ] - ) - - messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) - return messages - - -def map_formatter_zs_cot( - multimodal_data: dict[str, Any], - user_instruction: str, -) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide an instruction and some relevant context.\n" - "Your job is to answer the user's instruction given the context." 
- 'First give your reasoning. Then you MUST end your output with "Answer: your answer"' - ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] - - messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) - return messages - - def map_formatter( model: lotus.models.LM, multimodal_data: dict[str, Any], @@ -212,18 +161,22 @@ def map_formatter( examples_answer: list[str] | None = None, cot_reasoning: list[str] | None = None, strategy: ReasoningStrategy | str | None = None, + reasoning_instructions: str = "", ) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide an instruction and some relevant context.\n" - "Your job is to answer the user's instruction given the context." - ) - if cot_reasoning: - assert examples_multimodal_data is not None and examples_answer is not None - return map_formatter_cot( - multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning + """ + Creates a map formatter with chain-of-thought reasoning. + Supports both few-shot CoT (with examples) and zero-shot CoT. + """ + sys_instruction = """The user will provide an instruction and some relevant context. + Your job is to answer the user's instruction given the context. 
+ """ + + if strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT]: + sys_instruction += cot_prompt_formatter( + reasoning_instructions=reasoning_instructions, ) - elif strategy == ReasoningStrategy.ZS_COT: - return map_formatter_zs_cot(multimodal_data, user_instruction) + else: + sys_instruction += non_cot_prompt_formatter() messages = [ {"role": "system", "content": sys_instruction}, @@ -231,11 +184,34 @@ def map_formatter( if examples_multimodal_data: assert examples_answer is not None - for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer): + assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) + assert len(examples_multimodal_data) == len(examples_answer) + + if cot_reasoning: + # If CoT reasoning examples are provided, use them + assert isinstance(cot_reasoning, list) + assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) + + for idx in range(len(examples_multimodal_data)): + ex_df_txt = examples_multimodal_data[idx] + ex_ans = examples_answer[idx] + content = "" + + # If CoT reasoning is provided, use it. Otherwise, supply a default reasoning + if cot_reasoning: + content = cot_formatter(cot_reasoning[idx], str(ex_ans)) + elif strategy in [ReasoningStrategy.COT, ReasoningStrategy.ZS_COT]: + content = cot_formatter("Reasoning omitted", str(ex_ans)) + else: + content = answer_only_formatter(str(ex_ans)) + messages.extend( [ user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), - {"role": "assistant", "content": str(ex_ans)}, + { + "role": "assistant", + "content": content, + }, ] )