From 174647ecd902354fc21c2eaccfa0522b3fb71747 Mon Sep 17 00:00:00 2001 From: Sebastian Rassmann <39744645+sRassmann@users.noreply.github.com> Date: Mon, 4 Nov 2024 17:21:37 +0100 Subject: [PATCH] Fix bug in stage1 (en-/decoder pre-) training Fix the bug of choosing the trained target modality randomly. Instead source and target images are always identical --- VQ-GAN/taming/models/vqgan.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/VQ-GAN/taming/models/vqgan.py b/VQ-GAN/taming/models/vqgan.py index 48282da..deec908 100644 --- a/VQ-GAN/taming/models/vqgan.py +++ b/VQ-GAN/taming/models/vqgan.py @@ -1,3 +1,4 @@ +# fmt: off import random import torch import torch.nn as nn @@ -76,7 +77,7 @@ def decode_code(self, code_b): dec = self.decode(quant_b) return dec - def forward(self, input, target): + def forward(self, input, target=None): quant, diff, _ = self.encode(input) if target is not None: quant = self.spade(quant, target) dec = self.decode(quant) @@ -84,17 +85,20 @@ def forward(self, input, target): def get_input(self, batch, k): x = batch[k] - + return x.float() def training_step(self, batch, batch_idx, optimizer_idx): source = random.choice(self.modalities) - target = random.choice(self.modalities) + if self.stage == 1: + target = source + else: + target = random.choice(self.modalities) x_src = self.get_input(batch, source) x_tar = self.get_input(batch, target) skip_pass = 0 - if self.stage == 1: + if self.stage == 1: xrec, qloss = self(x_src) else: z_src, qloss, _ = self.encode(x_src) @@ -122,11 +126,14 @@ def training_step(self, batch, batch_idx, optimizer_idx): def validation_step(self, batch, batch_idx): source = random.choice(self.modalities) - target = random.choice(self.modalities) + if self.stage == 1: + target = source + else: + target = random.choice(self.modalities) x_src = self.get_input(batch, source) x_tar = self.get_input(batch, target) - - if self.stage == 1: + + if self.stage == 1: xrec, qloss = self(x_src) else: z_src, qloss, _ = self.encode(x_src) @@ -141,10 +148,10 @@ def validation_step(self, batch, batch_idx): discloss, log_dict_disc = self.loss(qloss, x_tar, xrec, 1, self.global_step, last_layer=self.get_last_layer(), split="val") rec_loss = log_dict_ae["val/rec_loss"] - self.log("val/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) - self.log("val/aeloss", aeloss, - prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + # self.log("val/rec_loss", rec_loss, + # prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + # self.log("val/aeloss", aeloss, + # prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict @@ -188,7 +195,7 @@ def log_images(self, batch, **kwargs): xrec = self.to_rgb(xrec) log["source"] = x_src log["target"] = x_tar - if self.stage == 1: + if self.stage == 1: log["recon"] = xrec else: log[f"recon_{source}_to_{target}"] = xrec @@ -448,4 +455,4 @@ def configure_optimizers(self): lr=lr, betas=(0.5, 0.9)) opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] \ No newline at end of file + return [opt_ae, opt_disc], []