Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
de68fbd
Python 3.6+ and Pytorch 1.0+
davidstap Mar 7, 2019
676657a
Update README.md
davidstap Mar 7, 2019
78f9e1d
Update README.md
davidstap Mar 19, 2019
3ea6c5b
Update config.py
ImmortalBoi Dec 3, 2023
2b83a7c
Update datasets.py
ImmortalBoi Dec 3, 2023
92d6d00
Update datasets.py
ImmortalBoi Dec 4, 2023
9c99553
Update datasets.py
ImmortalBoi Dec 4, 2023
3707608
Update datasets.py
ImmortalBoi Dec 4, 2023
b92c509
Update datasets.py
ImmortalBoi Dec 4, 2023
3494e50
Update datasets.py
ImmortalBoi Dec 4, 2023
7f400e0
Update datasets.py
ImmortalBoi Dec 4, 2023
38f0647
Update datasets.py
ImmortalBoi Dec 4, 2023
90e8c04
Update datasets.py
ImmortalBoi Dec 4, 2023
244926e
Update datasets.py
ImmortalBoi Dec 4, 2023
6113566
Update datasets.py
ImmortalBoi Dec 4, 2023
abde8f2
Update datasets.py
ImmortalBoi Dec 4, 2023
07a1c46
Update pretrain_DAMSM.py
ImmortalBoi Dec 4, 2023
fde1d31
Update pretrain_DAMSM.py
ImmortalBoi Dec 4, 2023
4977433
Update utils.py
ImmortalBoi Dec 4, 2023
5b00abf
Update utils.py
ImmortalBoi Dec 4, 2023
838e213
Update utils.py
ImmortalBoi Dec 4, 2023
13ccd3c
Update utils.py
ImmortalBoi Dec 4, 2023
1aee555
Update utils.py
ImmortalBoi Dec 4, 2023
8b5df51
Update utils.py
ImmortalBoi Dec 4, 2023
601afe4
Update pretrain_DAMSM.py
ImmortalBoi Dec 4, 2023
e5faf89
Update utils.py
ImmortalBoi Dec 4, 2023
dce1293
Update utils.py
ImmortalBoi Dec 4, 2023
0926b2c
Update utils.py
ImmortalBoi Dec 4, 2023
d9c3881
Update utils.py
ImmortalBoi Dec 4, 2023
83b7c0d
Update utils.py
ImmortalBoi Dec 4, 2023
6fc3a42
Update utils.py
ImmortalBoi Dec 4, 2023
bf399e2
Update utils.py
ImmortalBoi Dec 4, 2023
831e78f
Update utils.py
ImmortalBoi Dec 4, 2023
d44dcdc
Update utils.py
ImmortalBoi Dec 4, 2023
d208a96
Update utils.py
ImmortalBoi Dec 4, 2023
60c725a
Update utils.py
ImmortalBoi Dec 4, 2023
45536a1
Update utils.py
ImmortalBoi Dec 4, 2023
0a82fa5
Update utils.py
ImmortalBoi Dec 4, 2023
a561846
Update utils.py
ImmortalBoi Dec 4, 2023
4279798
update pretrain_DAMSM.py
ImmortalBoi Dec 4, 2023
9a9f9d3
Update pretrain_DAMSM.py
ImmortalBoi Dec 4, 2023
32a2cf9
Update pretrain_DAMSM.py
ImmortalBoi Dec 5, 2023
b05f202
Update pretrain_DAMSM.py
ImmortalBoi Dec 5, 2023
07d9502
update trainer.py
ImmortalBoi Dec 5, 2023
79f8f08
Update trainer.py
ImmortalBoi Dec 5, 2023
0e6eead
Update main.py
ImmortalBoi Dec 5, 2023
1a8067d
Update datasets.py
ImmortalBoi Dec 6, 2023
0625db0
Update datasets.py
ImmortalBoi Dec 6, 2023
655427a
Update datasets.py
ImmortalBoi Dec 6, 2023
7952fc0
Update datasets.py
ImmortalBoi Dec 6, 2023
02c6547
mask
ImmortalBoi Dec 7, 2023
7c10f9f
Update README.md
ImmortalBoi Mar 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# AttnGAN
# AttnGAN (Python 3, Pytorch 1.0)

Pytorch implementation for reproducing AttnGAN results in the paper [AttnGAN: Fine-Grained Text to Image Generation
with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf) by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research).
Expand All @@ -7,9 +7,9 @@ with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/c


