-
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexample.py
More file actions
22 lines (19 loc) · 630 Bytes
/
example.py
File metadata and controls
22 lines (19 loc) · 630 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from simba_torch.main import Simba

# Build the Simba model under test.
# NOTE(review): the per-argument meanings below follow the original example's
# annotations — confirm against the simba_torch API docs.
model = Simba(
    dim=4,  # transformer embedding dimension
    dropout=0.1,  # dropout probability used for regularization
    d_state=64,  # presumably the SSM/transformer state size — verify
    d_conv=64,  # presumably the convolutional dimension — verify
    num_classes=64,  # size of the classification output
    depth=8,  # number of stacked layers
    patch_size=16,  # side length of each image patch
    image_size=224,  # side length of the (square) input image
    channels=3,  # input channels (RGB)
)

# A single random RGB image batch: (batch=1, channels=3, height=224, width=224).
img = torch.randn(1, 3, 224, 224)

# Run one forward pass and show the resulting output shape.
out = model(img)
print(out.shape)