diff --git a/models/voicecraft.py b/models/voicecraft.py index 508e55f2..405c86a8 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -40,7 +40,7 @@ def top_k_top_p_filtering( max(top_k, min_tokens_to_keep), logits.size(-1) ) # Safety check # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + indices_to_remove = logits < torch.topk(logits, int(top_k))[0][..., -1, None] logits[indices_to_remove] = filter_value if top_p < 1.0: