1+ import torch
2+ from torch import nn
3+ from torch .optim import Adam
4+ from torch .utils .data import DataLoader
5+ from torchvision import transforms
6+ import pytorch_lightning as pl
7+
8+ from plinear import layers
9+
# LightningModule definition
class LitModel(pl.LightningModule):
    """CIFAR-100 classifier wrapping an ``rcvit_xs`` backbone.

    Args:
        linear: layer class used for linear layers inside the backbone
            (default: ``nn.Linear``).
        conv2d: layer class used for convolutions inside the backbone
            (default: ``nn.Conv2d``).
    """

    def __init__(self, linear=nn.Linear, conv2d=nn.Conv2d):
        super().__init__()

        # NOTE(review): `rcvit_xs` is not imported anywhere in this file —
        # presumably it comes from the CAS-ViT / plinear project; confirm.
        self.model = rcvit_xs(num_classes=100, linear=linear, conv2d=conv2d)
        self.criterion = nn.CrossEntropyLoss()

        # Per-step outputs collected for the epoch-end summaries. Lightning
        # >= 2.0 removed the `*_epoch_end(outputs)` hooks, so step outputs
        # must be accumulated manually.
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log batch-level results.
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        # Detach the stored copy so the autograd graph is not kept alive.
        self.training_step_outputs.append({"loss": loss.detach(), "accuracy": acc})
        return {"loss": loss, "accuracy": acc}

    def on_train_epoch_end(self):
        # Epoch-level summary. BUG FIX: the original defined
        # `train_epoch_end(outputs)`, which is not a Lightning hook name and
        # was therefore never invoked.
        outputs = self.training_step_outputs
        if outputs:
            avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
            avg_acc = torch.stack([x["accuracy"] for x in outputs]).mean()
            print(f"\n Epoch End - Train Loss: {avg_loss:.4f} , Train Accuracy: {avg_acc:.4f} ")
        outputs.clear()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Log batch-level results. FIX: `loss` was computed but never logged.
        self.log("val_loss", loss)
        self.log("val_acc", acc, prog_bar=True)
        self.validation_step_outputs.append(acc)
        return {"val_accuracy": acc}

    def on_validation_epoch_end(self):
        # Epoch-level summary. BUG FIX: the original name
        # `validation_step_epoch_end` is not a Lightning hook and was never
        # invoked.
        if self.validation_step_outputs:
            avg_acc = torch.stack(self.validation_step_outputs).mean()
            print(f"Validation Accuracy: {avg_acc:.4f} ")
        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log("test_accuracy", acc)
        return {"test_accuracy": acc}

    def configure_optimizers(self):
        # Plain Adam; learning rate kept at the original hard-coded value.
        return Adam(self.parameters(), lr=0.001)
68+
69+ """# Data Module"""
70+
71+ import pytorch_lightning as pl
72+ from torch .utils .data import DataLoader
73+ from torchvision import transforms , datasets
74+
class CIFAR100DataModule(pl.LightningDataModule):
    """LightningDataModule for CIFAR-100 with a 45k/5k train/val split.

    Args:
        data_dir: directory where the dataset is downloaded/stored.
        batch_size: batch size used by all dataloaders.
    """

    def __init__(self, data_dir="./", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # Preprocessing transforms (tensor conversion only, no augmentation).
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])

    def prepare_data(self):
        # Download the dataset once. BUG FIX: the original downloaded MNIST
        # here even though setup() consumes CIFAR-100, so CIFAR-100 was never
        # actually downloaded.
        datasets.CIFAR100(self.data_dir, train=True, download=True)
        datasets.CIFAR100(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Build dataset splits for the requested stage.
        if stage == "fit" or stage is None:
            train_full = datasets.CIFAR100(self.data_dir, train=True, transform=self.transform_train)
            val_full = datasets.CIFAR100(self.data_dir, train=True, transform=self.transform_test)

            # Train/validation split.
            val_size = 5000
            train_size = len(train_full) - val_size
            self.cifar_train, val_subset = torch.utils.data.random_split(train_full, [train_size, val_size])
            # BUG FIX: the original created a test-transform validation
            # dataset and then immediately overwrote it with the random_split
            # result, so validation silently used the training transform.
            # Re-wrap the split's indices around the test-transform dataset
            # instead (identical output today, since both transforms are
            # ToTensor-only, but correct once augmentation is added).
            self.cifar_val = torch.utils.data.Subset(val_full, val_subset.indices)

        if stage == "test" or stage is None:
            self.cifar_test = datasets.CIFAR100(self.data_dir, train=False, transform=self.transform_test)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size, shuffle=False, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size, shuffle=False, num_workers=4)
