diff --git a/models/voicecraft.py b/models/voicecraft.py
index 8d83729d..45819b13 100644
--- a/models/voicecraft.py
+++ b/models/voicecraft.py
@@ -678,7 +678,7 @@ def inference(
 
         ##################### silence repetition handling #####################
         # prepare the cache placeholder
         # n_layers, 2, bsz, num_heads, src_len, head_dim
-        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
+        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
         # handle multi-span kv-cache
         new_masked_span = False
@@ -978,7 +978,7 @@ def inference_tts(
 
         # prepare the cache placeholder
         # n_layers, 2, bsz, num_heads, src_len, head_dim
-        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
+        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
         # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
         # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
         # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@@ -1228,7 +1228,7 @@ def inference_tts_batch(
 
         # prepare the cache placeholder
         # n_layers, 2, bsz, num_heads, src_len, head_dim
-        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
+        past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float16) if kvcache else None
         # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
         # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
         # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
@@ -1403,4 +1403,4 @@ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_t
 
             res = res - int(self.args.n_special)
             flatten_gen = flatten_gen - int(self.args.n_special)
-        return res, flatten_gen[0].unsqueeze(0)
\ No newline at end of file
+        return res, flatten_gen[0].unsqueeze(0)