-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_selector.py
More file actions
95 lines (79 loc) · 2.98 KB
/
batch_selector.py
File metadata and controls
95 lines (79 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import math
from typing import Callable
def selectBatchSize(N: int,
                    model_fn: Callable,
                    train_fn: Callable,
                    widen: int = 2) -> int:
    """
    Select optimal batch size via actual training efficiency.

    Compares powers-of-two around sqrt(N) and selects the one
    with highest (accuracy / training time).

    Parameters:
        N (int): Number of training samples
        model_fn (Callable): Function that returns a new model
        train_fn (Callable): Function(model, batch_size) -> (accuracy, time)
        widen (int): Number of powers-of-two to test on each side

    Returns:
        int: Best-performing batch size

    Raises:
        ValueError: If N is not a positive integer.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if N <= 0:
        raise ValueError("N must be a positive integer")
    center_exp = round(math.log2(math.sqrt(N)))
    # Floor the exponent at 4 (batch size 16), and clamp the upper bound so the
    # candidate list is never empty. Previously, small N (e.g. N=1 with
    # widen=2) produced range(4, 3) -> no candidates -> max() raised ValueError.
    lo = max(4, center_exp - widen)
    hi = max(lo, center_exp + widen)
    candidates = [2 ** i for i in range(lo, hi + 1)]
    results = []
    for B in candidates:
        acc, t = train_fn(model_fn(), B)
        # Guard against a ~0s timing measurement: treat a zero-time run as
        # infinitely efficient rather than raising ZeroDivisionError.
        efficiency = acc / t if t > 0 else float("inf")
        results.append((efficiency, B))
    _, best_B = max(results, key=lambda x: x[0])
    return best_B
def getBatchSizeMNIST(N: int, widen: int = 2) -> int:
    """
    Return the best batch size for MNIST using selectBatchSize with a fixed
    model and training function.

    Trains a small 2-layer MLP for one epoch on the first min(N, dataset size)
    MNIST training samples at each candidate batch size, and picks the batch
    size with the highest (test accuracy / training time).

    Parameters:
        N (int): Number of training samples to use (capped at dataset size).
        widen (int): Number of powers-of-two to test on each side of sqrt(N).

    Returns:
        int: Best-performing batch size.
    """
    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, Subset
    from torchvision import datasets, transforms
    import time

    class SimpleNet(nn.Module):
        # Minimal MLP: 784 -> 128 -> 10, ReLU in between.
        def __init__(self):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(28 * 28, 128),
                nn.ReLU(),
                nn.Linear(128, 10)
            )

        def forward(self, x):
            return self.fc(x)

    # Load the datasets ONCE, outside train_fn: selectBatchSize calls
    # train_fn once per candidate batch size, and reconstructing the
    # MNIST datasets on every call is pure loop-invariant overhead.
    transform = transforms.ToTensor()
    train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    # download=True here too: previously only the train split was downloaded,
    # so a fresh ./data directory crashed on the test split.
    test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    # Cap at the actual dataset size instead of the hard-coded 60000.
    subset = Subset(train_data, list(range(min(N, len(train_data)))))
    test_loader = DataLoader(test_data, batch_size=256)

    def model_fn():
        return SimpleNet()

    def train_fn(model, batch_size):
        # Only the DataLoader depends on batch_size; everything else is shared.
        train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
        model.train()
        opt = optim.Adam(model.parameters(), lr=0.001)
        loss_fn = nn.CrossEntropyLoss()
        # Time exactly one training epoch over the subset.
        start = time.time()
        for xb, yb in train_loader:
            out = model(xb)
            loss = loss_fn(out, yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
        elapsed = time.time() - start
        # Evaluate accuracy on the full test set, without gradients.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for xb, yb in test_loader:
                preds = model(xb).argmax(dim=1)
                correct += (preds == yb).sum().item()
                total += yb.size(0)
        accuracy = correct / total
        return accuracy, elapsed

    return selectBatchSize(N, model_fn, train_fn, widen)