-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathutils.py
More file actions
97 lines (88 loc) · 3.68 KB
/
utils.py
File metadata and controls
97 lines (88 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import re
import ast
import json
import datasets
import importlib
from tqdm import tqdm
from dotenv import load_dotenv
from typing import List, Dict, Literal, Tuple
# Provide defaults for the TQDM_* environment variables without overriding
# anything the user has already exported (setdefault keeps existing values).
# NOTE(review): presumably consumed by tqdm to render plain-ASCII bars at a
# fixed 80-column width -- confirm against the installed tqdm version.
for _tqdm_var, _tqdm_default in (
    ("TQDM_ASCII", "1"),
    ("TQDM_DYNAMIC_NCOLS", "0"),
    ("TQDM_NCOLS", "80"),
):
    os.environ.setdefault(_tqdm_var, _tqdm_default)
del _tqdm_var, _tqdm_default

# Turn off the progress bars of the `datasets` library. The helper is looked
# up under both spellings, plural first; only the first one found is called.
for _disable_hook in ("disable_progress_bars", "disable_progress_bar"):
    if hasattr(datasets, _disable_hook):
        getattr(datasets, _disable_hook)()
        break
del _disable_hook
def change_dialsim_conversation_to_locomo_form(raw_text) -> Tuple[Dict, int]:
    """
    Convert a raw DialSim transcript into the Locomo conversation layout.

    The transcript is split on headers of the form
    "[Date: <date>, Session #<n>]" followed by a blank line. For the k-th
    header found (1-based position in the text) two keys are written:
    "session_k_date_time" holding "<date>, Session #<n>", and "session_k"
    holding the list of utterances of that session.

    Args:
        raw_text: original raw text of DialSim conversation
    Returns:
        conversation: the converted conversation dict
        session_cnt: the number of sessions in the conversation
    """
    header_re = re.compile(r"\[Date: (.*?), Session #(\d+)\]\n\n(.*?)(?=(?:\[Date:)|$)", re.S)
    # Matches "Speaker: text"; the speaker is everything before the first colon.
    utterance_re = re.compile(r"^(.*?):\s*(.*)$")

    found_sessions = header_re.findall(raw_text)
    converted = {}
    for position, (date_part, number_part, body) in enumerate(found_sessions, start=1):
        converted[f"session_{position}_date_time"] = f"{date_part}, Session #{number_part}"

        turns = []
        # Line numbers count every line of the session body, including lines
        # that are skipped because they carry no "Speaker:" prefix.
        for line_no, raw_line in enumerate(body.strip().split("\n"), start=1):
            hit = utterance_re.match(raw_line)
            if hit is None:
                continue
            who, said = hit.groups()
            turns.append({
                "speaker": who.strip(),
                # The id uses the session number written in the header, not
                # the position used for the dict keys.
                "dia_id": f"D{number_part}:{line_no}",
                "text": said.strip()
            })
        converted[f"session_{position}"] = turns
    return converted, len(found_sessions)
def convert_str_to_obj(example):
    """
    Decode stringified columns of one dataset row back into Python objects.

    Columns whose name starts with "dialog" or "implicit_feedback", plus the
    columns "input_chat_messages" and "info", are decoded when they hold a
    string: first as a Python literal, then as JSON. A value that neither
    parser accepts is left as the original string.

    For rows whose "dataset_name" contains "Locomo", info["golden_answer"] is
    additionally normalised to a string: JSON-encoded when info["category"]
    is 5, passed through str() otherwise.

    Args:
        example: mapping of column name to value; modified in place
    Returns:
        example: the same mapping, with the decoded columns
    Raises:
        KeyError: if "dataset_name" is missing, or a Locomo row lacks "info",
            "category" or "golden_answer"
    """
    # Every exception ast.literal_eval is documented to raise on malformed
    # input. Catching only ValueError/SyntaxError let inputs such as
    # "{[1]: 2}" (TypeError: unhashable key) abort the whole conversion
    # instead of falling back to JSON.
    literal_errors = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)

    # Iterate over a snapshot of the keys because values are reassigned below.
    for col in list(example.keys()):
        wanted = col.startswith(("dialog", "implicit_feedback")) or col in ("input_chat_messages", "info")
        if not wanted or not isinstance(example[col], str):
            continue
        try:
            example[col] = ast.literal_eval(example[col])
        except literal_errors:
            try:
                example[col] = json.loads(example[col])
            except Exception:
                # Deliberate best effort: a value that is neither a Python
                # literal nor JSON stays as the raw string.
                pass

    if "Locomo" in example["dataset_name"]:
        if example["info"]["category"] == 5:
            example["info"]["golden_answer"] = json.dumps(example["info"]["golden_answer"])
        else:
            example["info"]["golden_answer"] = str(example["info"]["golden_answer"])
    return example
def load_from_hf(dataset_name: str) -> Dict:
    """
    Load one dataset configuration and, when it has one, its conversation corpus.

    The repository id or path is read from the MEMORY_BENCH_PATH environment
    variable and defaults to "THUIR/MemoryBench". Every row of the loaded
    dataset is passed through convert_str_to_obj.

    When dataset_name contains "Locomo" or "DialSim", the file
    "corpus/<dataset_name>.jsonl" is loaded as well and the "text" field of
    its first "train" record is used as the corpus source:
      - Locomo: the text is parsed as JSON and its "conversation" entry is
        the corpus; sessions are counted as the consecutive keys
        "session_1", "session_2", ... that are present.
      - DialSim: the text is converted with
        change_dialsim_conversation_to_locomo_form.

    Args:
        dataset_name: configuration name passed to datasets.load_dataset
    Returns:
        dict with the keys "dataset_name", "dataset", "corpus" and
        "session_cnt"; the last two are None when the dataset has no corpus
    """
    hf_datasets_path = os.getenv("MEMORY_BENCH_PATH", "THUIR/MemoryBench")
    dataset = datasets.load_dataset(hf_datasets_path, dataset_name)
    dataset = dataset.map(convert_str_to_obj)

    corpus = None
    session_cnt = None
    if "Locomo" in dataset_name or "DialSim" in dataset_name:
        corpus_files = datasets.load_dataset(hf_datasets_path, data_files=f"corpus/{dataset_name}.jsonl")
        corpus_text = corpus_files["train"][0]['text']
        if "Locomo" in dataset_name:
            corpus = json.loads(corpus_text)["conversation"]
            # Count consecutive session keys explicitly. The earlier
            # range(1, len(corpus)) scan only assigned session_cnt when it hit
            # a missing key, so a corpus made solely of session entries (or an
            # empty one) reached the return with the name unbound.
            session_cnt = 0
            while f"session_{session_cnt + 1}" in corpus:
                session_cnt += 1
        else:
            # Only "DialSim" names can reach this branch.
            corpus, session_cnt = change_dialsim_conversation_to_locomo_form(corpus_text)

    return {
        "dataset_name": dataset_name,
        "dataset": dataset,
        "corpus": corpus,
        "session_cnt": session_cnt
    }