diff --git a/README.md b/README.md index aa75fe10..6ac9a6a3 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# AttnGAN +# AttnGAN (Python 3, Pytorch 1.0) Pytorch implementation for reproducing AttnGAN results in the paper [AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf) by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research). @@ -7,9 +7,9 @@ with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/c ### Dependencies -python 2.7 +python 3.6+ -Pytorch +Pytorch 1.0+ In addition, please add the project folder to PYTHONPATH and `pip install` the following packages: - `python-dateutil` @@ -27,7 +27,10 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f 2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data. Extract them to `data/birds/` 3. Download [coco](http://cocodataset.org/#download) dataset and extract the images to `data/coco/` - +**Expected Dataset Folder Structure in YML** +
|- text
**Training** - Pre-train DAMSM models: @@ -40,8 +43,6 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f - `*.yml` files are example configuration files for training/evaluation our models. - - **Pretrained Model** - [DAMSM for bird](https://drive.google.com/open?id=1GNUKjVeyWYBJ8hEU-yrfYQpDOkxEyP3V). Download and save it to `DAMSMencoders/` - [DAMSM for coco](https://drive.google.com/open?id=1zIrXCE9F6yfbEJIbNP5-YrEe2pZcPSGJ). Download and save it to `DAMSMencoders/` diff --git a/code/GlobalAttention.py b/code/GlobalAttention.py index 501fb720..9d629e25 100644 --- a/code/GlobalAttention.py +++ b/code/GlobalAttention.py @@ -48,7 +48,7 @@ def func_attention(query, context, gamma1): attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper # --> batch*sourceL x queryL attn = attn.view(batch_size*sourceL, queryL) - attn = nn.Softmax()(attn) # Eq. (8) + attn = nn.Softmax(dim=1)(attn) # Eq. (8) # --> batch x sourceL x queryL attn = attn.view(batch_size, sourceL, queryL) @@ -57,7 +57,7 @@ def func_attention(query, context, gamma1): attn = attn.view(batch_size*queryL, sourceL) # Eq. (9) attn = attn * gamma1 - attn = nn.Softmax()(attn) + attn = nn.Softmax(dim=1)(attn) attn = attn.view(batch_size, queryL, sourceL) # --> batch x sourceL x queryL attnT = torch.transpose(attn, 1, 2).contiguous() @@ -73,7 +73,7 @@ class GlobalAttentionGeneral(nn.Module): def __init__(self, idf, cdf): super(GlobalAttentionGeneral, self).__init__() self.conv_context = conv1x1(cdf, idf) - self.sm = nn.Softmax() + self.sm = nn.Softmax(dim=1) self.mask = None def applyMask(self, mask): @@ -104,7 +104,7 @@ def forward(self, input, context): attn = attn.view(batch_size*queryL, sourceL) if self.mask is not None: # batch_size x sourceL --> batch_size*queryL x sourceL - mask = self.mask.repeat(queryL, 1) + mask = self.mask.repeat(queryL, 1).to(torch.bool) attn.data.masked_fill_(mask.data, -float('inf')) attn = self.sm(attn) # Eq. 
(2) # --> batch x queryL x sourceL diff --git a/code/datasets.py b/code/datasets.py index 24ffdc4a..fee54e6e 100644 --- a/code/datasets.py +++ b/code/datasets.py @@ -3,7 +3,6 @@ from __future__ import print_function from __future__ import unicode_literals - from nltk.tokenize import RegexpTokenizer from collections import defaultdict from miscc.config import cfg @@ -13,6 +12,10 @@ from torch.autograd import Variable import torchvision.transforms as transforms +import os +import shutil +from sklearn.model_selection import train_test_split + import os import sys import numpy as np @@ -80,7 +83,7 @@ def get_imgs(img_path, imsize, bbox=None, for i in range(cfg.TREE.BRANCH_NUM): # print(imsize[i]) if i < (cfg.TREE.BRANCH_NUM - 1): - re_img = transforms.Scale(imsize[i])(img) + re_img = transforms.Resize(imsize[i])(img) else: re_img = img ret.append(normalize(re_img)) @@ -133,7 +136,7 @@ def load_bbox(self): # filename_bbox = {img_file[:-4]: [] for img_file in filenames} numImgs = len(filenames) - for i in xrange(0, numImgs): + for i in range(0, numImgs): # bbox = [x-left, y-top, width, height] bbox = df_bounding_boxes.iloc[i][1:].tolist() @@ -142,12 +145,12 @@ def load_bbox(self): # return filename_bbox - def load_captions(self, data_dir, filenames): + def load_captions(self, data_dir, filenames, split): all_captions = [] for i in range(len(filenames)): - cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) + cap_path = '%s/%s/text/%s.txt' % (data_dir, split, filenames[i]) - with open(cap_path, "r") as f: + with open(cap_path, "r", encoding="utf8") as f: - captions = f.read().decode('utf8').split('\n') + captions = f.read().split('\n') cnt = 0 for cap in captions: if len(cap) == 0: @@ -221,8 +224,8 @@ def load_text_data(self, data_dir, split): train_names = self.load_filenames(data_dir, 'train') test_names = self.load_filenames(data_dir, 'test') if not os.path.isfile(filepath): - train_captions = self.load_captions(data_dir, train_names) - test_captions = self.load_captions(data_dir, test_names) +
train_captions = self.load_captions(data_dir, train_names, "train") + test_captions = self.load_captions(data_dir, test_names, "test") train_captions, test_captions, ixtoword, wordtoix, n_words = \ self.build_dictionary(train_captions, test_captions) @@ -251,7 +254,7 @@ def load_text_data(self, data_dir, split): def load_class_id(self, data_dir, total_num): if os.path.isfile(data_dir + '/class_info.pickle'): with open(data_dir + '/class_info.pickle', 'rb') as f: - class_id = pickle.load(f) + class_id = pickle.load(f, encoding="bytes") else: class_id = np.arange(total_num) return class_id @@ -262,13 +265,69 @@ def load_filenames(self, data_dir, split): with open(filepath, 'rb') as f: filenames = pickle.load(f) print('Load filenames from: %s (%d)' % (filepath, len(filenames))) + + return filenames + else: - filenames = [] - return filenames + image_dir = os.path.join(data_dir, 'images') + text_dir = os.path.join(data_dir, 'text') + + image_files = sorted(os.listdir(image_dir)) + + filenames = [os.path.splitext(f)[0] for f in image_files] + + train_filenames, test_filenames = train_test_split(filenames, test_size=0.3) + + image_train = [f + '.jpg' for f in train_filenames] + text_train = [f + '.txt' for f in train_filenames] + + image_test = [f + '.jpg' for f in test_filenames] + text_test = [f + '.txt' for f in test_filenames] + + train_image_dir = os.path.join(data_dir, 'train/images') + test_image_dir = os.path.join(data_dir, 'test/images') + train_text_dir = os.path.join(data_dir, 'train/text') + test_text_dir = os.path.join(data_dir, 'test/text') + + os.makedirs(train_image_dir, exist_ok=True) + os.makedirs(test_image_dir, exist_ok=True) + os.makedirs(train_text_dir, exist_ok=True) + os.makedirs(test_text_dir, exist_ok=True) + + for file in image_train: + shutil.move(os.path.join(image_dir, file), train_image_dir) + + for file in image_test: + shutil.move(os.path.join(image_dir, file), test_image_dir) + + for file in text_train: + 
shutil.move(os.path.join(text_dir, file), train_text_dir) + + for file in text_test: + shutil.move(os.path.join(text_dir, file), test_text_dir) + + os.rmdir(image_dir) + os.rmdir(text_dir) + + with open('%s/%s/filenames.pickle' % (data_dir, "train"), 'wb') as f: + pickle.dump(train_filenames, f, protocol=pickle.HIGHEST_PROTOCOL) + + with open('%s/%s/filenames.pickle' % (data_dir, "test"), 'wb') as f: + pickle.dump(test_filenames, f, protocol=pickle.HIGHEST_PROTOCOL) + + print('Create pickle and Load filenames from: %s (%d)' % (filepath, len(filenames))) + + if split == 'train': + return train_filenames + else: + return test_filenames def get_caption(self, sent_ix): # a list of indices for a sentence - sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') + try: + sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') + except: + print(f'Getting caption for sent_ix: {sent_ix}') if (sent_caption == 0).sum() > 0: print('ERROR: do not need END (0) token', sent_caption) num_words = len(sent_caption) @@ -298,13 +357,24 @@ def __getitem__(self, index): bbox = None data_dir = self.data_dir # - img_name = '%s/images/%s.jpg' % (data_dir, key) + img_name = "" + if os.path.isfile('%s/%s/images/%s.jpg' % (data_dir, "train", key)): + img_name = '%s/%s/images/%s.jpg' % (data_dir, "train", key) + + if os.path.isfile('%s/%s/images/%s.jpg' % (data_dir, "test", key)): + img_name = '%s/%s/images/%s.jpg' % (data_dir, "test", key) + imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm) # random select a sentence sent_ix = random.randint(0, self.embeddings_num) new_sent_ix = index * self.embeddings_num + sent_ix - caps, cap_len = self.get_caption(new_sent_ix) + try: + caps, cap_len = self.get_caption(new_sent_ix) + except Exception as error: + print(error) + print(f'index: {index}, new_sent_ix: {new_sent_ix}, sent_ix: {sent_ix}, len(self.captions): {len(self.captions)}') + caps, cap_len = self.get_caption(new_sent_ix-1) return imgs, 
caps, cap_len, cls_id, key diff --git a/code/main.py b/code/main.py index 934e7764..b3ced207 100644 --- a/code/main.py +++ b/code/main.py @@ -4,6 +4,8 @@ from datasets import TextDataset from trainer import condGANTrainer as trainer +from pathlib import Path + import os import sys import time @@ -39,14 +41,14 @@ def gen_example(wordtoix, algo): filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR) data_dic = {} with open(filepath, "r") as f: - filenames = f.read().decode('utf8').split('\n') + filenames = f.read().split('\n') for name in filenames: if len(name) == 0: continue filepath = '%s/%s.txt' % (cfg.DATA_DIR, name) with open(filepath, "r") as f: print('Load from:', name) - sentences = f.read().decode('utf8').split('\n') + sentences = f.read().split('\n') # a list of indices for a sentence captions = [] cap_lens = [] @@ -110,8 +112,8 @@ def gen_example(wordtoix, algo): now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') - output_dir = '../output/%s_%s_%s' % \ - (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) + output_dir = '%s/../output/%s_%s_%s' % \ + (str(Path(cfg.DATA_DIR).parent.parent) ,cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) split_dir, bshuffle = 'train', True if not cfg.TRAIN.FLAG: @@ -121,7 +123,7 @@ def gen_example(wordtoix, algo): # Get data loader imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1)) image_transform = transforms.Compose([ - transforms.Scale(int(imsize * 76 / 64)), + transforms.Resize(int(imsize * 76 / 64)), transforms.RandomCrop(imsize), transforms.RandomHorizontalFlip()]) dataset = TextDataset(cfg.DATA_DIR, split_dir, diff --git a/code/miscc/config.py b/code/miscc/config.py index 797319ba..413e9d77 100644 --- a/code/miscc/config.py +++ b/code/miscc/config.py @@ -70,9 +70,9 @@ def _merge_a_into_b(a, b): if type(a) is not edict: return - for k, v in a.iteritems(): + for k, v in a.items(): # a must specify keys that are in b - if not b.has_key(k): + if k not in b: raise 
KeyError('{} is not a valid config key'.format(k)) # the types must match, too @@ -100,6 +100,6 @@ def cfg_from_file(filename): """Load a config file and merge it into the default options.""" import yaml with open(filename, 'r') as f: - yaml_cfg = edict(yaml.load(f)) + yaml_cfg = edict(yaml.safe_load(f)) _merge_a_into_b(yaml_cfg, __C) diff --git a/code/miscc/losses.py b/code/miscc/losses.py index b15612bf..d3f30b37 100644 --- a/code/miscc/losses.py +++ b/code/miscc/losses.py @@ -49,7 +49,8 @@ def sent_loss(cnn_code, rnn_code, labels, class_ids, # --> batch_size x batch_size scores0 = scores0.squeeze() if class_ids is not None: - scores0.data.masked_fill_(masks, -float('inf')) + masks_bool = masks.to(torch.bool) + scores0.data.masked_fill_(masks_bool, -float('inf')) scores1 = scores0.transpose(0, 1) if labels is not None: loss0 = nn.CrossEntropyLoss()(scores0, labels) @@ -122,7 +123,8 @@ def words_loss(img_features, words_emb, labels, similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 if class_ids is not None: - similarities.data.masked_fill_(masks, -float('inf')) + masks_bool = masks.to(torch.bool) + similarities.data.masked_fill_(masks_bool, -float('inf')) similarities1 = similarities.transpose(0, 1) if labels is not None: loss0 = nn.CrossEntropyLoss()(similarities, labels) @@ -181,7 +183,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels, g_loss = cond_errG errG_total += g_loss # err_img = errG_total.data[0] - logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0]) + logs += 'g_loss%d: %.2f ' % (i, g_loss.item()) # Ranking loss if i == (numDs - 1): @@ -202,7 +204,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels, # err_sent = err_sent + s_loss.data[0] errG_total += w_loss + s_loss - logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0]) + logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item()) return errG_total, logs diff --git a/code/miscc/utils.py b/code/miscc/utils.py index f131a365..bad1050e 
100644 --- a/code/miscc/utils.py +++ b/code/miscc/utils.py @@ -32,7 +32,6 @@ def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): img_txt = Image.fromarray(convas) # get a font # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) - fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) # get a drawing context d = ImageDraw.Draw(img_txt) sentence_list = [] @@ -44,7 +43,7 @@ def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): break word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), - font=fnt, fill=(255, 255, 255, 255)) + fill=(255, 255, 255, 255)) sentence.append(word) sentence_list.append(sentence) return img_txt, sentence_list @@ -127,9 +126,10 @@ def build_super_images(real_imgs, captions, ixtoword, for j in range(num_attn): one_map = attn[j] if (vis_size // att_sze) > 1: - one_map = \ - skimage.transform.pyramid_expand(one_map, sigma=20, - upscale=vis_size // att_sze) + # one_map = \ + # skimage.transform.pyramid_expand(one_map, sigma=20, + # upscale=vis_size // att_sze) + one_map = skimage.transform.resize(one_map, (vis_size, vis_size), mode='reflect') row_beforeNorm.append(one_map) minV = one_map.min() maxV = one_map.max() @@ -175,7 +175,6 @@ def build_super_images(real_imgs, captions, ixtoword, else: return None - def build_super_images2(real_imgs, captions, cap_lens, ixtoword, attn_maps, att_sze, vis_size=256, topK=5): batch_size = real_imgs.size(0) @@ -185,7 +184,8 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword, dtype=np.uint8) real_imgs = \ - nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) + nn.functional.interpolate(real_imgs,size=(vis_size, vis_size), + mode='bilinear', align_corners=False) # [-1, 1] --> [0, 1] real_imgs.add_(1).div_(2).mul_(255) real_imgs = real_imgs.data.numpy() @@ -228,7 +228,9 @@ def build_super_images2(real_imgs, captions, cap_lens, 
ixtoword, if (vis_size // att_sze) > 1: one_map = \ skimage.transform.pyramid_expand(one_map, sigma=20, - upscale=vis_size // att_sze) + upscale=vis_size // att_sze, +) + minV = one_map.min() maxV = one_map.max() one_map = (one_map - minV) / (maxV - minV) @@ -286,12 +288,12 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword, def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: - nn.init.orthogonal(m.weight.data, 1.0) + nn.init.orthogonal_(m.weight.data, 1.0) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) elif classname.find('Linear') != -1: - nn.init.orthogonal(m.weight.data, 1.0) + nn.init.orthogonal_(m.weight.data, 1.0) if m.bias is not None: m.bias.data.fill_(0.0) diff --git a/code/model.py b/code/model.py index 09bde34e..9bc514b9 100644 --- a/code/model.py +++ b/code/model.py @@ -20,7 +20,19 @@ def forward(self, x): nc = x.size(1) assert nc % 2 == 0, 'channels dont divide 2!' nc = int(nc/2) - return x[:, :nc] * F.sigmoid(x[:, nc:]) + return x[:, :nc] * torch.sigmoid(x[:, nc:]) + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, size=None): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.size = size + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, size=self.size) + return x def conv1x1(in_planes, out_planes, bias=False): @@ -38,7 +50,7 @@ def conv3x3(in_planes, out_planes): # Upsale the spatial size by a factor of 2 def upBlock(in_planes, out_planes): block = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), + Interpolate(scale_factor=2, mode='nearest'), conv3x3(in_planes, out_planes * 2), nn.BatchNorm2d(out_planes * 2), GLU()) @@ -207,7 +219,7 @@ def init_trainable_weights(self): def forward(self, x): features = None # --> fixed-size input: batch x 3 x 299 x 299 - x = nn.Upsample(size=(299, 299), 
mode='bilinear')(x) + x = nn.functional.interpolate(x,size=(299, 299), mode='bilinear', align_corners=False) # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32 diff --git a/code/pretrain_DAMSM.py b/code/pretrain_DAMSM.py index 5f8b0ff9..f635e062 100644 --- a/code/pretrain_DAMSM.py +++ b/code/pretrain_DAMSM.py @@ -10,6 +10,8 @@ from model import RNN_ENCODER, CNN_ENCODER +from pathlib import Path + import os import sys import time @@ -93,18 +95,19 @@ def train(dataloader, cnn_model, rnn_model, batch_size, # # `clip_grad_norm` helps prevent # the exploding gradient problem in RNNs / LSTMs. - torch.nn.utils.clip_grad_norm(rnn_model.parameters(), + torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), cfg.TRAIN.RNN_GRAD_CLIP) optimizer.step() if step % UPDATE_INTERVAL == 0: count = epoch * len(dataloader) + step - s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL - s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL + + s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL + s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL - w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL - w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL + w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL + w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL elapsed = time.time() - start_time print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | ' @@ -157,8 +160,8 @@ def evaluate(dataloader, cnn_model, rnn_model, batch_size): if step == 50: break - s_cur_loss = s_total_loss[0] / step - w_cur_loss = w_total_loss[0] / step + s_cur_loss = s_total_loss.item() / step + w_cur_loss = w_total_loss.item() / step return s_cur_loss, w_cur_loss @@ -220,8 +223,8 @@ def build_models(): ########################################################################## now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') - output_dir = '../output/%s_%s_%s' % \ - (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) + output_dir = '%s/../output/%s_%s_%s' % \ + 
(str(Path(cfg.DATA_DIR).parent.parent) ,cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) model_dir = os.path.join(output_dir, 'Model') image_dir = os.path.join(output_dir, 'Image') @@ -235,7 +238,7 @@ def build_models(): imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) batch_size = cfg.TRAIN.BATCH_SIZE image_transform = transforms.Compose([ - transforms.Scale(int(imsize * 76 / 64)), + transforms.Resize(int(imsize * 76 / 64)), transforms.RandomCrop(imsize), transforms.RandomHorizontalFlip()]) dataset = TextDataset(cfg.DATA_DIR, 'train', @@ -282,14 +285,13 @@ def build_models(): print('-' * 89) if lr > cfg.TRAIN.ENCODER_LR/10.: lr *= 0.98 - if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch == cfg.TRAIN.MAX_EPOCH): torch.save(image_encoder.state_dict(), '%s/image_encoder%d.pth' % (model_dir, epoch)) torch.save(text_encoder.state_dict(), '%s/text_encoder%d.pth' % (model_dir, epoch)) - print('Save G/Ds models.') + print(f'Save G/Ds models in {model_dir}.') except KeyboardInterrupt: print('-' * 89) print('Exiting from training early') diff --git a/code/trainer.py b/code/trainer.py index a6d4180f..bd44057c 100644 --- a/code/trainer.py +++ b/code/trainer.py @@ -242,7 +242,7 @@ def train(self): ###################################################### # (1) Prepare training data and Compute text embeddings ###################################################### - data = data_iter.next() + data = next(data_iter) imgs, captions, cap_lens, class_ids, keys = prepare_data(data) hidden = text_encoder.init_hidden(batch_size) @@ -274,7 +274,7 @@ def train(self): errD.backward() optimizersD[i].step() errD_total += errD - D_logs += 'errD%d: %.2f ' % (i, errD.data[0]) + D_logs += 'errD%d: %.2f ' % (i, errD.item()) ####################################################### # (4) Update G network: maximize log(D(G(z))) @@ -291,12 +291,12 @@ def train(self): words_embs, sent_emb, match_labels, cap_lens, class_ids) kl_loss = KL_loss(mu, logvar) errG_total += kl_loss - G_logs += 
'kl_loss: %.2f ' % kl_loss.data[0] + G_logs += 'kl_loss: %.2f ' % kl_loss.item() # backward and update parameters errG_total.backward() optimizerG.step() for p, avg_p in zip(netG.parameters(), avg_param_G): - avg_p.mul_(0.999).add_(0.001, p.data) + avg_p.mul_(0.999).add_(p.data, alpha=0.001) if gen_iterations % 100 == 0: print(D_logs + '\n' + G_logs) @@ -318,7 +318,7 @@ def train(self): print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' % (epoch, self.max_epoch, self.num_batches, - errD_total.data[0], errG_total.data[0], + errD_total.item(), errG_total.item(), end_t - start_t)) if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: # and epoch != 0: @@ -370,8 +370,10 @@ def sampling(self, split_dir): batch_size = self.batch_size nz = cfg.GAN.Z_DIM - noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) - noise = noise.cuda() + + with torch.no_grad(): + noise = Variable(torch.FloatTensor(batch_size, nz)) + noise = noise.cuda() model_dir = cfg.TRAIN.NET_G state_dict = \ @@ -463,14 +465,18 @@ def gen_example(self, data_dic): batch_size = captions.shape[0] nz = cfg.GAN.Z_DIM - captions = Variable(torch.from_numpy(captions), volatile=True) - cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) - captions = captions.cuda() - cap_lens = cap_lens.cuda() + with torch.no_grad(): + captions = Variable(torch.from_numpy(captions)) + cap_lens = Variable(torch.from_numpy(cap_lens)) + + captions = captions.cuda() + cap_lens = cap_lens.cuda() + for i in range(1): # 16 - noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) - noise = noise.cuda() + with torch.no_grad(): + noise = Variable(torch.FloatTensor(batch_size, nz)) + noise = noise.cuda() ####################################################### # (1) Extract text embeddings ###################################################### diff --git a/eval/GlobalAttention.py b/eval/GlobalAttention.py index 501fb720..e59cbbc8 100644 --- a/eval/GlobalAttention.py +++ b/eval/GlobalAttention.py @@ -104,7 
+104,7 @@ def forward(self, input, context): attn = attn.view(batch_size*queryL, sourceL) if self.mask is not None: # batch_size x sourceL --> batch_size*queryL x sourceL - mask = self.mask.repeat(queryL, 1) + mask = self.mask.repeat(queryL, 1).to(torch.bool) attn.data.masked_fill_(mask.data, -float('inf')) attn = self.sm(attn) # Eq. (2) # --> batch x queryL x sourceL diff --git a/eval/eval.py b/eval/eval.py index 48005f73..dc42f39c 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -54,16 +54,17 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi batch_size = captions.shape[0] nz = cfg.GAN.Z_DIM - captions = Variable(torch.from_numpy(captions), volatile=True) - cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) - noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) + with torch.no_grad(): + captions = Variable(torch.from_numpy(captions)) + cap_lens = Variable(torch.from_numpy(cap_lens)) + noise = Variable(torch.FloatTensor(batch_size, nz)) if cfg.CUDA: captions = captions.cuda() cap_lens = cap_lens.cuda() noise = noise.cuda() - + ####################################################### # (1) Extract text embeddings @@ -71,7 +72,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi hidden = text_encoder.init_hidden(batch_size) words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) mask = (captions == 0) - + ####################################################### # (2) Generate fake images @@ -131,7 +132,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi im = fake_imgs[k + 1].detach().cpu() else: im = fake_imgs[0].detach().cpu() - + attn_maps = attention_maps[k] att_sze = attn_maps.size(2) @@ -152,7 +153,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi urls.append(full_path % blob_name) if copies == 2: break - + #print(len(urls), urls) return urls @@ -223,7 +224,7 @@ def eval(caption): if __name__ == 
"__main__": caption = "the bird has a yellow crown and a black eyering that is round" - + # load configuration #cfg_from_file('eval_bird.yml') # load word dictionaries @@ -232,9 +233,9 @@ def eval(caption): text_encoder, netG = models(len(wordtoix)) # load blob service blob_service = BlockBlobService(account_name='attgan', account_key='[REDACTED]') - + t0 = time.time() urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service) t1 = time.time() print(t1-t0) - print(urls) \ No newline at end of file + print(urls) diff --git a/eval/miscc/utils.py b/eval/miscc/utils.py index 13fc4739..1993b4c0 100644 --- a/eval/miscc/utils.py +++ b/eval/miscc/utils.py @@ -58,7 +58,7 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword, dtype=np.uint8) real_imgs = \ - nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) + nn.functional.interpolate(real_imgs,size=(vis_size, vis_size), mode='bilinear') # [-1, 1] --> [0, 1] real_imgs.add_(1).div_(2).mul_(255) real_imgs = real_imgs.data.numpy() @@ -159,12 +159,12 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword, def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: - nn.init.orthogonal(m.weight.data, 1.0) + nn.init.orthogonal_(m.weight.data, 1.0) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) elif classname.find('Linear') != -1: - nn.init.orthogonal(m.weight.data, 1.0) + nn.init.orthogonal_(m.weight.data, 1.0) if m.bias is not None: m.bias.data.fill_(0.0) diff --git a/eval/model.py b/eval/model.py index 6d37ab3d..183769c1 100644 --- a/eval/model.py +++ b/eval/model.py @@ -26,7 +26,7 @@ def __init__(self, ntoken, ninput=300, drop_prob=0.5, self.drop_prob = drop_prob # probability of an element to be zeroed self.nlayers = nlayers # Number of recurrent layers self.bidirectional = bidirectional - + if bidirectional: self.num_directions = 2 else: @@ -113,7 +113,7 @@ def __init__(self): nef = 
cfg.TEXT.EMBEDDING_DIM ncf = cfg.GAN.CONDITION_DIM - + self.ca_net = CA_NET() if cfg.TREE.BRANCH_NUM > 0: @@ -170,7 +170,7 @@ def __init__(self): self.t_dim = cfg.TEXT.EMBEDDING_DIM self.c_dim = cfg.GAN.CONDITION_DIM - + self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True) self.relu = GLU() @@ -204,7 +204,7 @@ def forward(self, x): nc = x.size(1) assert nc % 2 == 0, 'channels dont divide 2!' nc = int(nc/2) - return x[:, :nc] * F.sigmoid(x[:, nc:]) + return x[:, :nc] * torch.sigmoid(x[:, nc:]) def conv1x1(in_planes, out_planes, bias=False): @@ -222,7 +222,7 @@ def conv3x3(in_planes, out_planes): # Upsale the spatial size by a factor of 2 def upBlock(in_planes, out_planes): block = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), + nn.Upsample(scale_factor=2, mode='nearest'), conv3x3(in_planes, out_planes * 2), nn.BatchNorm2d(out_planes * 2), GLU()) @@ -304,7 +304,7 @@ def init_trainable_weights(self): def forward(self, x): features = None # --> fixed-size input: batch x 3 x 299 x 299 - x = nn.Upsample(size=(299, 299), mode='bilinear')(x) + x = nn.functional.interpolate(x,size=(299, 299), mode='bilinear', align_corners=False) # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32