Skip to content

Commit 0c63099

Browse files
authored
Update train.py
1 parent e87a11a commit 0c63099

1 file changed

Lines changed: 4 additions & 18 deletions

File tree

Experiments/rl_nand_network/train.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torchvision import datasets, transforms
66
from torch.utils.data import DataLoader
77

8-
# 1️⃣ 데이터 전처리 및 로더
98
transform = transforms.Compose([
109
transforms.ToTensor(),
1110
transforms.Lambda(lambda x: (x > 0).float()) # Binarize
@@ -14,19 +13,18 @@
1413
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
1514
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
1615

17-
# 2️⃣ SparseBtnn_Selector 정의
1816
def check_binary_tensor(tensor, tensor_name="tensor"):
    """Assert that every element of *tensor* is exactly 0 or 1.

    Args:
        tensor: tensor expected to be binarized (values in {0, 1}).
        tensor_name: label used in the failure message.

    Raises:
        AssertionError: if any element is outside {0, 1}.
    """
    # Fix: the old message claimed "-1 or 1", contradicting the actual
    # {0, 1} check below — corrected so failures report the real contract.
    assert torch.all((tensor == 0) | (tensor == 1)), f"{tensor_name} contains values other than 0 or 1! as {tensor}"
2018

2119
def my_hard_mask(logits):
    """Return a one-hot mask over dim 1 at the softmax argmax of *logits*.

    Uses the straight-through estimator (STE): the forward value is the
    hard one-hot mask, while gradients flow through the soft softmax.
    """
    probs = torch.softmax(logits, dim=1)
    top_idx = probs.argmax(dim=1, keepdim=True)
    one_hot = torch.zeros_like(probs)
    one_hot.scatter_(1, top_idx, 1.0)
    # STE: (one_hot - probs.detach()) cancels numerically, leaving the hard
    # mask in the forward pass; probs carries the gradient in the backward pass.
    return one_hot - probs.detach() + probs
2523

2624
def my_action_mask(logits, action):
    """Return a one-hot mask over dim 1 at the given *action* indices.

    Args:
        logits: (batch, n) tensor of unnormalized scores.
        action: (batch, 1) integer tensor of chosen indices per row.

    The straight-through estimator (STE) makes the forward value the hard
    one-hot mask while gradients propagate through the softmax.
    """
    probs = torch.softmax(logits, dim=1)
    one_hot = torch.zeros_like(probs)
    one_hot.scatter_(1, action, 1.0)
    # STE: hard value forward, soft gradient backward.
    return one_hot - probs.detach() + probs
3028

3129
class SparseBtnn_Selector(nn.Module):
3230
def __init__(self, x, y):
@@ -41,18 +39,14 @@ def forward(self, x, reward=None, train_mode=False):
4139
if train_mode and reward is not None:
4240
action = torch.multinomial(action_probs, num_samples=1)
4341
policy_loss = -10 * torch.log(action_probs.gather(1, action)) * reward
44-
policy_loss.mean().backward(retain_graph=True) # retain_graph 사용
42+
policy_loss.mean().backward(retain_graph=True) # retain_graph
4543
mask = my_action_mask(selector_logits, action)
4644
else:
47-
mask = my_hard_mask(selector_logits) # STE 적용 마스크
45+
mask = my_hard_mask(selector_logits) # had mask
4846
out = F.linear(x, mask)
4947
check_binary_tensor(out, "selector")
5048
return out
5149

52-
53-
54-
55-
# 3️⃣ Compacted_Nand 정의
5650
class Compacted_Nand(nn.Module):
5751
def __init__(self, x, y):
5852
super().__init__()
@@ -66,12 +60,9 @@ def forward(self, x, reward=None, train_mode=False):
6660
check_binary_tensor(out, "nand")
6761
return out
6862

69-
# 4️⃣ 보상 함수 정의
7063
def get_reward(prediction, target):
    """Elementwise reward: +1.0 where prediction equals target, -1.0 otherwise.

    Both reward tensors are created on prediction's device so the result
    never triggers a cross-device comparison.
    """
    dev = prediction.device
    hit = torch.tensor(1.0, device=dev)
    miss = torch.tensor(-1.0, device=dev)
    return torch.where(prediction == target, hit, miss)
7365

74-
# 5️⃣ 학습 루프
7566
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7667
model = Compacted_Nand(28*28, 10).to(device)
7768
optimizer = optim.Adam(model.parameters(), lr=0.001)
@@ -85,13 +76,8 @@ def get_reward(prediction, target):
8576
for batch_idx, (data, target) in enumerate(train_loader):
8677
data, target = data.to(device).view(data.size(0), -1), target.to(device)
8778

88-
# 1️⃣ 순전파 (logits 생성)
8979
logits = model(data, train_mode=True)
90-
91-
# 2️⃣ 보상 계산 (logits가 정의된 후)
9280
reward = get_reward(logits.argmax(dim=1), target)
93-
94-
# 3️⃣ 손실 계산 및 역전파
9581
loss = criterion(logits, target)
9682
optimizer.zero_grad()
9783
loss.backward()

0 commit comments

Comments
 (0)