Skip to content

Commit 34e4de7

Browse files
committed
nand layer updated
1 parent 3a0a655 commit 34e4de7

4 files changed

Lines changed: 280 additions & 1 deletion

File tree

Experiments/mnist/exp.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import torch
2+
from torch import nn
3+
from torch.optim import Adam
4+
from torch.utils.data import DataLoader
5+
from torchvision import transforms
6+
import pytorch_lightning as pl
7+
8+
from plinear import layers
9+
10+
# LightningModule 정의
11+
# LightningModule definition
class LitModel(pl.LightningModule):
    """Lightning wrapper around a CAS-ViT classifier for CIFAR-100.

    The linear/conv layer classes are injectable so the same training code
    can compare standard layers against plinear's quantized variants.
    """

    def __init__(self, linear=nn.Linear, conv2d=nn.Conv2d):
        super().__init__()
        # NOTE(review): `rcvit_xs` is not imported anywhere in this file —
        # confirm which module provides it before running.
        self.model = rcvit_xs(num_classes=100, linear=linear, conv2d=conv2d)
        self.criterion = nn.CrossEntropyLoss()
        # Manual accumulators: Lightning 2.x removed the `*_epoch_end(outputs)`
        # hooks, so per-batch results are collected here instead.
        self._train_outputs = []
        self._val_accs = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Batch-level logging
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        self._train_outputs.append({"loss": loss.detach(), "accuracy": acc})
        return {"loss": loss, "accuracy": acc}

    def on_train_epoch_end(self):
        # Epoch-level summary. BUG FIX: the original defined
        # `train_epoch_end(outputs)`, which is not a Lightning hook name and
        # was therefore never invoked.
        if self._train_outputs:
            avg_loss = torch.stack([o["loss"] for o in self._train_outputs]).mean()
            avg_acc = torch.stack([o["accuracy"] for o in self._train_outputs]).mean()
            print(f"\nEpoch End - Train Loss: {avg_loss:.4f}, Train Accuracy: {avg_acc:.4f}")
            self._train_outputs.clear()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Batch-level logging (val_loss added so validation loss is tracked too)
        self.log("val_loss", loss)
        self.log("val_acc", acc, prog_bar=True)
        self._val_accs.append(acc)
        return {"val_accuracy": acc}

    def on_validation_epoch_end(self):
        # Epoch-level summary. BUG FIX: the original `validation_step_epoch_end`
        # is not a Lightning hook name and was never invoked.
        if self._val_accs:
            avg_acc = torch.stack(self._val_accs).mean()
            print(f"Validation Accuracy: {avg_acc:.4f}")
            self._val_accs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_accuracy", acc)
        return {"test_accuracy": acc}

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.001)
69+
"""# Data Module"""
70+
71+
import pytorch_lightning as pl
72+
from torch.utils.data import DataLoader
73+
from torchvision import transforms, datasets
74+
75+
class CIFAR100DataModule(pl.LightningDataModule):
    """CIFAR-100 data module: 45k train / 5k validation split plus test set."""

    def __init__(self, data_dir="./", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # Preprocessing transforms (kept separate so train-time augmentation
        # can be added later without touching eval data).
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])

    def prepare_data(self):
        # Download the dataset once (Lightning calls this on a single process).
        # BUG FIX: the original downloaded MNIST here while setup() loads
        # CIFAR-100, so a fresh environment crashed on the missing dataset.
        datasets.CIFAR100(self.data_dir, train=True, download=True)
        datasets.CIFAR100(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Build datasets for the requested stage.
        if stage == "fit" or stage is None:
            full_train = datasets.CIFAR100(self.data_dir, train=True,
                                           transform=self.transform_train)
            # Train/validation split. (The original also instantiated a second
            # CIFAR100 object for validation and immediately discarded it when
            # random_split overwrote both attributes.)
            val_size = 5000
            train_size = len(full_train) - val_size
            self.cifar_train, self.cifar_val = torch.utils.data.random_split(
                full_train, [train_size, val_size])

        if stage == "test" or stage is None:
            self.cifar_test = datasets.CIFAR100(self.data_dir, train=False,
                                                transform=self.transform_test)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size,
                          shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size,
                          shuffle=False, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size,
                          shuffle=False, num_workers=4)
