-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_pseudo_triplets_task.py
More file actions
73 lines (61 loc) · 2.23 KB
/
create_pseudo_triplets_task.py
File metadata and controls
73 lines (61 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import pandas as pd
import json
import networkx as nx
from collections import defaultdict
import numpy.random as rng
def _count_triplets(n):
    """Return the number of unordered 3-element subsets of an n-element set."""
    return n * (n - 1) * (n - 2) // 6 if n >= 3 else 0


def _read_graph(path):
    """Load the pickled networkx graph stored at *path*.

    Uses ``nx.read_gpickle`` when the installed networkx still provides it and
    falls back to the standard ``pickle`` module otherwise (networkx 3.0
    removed ``read_gpickle``).
    """
    reader = getattr(nx, "read_gpickle", None)
    if reader is not None:
        return reader(path)

    # NOTE: unpickling executes arbitrary code; only load trusted,
    # locally generated files this way.
    import pickle

    with open(path, "rb") as f:
        return pickle.load(f)


def main(n_samples_within=5000, n_samples_across=5000):
    """Sample gene triplets and write them to ``../generated-data/pseudo_triplets``.

    Genes are read from ``../generated-data/yeast_complexes.json``; only genes
    that map to exactly one group (complex) and that are nodes of the graph in
    ``../generated-data/ppc_yeast`` are eligible.

    Two sets of unique, alphabetically sorted triplets are drawn by rejection
    sampling:

    * ``bin == 0``: all three genes belong to the same complex (complexes with
      fewer than three eligible genes are never picked).
    * ``bin == 1``: the three genes are *not all* in the same complex; two of
      the three may still share one.

    The CSV has the columns ``a, b, c, bin, a_id, b_id, c_id`` where the
    ``*_id`` columns are the positions of the genes in the sorted node list.

    Args:
        n_samples_within: number of unique same-complex triplets to draw.
        n_samples_across: number of unique mixed-complex triplets to draw.

    Raises:
        ValueError: if more unique triplets are requested than exist. Without
            this check the sampling loops would never terminate.
    """
    G = _read_graph("../generated-data/ppc_yeast")
    nodes = sorted(G.nodes())
    node_ix = {node: ix for ix, node in enumerate(nodes)}

    with open('../generated-data/yeast_complexes.json', 'r') as f:
        genes_to_groups = json.load(f)

    # Keep only genes assigned to exactly one group so membership is unambiguous.
    genes_to_group = {
        g: list(v)[0] for g, v in genes_to_groups.items() if len(v) == 1
    }

    group_to_genes = defaultdict(set)
    for gene, group in genes_to_group.items():
        if gene in node_ix:
            group_to_genes[group].add(gene)

    # Materialise the member lists once instead of on every sampling iteration.
    members = {group: sorted(genes) for group, genes in group_to_genes.items()}
    eligible_genes = sorted(gene for genes in members.values() for gene in genes)
    groups = [group for group, genes in members.items() if len(genes) >= 3]

    # Rejection sampling only terminates if enough distinct triplets exist.
    max_within = sum(_count_triplets(len(members[group])) for group in groups)
    if n_samples_within > max_within:
        raise ValueError(
            "n_samples_within=%d exceeds the %d distinct same-complex "
            "triplets available" % (n_samples_within, max_within)
        )
    # Every eligible gene is in exactly one group, so the mixed triplets are
    # all triplets minus the ones that lie entirely inside a single group.
    max_across = _count_triplets(len(eligible_genes)) - max_within
    if n_samples_across > max_across:
        raise ValueError(
            "n_samples_across=%d exceeds the %d distinct mixed-complex "
            "triplets available" % (n_samples_across, max_across)
        )

    triplets_in_same_complex = set()
    while len(triplets_in_same_complex) < n_samples_within:
        # pick a random complex, then three distinct genes from it
        group = rng.choice(groups)
        triplet = tuple(sorted(rng.choice(members[group], size=3, replace=False)))
        triplets_in_same_complex.add(triplet)

    triplets_in_diff_complexes = set()
    while len(triplets_in_diff_complexes) < n_samples_across:
        triplet = tuple(sorted(rng.choice(eligible_genes, size=3, replace=False)))
        # accept unless all three genes share a single complex
        if len({genes_to_group[gene] for gene in triplet}) > 1:
            triplets_in_diff_complexes.add(triplet)

    labelled = (
        (0, triplets_in_same_complex),
        (1, triplets_in_diff_complexes),
    )
    rows = [
        {
            "a": a,
            "b": b,
            "c": c,
            "bin": label,
            "a_id": node_ix[a],
            "b_id": node_ix[b],
            "c_id": node_ix[c],
        }
        for label, triplets in labelled
        for a, b, c in triplets
    ]

    df = pd.DataFrame(rows)
    df.to_csv('../generated-data/pseudo_triplets')
if __name__ == "__main__":
    # Executed as a script: generate the triplets with the default sample sizes.
    main()