-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathdata.py
More file actions
127 lines (103 loc) · 6.78 KB
/
data.py
File metadata and controls
127 lines (103 loc) · 6.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import ast
import json
from typing import Tuple

import pandas as pd
from datasets import load_dataset
class RarePrompt:
    """Produce (system, user) prompt pairs for rare-disease differential diagnosis."""

    def __init__(self) -> None:
        # Persona sentence shared by every prompt built here.
        self.system_prompt = "You are a specialist in the field of rare diseases."
        # Task instructions appended to the persona for the diagnosis task.
        task_instructions = (
            " You will be provided and asked about a complicated clinical case; "
            "read it carefully and then provide a diverse and comprehensive differential diagnosis. "
            "Also, you will be provided some knowledge about the patient's phenotype "
            "and online diagnosis suggestions as reference, please read it carefully."
        )
        self.diagnosis_system_prompt = self.system_prompt + task_instructions

    def diagnosis_prompt(self, patient_info: str) -> Tuple[str, str]:
        """Return ``(system_prompt, user_prompt)`` asking for a top-5 diagnosis list.

        ``patient_info`` is interpolated verbatim after a "Patient's phenotype:"
        header; the user prompt then asks for five rare-disease diagnoses,
        tagged with ``##`` and ordered from most to least likely.
        """
        header = f"Patient's phenotype: {patient_info}\n"
        instructions = [
            "Enumerate the top 5 most likely diagnoses. Be precise, and try to cover many unique possibilities. ",
            "Each diagnosis should be a rare disease. ",
            "Use ## to tag the disease name. ",
            "Make sure to reorder the diagnoses from most likely to least likely. ",
            "The top 5 diagnoses are:",
        ]
        user_prompt = header + "".join(instructions)
        return self.diagnosis_system_prompt, user_prompt
class RareDataset:
    """Load a rare-disease benchmark split and convert it into patient records.

    ``args`` must provide ``dataset_name``, ``dataset_path``,
    ``phenotype_mapping`` and ``disease_mapping``; the last two are paths to
    JSON files whose dicts translate phenotype / disease identifiers into the
    names placed in the records.

    After construction ``self.patient`` holds one tuple per case:
    ``(phenotype_str, disease_str, phenotype_names, phenotype_ids)``, with a
    fifth ``vcf_path`` element for Xinhua / hunan when the CSV has a
    ``vcf_path`` column.

    Raises:
        ValueError: if ``args.dataset_name`` is not a known dataset.
    """

    # Datasets fetched with ``load_dataset(dataset_path, name, split='test')``.
    _HF_DATASETS = ("RAMEDIS", "MME", "HMS", "LIRICAL")
    # Local CSV test splits: dataset name -> (path, extra ``pd.read_csv`` kwargs).
    _CSV_DATASETS = {
        'Xinhua': ('dataset/xinhua_test_0331.csv', {}),
        'MIMIC': ('dataset/mimic_test.csv', {'sep': '|'}),
        'mygene': ('dataset/mygene_test.csv', {}),
        'DDD': ('dataset/ddd_test.csv', {}),
        'case': ('dataset/cases.csv', {}),
        'hunan': ('dataset/hunan_cases.csv', {}),
    }

    def __init__(self, args) -> None:
        self.dataset_name = args.dataset_name
        self.dataset_path = args.dataset_path
        # Fail fast, before any file I/O: an unknown name used to leave
        # ``self.data`` unset and silently yield an empty patient list.
        if (self.dataset_name not in self._HF_DATASETS
                and self.dataset_name not in self._CSV_DATASETS):
            raise ValueError(f"Unknown dataset_name: {self.dataset_name!r}")
        self.phenotype_mapping = self._load_json(args.phenotype_mapping)
        self.disease_mapping = self._load_json(args.disease_mapping)
        if self.dataset_name in self._HF_DATASETS:
            self.data = load_dataset(self.dataset_path, self.dataset_name, split='test', trust_remote_code=True)
        else:
            path, read_kwargs = self._CSV_DATASETS[self.dataset_name]
            self.data = pd.read_csv(path, **read_kwargs)
        self.patient = self.load_ehr_phenotype_data()

    @staticmethod
    def _load_json(path):
        """Read a JSON file (BOM-tolerant) and close the handle afterwards."""
        with open(path, "r", encoding="utf-8-sig") as fh:
            return json.load(fh)

    @staticmethod
    def _parse_list_cell(cell):
        """Parse a CSV cell holding a Python list literal such as "['HP:1']".

        ``ast.literal_eval`` replaces the former ``eval``: it accepts only
        literals, so a malicious or corrupted cell raises ``ValueError``
        instead of executing code.
        """
        return ast.literal_eval(cell)

    def _build_record(self, phenotype_list, disease_list, map_diseases=True):
        """Map raw IDs to names and return the 4-tuple patient record.

        Phenotype IDs missing from ``phenotype_mapping`` are dropped. When
        ``map_diseases`` is true, diseases are likewise translated through
        ``disease_mapping`` (unmapped ones dropped); otherwise they are joined
        as given.
        """
        phenotype_ids = [ph for ph in phenotype_list if ph in self.phenotype_mapping]
        phenotype_names = [self.phenotype_mapping[ph] for ph in phenotype_ids]
        if map_diseases:
            disease_names = [self.disease_mapping[d] for d in disease_list if d in self.disease_mapping]
        else:
            disease_names = disease_list
        return (", ".join(phenotype_names), ", ".join(disease_names), phenotype_names, phenotype_ids)

    @staticmethod
    def _with_vcf(record, row):
        """Append the row's ``vcf_path`` to ``record`` if that column exists."""
        # ``in`` on a pandas Series tests the index, i.e. the column labels.
        if 'vcf_path' in row:
            return record + (row['vcf_path'],)
        return record

    def load_ehr_phenotype_data(self):
        """Convert ``self.data`` into a list of patient-record tuples.

        See the class docstring for the tuple layout. Returns an empty list
        for a dataset name with no conversion branch.
        """
        name = self.dataset_name
        parse = self._parse_list_cell
        patient = []
        if name in self._HF_DATASETS:
            for p in self.data:
                patient.append(self._build_record(p['Phenotype'], p['RareDisease']))
        elif name == 'MIMIC':
            for _, row in self.data.iterrows():
                patient.append(self._build_record(parse(row['HPO']), parse(row['orpha'])))
        elif name == 'Xinhua':
            for _, row in self.data.iterrows():
                record = self._build_record(parse(row['hpo']), parse(row['orpha']))
                patient.append(self._with_vcf(record, row))
        elif name in ('mygene', 'DDD'):
            for _, row in self.data.iterrows():
                patient.append(self._build_record(parse(row['phenotype']), parse(row['rare_disease'])))
        elif name == 'hunan':
            for _, row in self.data.iterrows():
                # Phenotype IDs are '|'-separated; diseases are assumed (by the
                # original author) to already be readable names, so they skip
                # the disease mapping.
                record = self._build_record(row['hpo'].split('|'), parse(row['disease']), map_diseases=False)
                patient.append(self._with_vcf(record, row))
        # NOTE(review): 'case' loads dataset/cases.csv but has no conversion
        # branch here, so it yields no patients — confirm whether intended.
        return patient