From 2bfba8708ff16c6b6d0c63eab9fe99be3269b07e Mon Sep 17 00:00:00 2001 From: Jaroslaw Bojar Date: Wed, 11 Aug 2021 13:15:57 +0200 Subject: [PATCH 1/2] Switch to torch.fft module --- torch_dct/_dct.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_dct/_dct.py b/torch_dct/_dct.py index a09453f..73dd4fc 100644 --- a/torch_dct/_dct.py +++ b/torch_dct/_dct.py @@ -1,6 +1,7 @@ import numpy as np import torch import torch.nn as nn +import torch.fft def dct1(x): @@ -13,7 +14,7 @@ def dct1(x): x_shape = x.shape x = x.view(-1, x_shape[-1]) - return torch.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape) + return torch.view_as_real(torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), dim=1))[:, :, 0].view(*x_shape) def idct1(X): @@ -46,7 +47,7 @@ def dct(x, norm=None): v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) - Vc = torch.rfft(v, 1, onesided=False) + Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) W_r = torch.cos(k) @@ -98,7 +99,7 @@ def idct(X, norm=None): V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) - v = torch.irfft(V, 1, onesided=False) + v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) x = v.new_zeros(v.shape) x[:, ::2] += v[:, :N - (N // 2)] x[:, 1::2] += v.flip([1])[:, :N // 2] From 0002278069db971054a5eb618588267defa83f72 Mon Sep 17 00:00:00 2001 From: Jaroslaw Bojar Date: Wed, 11 Aug 2021 13:51:26 +0200 Subject: [PATCH 2/2] Compatibility with older pytorch versions --- torch_dct/_dct.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/torch_dct/_dct.py b/torch_dct/_dct.py index 73dd4fc..25fc6b9 100644 --- a/torch_dct/_dct.py +++ b/torch_dct/_dct.py @@ -1,7 +1,30 @@ import numpy as np import torch import torch.nn as nn -import torch.fft + +try: + # PyTorch 1.7.0 and newer versions + import torch.fft + + def dct1_rfft_impl(x): + return torch.view_as_real(torch.fft.rfft(x, dim=1)) + + def dct_fft_impl(v): + return torch.view_as_real(torch.fft.fft(v, dim=1)) + + def idct_irfft_impl(V): + return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) +except ImportError: + # PyTorch 1.6.0 and older versions + def dct1_rfft_impl(x): + return torch.rfft(x, 1) + + def dct_fft_impl(v): + return torch.rfft(v, 1, onesided=False) + + def idct_irfft_impl(V): + return torch.irfft(V, 1, onesided=False) + def dct1(x): @@ -13,8 +36,9 @@ def dct1(x): """ x_shape = x.shape x = x.view(-1, x_shape[-1]) + x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1) - return torch.view_as_real(torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), dim=1))[:, :, 0].view(*x_shape) + return dct1_rfft_impl(x)[:, :, 0].view(*x_shape) def idct1(X): @@ -47,7 +71,7 @@ def dct(x, norm=None): v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) - Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) + Vc = dct_fft_impl(v) k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) W_r = torch.cos(k) @@ -99,7 +123,7 @@ def idct(X, norm=None): V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) - v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) + v = idct_irfft_impl(V) x = v.new_zeros(v.shape) x[:, ::2] += v[:, :N - (N // 2)] x[:, 1::2] += v.flip([1])[:, :N // 2]