117+
118+
"""# Train Test"""
119+
120+
# 학습 및 테스트 실행
121+
if __name__ == "__main__":
122+
model = LitModel(linear = btnnLinear, conv2d = btnnConv2d)
123+
data_module = CIFAR100DataModule(batch_size=64)
124+
125+
device = "gpu" if torch.cuda.is_available() else "cpu"
126+
print(device)
127+
# 학습
128+
logger = pl.loggers.CSVLogger("logs", name = "CIFAR100_BT_CASViT")
129+
trainer = pl.Trainer(max_epochs=5, devices = 1, accelerator=device, logger = logger)
130+
trainer.fit(model, data_module)
131+
132+
# 테스트
133+
test_results = trainer.test(datamodule=data_module)
134+
print("Test Results:", test_results)
135+
136+
# 학습 및 테스트 실행
137+
if __name__ == "__main__":
138+
model = LitModel()
139+
data_module = CIFAR100DataModule(batch_size=64)
140+
141+
device = "gpu" if torch.cuda.is_available() else "cpu"
142+
print(device)
143+
# 학습
144+
logger = pl.loggers.CSVLogger("logs", name = "CIFAR100_CASViT")
145+
trainer = pl.Trainer(max_epochs=5, devices = 1, accelerator=device, logger = logger)
146+
trainer.fit(model, data_module)
147+
148+
# 테스트
149+
test_results = trainer.test(datamodule=data_module)
150+
print("Test Results:", test_results)
151+
152+
import pandas as pd
153+
import matplotlib.pyplot as plt
154+
155+
# 두 모델의 metrics.csv 파일 경로
156+
paths = {
157+
"BT_CASViT": "/content/logs/CIFAR100_BT_CASViT/version_0/metrics.csv",
158+
"CASViT": "/content/logs/CIFAR100_CASViT/version_0/metrics.csv"
159+
}
160+
161+
# 데이터를 로드하고 처리
162+
data = {}
163+
for model_name, path in paths.items():
164+
df = pd.read_csv(path)
165+
166+
# train 데이터와 validation 데이터 필터링
167+
df_train = df[df["train_acc"].notna()] # train 데이터 (step 단위)
168+
df_val = df[df["val_acc"].notna()] # validation 데이터 (epoch 단위)
169+
170+
data[model_name] = {
171+
"train_step": df_train["step"].values,
172+
"train_acc": df_train["train_acc"].values,
173+
"train_loss": df_train["train_loss"].values,
174+
"val_epoch": range(1, len(df_val) + 1), # epoch 번호
175+
"val_acc": df_val["val_acc"].values
176+
}
177+
178+
save_dir = "/content/logs/"
179+
180+
# Training Accuracy 그래프
181+
plt.figure(figsize=(12, 6))
182+
for model_name in data:
183+
plt.plot(data[model_name]["train_step"], data[model_name]["train_acc"], label=f"{model_name} Train Accuracy")
184+
plt.title("Training Accuracy")
185+
plt.xlabel("Step")
186+
plt.ylabel("Accuracy")
187+
plt.legend()
188+
plt.grid(True)
189+
plt.savefig(f"{save_dir}training_accuracy.png")
190+
plt.close()
191+
192+
# Training Loss 그래프
193+
plt.figure(figsize=(12, 6))
194+
for model_name in data:
195+
plt.plot(data[model_name]["train_step"], data[model_name]["train_loss"], label=f"{model_name} Train Loss")
196+
plt.title("Training Loss")
197+
plt.xlabel("Step")
198+
plt.ylabel("Loss")
199+
plt.legend()
200+
plt.grid(True)
201+
plt.savefig(f"{save_dir}training_loss.png")
202+
plt.close()
203+
204+
# Validation Accuracy 그래프
205+
plt.figure(figsize=(12, 6))
206+
for model_name in data:
207+
plt.plot(data[model_name]["val_epoch"], data[model_name]["val_acc"], marker='o', linestyle='--', label=f"{model_name} Validation Accuracy")
208+
plt.title("Validation Accuracy")
209+
plt.xlabel("Epoch")
210+
plt.ylabel("Accuracy")
211+
plt.legend()
212+
plt.grid(True)
213+
plt.savefig(f"{save_dir}validation_accuracy.png")
214+
plt.close()

