@@ -769,7 +769,7 @@ def sample(self, sample_shape=torch.Size()):
769769 dm_sample:
770770 Sample(s) from the Dirichlet Multinomial distribution.
771771 """
772- shape = self ._extended_shape (sample_shape )
772+ # shape = self._extended_shape(sample_shape)
773773 p = td .Dirichlet (self .concentration ).sample (sample_shape )
774774
775775 batch_dims = p .shape [:- 1 ]
@@ -1207,7 +1207,7 @@ def kl_div_loss(self, x):
12071207 KL-divergence loss
12081208 """
12091209 q = self .encoder (x )
1210- z = q .rsample ()
1210+ # z = q.rsample()
12111211 kl_loss = torch .mean (
12121212 self .beta * td .kl_divergence (q , self .prior ()),
12131213 dim = 0 ,
@@ -1359,7 +1359,7 @@ def elbo(self, x):
13591359
13601360 def kl_div_loss (self , x ):
13611361 q = self .encoder (x )
1362- z = q .rsample ()
1362+ # z = q.rsample()
13631363 kl_loss = torch .mean (
13641364 self .beta * td .kl_divergence (q , self .prior ()),
13651365 dim = 0 ,
@@ -2035,7 +2035,7 @@ def train_abaco(
20352035 for loader_data in data_iter :
20362036 x = loader_data [0 ].to (device )
20372037 y = loader_data [1 ].to (device ).float () # Batch label
2038- z = loader_data [2 ].to (device ).float () # Bio type label
2038+ # z = loader_data[2].to(device).float() # Bio type label
20392039
20402040 # VAE ELBO computation with masked batch label
20412041 vae_optim_post .zero_grad ()
@@ -2050,8 +2050,8 @@ def train_abaco(
20502050 p_xz = vae .decoder (torch .cat ([latent_points , alpha * y ], dim = 1 ))
20512051
20522052 # Log probabilities of prior and posterior
2053- log_q_zx = q_zx .log_prob (latent_points )
2054- log_p_z = vae .log_prob (latent_points )
2053+ # log_q_zx = q_zx.log_prob(latent_points)
2054+ # log_p_z = vae.log_prob(latent_points)
20552055
20562056 # Compute ELBO
20572057 recon_term = p_xz .log_prob (x ).mean ()
@@ -2829,7 +2829,7 @@ def train_abaco_ensemble(
28292829 for loader_data in data_iter :
28302830 x = loader_data [0 ].to (device )
28312831 y = loader_data [1 ].to (device ).float () # Batch label
2832- z = loader_data [2 ].to (device ).float () # Bio type label
2832+ # z = loader_data[2].to(device).float() # Bio type label
28332833
28342834 # VAE ELBO computation with masked batch label
28352835 vae_optim_post .zero_grad ()
@@ -2849,8 +2849,8 @@ def train_abaco_ensemble(
28492849 p_xzs .append (p_xz )
28502850
28512851 # Log probabilities of prior and posterior
2852- log_q_zx = q_zx .log_prob (latent_points )
2853- log_p_z = vae .log_prob (latent_points )
2852+ # log_q_zx = q_zx.log_prob(latent_points)
2853+ # log_p_z = vae.log_prob(latent_points)
28542854
28552855 # Compute ELBO
28562856
@@ -4529,7 +4529,7 @@ def correct(
45294529 for loader_data in iter (self .dataloader ):
45304530 x = loader_data [0 ].to (self .device )
45314531 ohe_batch = loader_data [1 ].to (self .device ).float () # Batch label
4532- ohe_bio = loader_data [2 ].to (self .device ).float () # Bio type label
4532+ # ohe_bio = loader_data[2].to(self.device).float() # Bio type label
45334533
45344534 # Encode and decode the input data along with the one-hot encoded batch label
45354535 q_zx = self .vae .encoder (torch .cat ([x , ohe_batch ], dim = 1 )) # td.Distribution
0 commit comments