Why don't you use the reference in test mode?
def evaluate(self, batch, n_samples):
    """Sample imputations for ``batch`` conditioned on ``gt_mask``.

    ``batch`` is unpacked by ``self.process_data``. The entries selected by
    ``gt_mask`` are used as the conditioning mask for ``self.impute``. The
    observed entries outside ``gt_mask`` form the returned target mask. All
    model calls run under ``torch.no_grad()``, so no autograd graph is built.

    Args:
        batch: Raw batch, passed unchanged to ``self.process_data``.
        n_samples: Number of samples, forwarded to ``self.impute``.

    Returns:
        Tuple ``(samples, observed_data, target_mask, observed_mask,
        observed_tp)``. ``samples`` is whatever ``self.impute`` returns, and
        ``target_mask = observed_mask * (1 - gt_mask)``.
    """
    # NOTE(review): process_data yields six values; the 5th and 6th are
    # discarded here and never reach impute(). The 5th is presumably a
    # reference series; confirm whether test-time imputation should use it.
    (
        observed_data,
        observed_mask,
        observed_tp,
        gt_mask,
        _,
        _,
    ) = self.process_data(batch)

    with torch.no_grad():
        # Condition the sampler only on the entries kept by gt_mask.
        cond_mask = gt_mask
        # Observed-but-not-conditioned entries are the evaluation targets
        # (assumes 0/1 masks, so 1 - gt_mask is the complement).
        target_mask = observed_mask * (1 - gt_mask)
        side_info = self.get_side_info(observed_tp, cond_mask)
        samples = self.impute(observed_data, cond_mask, side_info, n_samples)
    return samples, observed_data, target_mask, observed_mask, observed_tp
Why don't you use the reference in test mode?
def evaluate(self, batch, n_samples):
(
observed_data,
observed_mask,
observed_tp,
gt_mask,
_, #you don't use reference
_,
) = self.process_data(batch)