Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 183 additions & 14 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,16 +1,185 @@
name: fpt
channels:
- anaconda
- pytorch
- conda-forge
- defaults
dependencies:
- python=3.7
- pip:
- boto3==1.17.102
- einops==0.3.0
- matplotlib==3.2.1
- numpy==1.18.3
- tape-proteins==0.4
- tensorflow==2.3.0
- tensorflow-datasets==4.0.1
- torch==1.7.1
- torchvision==0.8.2
- transformers==4.1.1
- tqdm==4.46.0
- wandb==0.9.1
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- _pytorch_select=0.1=cpu_0
- _tflow_select=2.3.0=mkl
- absl-py=0.13.0=py37h06a4308_0
- aiohttp=3.8.1=py37h7f8727e_0
- aiosignal=1.2.0=pyhd3eb1b0_0
- astor=0.8.1=py37h06a4308_0
- astunparse=1.6.3=py_0
- async-timeout=4.0.1=pyhd3eb1b0_0
- asynctest=0.13.0=py_0
- attrs=21.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- blinker=1.4=py37h06a4308_0
- boto3=1.20.16=pyhd8ed1ab_0
- botocore=1.23.16=pyhd8ed1ab_0
- bottleneck=1.3.2=py37heb32a55_1
- brotli=1.0.9=he6710b0_2
- brotlipy=0.7.0=py37h27cfd23_1003
- c-ares=1.17.1=h27cfd23_0
- ca-certificates=2021.10.26=h06a4308_2
- cachetools=4.2.2=pyhd3eb1b0_0
- certifi=2021.10.8=py37h06a4308_0
- cffi=1.14.6=py37h400218f_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- click=8.0.3=pyhd3eb1b0_0
- cryptography=3.4.8=py37hd23ed53_0
- cudatoolkit=10.2.89=hfd86e86_1
- cycler=0.11.0=pyhd3eb1b0_0
- dataclasses=0.8=pyh6d0b6a4_7
- dbus=1.13.18=hb2f20db_0
- dill=0.3.4=pyhd3eb1b0_0
- einops=0.3.2=pyhd8ed1ab_0
- expat=2.4.1=h2531618_2
- fontconfig=2.13.1=h6c09931_0
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.11.0=h70c0345_0
- frozenlist=1.2.0=py37h7f8727e_0
- future=0.18.2=py37_1
- gast=0.4.0=pyhd3eb1b0_0
- giflib=5.2.1=h7b6447c_0
- glib=2.69.1=h5202010_0
- google-auth=1.33.0=pyhd3eb1b0_0
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
- google-pasta=0.2.0=pyhd3eb1b0_0
- googleapis-common-protos=1.53.0=py37h06a4308_0
- grpcio=1.42.0=py37hce63b2e_0
- gst-plugins-base=1.14.0=h8213a91_2
- gstreamer=1.14.0=h28cd5cc_2
- h5py=2.10.0=py37hd6299e0_1
- hdf5=1.10.6=hb1b8bf9_0
- icu=58.2=he6710b0_3
- idna=3.3=pyhd3eb1b0_0
- importlib-metadata=4.8.1=py37h06a4308_0
- intel-openmp=2021.4.0=h06a4308_3561
- jmespath=0.10.0=pyhd3eb1b0_0
- joblib=1.1.0=pyhd3eb1b0_0
- jpeg=9d=h7f8727e_0
- keras-preprocessing=1.1.2=pyhd3eb1b0_0
- kiwisolver=1.3.1=py37h2531618_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=ha8ba4b0_17
- libgfortran4=7.5.0=ha8ba4b0_17
- libgomp=9.3.0=h5101ec6_17
- libpng=1.6.37=hbc83047_0
- libprotobuf=3.17.2=h4ff587b_1
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtiff=4.2.0=h85742a9_0
- libuuid=1.0.3=h7f8727e_2
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.0=h89dd481_0
- libwebp-base=1.2.0=h27cfd23_0
- libxcb=1.14=h7b6447c_0
- libxml2=2.9.12=h03d6c58_0
- lz4-c=1.9.3=h295c915_1
- markdown=3.3.4=py37h06a4308_0
- matplotlib=3.4.3=py37h06a4308_0
- matplotlib-base=3.4.3=py37hbbc1b5f_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py37h7f8727e_0
- mkl_fft=1.3.1=py37hd3c417c_0
- mkl_random=1.2.2=py37h51133e4_0
- multidict=5.1.0=py37h27cfd23_2
- munkres=1.1.4=py_0
- ncurses=6.3=h7f8727e_2
- ninja=1.10.2=py37hd09550d_3
- numexpr=2.7.3=py37h22e1b3c_1
- numpy=1.21.2=py37h20f2e39_0
- numpy-base=1.21.2=py37h79a1101_0
- oauthlib=3.1.1=pyhd3eb1b0_0
- olefile=0.46=py37_0
- openssl=1.1.1l=h7f8727e_0
- opt_einsum=3.3.0=pyhd3eb1b0_1
- pandas=1.3.4=py37h8c16a72_0
- pcre=8.45=h295c915_0
- pillow=8.4.0=py37h5aabda8_0
- pip=21.2.2=py37h06a4308_0
- promise=2.3=py37h06a4308_0
- psutil=5.8.0=py37h27cfd23_1
- pyasn1=0.4.8=pyhd3eb1b0_0
- pyasn1-modules=0.2.8=py_0
- pycparser=2.21=pyhd3eb1b0_0
- pyjwt=2.1.0=py37h06a4308_0
- pyopenssl=21.0.0=pyhd3eb1b0_1
- pyparsing=3.0.4=pyhd3eb1b0_0
- pyqt=5.9.2=py37h05f1152_2
- pysocks=1.7.1=py37_1
- python=3.7.11=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-flatbuffers=2.0=pyhd3eb1b0_0
- pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
- pytz=2021.3=pyhd3eb1b0_0
- qt=5.9.7=h5867ecd_1
- readline=8.1=h27cfd23_0
- requests=2.26.0=pyhd3eb1b0_0
- requests-oauthlib=1.3.0=py_0
- rsa=4.7.2=pyhd3eb1b0_1
- s3transfer=0.5.0=pyhd3eb1b0_0
- scikit-learn=0.23.2=py37h0573a6f_0
- setuptools=58.0.4=py37h06a4308_0
- sip=4.19.8=py37hf484d3e_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.36.0=hc218d9a_0
- tensorboard=2.4.0=pyhc547734_0
- tensorboard-plugin-wit=1.6.0=py_0
- tensorflow=2.4.1=mkl_py37h2d14ff2_0
- tensorflow-base=2.4.1=mkl_py37h43e0292_0
- tensorflow-datasets=1.2.0=py37_0
- tensorflow-estimator=2.6.0=pyh7b7c402_0
- tensorflow-metadata=0.14.0=pyhe6710b0_1
- termcolor=1.1.0=py37h06a4308_1
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.11=h1ccaba5_0
- torchvision=0.8.2=cpu_py37ha229d99_0
- tornado=6.1=py37h27cfd23_0
- tqdm=4.62.3=pyhd3eb1b0_1
- typing-extensions=3.10.0.2=hd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- urllib3=1.26.7=pyhd3eb1b0_0
- werkzeug=2.0.2=pyhd3eb1b0_0
- wheel=0.37.0=pyhd3eb1b0_1
- wrapt=1.13.3=py37h7f8727e_2
- xz=5.2.5=h7b6447c_0
- yarl=1.6.3=py37h27cfd23_0
- zipp=3.6.0=pyhd3eb1b0_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- biopython==1.79
- configparser==5.1.0
- docker-pycreds==0.4.0
- filelock==3.4.0
- gitdb==4.0.9
- gitpython==3.1.24
- lmdb==1.2.1
- packaging==21.3
- pathlib==1.0.1
- pathtools==0.1.2
- protobuf==3.19.1
- pyyaml==6.0
- regex==2021.11.10
- sacremoses==0.0.46
- scipy==1.7.3
- sentry-sdk==1.5.0
- shortuuid==1.0.8
- smmap==5.0.0
- subprocess32==3.5.4
- tape-proteins==0.5
- tensorboardx==2.4.1
- timm==0.4.12
- tokenizers==0.9.4
- transformers==4.1.1
- wandb==0.12.7
- yaspin==2.1.0
prefix: /u/home/a/anivrit/.conda/envs/fpt
6 changes: 3 additions & 3 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

