From 66188c5edb15b92fa791b0b8d8234ba1da779e3a Mon Sep 17 00:00:00 2001 From: Pulkit Khandelwal Date: Mon, 2 Dec 2024 00:00:21 -0500 Subject: [PATCH 1/2] Update normalization.py --- LDM/ldm/models/normalization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/LDM/ldm/models/normalization.py b/LDM/ldm/models/normalization.py index 97b6a19..23d2c2a 100644 --- a/LDM/ldm/models/normalization.py +++ b/LDM/ldm/models/normalization.py @@ -15,7 +15,7 @@ def __init__(self, norm_nc, label_nc, kernel_size=3, norm_type='instance'): % norm_type) # The dimension of the intermediate embedding space. Yes, hardcoded. - nhidden = 128 + nhidden = 64 pw = kernel_size // 2 self.mlp_shared = nn.Sequential( @@ -94,7 +94,7 @@ def actvn(self, x): class SPADEGenerator(nn.Module): def __init__(self,modalities, z_dim=3): super().__init__() - nf = 128 + nf = 64 self.in_spade = SPADEResnetBlock(modalities, z_dim, nf) self.out_spade = SPADEResnetBlock(modalities, nf, z_dim) @@ -103,4 +103,4 @@ def forward(self, x, modality): x = self.in_spade(x, modality) x = self.out_spade(x, modality) - return x \ No newline at end of file + return x From 86980d0c7a81f9accda62db2d3d073c7b83a4ef3 Mon Sep 17 00:00:00 2001 From: Pulkit Khandelwal Date: Mon, 2 Dec 2024 00:01:04 -0500 Subject: [PATCH 2/2] Update ddpm.py --- LDM/ldm/models/diffusion/ddpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LDM/ldm/models/diffusion/ddpm.py b/LDM/ldm/models/diffusion/ddpm.py index f80ca62..ebe42dc 100644 --- a/LDM/ldm/models/diffusion/ddpm.py +++ b/LDM/ldm/models/diffusion/ddpm.py @@ -939,7 +939,7 @@ def p_losses(self, x_src, x_tgt, cond, t, noise=None): loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - logvar_t = self.logvar[t].to(self.device) + logvar_t = self.logvar[t.cpu()].to(self.device) logvar_t = repeat(logvar_t, 'b -> b c', c=loss_simple.shape[1]) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar