Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions assets/ar/QA/PalmQA_Fanar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json

from llmebench.datasets import PaLMEvalDataset
from llmebench.models import OpenAIModel
from llmebench.tasks import MultiNativQATask


def metadata():
return {
"author": "UBC-NLP / Adapted by QCRI",
"model": "OpenAIModel",
"description": "Evaluation on PaLM dataset containing MSA and dialect instructions across 22 Arab countries.",
"scores": {},
}


def config():
return {
"dataset": PaLMEvalDataset,
"task": MultiNativQATask,
"model": OpenAIModel,
"general_args": {"test_split": "default"},
}


def prompt(input_sample):
# Define the question prompt
question_prompt = f"""
Please use your expertise to answer the following Arabic question. Answer in Arabic. Please provide Answer only. No additional text.

Question: {input_sample['question']}

"""

# Define the assistant prompt
assistant_prompt = """
You are an Arabic AI assistant specialized in providing detailed and accurate answers across various fields. Your task is to deliver clear, concise, and relevant information.
"""
return [
{"role": "user", "content": question_prompt},
{"role": "assistant", "content": assistant_prompt},
]


def post_process(response):
content = response["choices"][0]["message"]["content"].strip()
return content
16 changes: 8 additions & 8 deletions llmebench/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ def run_benchmark(self, dry_run=False):
predictions = []

num_processed = 0
full_summary_fp = open(full_summary_path, "w")
full_summary_fp = open(full_summary_path, "w", encoding="utf-8")

num_failed = 0
failed_summary_fp = open(failed_summary_path, "w")
failed_summary_fp = open(failed_summary_path, "w", encoding="utf-8")

for sample_idx, (input_sample, few_shot_examples) in enumerate(
zip_longest(data, few_shots_data, fillvalue=None)
Expand All @@ -222,7 +222,7 @@ def run_benchmark(self, dry_run=False):
cache_payload["few_shot_examples"] = few_shot_examples

if cache_path.exists() and not self.ignore_cache and not dry_run:
with open(cache_path, "r") as fp:
with open(cache_path, "r", encoding="utf-8") as fp:
cache_payload = json.load(fp)

summarized_payload = {
Expand Down Expand Up @@ -258,7 +258,7 @@ def run_benchmark(self, dry_run=False):
)

# Save the cache payload
with open(cache_path, "w") as fp:
with open(cache_path, "w", encoding="utf-8") as fp:
json.dump(cache_payload, fp, ensure_ascii=False)

full_summary_fp.close()
Expand All @@ -280,7 +280,7 @@ def run_benchmark(self, dry_run=False):

task_result_path = cache_dir / "results.json"

with open(task_result_path, "w") as fp:
with open(task_result_path, "w", encoding="utf-8") as fp:
json.dump(task_results, fp, ensure_ascii=False)

all_task_results[name] = task_results
Expand Down Expand Up @@ -497,10 +497,10 @@ def main():
all_results_path = args.results_dir / "all_results.json"

if not all_results_path.exists():
with open(all_results_path, "w") as fp:
with open(all_results_path, "w", encoding="utf-8") as fp:
json.dump({}, fp)

with open(all_results_path, "r") as fp:
with open(all_results_path, "r", encoding="utf-8") as fp:
all_results = json.load(fp)

for asset in assets:
Expand Down Expand Up @@ -544,5 +544,5 @@ def main():
logging.error(f"{name} failed to run")
traceback.print_exc()

with open(all_results_path, "w") as fp:
with open(all_results_path, "w", encoding="utf-8") as fp:
json.dump(all_results, fp, ensure_ascii=False)
2 changes: 1 addition & 1 deletion llmebench/datasets/ArSAS.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def load_data(self, data_path, no_labels=False):
data_path = self.resolve_path(data_path)

data = []
with open(data_path, "r") as fp:
with open(data_path, "r", encoding="utf-8") as fp:
for line_idx, line in enumerate(fp):
text, label = line.strip().split("\t")
data.append({"input": text, "label": label, "line_number": line_idx})
Expand Down
62 changes: 62 additions & 0 deletions llmebench/datasets/PalmQA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import json

from llmebench.datasets.dataset_base import DatasetBase
from llmebench.tasks import TaskType


class PaLMEvalDataset(DatasetBase):
def __init__(self, **kwargs):
super(PaLMEvalDataset, self).__init__(**kwargs)

@staticmethod
def get_data_sample():
return {
"data_id": "1",
"input": {
"question": "من الملك الذي كان يتولى الحكم في الأردن عندما تم بناء مسجد الحسين؟"
},
"label": "بني مسجد الحسين في عهد الملك عبد الله الثاني.",
}

@staticmethod
def metadata():
return {
"language": "ar",
"citation": "Refer to PaLM eval paper",
"link": "https://github.com/UBC-NLP/palm",
"license": "",
"splits": {"default": {"test": "test.jsonl"}},
"task_type": TaskType.Other,
}

def load_data(self, data_path, no_labels=False):
data_path = self.resolve_path(data_path)
data = []

with open(data_path, encoding="utf-8") as f:
for line in f:
obj = json.loads(line)

# Concatenate instruction and input
instruction = obj.get("instruction") or ""
input_text = obj.get("input") or ""

full_prompt = f"{instruction.strip()} {input_text.strip()}".strip()

# Use "output" instead of "ideal"
output = obj.get("output")
if output is None:
print(f"Missing output for ID {obj.get('id')}")
output = ""

label = output

data.append(
{
"data_id": obj.get("id"),
"input": {"question": full_prompt},
"label": label,
}
)

return data
1 change: 1 addition & 0 deletions llmebench/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .OSACT4SubtaskA import OSACT4SubtaskADataset
from .OSACT4SubtaskB import OSACT4SubtaskBDataset
from .PADT import PADTDataset
from .PalmQA import PaLMEvalDataset
from .PIQA import PIQADataset
from .QADI import QADIDataset
from .QCRIDialectalArabicPOS import QCRIDialectalArabicPOSDataset
Expand Down