From 8e30319d5ce4fe3b0aa693bdbdcf6041931b6a37 Mon Sep 17 00:00:00 2001 From: zhf459 <459314612@qq.com> Date: Thu, 19 Apr 2018 09:50:53 +0800 Subject: [PATCH] Update train_student.py --- train_student.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/train_student.py b/train_student.py index 9968f4c..4f4b59a 100644 --- a/train_student.py +++ b/train_student.py @@ -357,22 +357,22 @@ def __train_step(phase, epoch, global_step, global_test_step, y_hat = teacher(z, c=c, g=g) # y_hat: (B x C x T) teacher: 10-mixture-logistic h_pt_ps = 0 # TODO add some constrain on scale ,we want it to be small? - for i in range(sample_T): - u = Variable(torch.from_numpy(np.random.uniform(1e-5, 1 - 1e-5, x.size())).float().cuda(), requires_grad=False) - z = torch.log(u) - torch.log(1 - u) - student_predict = m + s * z # predicted wave - student_predict.clamp(-0.99, 0.99) - student_predict = student_predict.permute(0, 2, 1) - _, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False) - h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum() - student_predict = student_predict.permute(0, 2, 1) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=512) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=256) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=2048) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=1024) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=128) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=64) - power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=4096) + # for i in range(sample_T): + u = Variable(torch.from_numpy(np.random.uniform(1e-5, 1 - 1e-5, x.size())).float().cuda(), requires_grad=False) + z = torch.log(u) - torch.log(1 - u) + student_predict = m + s * z # predicted wave + student_predict.clamp(-0.99, 0.99) + student_predict = student_predict.permute(0, 2, 1) + _, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False) + h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum() + student_predict = student_predict.permute(0, 2, 1) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=512) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=256) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=2048) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=1024) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=128) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=64) + power_loss_sum += get_power_loss_torch(student_predict, x,n_fft=4096) a = s.permute(0, 2, 1) h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / mask.sum() cross_entropy = h_pt_ps / (sample_T)