### Dependencies
python 2.7
python 3.6+

Pytorch
Pytorch 1.0+

In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
- `python-dateutil`
Expand All @@ -27,7 +27,10 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f
2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data. Extract them to `data/birds/`
3. Download [coco](http://cocodataset.org/#download) dataset and extract the images to `data/coco/`


**Expected Dataset Folder Structure in YML**
<div>- dataset</div>
<div>|- images</div>
<p>|- text</p>

**Training**
- Pre-train DAMSM models:
Expand All @@ -40,8 +43,6 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f

- `*.yml` files are example configuration files for training/evaluation our models.



**Pretrained Model**
- [DAMSM for bird](https://drive.google.com/open?id=1GNUKjVeyWYBJ8hEU-yrfYQpDOkxEyP3V). Download and save it to `DAMSMencoders/`
- [DAMSM for coco](https://drive.google.com/open?id=1zIrXCE9F6yfbEJIbNP5-YrEe2pZcPSGJ). Download and save it to `DAMSMencoders/`
Expand Down
8 changes: 4 additions & 4 deletions code/GlobalAttention.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def func_attention(query, context, gamma1):
attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper
# --> batch*sourceL x queryL
attn = attn.view(batch_size*sourceL, queryL)
attn = nn.Softmax()(attn) # Eq. (8)
attn = nn.Softmax(dim=1)(attn) # Eq. (8)

# --> batch x sourceL x queryL
attn = attn.view(batch_size, sourceL, queryL)
Expand All @@ -57,7 +57,7 @@ def func_attention(query, context, gamma1):
attn = attn.view(batch_size*queryL, sourceL)
# Eq. (9)
attn = attn * gamma1
attn = nn.Softmax()(attn)
attn = nn.Softmax(dim=1)(attn)
attn = attn.view(batch_size, queryL, sourceL)
# --> batch x sourceL x queryL
attnT = torch.transpose(attn, 1, 2).contiguous()
Expand All @@ -73,7 +73,7 @@ class GlobalAttentionGeneral(nn.Module):
def __init__(self, idf, cdf):
super(GlobalAttentionGeneral, self).__init__()
self.conv_context = conv1x1(cdf, idf)
self.sm = nn.Softmax()
self.sm = nn.Softmax(dim=1)
self.mask = None

def applyMask(self, mask):
Expand Down Expand Up @@ -104,7 +104,7 @@ def forward(self, input, context):
attn = attn.view(batch_size*queryL, sourceL)
if self.mask is not None:
# batch_size x sourceL --> batch_size*queryL x sourceL
mask = self.mask.repeat(queryL, 1)
mask = self.mask.repeat(queryL, 1).to(torch.bool)
attn.data.masked_fill_(mask.data, -float('inf'))
attn = self.sm(attn) # Eq. (2)
# --> batch x queryL x sourceL
Expand Down
98 changes: 84 additions & 14 deletions code/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import print_function
from __future__ import unicode_literals


from nltk.tokenize import RegexpTokenizer
from collections import defaultdict
from miscc.config import cfg
Expand All @@ -13,6 +12,10 @@
from torch.autograd import Variable
import torchvision.transforms as transforms

import os
import shutil
from sklearn.model_selection import train_test_split

import os
import sys
import numpy as np
Expand Down Expand Up @@ -80,7 +83,7 @@ def get_imgs(img_path, imsize, bbox=None,
for i in range(cfg.TREE.BRANCH_NUM):
# print(imsize[i])
if i < (cfg.TREE.BRANCH_NUM - 1):
re_img = transforms.Scale(imsize[i])(img)
re_img = transforms.Resize(imsize[i])(img)
else:
re_img = img
ret.append(normalize(re_img))
Expand Down Expand Up @@ -133,7 +136,7 @@ def load_bbox(self):
#
filename_bbox = {img_file[:-4]: [] for img_file in filenames}
numImgs = len(filenames)
for i in xrange(0, numImgs):
for i in range(0, numImgs):
# bbox = [x-left, y-top, width, height]
bbox = df_bounding_boxes.iloc[i][1:].tolist()

Expand All @@ -142,12 +145,12 @@ def load_bbox(self):
#
return filename_bbox

def load_captions(self, data_dir, filenames):
def load_captions(self, data_dir, filenames:list[str], split):
all_captions = []
for i in range(len(filenames)):
cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])
cap_path = '%s/%s/text/%s.txt' % (data_dir, split, filenames[i])
with open(cap_path, "r") as f:
captions = f.read().decode('utf8').split('\n')
captions = f.read().split('\n')
cnt = 0
for cap in captions:
if len(cap) == 0:
Expand Down Expand Up @@ -221,8 +224,8 @@ def load_text_data(self, data_dir, split):
train_names = self.load_filenames(data_dir, 'train')
test_names = self.load_filenames(data_dir, 'test')
if not os.path.isfile(filepath):
train_captions = self.load_captions(data_dir, train_names)
test_captions = self.load_captions(data_dir, test_names)
train_captions = self.load_captions(data_dir, train_names, "train")
test_captions = self.load_captions(data_dir, test_names, "test")

train_captions, test_captions, ixtoword, wordtoix, n_words = \
self.build_dictionary(train_captions, test_captions)
Expand Down Expand Up @@ -251,7 +254,7 @@ def load_text_data(self, data_dir, split):
def load_class_id(self, data_dir, total_num):
if os.path.isfile(data_dir + '/class_info.pickle'):
with open(data_dir + '/class_info.pickle', 'rb') as f:
class_id = pickle.load(f)
class_id = pickle.load(f, encoding="bytes")
else:
class_id = np.arange(total_num)
return class_id
Expand All @@ -262,13 +265,69 @@ def load_filenames(self, data_dir, split):
with open(filepath, 'rb') as f:
filenames = pickle.load(f)
print('Load filenames from: %s (%d)' % (filepath, len(filenames)))

return filenames

else:
filenames = []
return filenames
image_dir = os.path.join(data_dir, 'images')
text_dir = os.path.join(data_dir, 'text')

image_files = sorted(os.listdir(image_dir))

filenames = [os.path.splitext(f)[0] for f in image_files]

train_filenames, test_filenames = train_test_split(filenames, test_size=0.3)

image_train = [f + '.jpg' for f in train_filenames]
text_train = [f + '.txt' for f in train_filenames]

image_test = [f + '.jpg' for f in test_filenames]
text_test = [f + '.txt' for f in test_filenames]

train_image_dir = os.path.join(data_dir, 'train/images')
test_image_dir = os.path.join(data_dir, 'test/images')
train_text_dir = os.path.join(data_dir, 'train/text')
test_text_dir = os.path.join(data_dir, 'test/text')

os.makedirs(train_image_dir, exist_ok=True)
os.makedirs(test_image_dir, exist_ok=True)
os.makedirs(train_text_dir, exist_ok=True)
os.makedirs(test_text_dir, exist_ok=True)

for file in image_train:
shutil.move(os.path.join(image_dir, file), train_image_dir)

for file in image_test:
shutil.move(os.path.join(image_dir, file), test_image_dir)

for file in text_train:
shutil.move(os.path.join(text_dir, file), train_text_dir)

for file in text_test:
shutil.move(os.path.join(text_dir, file), test_text_dir)

os.rmdir(image_dir)
os.rmdir(text_dir)

with open('%s/%s/filenames.pickle' % (data_dir, "train"), 'wb') as f:
pickle.dump(train_filenames, f, protocol=pickle.HIGHEST_PROTOCOL)

with open('%s/%s/filenames.pickle' % (data_dir, "test"), 'wb') as f:
pickle.dump(test_filenames, f, protocol=pickle.HIGHEST_PROTOCOL)

print('Create pickle and Load filenames from: %s (%d)' % (filepath, len(filenames)))

if split == 'train':
return train_filenames
else:
return test_filenames

def get_caption(self, sent_ix):
# a list of indices for a sentence
sent_caption = np.asarray(self.captions[sent_ix]).astype('int64')
try:
sent_caption = np.asarray(self.captions[sent_ix]).astype('int64')
except:
print(f'Getting caption for sent_ix: {sent_ix}')
if (sent_caption == 0).sum() > 0:
print('ERROR: do not need END (0) token', sent_caption)
num_words = len(sent_caption)
Expand Down Expand Up @@ -298,13 +357,24 @@ def __getitem__(self, index):
bbox = None
data_dir = self.data_dir
#
img_name = '%s/images/%s.jpg' % (data_dir, key)
img_name = ""
if os.path.isfile('%s/%s/images/%s.jpg' % (data_dir, "train", key)):
img_name = '%s/%s/images/%s.jpg' % (data_dir, "train", key)

if os.path.isfile('%s/%s/images/%s.jpg' % (data_dir, "test", key)):
img_name = '%s/%s/images/%s.jpg' % (data_dir, "test", key)

imgs = get_imgs(img_name, self.imsize,
bbox, self.transform, normalize=self.norm)
# random select a sentence
sent_ix = random.randint(0, self.embeddings_num)
new_sent_ix = index * self.embeddings_num + sent_ix
caps, cap_len = self.get_caption(new_sent_ix)
try:
caps, cap_len = self.get_caption(new_sent_ix)
except Exception as error:
print(error)
print(f'index: {index}, new_sent_ix: {new_sent_ix}, sent_ix: {sent_ix}, len(self.captions): {len(self.captions)}')
caps, cap_len = self.get_caption(new_sent_ix-1)
return imgs, caps, cap_len, cls_id, key


Expand Down
12 changes: 7 additions & 5 deletions code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datasets import TextDataset
from trainer import condGANTrainer as trainer

from pathlib import Path

import os
import sys
import time
Expand Down Expand Up @@ -39,14 +41,14 @@ def gen_example(wordtoix, algo):
filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)
data_dic = {}
with open(filepath, "r") as f:
filenames = f.read().decode('utf8').split('\n')
filenames = f.read().split('\n')
for name in filenames:
if len(name) == 0:
continue
filepath = '%s/%s.txt' % (cfg.DATA_DIR, name)
with open(filepath, "r") as f:
print('Load from:', name)
sentences = f.read().decode('utf8').split('\n')
sentences = f.read().split('\n')
# a list of indices for a sentence
captions = []
cap_lens = []
Expand Down Expand Up @@ -110,8 +112,8 @@ def gen_example(wordtoix, algo):

now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
(cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
output_dir = '%s/../output/%s_%s_%s' % \
(str(Path(cfg.DATA_DIR).parent.parent) ,cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

split_dir, bshuffle = 'train', True
if not cfg.TRAIN.FLAG:
Expand All @@ -121,7 +123,7 @@ def gen_example(wordtoix, algo):
# Get data loader
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
image_transform = transforms.Compose([
transforms.Scale(int(imsize * 76 / 64)),
transforms.Resize(int(imsize * 76 / 64)),
transforms.RandomCrop(imsize),
transforms.RandomHorizontalFlip()])
dataset = TextDataset(cfg.DATA_DIR, split_dir,
Expand Down
6 changes: 3 additions & 3 deletions code/miscc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def _merge_a_into_b(a, b):
if type(a) is not edict:
return

for k, v in a.iteritems():
for k, v in a.items():
# a must specify keys that are in b
if not b.has_key(k):
if k not in b:
raise KeyError('{} is not a valid config key'.format(k))

# the types must match, too
Expand Down Expand Up @@ -100,6 +100,6 @@ def cfg_from_file(filename):
"""Load a config file and merge it into the default options."""
import yaml
with open(filename, 'r') as f:
yaml_cfg = edict(yaml.load(f))
yaml_cfg = edict(yaml.safe_load(f))

_merge_a_into_b(yaml_cfg, __C)
10 changes: 6 additions & 4 deletions code/miscc/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def sent_loss(cnn_code, rnn_code, labels, class_ids,
# --> batch_size x batch_size
scores0 = scores0.squeeze()
if class_ids is not None:
scores0.data.masked_fill_(masks, -float('inf'))
masks_bool = masks.to(torch.bool)
scores0.data.masked_fill_(masks_bool, -float('inf'))
scores1 = scores0.transpose(0, 1)
if labels is not None:
loss0 = nn.CrossEntropyLoss()(scores0, labels)
Expand Down Expand Up @@ -122,7 +123,8 @@ def words_loss(img_features, words_emb, labels,

similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3
if class_ids is not None:
similarities.data.masked_fill_(masks, -float('inf'))
masks_bool = masks.to(torch.bool)
similarities.data.masked_fill_(masks_bool, -float('inf'))
similarities1 = similarities.transpose(0, 1)
if labels is not None:
loss0 = nn.CrossEntropyLoss()(similarities, labels)
Expand Down Expand Up @@ -181,7 +183,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
g_loss = cond_errG
errG_total += g_loss
# err_img = errG_total.data[0]
logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0])
logs += 'g_loss%d: %.2f ' % (i, g_loss.item())

# Ranking loss
if i == (numDs - 1):
Expand All @@ -202,7 +204,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
# err_sent = err_sent + s_loss.data[0]

errG_total += w_loss + s_loss
logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0])
logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item())
return errG_total, logs


Expand Down
Loading