File: fft.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (130 lines) | stat: -rw-r--r-- 4,380 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <test/cpp/api/support.h>
#include <torch/torch.h>

// Reference (O(N^2)) discrete Fourier transform of a one-dimensional tensor,
// used to cross-check torch::fft results.
//
// Computes out[k] = sum_j in[j] * w^(k*j) with w = exp(-2*pi*i/N) when
// `forward` is true and w = exp(+2*pi*i/N) otherwise. No 1/N normalization is
// applied in either direction; callers scale the inverse result themselves.
//
// x:       must be 1-D (asserted). Its data is read as c10::complex<double>,
//          so it is expected to be a kComplexDouble tensor.
// forward: chooses the sign of the exponent (forward vs. inverse transform).
// Returns a new tensor shaped like x holding the transform.
torch::Tensor naive_dft(torch::Tensor x, bool forward = true) {
  TORCH_INTERNAL_ASSERT(x.dim() == 1);
  x = x.contiguous();
  auto result = torch::zeros_like(x);
  const int64_t n = x.size(0);

  // Table of the N roots of unity w^k, k in [0, N).
  const auto step = (forward ? -2.0 : 2.0) * M_PI / n;
  std::vector<c10::complex<double>> twiddle(n);
  for (const auto k : c10::irange(n)) {
    const auto theta = k * step;
    twiddle[k] = c10::complex<double>(std::cos(theta), std::sin(theta));
  }

  const auto* src = x.data_ptr<c10::complex<double>>();
  auto* dst = result.data_ptr<c10::complex<double>>();
  for (const auto k : c10::irange(n)) {
    // `idx` tracks (k * j) mod N, advanced by k on every step of j.
    int64_t idx = 0;
    for (const auto j : c10::irange(n)) {
      dst[k] += twiddle[idx] * src[j];
      idx = (idx + k) % n;
    }
  }
  return result;
}

// NOTE: As of August 2020, Visual Studio and ROCm builds do not support
//   complex literals, so the tests below avoid using them.

// torch::fft::fft on a complex signal must agree with the naive reference DFT.
TEST(FFTTest, fft) {
  const auto input = torch::randn(128, torch::kComplexDouble);
  const auto result = torch::fft::fft(input);
  ASSERT_TRUE(torch::allclose(result, naive_dft(input)));
}

// Passing a real signal to torch::fft::fft must give the same result as first
// promoting it to complex and transforming that.
TEST(FFTTest, fft_real) {
  const auto real_input = torch::randn(128, torch::kDouble);
  const auto from_real = torch::fft::fft(real_input);
  const auto from_complex =
      torch::fft::fft(real_input.to(torch::kComplexDouble));
  ASSERT_TRUE(torch::allclose(from_real, from_complex));
}

// The signal-length argument `n` of torch::fft::fft must behave like padding
// (n > size) or trimming (n < size) the input with constant_pad_nd beforehand.
TEST(FFTTest, fft_pad) {
  const auto signal = torch::randn(128, torch::kComplexDouble);

  // Lengthen 128 -> 200.
  const auto lengthened = torch::fft::fft(signal, 200);
  const auto padded_ref =
      torch::fft::fft(torch::constant_pad_nd(signal, {0, 72}));
  ASSERT_TRUE(torch::allclose(lengthened, padded_ref));

  // Shorten 128 -> 64 (a negative pad amount removes trailing elements).
  const auto shortened = torch::fft::fft(signal, 64);
  const auto trimmed_ref =
      torch::fft::fft(torch::constant_pad_nd(signal, {0, -64}));
  ASSERT_TRUE(torch::allclose(shortened, trimmed_ref));
}

// Checks the `norm` argument of torch::fft::fft: relative to the default
// (norm unset) result, "forward" must divide by the signal length (128) and
// "ortho" by its square root.
TEST(FFTTest, fft_norm) {
  auto t = torch::randn(128, torch::kComplexDouble);
  // The third parameter of torch::fft::fft is named `dim`; labelling it as
  // such keeps bugprone-argument-comment satisfied without NOLINT suppressions.
  auto unnorm = torch::fft::fft(t, /*n=*/{}, /*dim=*/-1, /*norm=*/{});
  auto norm = torch::fft::fft(t, /*n=*/{}, /*dim=*/-1, /*norm=*/"forward");
  ASSERT_TRUE(torch::allclose(unnorm / 128, norm));

  auto ortho_norm = torch::fft::fft(t, /*n=*/{}, /*dim=*/-1, /*norm=*/"ortho");
  ASSERT_TRUE(torch::allclose(unnorm / std::sqrt(128), ortho_norm));
}