plinear/layers/__init__.py

Whitespace-only changes.

plinear/layers/sparse_btnn.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class SparseBtnn_Selector(nn.Module):
    """Hard one-hot input selector trained with a straight-through estimator.

    In the forward pass each output row picks exactly one input feature (the
    argmax of its weight row); gradients still flow into the dense weights.
    """

    def __init__(self, *dim):
        super(SparseBtnn_Selector, self).__init__()
        # nn.Linear is used purely as a weight container of shape
        # (out_features, in_features); its bias is never applied.
        self.selector = nn.Linear(*dim)

        torch.nn.init.uniform_(self.selector.weight, -1, 1)

    def forward(self, x):
        weight = self.selector.weight

        # One-hot mask with a 1 at each row's argmax.
        one_hot = torch.zeros_like(weight).scatter_(
            1, weight.argmax(dim=1, keepdim=True), 1.0)

        # Straight-through estimator: forward value equals `one_hot`,
        # backward gradient reaches `weight`.
        ste = one_hot - weight.detach() + weight

        # NOTE(review): arguments look swapped versus the usual
        # F.linear(input, weight); the resulting (out_features, batch)
        # transposed output appears to be compensated by the permute inside
        # SparseBtnn_Not — confirm before "fixing".
        return F.linear(ste, x)
21+
22+
class SparseBtnn_And(nn.Module):
    """Elementwise AND of two one-hot input selectors.

    For {0, 1} activations, multiplying the two selected values realizes a
    logical AND per unit.
    """

    def __init__(self, *dim):
        super(SparseBtnn_And, self).__init__()
        self.a = SparseBtnn_Selector(*dim)
        self.b = SparseBtnn_Selector(*dim)

    def forward(self, x):
        # Product of the two selections == AND for binary inputs.
        return self.a(x) * self.b(x)
33+
34+
class SparseBtnn_Not(nn.Module):
    """Learnable elementwise NOT gate with a straight-through estimator.

    Where the binarized gate is 1 the input bit is inverted; where it is 0
    the input passes through unchanged (XOR in arithmetic form).
    """

    def __init__(self, *dim):
        super(SparseBtnn_Not, self).__init__()
        self.a = nn.parameter.Parameter(torch.randn(*dim))

    def forward(self, x):
        # Broadcast the gate parameter across the input's trailing dimension.
        gate = self.a.expand(x.shape[-1], -1)
        # Straight-through binarization: forward uses (gate > 0), gradients
        # flow through `gate`.
        q = (gate > 0).float() - gate.detach() + gate
        xt = x.permute(1, 0)

        # q XOR x written arithmetically: q + x - 2*q*x.
        return q + xt - 2 * q * xt
45+
46+
class SparseBtnn_Nand(nn.Module):
    """NAND gate layer: selector-based AND followed by a learnable NOT."""

    def __init__(self, x, y):
        super(SparseBtnn_Nand, self).__init__()
        self.a = SparseBtnn_And(x, y)
        self.n = SparseBtnn_Not(y)

    def forward(self, x):
        # NOT(AND(x)) == NAND(x)
        return self.n(self.a(x))
56+
57+
class SparseBtnn_Nand_Multihead(nn.Module):
    """Applies `n_heads` independent NAND layers to the same input.

    BUG FIX: the original forward ended in a bare `return`, so the module
    always produced None and discarded every head's computation.
    """

    def __init__(self, x, y, n_heads):
        super(SparseBtnn_Nand_Multihead, self).__init__()
        self.attns = nn.ModuleList([
            SparseBtnn_Nand(x, y) for _ in range(n_heads)
        ])

    def forward(self, x):
        # Stack head outputs along a new leading dimension.
        # NOTE(review): confirm the intended aggregation (stack vs. concat)
        # against downstream consumers — none are visible in this file.
        return torch.stack([head(x) for head in self.attns], dim=0)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "plinear"
3-
version = "0.2.0.3"
3+
version = "0.2.0.4"
44
description = "parallel neural network layer for binarization of ternarization - quantized layers from the beginning"
55
authors = ["Choi Soon Ho <sosaror@gmail.com>"]
66
license = "MIT"

0 commit comments

Comments
 (0)