#include <gtest/gtest.h>
#include <c10/util/irange.h>
#include <test/cpp/api/support.h>
#include <torch/torch.h>
// Naive O(N^2) DFT of a 1-dimensional complex<double> tensor, used as a
// reference implementation for the torch::fft tests in this file.
//
// Computes out[k] = sum_{n=0}^{N-1} in[n] * exp(s * 2*pi*i * k*n / N), where
// s = -1 for the forward transform and s = +1 for the inverse transform.
// No normalization is applied in either direction; callers comparing against
// a normalized inverse must divide the result by N themselves.
//
// x:       1-D tensor of dtype kComplexDouble (both are asserted below).
// forward: true for the forward transform, false for the unnormalized
//          inverse transform.
// Returns a new tensor with the same shape and dtype as x.
torch::Tensor naive_dft(torch::Tensor x, bool forward = true) {
  TORCH_INTERNAL_ASSERT(x.dim() == 1);
  // The raw data_ptr accesses below are only valid for complex<double> data;
  // fail with a clear precondition error instead of a data_ptr type mismatch.
  TORCH_INTERNAL_ASSERT(x.scalar_type() == torch::kComplexDouble);
  x = x.contiguous();
  auto out_tensor = torch::zeros_like(x);
  const int64_t len = x.size(0);
  // Precompute the N roots of unity exp(s*2*pi*i*n/N) for n in [0, N).
  // exp(s*2*pi*i*k*n/N) is periodic in k*n with period N, so every twiddle
  // factor needed below is roots[(k*n) % N].
  std::vector<c10::complex<double>> roots(len);
  const auto angle_base = (forward ? -2.0 : 2.0) * M_PI / len;
  for (const auto i : c10::irange(len)) {
    const auto angle = i * angle_base;
    roots[i] = c10::complex<double>(std::cos(angle), std::sin(angle));
  }
  const auto in = x.data_ptr<c10::complex<double>>();
  const auto out = out_tensor.data_ptr<c10::complex<double>>();
  for (const auto k : c10::irange(len)) {
    // Accumulate locally (same summation order as before) instead of
    // read-modify-writing out[k] on every term.
    c10::complex<double> sum(0.0, 0.0);
    for (const auto n : c10::irange(len)) {
      sum += roots[(k * n) % len] * in[n];
    }
    out[k] = sum;
  }
  return out_tensor;
}
// NOTE: As of August 2020, Visual Studio and ROCm builds do not support
// complex literals, so the tests below avoid them.
// The library forward FFT must agree with the O(N^2) reference DFT.
TEST(FFTTest, fft) {
  constexpr int64_t kLen = 128;
  const auto signal = torch::randn(kLen, torch::kComplexDouble);

  const auto result = torch::fft::fft(signal);
  const auto reference = naive_dft(signal);

  ASSERT_TRUE(torch::allclose(result, reference));
}
// fft of a real-valued input must produce the same spectrum as fft of the
// same values promoted to complex<double>.
TEST(FFTTest, fft_real) {
  constexpr int64_t kLen = 128;
  const auto real_signal = torch::randn(kLen, torch::kDouble);

  const auto from_real = torch::fft::fft(real_signal);
  const auto complex_signal = real_signal.to(torch::kComplexDouble);
  const auto from_complex = torch::fft::fft(complex_signal);

  ASSERT_TRUE(torch::allclose(from_real, from_complex));
}
// fft with an explicit length n must match fft of the input zero-padded
// (n > size) or truncated (n < size) to exactly n elements.
TEST(FFTTest, fft_pad) {
  const auto signal = torch::randn(128, torch::kComplexDouble);

  // n = 200: append 72 trailing zeros (128 + 72 = 200).
  const auto padded_result = torch::fft::fft(signal, 200);
  const auto padded_reference =
      torch::fft::fft(torch::constant_pad_nd(signal, {0, 72}));
  ASSERT_TRUE(torch::allclose(padded_result, padded_reference));

  // n = 64: a negative pad trims the last 64 elements (128 - 64 = 64).
  const auto trimmed_result = torch::fft::fft(signal, 64);
  const auto trimmed_reference =
      torch::fft::fft(torch::constant_pad_nd(signal, {0, -64}));
  ASSERT_TRUE(torch::allclose(trimmed_result, trimmed_reference));
}
// Checks the "forward" and "ortho" normalization modes of fft against the
// default (unnormalized forward) transform: they should scale it by 1/N and
// 1/sqrt(N) respectively.
//
// The argument comments name the real parameters of torch::fft::fft
// (n, dim, norm), so clang-tidy's bugprone-argument-comment check passes
// without suppression.
TEST(FFTTest, fft_norm) {
  constexpr int64_t kLen = 128;
  auto t = torch::randn(kLen, torch::kComplexDouble);
  auto unnorm = torch::fft::fft(t, /*n=*/{}, /*dim=*/-1, /*norm=*/{});
  auto norm = torch::fft::fft(t, /*n=*/{}, /*dim=*/-1, /*norm=*/"forward");
  ASSERT_TRUE(torch::allclose(unnorm / kLen, norm));
  auto ortho_norm =
      torch::fft::fft(t, /*n=*/{}, /*dim=*/-1, /*norm=*/"ortho");
  ASSERT_TRUE(torch::allclose(
      unnorm / std::sqrt(static_cast<double>(kLen)), ortho_norm));
}
// The default ifft must equal the reference inverse DFT (which applies no
// normalization) divided by the signal length.
TEST(FFTTest, ifft) {
  constexpr int64_t kLen = 128;
  const auto spectrum = torch::randn(kLen, torch::kComplexDouble);

  const auto result = torch::fft::ifft(spectrum);
  const auto unnormalized = naive_dft(spectrum, /*forward=*/false);
  const auto reference = unnormalized / kLen;

  ASSERT_TRUE(torch::allclose(result, reference));
}
// Round trip: ifft(fft(x)) must reproduce x. Both transforms must preserve
// the length and the complex<double> dtype. Uses an odd, non-power-of-two
// length.
TEST(FFTTest, fft_ifft) {
  constexpr int64_t kLen = 77;
  const auto signal = torch::randn(kLen, torch::kComplexDouble);

  const auto spectrum = torch::fft::fft(signal);
  ASSERT_EQ(spectrum.size(0), kLen);
  ASSERT_EQ(spectrum.scalar_type(), torch::kComplexDouble);

  const auto recovered = torch::fft::ifft(spectrum);
  ASSERT_EQ(recovered.size(0), kLen);
  ASSERT_EQ(recovered.scalar_type(), torch::kComplexDouble);

  ASSERT_TRUE(torch::allclose(signal, recovered));
}
// rfft of a real signal must equal the first N/2 + 1 bins of the full complex
// FFT of the same values (129 / 2 + 1 = 65 here).
TEST(FFTTest, rfft) {
  constexpr int64_t kLen = 129;
  constexpr int64_t kBins = kLen / 2 + 1;
  const auto real_signal = torch::randn(kLen, torch::kDouble);

  const auto half_spectrum = torch::fft::rfft(real_signal);
  const auto full_spectrum =
      torch::fft::fft(real_signal.to(torch::kComplexDouble));
  const auto reference = full_spectrum.slice(0, 0, kBins);

  ASSERT_TRUE(torch::allclose(half_spectrum, reference));
}
// Round trip: irfft(rfft(x)) must reproduce a real signal x. rfft yields
// N/2 + 1 complex<double> bins and irfft restores N real doubles.
TEST(FFTTest, rfft_irfft) {
  constexpr int64_t kLen = 128;
  constexpr int64_t kBins = kLen / 2 + 1;
  const auto real_signal = torch::randn(kLen, torch::kDouble);

  const auto half_spectrum = torch::fft::rfft(real_signal);
  ASSERT_EQ(half_spectrum.size(0), kBins);
  ASSERT_EQ(half_spectrum.scalar_type(), torch::kComplexDouble);

  const auto recovered = torch::fft::irfft(half_spectrum);
  ASSERT_EQ(recovered.size(0), kLen);
  ASSERT_EQ(recovered.scalar_type(), torch::kDouble);

  ASSERT_TRUE(torch::allclose(real_signal, recovered));
}
// ihfft of a real signal must equal the first N/2 + 1 bins of the complex
// ifft of the same values (129 / 2 + 1 = 65 here).
TEST(FFTTest, ihfft) {
  constexpr int64_t kLen = 129;
  constexpr int64_t kBins = kLen / 2 + 1;
  const auto real_signal = torch::randn(kLen, torch::kDouble);

  const auto half_result = torch::fft::ihfft(real_signal);
  const auto full_inverse =
      torch::fft::ifft(real_signal.to(torch::kComplexDouble));
  const auto reference = full_inverse.slice(0, 0, kBins);

  ASSERT_TRUE(torch::allclose(half_result, reference));
}
// Round trip: ihfft(hfft(x, n)) must reproduce the one-sided Hermitian input
// x. hfft with odd n = 127 yields 127 real doubles, and ihfft maps them back
// to 127 / 2 + 1 = 64 complex<double> values.
TEST(FFTTest, hfft_ihfft) {
  constexpr int64_t kBins = 64;
  constexpr int64_t kLen = 127;
  auto half_signal = torch::randn(kBins, torch::kComplexDouble);
  // The zero-frequency term must be purely real for Hermitian symmetry.
  half_signal[0] = .5;

  const auto real_output = torch::fft::hfft(half_signal, kLen);
  ASSERT_EQ(real_output.size(0), kLen);
  ASSERT_EQ(real_output.scalar_type(), torch::kDouble);

  const auto recovered = torch::fft::ihfft(real_output);
  ASSERT_EQ(recovered.size(0), kBins);
  ASSERT_EQ(recovered.scalar_type(), torch::kComplexDouble);

  ASSERT_TRUE(torch::allclose(half_signal, recovered));
}