diff --git a/train_student.py b/train_student.py index e9067d9..ce217f7 100644 --- a/train_student.py +++ b/train_student.py @@ -343,21 +343,22 @@ def __train_step(phase, epoch, global_step, global_test_step, y_hat = teacher(predict, c=c, g=g) # y_hat: (B x C x T) teacher: 10-mixture-logistic h_pt_ps = 0 # TODO add some constrain on scale ,we want it to be small? - for i in range(sample_T): - # https://en.wikipedia.org/wiki/Logistic_distribution - u = Variable(torch.zeros(*x.size()).uniform_(1e-5,1-1e-5),requires_grad=False).cuda() - z = torch.log(u) - torch.log(1 - u) - student_predict = m + s * z # predicted wave - # student_predict.clamp(-0.99, 0.99) - student_predict = student_predict.permute(0, 2, 1) - _, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False) - h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum() - student_predict = student_predict.permute(0, 2, 1) - power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=512, hop_length=128) - power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=256, hop_length=64) - power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=2048, hop_length=512) - power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=1024, hop_length=256) - power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=128, hop_length=32) + # for i in range(sample_T): + u = Variable(torch.from_numpy(np.random.uniform(1e-5, 1 - 1e-5, x.size())).float().cuda(), requires_grad=False) + z = torch.log(u) - torch.log(1 - u) + student_predict = m + s * z # predicted wave + student_predict.clamp(-0.99, 0.99) + student_predict = student_predict.permute(0, 2, 1) + _, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False) + h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum() + student_predict = student_predict.permute(0, 2, 1) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=512) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=256) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=2048) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=1024) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=128) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=64) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=4096) a = s.permute(0, 2, 1) h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / ( mask.sum()) cross_entropy = h_pt_ps /(sample_T)