// torch::fft::ifft must match the naive inverse DFT divided by the length.
TEST(FFTTest, ifft) {
  const auto spectrum = torch::randn(128, torch::kComplexDouble);
  const auto result = torch::fft::ifft(spectrum);
  const auto unscaled = naive_dft(spectrum, /*forward=*/false);
  ASSERT_TRUE(torch::allclose(result, unscaled / 128));
}

// fft followed by ifft must reproduce the original complex signal. Length and
// dtype are checked after each step; an odd, non-power-of-two length is used.
TEST(FFTTest, fft_ifft) {
  constexpr int64_t kLen = 77;
  const auto signal = torch::randn(kLen, torch::kComplexDouble);

  const auto spectrum = torch::fft::fft(signal);
  ASSERT_EQ(spectrum.size(0), kLen);
  ASSERT_EQ(spectrum.scalar_type(), torch::kComplexDouble);

  const auto recovered = torch::fft::ifft(spectrum);
  ASSERT_EQ(recovered.size(0), kLen);
  ASSERT_EQ(recovered.scalar_type(), torch::kComplexDouble);
  ASSERT_TRUE(torch::allclose(signal, recovered));
}

// rfft of a real signal must equal the first N/2 + 1 bins of the full complex
// FFT of the same signal.
TEST(FFTTest, rfft) {
  constexpr int64_t kLen = 129;
  const auto signal = torch::randn(kLen, torch::kDouble);
  const auto half_spectrum = torch::fft::rfft(signal);
  const auto full_spectrum =
      torch::fft::fft(signal.to(torch::kComplexDouble));
  ASSERT_TRUE(torch::allclose(
      half_spectrum, full_spectrum.slice(0, 0, kLen / 2 + 1)));
}

// rfft followed by irfft must reproduce the original real signal, with the
// half-spectrum having N/2 + 1 complex bins and the output restored to N reals.
TEST(FFTTest, rfft_irfft) {
  constexpr int64_t kLen = 128;
  const auto signal = torch::randn(kLen, torch::kDouble);

  const auto half_spectrum = torch::fft::rfft(signal);
  ASSERT_EQ(half_spectrum.size(0), kLen / 2 + 1);
  ASSERT_EQ(half_spectrum.scalar_type(), torch::kComplexDouble);

  const auto recovered = torch::fft::irfft(half_spectrum);
  ASSERT_EQ(recovered.size(0), kLen);
  ASSERT_EQ(recovered.scalar_type(), torch::kDouble);
  ASSERT_TRUE(torch::allclose(signal, recovered));
}

// ihfft of a real signal must equal the first N/2 + 1 bins of the complex
// inverse FFT of the same signal.
TEST(FFTTest, ihfft) {
  constexpr int64_t kLen = 129;
  const auto signal = torch::randn(kLen, torch::kDouble);
  const auto half_result = torch::fft::ihfft(signal);
  const auto full_result =
      torch::fft::ifft(signal.to(torch::kComplexDouble));
  ASSERT_TRUE(torch::allclose(
      half_result, full_result.slice(0, 0, kLen / 2 + 1)));
}

// hfft of a 64-element complex half-signal with n = 127 must yield 127 real
// values, and ihfft of those must give back the original half-signal.
TEST(FFTTest, hfft_ihfft) {
  constexpr int64_t kOutLen = 127;
  auto half_signal = torch::randn(64, torch::kComplexDouble);
  // Element 0 must be purely real to satisfy Hermitian symmetry.
  half_signal[0] = .5;

  const auto real_signal = torch::fft::hfft(half_signal, kOutLen);
  ASSERT_EQ(real_signal.size(0), kOutLen);
  ASSERT_EQ(real_signal.scalar_type(), torch::kDouble);

  const auto recovered = torch::fft::ihfft(real_signal);
  ASSERT_EQ(recovered.size(0), 64);
  ASSERT_EQ(recovered.scalar_type(), torch::kComplexDouble);
  ASSERT_TRUE(torch::allclose(half_signal, recovered));
}