diff --git a/torch_dct/_dct.py b/torch_dct/_dct.py index a09453f..25fc6b9 100644 --- a/torch_dct/_dct.py +++ b/torch_dct/_dct.py @@ -2,6 +2,30 @@ import torch import torch.nn as nn +try: + # PyTorch 1.7.0 and newer versions + import torch.fft + + def dct1_rfft_impl(x): + return torch.view_as_real(torch.fft.rfft(x, dim=1)) + + def dct_fft_impl(v): + return torch.view_as_real(torch.fft.fft(v, dim=1)) + + def idct_irfft_impl(V): + return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) +except ImportError: + # PyTorch 1.6.0 and older versions + def dct1_rfft_impl(x): + return torch.rfft(x, 1) + + def dct_fft_impl(v): + return torch.rfft(v, 1, onesided=False) + + def idct_irfft_impl(V): + return torch.irfft(V, 1, onesided=False) + + def dct1(x): """ @@ -12,8 +36,9 @@ def dct1(x): """ x_shape = x.shape x = x.view(-1, x_shape[-1]) + x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1) - return torch.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape) + return dct1_rfft_impl(x)[:, :, 0].view(*x_shape) def idct1(X): @@ -46,7 +71,7 @@ def dct(x, norm=None): v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) - Vc = torch.rfft(v, 1, onesided=False) + Vc = dct_fft_impl(v) k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) W_r = torch.cos(k) @@ -98,7 +123,7 @@ def idct(X, norm=None): V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) - v = torch.irfft(V, 1, onesided=False) + v = idct_irfft_impl(V) x = v.new_zeros(v.shape) x[:, ::2] += v[:, :N - (N // 2)] x[:, 1::2] += v.flip([1])[:, :N // 2]