@@ -5,28 +5,26 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

-# 1️⃣ Data preprocessing and loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x > 0).float())  # Binarize pixels to {0, 1}
])

train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

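# The Lambda above maps every nonzero pixel to 1.0, so the network only ever
# sees binary {0, 1} inputs; check_binary_tensor below enforces the same
# invariant on layer outputs.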
-# 2️⃣ Define SparseBtnn_Selector
def check_binary_tensor(tensor, tensor_name="tensor"):
    # Guard: every element must be exactly 0 or 1.
    assert torch.all((tensor == 0) | (tensor == 1)), f"{tensor_name} contains values other than 0 or 1: {tensor}"

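# Both mask helpers below rely on the straight-through estimator (STE): the
# value returned in the forward pass equals the one-hot hard_mask, while the
# backward pass sees only the softmax_probs term, because hard_mask and
# softmax_probs.detach() contribute no gradient. The discrete selection thus
# stays trainable.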
def my_hard_mask(logits):
    softmax_probs = torch.softmax(logits, dim=1)
    # One-hot row vector at the argmax of each row
    hard_mask = torch.zeros_like(softmax_probs).scatter_(1, softmax_probs.argmax(dim=1, keepdim=True), 1.0)
-    return hard_mask - softmax_probs.detach() + softmax_probs  # apply STE
+    return hard_mask - softmax_probs.detach() + softmax_probs  # STE

def my_action_mask(logits, action):
    softmax_probs = torch.softmax(logits, dim=1)
    # One-hot row vector at the sampled action index of each row
    hard_mask = torch.zeros_like(softmax_probs).scatter_(1, action, 1.0)
-    return hard_mask - softmax_probs.detach() + softmax_probs  # apply STE
+    return hard_mask - softmax_probs.detach() + softmax_probs  # STE

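# Each row of the selector mask is one-hot, so F.linear(x, mask) routes exactly
# one input feature to each output unit and the output remains binary.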
class SparseBtnn_Selector(nn.Module):
    def __init__(self, x, y):
@@ -41,18 +39,14 @@ def forward(self, x, reward=None, train_mode=False):
        if train_mode and reward is not None:
            # REINFORCE-style update: sample one index per row of action_probs
            # and weight its negative log-probability by the reward.
            action = torch.multinomial(action_probs, num_samples=1)
            policy_loss = -10 * torch.log(action_probs.gather(1, action)) * reward
-            policy_loss.mean().backward(retain_graph=True)  # uses retain_graph
+            policy_loss.mean().backward(retain_graph=True)  # retain_graph
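            # retain_graph=True keeps the autograd graph alive so a later
            # loss.backward() can traverse the same graph through
            # selector_logits without raising a runtime error.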
            mask = my_action_mask(selector_logits, action)
        else:
-            mask = my_hard_mask(selector_logits)  # mask with STE applied
+            mask = my_hard_mask(selector_logits)  # hard mask (STE)
        out = F.linear(x, mask)
        check_binary_tensor(out, "selector")
        return out

-
-
-
-# 3️⃣ Define Compacted_Nand
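# Compacted_Nand is the full model (28*28 binary pixels in, 10 class logits out);
# like the selector, its forward pass asserts that outputs stay strictly binary.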
class Compacted_Nand(nn.Module):
    def __init__(self, x, y):
        super().__init__()
@@ -66,12 +60,9 @@ def forward(self, x, reward=None, train_mode=False):
        check_binary_tensor(out, "nand")
        return out

-# 4️⃣ Define the reward function
def get_reward(prediction, target):
-    """+1 reward for a correct prediction, -1 penalty for a wrong one"""
    return torch.where(prediction == target, torch.tensor(1.0, device=prediction.device), torch.tensor(-1.0, device=prediction.device))

-# 5️⃣ Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Compacted_Nand(28 * 28, 10).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
@@ -85,13 +76,8 @@ def get_reward(prediction, target):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-dim binary vector
        data, target = data.to(device).view(data.size(0), -1), target.to(device)

-        # 1️⃣ Forward pass (produce logits)
        logits = model(data, train_mode=True)
-
-        # 2️⃣ Compute the reward (after logits are defined)
        reward = get_reward(logits.argmax(dim=1), target)
-
-        # 3️⃣ Compute the loss and backpropagate
        loss = criterion(logits, target)
        optimizer.zero_grad()
        loss.backward()
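        # reward is None in this forward call, so the selector's policy branch
        # cannot fire; gradients from this backward pass instead reach
        # selector_logits through the straight-through hard mask.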