if __name__ == '__main__':

experiment_name = 'fpt'
experiment_name = 'eurosat'

experiment_params = dict(
task='bit-memory',
task='eurosat',
n=1000, # ignored if not a bit task
num_patterns=5, # ignored if not a bit task
patch_size=50,
patch_size=16,

model_name='gpt2',
pretrained=True, # if vit this is forced to true, if lstm this is forced to false
Expand Down
103 changes: 103 additions & 0 deletions universal_computation/datasets/eurosat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
import pathlib
from pathlib import Path
import pandas as pd
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import torch
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.transforms as transforms

from universal_computation.datasets.dataset import Dataset

class EuroSatDatasetHelper(torch.utils.data.Dataset):
def __init__(self, img_dir, ann_file, transform=None, target_transform=None):
df = pd.read_csv(ann_file)
self.img_labels = df[['img_name', 'int_label']].reset_index(drop=True)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform

def __len__(self):
return len(self.img_labels)

def __getitem__(self, idx):
label = self.img_labels.iloc[idx, 1]
img_name = self.img_labels.iloc[idx, 0]
dir_name = img_name.split('_')[0]
img_path = os.path.join(self.img_dir, dir_name, img_name)
img = Image.open(img_path)
if self.transform:
img = self.transform(img)
if self.target_transform:
label = self.target_transform(label)
return img, label


class EuroSatDataset(Dataset):
def __init__(self, batch_size, patch_size=None, data_aug=True, *args, **kwargs):
super(EuroSatDataset, self).__init__(*args, **kwargs)

self.batch_size = batch_size
self.patch_size = patch_size

if data_aug:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224,224), interpolation=3),
transforms.RandomApply([transforms.GaussianBlur(3)]),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
else:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224,224), interpolation=3),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

val_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224,224), interpolation=3),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
train_test_dir = 'data/2750'
self.d_train = DataLoader(
EuroSatDatasetHelper(train_test_dir, os.path.join(train_test_dir, 'train.csv'), transform=transform),
batch_size=batch_size, drop_last=True, shuffle=True,
)
self.d_test = DataLoader(
EuroSatDatasetHelper(train_test_dir, os.path.join(train_test_dir, 'test.csv'), transform=val_transform),
batch_size=batch_size, drop_last=True, shuffle=True,
)

self.train_enum = enumerate(self.d_train)
self.test_enum = enumerate(self.d_test)

self.train_size = len(self.d_train)
self.test_size = len(self.d_test)

def reset_test(self):
self.test_enum = enumerate(self.d_test)

def get_batch(self, batch_size=None, train=True):
if train:
_, (x, y) = next(self.train_enum, (None, (None, None)))
if x is None:
self.train_enum = enumerate(self.d_train)
_, (x, y) = next(self.train_enum)
else:
_, (x, y) = next(self.test_enum, (None, (None, None)))
if x is None:
self.test_enum = enumerate(self.d_test)
_, (x, y) = next(self.train_enum)

if self.patch_size is not None:
x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)

x = x.to(device=self.device)
y = y.to(device=self.device)

self._ind += 1

return x, y
23 changes: 23 additions & 0 deletions universal_computation/datasets/helpers/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pandas as pd
import os
import pathlib
from pathlib import Path

data = 'data/2750'

df = pd.DataFrame(columns=['label', 'int_label', 'img_name'])
labels_dict = {}
counter = 0
for subdir in os.listdir(data):
labels_dict[subdir] = counter
counter += 1
filepath = os.path.join(data,subdir)
if os.path.isdir(filepath):
for file in os.listdir(filepath):
dict = {'label': subdir, 'int_label': labels_dict[subdir], 'img_name': file}
df = df.append(dict, ignore_index = True)
train = df.sample(frac=0.75,random_state=200) # random state is a seed value
test = df.drop(train.index)

train.to_csv(data + '/train.csv')
test.to_csv(data + '/test.csv')
6 changes: 6 additions & 0 deletions universal_computation/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def experiment(
input_dim, output_dim = 30, 1200
use_embeddings = True
experiment_type = 'classification'
elif task == 'eurosat':
from universal_computation.datasets.eurosat import EuroSatDataset
dataset = EuroSatDataset(batch_size=batch_size, patch_size=patch_size, device=device)
input_dim, output_dim = 3 * patch_size**2, 3
use_embeddings = False
experiment_type = 'classification'

else:
raise NotImplementedError('dataset not implemented')
Expand Down
Empty file.
Loading