-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmask.py
More file actions
104 lines (98 loc) · 4.85 KB
/
mask.py
File metadata and controls
104 lines (98 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import dgl
import torch
import numpy as np
import os
import random
import pandas
# import bidict
from dgl.data import FraudAmazonDataset, FraudYelpDataset
from sklearn.model_selection import train_test_split
def set_seed(seed=3407):
    """Seed every random number generator this file relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and, when
    available, all CUDA devices), and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators.

    Note:
        ``PYTHONHASHSEED`` is written to ``os.environ`` here; whether that
        affects hashing in the already-running interpreter is outside what
        this function controls.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels are only reproducible when autotuning is
    # disabled as well, otherwise a different algorithm may be picked per run.
    # Both are plain flags, so setting them without a GPU is harmless.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class Dataset:
    """Load a node-classification graph and attach train/val/test masks.

    Attributes:
        name: Dataset name passed to the constructor; selects the loader in
            ``__init__`` and the split policy in ``split``.
        graph: The loaded DGL graph after the optional post-processing steps.
    """

    # Path the non-DGL graph is loaded from when no ``graph_path`` is given.
    DEFAULT_GRAPH_PATH = '/home/yuhanli/wangpeisong/GADBench/datasets/horizon'

    # Datasets whose stored 'train_mask'/'val_mask'/'test_mask' are reused.
    OFFICIAL_SPLIT_NAMES = ('amazon', 'yelp', 'elliptic', 'dgraphfin')

    def __init__(self, name='tfinance', homo=True, add_self_loop=True,
                 to_bidirectional=False, to_simple=True, graph_path=None):
        """Load the graph for ``name`` and apply the requested transforms.

        Args:
            name: 'yelp' and 'amazon' are loaded through DGL's fraud
                datasets; any other name loads a saved graph from
                ``graph_path``.
            homo: For 'yelp'/'amazon', convert the graph to a homogeneous
                graph, carrying over the feature, label and mask node data.
            add_self_loop: Apply ``dgl.add_self_loop`` to the graph.
            to_bidirectional: Apply ``dgl.to_bidirected`` (node data copied).
            to_simple: Apply ``dgl.to_simple`` to the graph.
            graph_path: File passed to ``dgl.load_graphs`` for names other
                than 'yelp'/'amazon'. ``None`` uses ``DEFAULT_GRAPH_PATH``,
                which is the path the original code hard-coded.
        """
        if name in ('yelp', 'amazon'):
            dataset = FraudYelpDataset() if name == 'yelp' else FraudAmazonDataset()
            graph = dataset[0]
            node_fields = ['feature', 'label', 'train_mask', 'val_mask', 'test_mask']
            for key in ('train_mask', 'val_mask', 'test_mask'):
                graph.ndata[key] = graph.ndata[key].bool()
            if name == 'amazon':
                # Nodes covered by at least one official mask; ``split``
                # restricts its node index to these when 'mark' is present.
                graph.ndata['mark'] = (graph.ndata['train_mask']
                                       | graph.ndata['val_mask']
                                       | graph.ndata['test_mask'])
                node_fields.append('mark')
            if homo:
                # Convert the graph that was just modified rather than
                # re-fetching dataset[0], so the boolean masks and 'mark'
                # are carried over even if dataset[0] returns a new object.
                graph = dgl.to_homogeneous(graph, ndata=node_fields)
        else:
            if graph_path is None:
                graph_path = self.DEFAULT_GRAPH_PATH
            graph = dgl.load_graphs(graph_path)[0][0]
            # The saved graph stores its features under 'features'; expose
            # them under the 'feature' key as a float tensor.
            graph.ndata['feature'] = graph.ndata['features'].float()
            graph.ndata['label'] = graph.ndata['label'].long()

        self.name = name
        self.graph = graph
        if add_self_loop:
            self.graph = dgl.add_self_loop(self.graph)
        if to_bidirectional:
            self.graph = dgl.to_bidirected(self.graph, copy_ndata=True)
        if to_simple:
            self.graph = dgl.to_simple(self.graph)

    @staticmethod
    def _split_ratios(name):
        """Return the ``(train_ratio, val_ratio)`` used for a random split."""
        if name in ('tolokers', 'questions'):
            return 0.5, 0.25
        if name in ('tsocial', 'tfinance', 'reddit', 'weibo'):
            return 0.4, 0.2
        return 0.7, 0.1

    def split(self, samples=2):
        """Create 20-column train/val/test masks and store them on the graph.

        Columns 0-9 hold ten splits. Datasets in ``OFFICIAL_SPLIT_NAMES``
        repeat their stored masks in all ten columns; every other dataset
        gets ten stratified random splits seeded with ``3407 + 10 * i``.
        Columns 10-19 are left all-False (the few-shot split that used to
        fill them is disabled).

        The results are written to ``graph.ndata`` under 'train_masks',
        'val_masks' and 'test_masks'.

        Args:
            samples: Unused; kept so existing callers keep working. It only
                parameterised the disabled few-shot split.
        """
        graph = self.graph
        labels = graph.ndata['label']
        n = graph.num_nodes()
        if 'mark' in graph.ndata:
            index = graph.ndata['mark'].nonzero()[:, 0].numpy().tolist()
        else:
            index = list(range(n))

        train_masks = torch.zeros([n, 20]).bool()
        val_masks = torch.zeros([n, 20]).bool()
        test_masks = torch.zeros([n, 20]).bool()

        if self.name in self.OFFICIAL_SPLIT_NAMES:  # official split
            train_masks[:, :10] = graph.ndata['train_mask'].repeat(10, 1).T
            val_masks[:, :10] = graph.ndata['val_mask'].repeat(10, 1).T
            test_masks[:, :10] = graph.ndata['test_mask'].repeat(10, 1).T
        else:
            # Previously the per-dataset ratios were unconditionally
            # overwritten with (0.7, 0.1); that pair is now only the
            # fallback for names without a dedicated ratio.
            train_ratio, val_ratio = self._split_ratios(self.name)
            valid_size = int(len(index) * val_ratio)
            index_labels = labels[index]
            for i in range(10):
                seed = 3407 + 10 * i
                set_seed(seed)
                idx_train, idx_rest, y_train, y_rest = train_test_split(
                    index, index_labels, stratify=index_labels,
                    train_size=train_ratio, random_state=seed, shuffle=True)
                idx_valid, idx_test, y_valid, y_test = train_test_split(
                    idx_rest, y_rest, stratify=y_rest,
                    train_size=valid_size, random_state=seed, shuffle=True)
                train_masks[idx_train, i] = True
                val_masks[idx_valid, i] = True
                test_masks[idx_test, i] = True

        graph.ndata['train_masks'] = train_masks
        graph.ndata['val_masks'] = val_masks
        graph.ndata['test_masks'] = test_masks
# Build the split masks for every listed dataset, report them, and save the
# resulting graph to a file named after the dataset (relative to the
# current working directory).
for data_name in ('horizon',):
    data = Dataset(data_name)
    data.split()
    print(data.graph)
    # Per-column node counts of the train / val / test masks, in that order.
    print(*(data.graph.ndata[key].sum(0)
            for key in ('train_masks', 'val_masks', 'test_masks')))
    dgl.save_graphs(data_name, [data.graph])