From ebcc4438d51d3110b43b85eb39e7353a190ed012 Mon Sep 17 00:00:00 2001 From: Lucas Resck <41991486+lucasresck@users.noreply.github.com> Date: Wed, 20 Nov 2024 18:26:15 +0000 Subject: [PATCH] Fix `MMLUEval` answer extraction regex Find all occurrences of the answer regular expression pattern in the response and take the last one; wrap the pattern in a zero-width lookahead so that overlapping matches are also found. --- common.py | 2 +- mmlu_eval.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common.py b/common.py index e035f5d4..79fe1d9b 100644 --- a/common.py +++ b/common.py @@ -23,7 +23,7 @@ ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( - "(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" + "(?i)(?=(?:{}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])))" ) # All the different ways "Answer" is written in different languages MULTILINGUAL_ANSWER_REGEXES = [ diff --git a/mmlu_eval.py b/mmlu_eval.py index 90b83287..6fe1fcc7 100644 --- a/mmlu_eval.py +++ b/mmlu_eval.py @@ -105,9 +105,10 @@ def fn(row: dict): extracted_answer = None for answer_regex in MULTILINGUAL_ANSWER_REGEXES: regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex) - match = re.search(regex, response_text) - if match: - extracted_answer = normalize_extracted_answer(match.group(1)) + matches = re.findall(regex, response_text) + if matches: + match = matches[-1] + extracted_answer = normalize_extracted_answer(match) break score = 1.0 if extracted_answer == row["Answer"] else 0.0 html = common.jinja_env.from_string(HTML_JINJA).render(