117+
118+ """# Train Test"""
119+
# Run training and testing for the BT (plinear-layer) variant.
if __name__ == "__main__":
    # NOTE(review): `btnnLinear` / `btnnConv2d` are not defined in this file —
    # presumably they come from `plinear.layers`; confirm the import.
    model = LitModel(linear=btnnLinear, conv2d=btnnConv2d)
    data_module = CIFAR100DataModule(batch_size=64)

    # Lightning accelerator string ("gpu"/"cpu"), not a torch device string.
    device = "gpu" if torch.cuda.is_available() else "cpu"
    print(device)

    # Training.
    logger = pl.loggers.CSVLogger("logs", name="CIFAR100_BT_CASViT")
    trainer = pl.Trainer(max_epochs=5, devices=1, accelerator=device, logger=logger)
    trainer.fit(model, data_module)

    # Testing. FIX: pass the fitted model explicitly — without it, Lightning
    # falls back to checkpoint lookup, which errors when no ModelCheckpoint
    # callback produced a checkpoint.
    test_results = trainer.test(model, datamodule=data_module)
    print("Test Results:", test_results)
135+
# Run training and testing for the baseline (nn.Linear / nn.Conv2d) variant.
if __name__ == "__main__":
    model = LitModel()
    data_module = CIFAR100DataModule(batch_size=64)

    # Lightning accelerator string ("gpu"/"cpu"), not a torch device string.
    device = "gpu" if torch.cuda.is_available() else "cpu"
    print(device)

    # Training.
    logger = pl.loggers.CSVLogger("logs", name="CIFAR100_CASViT")
    trainer = pl.Trainer(max_epochs=5, devices=1, accelerator=device, logger=logger)
    trainer.fit(model, data_module)

    # Testing. FIX: pass the fitted model explicitly — without it, Lightning
    # falls back to checkpoint lookup, which errors when no ModelCheckpoint
    # callback produced a checkpoint.
    test_results = trainer.test(model, datamodule=data_module)
    print("Test Results:", test_results)
151+
import pandas as pd
import matplotlib.pyplot as plt

# metrics.csv files written by CSVLogger for the two runs above.
paths = {
    "BT_CASViT": "/content/logs/CIFAR100_BT_CASViT/version_0/metrics.csv",
    "CASViT": "/content/logs/CIFAR100_CASViT/version_0/metrics.csv"
}

# Load each run and split rows into train (step-level) and validation
# (epoch-level) records; CSVLogger leaves NaN in columns a row did not log.
data = {}
for model_name, path in paths.items():
    df = pd.read_csv(path)

    df_train = df[df["train_acc"].notna()]  # training rows (per step)
    df_val = df[df["val_acc"].notna()]      # validation rows (per epoch)

    data[model_name] = {
        "train_step": df_train["step"].values,
        "train_acc": df_train["train_acc"].values,
        "train_loss": df_train["train_loss"].values,
        "val_epoch": range(1, len(df_val) + 1),  # epoch numbers
        "val_acc": df_val["val_acc"].values
    }

save_dir = "/content/logs/"

def _save_comparison_plot(x_key, y_key, label_suffix, title, xlabel, ylabel, filename, **plot_kwargs):
    """Plot one metric for every model in `data` and save it under save_dir.

    Factored out of three near-identical copy-pasted plotting stanzas.
    """
    plt.figure(figsize=(12, 6))
    for model_name in data:
        plt.plot(
            data[model_name][x_key],
            data[model_name][y_key],
            label=f"{model_name} {label_suffix}",
            **plot_kwargs,
        )
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    # NOTE(review): the original's mangled f-string appeared to put a space
    # between the directory and the file name; assumed to be a paste artifact
    # and normalized to a plain concatenation.
    plt.savefig(f"{save_dir}{filename}")
    plt.close()

# The three comparison figures (same titles, labels, and filenames as before).
_save_comparison_plot("train_step", "train_acc", "Train Accuracy",
                      "Training Accuracy", "Step", "Accuracy", "training_accuracy.png")
_save_comparison_plot("train_step", "train_loss", "Train Loss",
                      "Training Loss", "Step", "Loss", "training_loss.png")
_save_comparison_plot("val_epoch", "val_acc", "Validation Accuracy",
                      "Validation Accuracy", "Epoch", "Accuracy", "validation_accuracy.png",
                      marker='o', linestyle='--')