forked from ducngg/tijepa
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path make_mvsa.py
More file actions
executable file
·69 lines (61 loc) · 1.81 KB
/
make_mvsa.py
File metadata and controls
executable file
·69 lines (61 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from eval_on_mvsa import encode_dataset, train_simple_linear_module, MVSA
import os
import sys
# Name of the training run directory under trains/ whose checkpoint and
# encodings this script works with.
TRAIN: str = "SMALL-A100-448-600-10k-OBS-SCHEDULER-1744373585"
# Checkpoint epoch to use; selects trains/{TRAIN}/epoch-{EP}.pt and the
# {EP}-epoch-* encoding folders.
EP: int = 300
def to_tensor():
    """Tensorize the raw MVSA dataset at img_size=448.

    Per the original author's note, constructing ``MVSA`` with
    ``tensor_folder=None`` is what performs the tensorization. The dataset
    object returned by the constructor is not needed afterwards, so it is
    discarded.
    """
    # Local import: torchvision is only needed for this one-off stage.
    from torchvision import transforms

    MVSA(
        batch_size=1,
        img_size=448,
        device='cuda:0',
        # Only ToTensor() is applied here. NOTE(review): presumably any
        # resizing to img_size happens inside MVSA — confirm in eval_on_mvsa.
        transform=transforms.Compose([transforms.ToTensor()]),
        tensor_folder=None,
    )
def encode_all(crosser_type='target',
               tensor_folder='src/datasets/mvsa-tensor-448-new'):
    """Encode the tensorized MVSA dataset with checkpoint epoch ``EP`` of ``TRAIN``.

    Args:
        crosser_type: Forwarded to ``encode_dataset``. The script's default
            run uses ``'target'``; an earlier variant used ``'context'``.
        tensor_folder: Folder of the pre-tensorized dataset to encode. The
            earlier ``'context'`` variant used
            ``'src/datasets/mvsa-tensor-448'``.

    Encodings are written to
    ``trains/{TRAIN}/tensors/{EP}-epoch-{crosser_type}``. This is the layout
    ``train()`` reads from, for both the target and the context runs. The
    previous commented-out context call omitted the ``tensors/`` component,
    so its output would not have been found.
    """
    save_path = f"trains/{TRAIN}/tensors/{EP}-epoch-{crosser_type}"
    # NOTE(review): the original had commented-out os.makedirs calls for
    # this directory; confirm whether encode_dataset creates it itself.
    encode_dataset(
        checkpoint_path=f"trains/{TRAIN}/epoch-{EP}.pt",
        batch_size=200,
        device='cuda:0',
        save_path=save_path,
        crosser_type=crosser_type,
        tensor_folder=tensor_folder,
    )
def train():
    """Train the simple linear module on the 'target' encodings.

    Features are read from ``trains/{TRAIN}/tensors/{EP}-epoch-target``.
    An earlier configuration trained on the 'context' encodings
    (``trains/{TRAIN}/tensors/{EP}-epoch-context``) with hidden_size=768,
    batch_size=512, epochs=50, seed=100, and no explicit lr.
    """
    module_config = dict(
        save_path=f"trains/{TRAIN}/tensors/{EP}-epoch-target",
        hidden_size=1024,
        batch_size=128,
        epochs=50,
        device='cuda:0',
        seed=200,  # previously 100
        lr=1e-3,
    )
    train_simple_linear_module(**module_config)
if __name__ == "__main__":
    # Pipeline stages, executed in order. Prepend encode_all to the tuple
    # to (re)encode the dataset before training.
    for _stage in (train,):
        _stage()