File: test_spectral_ops.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (1582 lines) | stat: -rw-r--r-- 64,256 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
# Owner(s): ["module: fft"]

import torch
import unittest
import math
from contextlib import contextmanager
from itertools import product
import itertools
import doctest
import inspect

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
     make_tensor)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
from torch.testing._internal.common_methods_invocations import (
    spectral_funcs, SpectralFuncType)
from torch.testing._internal.common_cuda import SM53OrLater

from setuptools import distutils
from typing import Optional, List


if TEST_NUMPY:
    import numpy as np


if TEST_LIBROSA:
    import librosa

has_scipy_fft = False
try:
    import scipy.fft
    has_scipy_fft = True
except ModuleNotFoundError:
    pass

LooseVersion = distutils.version.LooseVersion
# Normalization modes passed to the reference (NumPy / SciPy) implementations.
# The full set is used only when NumPy is >= 1.20.0 and, if scipy.fft could be
# imported, SciPy is >= 1.6.0; otherwise only the default (None) and "ortho".
# TEST_NUMPY is checked first because `np` is only bound when NumPy is
# available; without the guard, importing this module raises NameError.
REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if TEST_NUMPY and LooseVersion(np.__version__) >= '1.20.0' and (
        not has_scipy_fft or LooseVersion(scipy.__version__) >= '1.6.0')
    else (None, "ortho"))


def _complex_stft(x, *args, **kwargs):
    """Compute a two-sided STFT of complex ``x`` from two real-input STFTs.

    The real and imaginary components of ``x`` are transformed separately
    with ``torch.stft`` (``onesided=False``, ``return_complex=True``) and
    recombined as ``stft(real) + 1j * stft(imag)``. Extra positional and
    keyword arguments are forwarded to ``torch.stft`` unchanged.
    """
    def _two_sided(component):
        return torch.stft(component, *args, **kwargs, return_complex=True, onesided=False)

    real_spectrum, imag_spectrum = (_two_sided(part) for part in (x.real, x.imag))
    return real_spectrum + 1j * imag_spectrum


def _hermitian_conj(x, dim):
    """Returns the hermitian conjugate along a single dimension

    H(x)[i] = conj(x[-i])

    The index ``-i`` is taken modulo ``x.size(dim)``, so element 0 maps to
    itself. Sizes 1 and 2 along ``dim`` are handled too; the previous
    slice-swapping implementation raised on them because ``slice(-0, None)``
    selects the whole dimension.

    Args:
        x: input tensor.
        dim: dimension along which the index is reversed (may be negative).

    Returns:
        A tensor with the same shape as ``x``.
    """
    # flip:        y[i] = x[n - 1 - i]
    # roll by one: z[i] = y[i - 1] = x[(n - i) % n], which is x[-i]
    reversed_x = torch.flip(x, (dim,))
    return torch.roll(reversed_x, 1, dim).conj()


def _complex_istft(x, *args, **kwargs):
    """Invert a two-sided STFT ``x`` into a complex signal using real iSTFTs.

    ``x`` is split into a Hermitian part (FFT of a real signal) and an
    anti-Hermitian part (FFT of an imaginary signal) along dim -2. The
    one-sided half of each part is inverted with ``torch.istft`` and the two
    real results become the real and imaginary components of the output.
    Extra arguments are forwarded to ``torch.istft`` unchanged.
    """
    # Select the first n_fft // 2 + 1 bins along dim -2
    keep = (Ellipsis, slice(None, x.size(-2) // 2 + 1), slice(None))

    mirrored = _hermitian_conj(x, dim=-2)
    hermitian_part = (x + mirrored) / 2
    antihermitian_part = (x - mirrored) / 2

    def _onesided_istft(half_spectrum):
        return torch.istft(half_spectrum, *args, **kwargs, onesided=True)

    real_signal = _onesided_istft(hermitian_part[keep])
    imag_signal = _onesided_istft(-1j * antihermitian_part[keep])
    return torch.complex(real_signal, imag_signal)


def _stft_reference(x, hop_length, window):
    r"""Reference stft implementation

    This doesn't implement all of torch.stft, only the STFT definition:

    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}

    Args:
        x: signal, indexed along its first dimension
            (assumed 1-D; TODO confirm against callers).
        hop_length: step between the starts of consecutive frames.
        window: window tensor; its number of elements is used as ``n_fft``.

    Returns:
        A ``torch.cdouble`` tensor of shape ``(n_fft, n_frames)`` on
        ``x.device``, where
        ``n_frames = (x.numel() - n_fft + hop_length) // hop_length``.
    """
    n_fft = window.numel()
    n_frames = (x.numel() - n_fft + hop_length) // hop_length
    X = torch.empty((n_fft, n_frames), device=x.device, dtype=torch.cdouble)
    for m in range(n_frames):
        start = m * hop_length
        frame = x[start: start + n_fft]
        if frame.numel() < n_fft:
            # Defensive: with the frame count above every frame fits inside x.
            # Should a short frame ever occur, zero-pad it; the previous
            # torch.empty buffer left the tail uninitialised.
            padded = torch.zeros(n_fft, device=x.device, dtype=x.dtype)
            padded[:frame.numel()] = frame
            frame = padded
        X[:, m] = torch.fft.fft(frame * window)
    return X


def skip_helper_for_fft(device, dtype):
    """Raise ``unittest.SkipTest`` when half / complex32 FFTs cannot run.

    Does nothing for any other dtype. For ``torch.half`` and
    ``torch.complex32`` the test is skipped on CPU, under ROCm, and when
    ``SM53OrLater`` is false.
    """
    kind = torch.device(device).type
    is_half_precision = dtype in (torch.half, torch.complex32)
    if is_half_precision:
        if kind == 'cpu':
            raise unittest.SkipTest("half and complex32 are not supported on CPU")
        elif TEST_WITH_ROCM:
            raise unittest.SkipTest("half and complex32 are not supported on ROCM")
        elif not SM53OrLater:
            raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")


# Tests of functions related to Fourier analysis in the torch.fft namespace
class TestFFT(TestCase):
    exact_dtype = True

    @onlyNativeDeviceTypes
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
         allowed_dtypes=(torch.float, torch.cfloat))
    def test_reference_1d(self, device, dtype, op):
        """Compare a 1-D spectral op with ``op.ref`` over inputs, n, dim and norm."""
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        def rand(*shape):
            return torch.randn(*shape, device=device, dtype=dtype)

        # Each case is (input, n, dim, norm)
        plain_inputs = (rand(67), rand(80), rand(12, 14), rand(9, 6, 3))
        cases = list(product(plain_inputs, (None, 50, 6), (-1, 0), REFERENCE_NORM_MODES))
        # Test transforming middle dimensions of multi-dim tensor
        cases.extend(product((rand(4, 5, 6, 7),), (None,), (1, 2, -2), REFERENCE_NORM_MODES))

        # Result dtypes must match exactly only for double precision inputs
        strict_dtype = dtype in (torch.double, torch.complex128)
        for tensor, *extra_args in cases:
            reference = op.ref(tensor.cpu().numpy(), *extra_args)
            result = op(tensor, *extra_args)
            self.assertEqual(result, reference, exact_dtype=strict_dtype)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
    def test_fft_round_trip(self, device, dtype):
        """Check that each inverse 1-D transform undoes its forward transform."""
        skip_helper_for_fft(device, dtype)
        # Test that round trip through ifft(fft(x)) is the identity
        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            shapes = ((64,), (128,), (4, 16), (8, 6, 2))
        else:
            shapes = ((67,), (80,), (12, 14), (9, 6, 3))
        inputs = tuple(torch.randn(*shape, device=device, dtype=dtype) for shape in shapes)
        # Each case is (input, dim, norm)
        cases = list(product(inputs, (-1, 0), (None, "forward", "backward", "ortho")))

        transform_pairs = [(torch.fft.fft, torch.fft.ifft)]
        # Real-only functions
        if not dtype.is_complex:
            # NOTE: Using ihfft as "forward" transform to avoid needing to
            # generate true half-complex input
            transform_pairs.append((torch.fft.rfft, torch.fft.irfft))
            transform_pairs.append((torch.fft.ihfft, torch.fft.hfft))

        for (forward, backward), (x, dim, norm) in product(transform_pairs, cases):
            kwargs = dict(n=x.size(dim), dim=dim, norm=norm)
            y = backward(forward(x, **kwargs), **kwargs)
            if x.dtype is torch.half and y.dtype is torch.complex32:
                # Since type promotion currently doesn't work with complex32
                # manually promote `x` to complex32
                x = x.to(torch.complex32)
            # For real input, ifft(fft(x)) will convert to complex
            same_dtype_expected = forward != torch.fft.fft or x.is_complex()
            self.assertEqual(x, y, exact_dtype=same_dtype_expected)

    # Note: NumPy will throw a ValueError for an empty input
    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
    def test_empty_fft(self, device, dtype, op):
        """A spectral op applied to a ``(1, 0)`` tensor must raise RuntimeError."""
        empty_input = torch.empty(1, 0, device=device, dtype=dtype)
        expected_message = r"Invalid number of data points \([-\d]*\) specified"

        with self.assertRaisesRegex(RuntimeError, expected_message):
            op(empty_input)

    @onlyNativeDeviceTypes
    def test_empty_ifft(self, device):
        """Complex-to-real transforms of a ``(2, 1)`` complex64 tensor must raise."""
        expected_message = r"Invalid number of data points \([-\d]*\) specified"
        complex_input = torch.empty(2, 1, device=device, dtype=torch.complex64)

        c2r_transforms = (
            torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
            torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn,
        )
        for transform in c2r_transforms:
            with self.assertRaisesRegex(RuntimeError, expected_message):
                transform(complex_input)

    @onlyNativeDeviceTypes
    def test_fft_invalid_dtypes(self, device):
        """Real-input transforms must reject a complex128 tensor."""
        complex_input = torch.randn(64, device=device, dtype=torch.complex128)

        # (transform, expected error message)
        real_only = (
            (torch.fft.rfft, "rfft expects a real input tensor"),
            (torch.fft.rfftn, "rfftn expects a real-valued input tensor"),
            (torch.fft.ihfft, "ihfft expects a real input tensor"),
        )
        for transform, message in real_only:
            with self.assertRaisesRegex(RuntimeError, message):
                transform(complex_input)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fft_type_promotion(self, device, dtype):
        """Check the result dtypes of fft, hfft and rfft for each input dtype."""
        skip_helper_for_fft(device, dtype)

        if not (dtype.is_complex or dtype.is_floating_point):
            t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)
        else:
            t = torch.randn(64, device=device, dtype=dtype)

        # fft: the result is always complex
        fft_result_dtype = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
            torch.complex32: torch.complex32,
            torch.complex64: torch.complex64,
            torch.complex128: torch.complex128,
        }
        self.assertEqual(torch.fft.fft(t).dtype, fft_result_dtype[dtype])

        # hfft: complex-to-real, the result is always real
        hfft_result_dtype = {
            torch.int8: torch.float,
            torch.half: torch.half,
            torch.float: torch.float,
            torch.double: torch.double,
            torch.complex32: torch.half,
            torch.complex64: torch.float,
            torch.complex128: torch.double,
        }
        hfft_input = t
        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
            # we make sure that logical fft size is a power of two.
            hfft_input = torch.randn(65, device=device, dtype=dtype)
        self.assertEqual(torch.fft.hfft(hfft_input).dtype, hfft_result_dtype[dtype])

        if dtype.is_complex:
            return

        # rfft: real-to-complex, only checked for non-complex inputs
        rfft_result_dtype = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
        }
        self.assertEqual(torch.fft.rfft(t).dtype, rfft_result_dtype[dtype])

    @onlyNativeDeviceTypes
    @ops(spectral_funcs, dtypes=OpDTypes.unsupported,
         allowed_dtypes=[torch.half, torch.bfloat16])
    def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
        """Calling an op with an unsupported half / bfloat16 dtype must raise."""
        # TODO: Remove torch.half error when complex32 is fully implemented
        sample = first_sample(self, op.sample_inputs(device, dtype))
        on_cuda = torch.device(device).type == 'cuda'
        # Only half on (non-ROCm) CUDA below SM_53 gets the cuFFT-specific
        # message; every other combination reports an unsupported dtype.
        if dtype is torch.half and on_cuda and not TEST_WITH_ROCM and not SM53OrLater:
            expected_message = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
        else:
            expected_message = "Unsupported dtype "
        with self.assertRaisesRegex(RuntimeError, expected_message):
            op(sample.input, *sample.args, **sample.kwargs)

    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
    def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
        """Half / chalf ops must reject sizes that are not powers of two."""
        t = make_tensor(13, 13, device=device, dtype=dtype)
        err_msg = "cuFFT only supports dimensions whose sizes are powers of two"

        # Explicit (non power-of-two) signal size for the second call
        multi_dim = op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD)
        size_kwargs = {'s': (12, 12)} if multi_dim else {'n': 12}

        # First with the input's own 13 x 13 shape, then with the explicit size
        for call_kwargs in ({}, size_kwargs):
            with self.assertRaisesRegex(RuntimeError, err_msg):
                op(t, **call_kwargs)

    # nd-fft tests
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
         allowed_dtypes=(torch.cfloat, torch.cdouble))
    def test_reference_nd(self, device, dtype, op):
        """Compare an n-dim spectral op with ``op.ref`` over ndim, s, dim and norm."""
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        # Each entry is (input_ndim, s, dim)
        descriptions = list(product(range(2, 5), (None,), (None, (0,), (0, -1))))
        descriptions.extend(product(range(2, 5), (None, (4, 10)), (None,)))
        descriptions.extend([
            (6, None, None),
            (5, None, (1, 3, 4)),
            (3, None, (1,)),
            (1, None, (0,)),
            (4, (10, 10), None),
            (4, (10, 10), (0, 1)),
        ])

        # Result dtypes must match exactly only for double precision inputs
        strict_dtype = dtype in (torch.double, torch.complex128)

        for ndim, s, dim in descriptions:
            # Sizes cycle through 4, 5, 6, 7, 8
            shape = [4 + i % 5 for i in range(ndim)]
            tensor = torch.randn(*shape, device=device, dtype=dtype)

            for norm in REFERENCE_NORM_MODES:
                reference = op.ref(tensor.cpu().numpy(), s, dim, norm)
                result = op(tensor, s, dim, norm)
                self.assertEqual(result, reference, exact_dtype=strict_dtype)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fftn_round_trip(self, device, dtype):
        """Check that each inverse n-dim transform undoes its forward transform."""
        skip_helper_for_fft(device, dtype)

        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            size_cycle = (2, 4, 8)
        else:
            size_cycle = tuple(range(4, 9))

        transform_pairs = [(torch.fft.fftn, torch.fft.ifftn)]
        # Real-only functions
        if not dtype.is_complex:
            # NOTE: Using ihfftn as "forward" transform to avoid needing to
            # generate true half-complex input
            transform_pairs.append((torch.fft.rfftn, torch.fft.irfftn))
            transform_pairs.append((torch.fft.ihfftn, torch.fft.hfftn))

        # Each entry is (input_ndim, dim)
        descriptions = list(product(range(2, 5), (None, (0,), (0, -1))))
        descriptions += [(7, None), (5, (1, 3, 4)), (3, (1,)), (1, 0)]

        for ndim, dim in descriptions:
            shape = [size_cycle[i % len(size_cycle)] for i in range(ndim)]
            x = torch.randn(*shape, device=device, dtype=dtype)

            # Signal size matching the transformed dimensions of x
            if dim is None:
                s = x.size()
            elif isinstance(dim, tuple):
                s = [x.size(d) for d in dim]
            else:
                s = x.size(dim)

            for forward, backward in transform_pairs:
                for norm in (None, "forward", "backward", "ortho"):
                    kwargs = dict(s=s, dim=dim, norm=norm)
                    y = backward(forward(x, **kwargs), **kwargs)
                    if x.dtype is torch.half and y.dtype is torch.chalf:
                        # Since type promotion currently doesn't work with complex32
                        # manually promote `x` to complex32
                        self.assertEqual(x.to(torch.chalf), y)
                        continue
                    # For real input, ifftn(fftn(x)) will convert to complex
                    same_dtype_expected = forward != torch.fft.fftn or x.is_complex()
                    self.assertEqual(x, y, exact_dtype=same_dtype_expected)

    @onlyNativeDeviceTypes
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
         allowed_dtypes=[torch.float, torch.cfloat])
    def test_fftn_invalid(self, device, dtype, op):
        """Invalid ``s`` / ``dim`` arguments to an n-dim op must raise."""
        a = torch.rand(10, 10, 10, device=device, dtype=dtype)

        # (exception type, message regex, keyword arguments for the op)
        bad_calls = (
            (RuntimeError, "dims must be unique", {'dim': (0, 1, 0)}),
            (RuntimeError, "dims must be unique", {'dim': (2, -1)}),
            (RuntimeError, "dim and shape .* same length", {'s': (1,), 'dim': (0, 1)}),
            (IndexError, "Dimension out of range", {'dim': (3,)}),
            (RuntimeError, "tensor only has 3 dimensions", {'s': (10, 10, 10, 10)}),
        )
        for exc_type, message, call_kwargs in bad_calls:
            with self.assertRaisesRegex(exc_type, message):
                op(a, **call_kwargs)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double)
    def test_hfftn(self, device, dtype):
        """Check that hfftn recovers a real tensor from half of its ifftn.

        A random real tensor is transformed with ``ifftn``, the result is
        truncated to its first ``n // 2 + 1`` entries along the last
        transformed dimension, and ``hfftn`` of that truncated tensor (with
        the original sizes passed as ``s``) is compared with the original.
        """
        skip_helper_for_fft(device, dtype)

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (6, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, (0,)),
            (4, (0, 1))
        ]

        for input_ndim, dim in transform_desc:
            actual_dims = list(range(input_ndim)) if dim is None else dim
            if dtype is torch.half:
                # Half precision uses power-of-two sizes only
                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
            else:
                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
            expect = torch.randn(*shape, device=device, dtype=dtype)
            half_spectrum = torch.fft.ifftn(expect, dim=dim, norm="ortho")

            # Keep only the first n // 2 + 1 entries along the last transformed dim
            lastdim = actual_dims[-1]
            lastdim_size = half_spectrum.size(lastdim) // 2 + 1
            idx = [slice(None)] * input_ndim
            idx[lastdim] = slice(0, lastdim_size)
            # Index with a tuple: indexing with a list of slices is the
            # deprecated non-tuple form of multidimensional indexing.
            half_spectrum = half_spectrum[tuple(idx)]

            # Distinct loop name so the comprehension does not shadow `dim`
            s = [shape[d] for d in actual_dims]
            actual = torch.fft.hfftn(half_spectrum, s=s, dim=dim, norm="ortho")

            self.assertEqual(expect, actual)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double)
    def test_ihfftn(self, device, dtype):
        """Check that ihfftn matches ifftn truncated along the last transformed dim.

        ``ifftn`` of a random real tensor is truncated to its first
        ``n // 2 + 1`` entries along the last transformed dimension and
        compared with ``ihfftn`` of the same tensor.
        """
        skip_helper_for_fft(device, dtype)

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (6, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, (0,)),
            (4, (0, 1))
        ]

        for input_ndim, dim in transform_desc:
            if dtype is torch.half:
                # Half precision uses power-of-two sizes only
                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
            else:
                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))

            signal = torch.randn(*shape, device=device, dtype=dtype)
            expect = torch.fft.ifftn(signal, dim=dim, norm="ortho")

            # Slice off the half-symmetric component
            lastdim = -1 if dim is None else dim[-1]
            lastdim_size = expect.size(lastdim) // 2 + 1
            idx = [slice(None)] * input_ndim
            idx[lastdim] = slice(0, lastdim_size)
            # Index with a tuple: indexing with a list of slices is the
            # deprecated non-tuple form of multidimensional indexing.
            expect = expect[tuple(idx)]

            actual = torch.fft.ihfftn(signal, dim=dim, norm="ortho")
            self.assertEqual(expect, actual)


    # 2d-fft tests

    # NOTE: 2d transforms are only thin wrappers over n-dim transforms,
    # so don't require exhaustive testing.


    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double, torch.complex128)
    def test_fft2_numpy(self, device, dtype):
        """Compare torch.fft 2-D transforms against NumPy / SciPy references.

        Each function is called both directly and through a
        ``torch.jit.script``-compiled wrapper, once with ``dim`` left at its
        default and once with ``dim=(1, 0)``. Functions whose name contains
        "hfft" are compared against ``scipy.fft`` and skipped when it is not
        available; all others are compared against ``np.fft``.
        """
        norm_modes = REFERENCE_NORM_MODES

        # input_ndim, s
        transform_desc = [
            *product(range(2, 5), (None, (4, 10))),
        ]

        fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2']
        # Real-input-only functions are added for floating point dtypes
        if dtype.is_floating_point:
            fft_functions += ['rfft2', 'ihfft2']

        for input_ndim, s in transform_desc:
            # Sizes cycle through 4, 5, 6, 7, 8
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            input = torch.randn(*shape, device=device, dtype=dtype)
            for fname, norm in product(fft_functions, norm_modes):
                torch_fn = getattr(torch.fft, fname)
                if "hfft" in fname:
                    if not has_scipy_fft:
                        continue  # Requires scipy to compare against
                    numpy_fn = getattr(scipy.fft, fname)
                else:
                    numpy_fn = getattr(np.fft, fname)

                # Typed wrapper around torch_fn so it can be compiled with torch.jit.script
                def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
                    return torch_fn(t, s, dim, norm)

                # Exercise both the eager function and its scripted wrapper
                torch_fns = (torch_fn, torch.jit.script(fn))

                # Once with dim defaulted
                input_np = input.cpu().numpy()
                expected = numpy_fn(input_np, s, norm=norm)
                # NOTE: this loop rebinds the name `fn` defined above
                for fn in torch_fns:
                    actual = fn(input, s, norm=norm)
                    self.assertEqual(actual, expected)

                # Once with explicit dims
                dim = (1, 0)
                expected = numpy_fn(input_np, s, dim, norm)
                for fn in torch_fns:
                    actual = fn(input, s, dim, norm)
                    self.assertEqual(actual, expected)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.complex64)
    def test_fft2_fftn_equivalence(self, device, dtype):
        """Each 2-D transform must agree with its n-dim counterpart."""
        # Each entry is (input_ndim, s, dim)
        descriptions = list(product(range(2, 5), (None, (4, 10)), (None, (1, 0))))
        descriptions.append((3, None, (0, 2)))

        base_names = ['fft', 'ifft', 'irfft', 'hfft']
        # Real-only functions
        if dtype.is_floating_point:
            base_names.extend(['rfft', 'ihfft'])

        for ndim, s, dim in descriptions:
            # Sizes cycle through 4, 5, 6, 7, 8
            shape = [4 + i % 5 for i in range(ndim)]
            x = torch.randn(*shape, device=device, dtype=dtype)

            for base in base_names:
                two_dim_fn = getattr(torch.fft, base + '2')
                n_dim_fn = getattr(torch.fft, base + 'n')
                for norm in (None, "forward", "backward", "ortho"):
                    call_kwargs = {'s': s, 'norm': norm}
                    if dim is None:
                        # The 2-D variant is called without dim; the n-dim
                        # reference is given the last two dims explicitly.
                        reference = n_dim_fn(x, dim=(-2, -1), **call_kwargs)
                    else:
                        call_kwargs['dim'] = dim
                        reference = n_dim_fn(x, **call_kwargs)

                    result = two_dim_fn(x, **call_kwargs)
                    self.assertEqual(result, reference)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    def test_fft2_invalid(self, device):
        """Invalid ``s`` / ``dim`` arguments to the 2-D transforms must raise."""
        a = torch.rand(10, 10, 10, device=device)

        # (exception type, message regex, keyword arguments for the transform)
        bad_calls = (
            (RuntimeError, "dims must be unique", {'dim': (0, 0)}),
            (RuntimeError, "dims must be unique", {'dim': (2, -1)}),
            (RuntimeError, "dim and shape .* same length", {'s': (1,)}),
            (IndexError, "Dimension out of range", {'dim': (2, 3)}),
        )
        transforms = (torch.fft.fft2, torch.fft.ifft2,
                      torch.fft.rfft2, torch.fft.irfft2)

        for transform in transforms:
            for exc_type, message, call_kwargs in bad_calls:
                with self.assertRaisesRegex(exc_type, message):
                    transform(a, **call_kwargs)

        # rfft2 must reject complex input
        complex_input = torch.complex(a, a)
        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
            torch.fft.rfft2(complex_input)

    # Helper functions

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)
    def test_fftfreq_numpy(self, device, dtype):
        """Compare torch.fft.fftfreq / rfftfreq with their np.fft counterparts."""
        # Each case is (n, d); d=None means the spacing argument is omitted
        cases = list(product(range(1, 20), (None, 10.0)))

        for name in ('fftfreq', 'rfftfreq'):
            torch_impl = getattr(torch.fft, name)
            numpy_impl = getattr(np.fft, name)

            for n, d in cases:
                call_args = (n, d) if d is not None else (n,)
                reference = numpy_impl(*call_args)
                result = torch_impl(*call_args, device=device, dtype=dtype)
                self.assertEqual(result, reference, exact_dtype=False)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_fftfreq_out(self, device, dtype):
        """Check that ``out=`` with a 0-dim tensor warns about resizing and yields the regular result."""
        for freq_fn in (torch.fft.fftfreq, torch.fft.rfftfreq):
            reference = freq_fn(n=100, d=.5, device=device, dtype=dtype)
            # Deliberately the wrong shape: the call has to resize it
            out = torch.empty((), device=device, dtype=dtype)
            with self.assertWarnsRegex(UserWarning, "out tensor will be resized"):
                freq_fn(n=100, d=.5, out=out)
            self.assertEqual(out, reference)


    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_fftshift_numpy(self, device, dtype):
        """Compare ``torch.fft.fftshift``/``ifftshift`` against NumPy for several shapes and dims."""
        # Each entry pairs a group of shapes with the dims to try on every shape in it
        shape_groups = (
            (((11,), (12,)), (None, 0, -1)),
            (((4, 5), (6, 6)), (None, 0, (-1,))),
            (((1, 1, 4, 6, 7, 2),), (None, (3, 4))),
        )
        cases = [(shape, dim)
                 for shapes, dims in shape_groups
                 for shape in shapes
                 for dim in dims]

        for shape, dim in cases:
            tensor = torch.rand(*shape, device=device, dtype=dtype)
            array = tensor.cpu().numpy()

            for name in ('fftshift', 'ifftshift'):
                reference = getattr(np.fft, name)(array, axes=dim)
                result = getattr(torch.fft, name)(tensor, dim=dim)
                self.assertEqual(result, reference)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)
    def test_fftshift_frequencies(self, device, dtype):
        """Check that ``fftshift`` sorts the output of ``fftfreq`` and ``ifftshift`` undoes it."""
        for n in range(10, 15):
            half = n // 2
            expected_sorted = torch.arange(-half, n - half, device=device, dtype=dtype)
            freqs = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype)

            # fftshift must put the frequencies in ascending order
            centered = torch.fft.fftshift(freqs)
            self.assertEqual(centered, centered.sort().values)
            self.assertEqual(expected_sorted, centered)

            # ifftshift must restore the original ordering
            restored = torch.fft.ifftshift(centered)
            self.assertEqual(freqs, restored)

    # Legacy fft tests
    def _test_fft_ifft_rfft_irfft(self, device, dtype):
        """Shared round-trip checks for the n-dimensional FFT functions.

        For contiguous and non-contiguous inputs, verifies that
        ``fftn``/``ifftn`` and ``rfftn``/``irfftn`` invert each other over the
        trailing ``signal_ndim`` dimensions, with ``norm=None`` and
        ``norm='ortho'``.

        Args:
            device: device on which the random inputs are created.
            dtype: real dtype of the inputs (float16, float32 or float64); the
                matching complex dtype is used for the complex-input cases.
        """
        complex_dtype = {
            torch.float16: torch.complex32,
            torch.float32: torch.complex64,
            torch.float64: torch.complex128
        }[dtype]

        def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
            # prepro_fn is applied to the fresh random tensor; the callers
            # below use it to produce non-contiguous inputs.
            x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
            dim = tuple(range(-signal_ndim, 0))
            for norm in ('ortho', None):
                res = torch.fft.fftn(x, dim=dim, norm=norm)
                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft')
                res = torch.fft.ifftn(x, dim=dim, norm=norm)
                rec = torch.fft.fftn(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft')

        def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
            x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
            # irfftn is given the original signal sizes explicitly via s=
            signal_sizes = x.size()[-signal_ndim:]
            dim = tuple(range(-signal_ndim, 0))
            for norm in (None, 'ortho'):
                res = torch.fft.rfftn(x, dim=dim, norm=norm)
                rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft')
                res = torch.fft.fftn(x, dim=dim, norm=norm)
                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
                # The complex round trip of a real input is compared against
                # the input with a zero imaginary part attached.
                x_complex = torch.complex(x, torch.zeros_like(x))
                self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)')

        # contiguous case
        _test_real((100,), 1)
        _test_real((10, 1, 10, 100), 1)
        _test_real((100, 100), 2)
        _test_real((2, 2, 5, 80, 60), 2)
        _test_real((50, 40, 70), 3)
        _test_real((30, 1, 50, 25, 20), 3)

        _test_complex((100,), 1)
        _test_complex((100, 100), 1)
        _test_complex((100, 100), 2)
        _test_complex((1, 20, 80, 60), 2)
        _test_complex((50, 40, 70), 3)
        _test_complex((6, 5, 50, 25, 20), 3)

        # non-contiguous case
        _test_real((165,), 1, lambda x: x.narrow(0, 25, 100))  # input is not aligned to complex type
        _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
        _test_real((100, 100), 2, lambda x: x.t())
        _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
        _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
        _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))

        _test_complex((100,), 1, lambda x: x.expand(100, 100))
        _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
        _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
        _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_fft_ifft_rfft_irfft(self, device, dtype):
        """Run the shared FFT round-trip checks on the given device and dtype."""
        self._test_fft_ifft_rfft_irfft(device=device, dtype=dtype)

    @deviceCountAtLeast(1)
    @onlyCUDA
    @dtypes(torch.double)
    def test_cufft_plan_cache(self, devices, dtype):
        """Exercise the cuFFT plan cache: size limits, clearing, error cases and per-device caches.

        Args:
            devices: the CUDA devices available to the test; the multi-GPU
                section only runs when there is more than one.
            dtype: dtype forwarded to the shared FFT round-trip checks.
        """
        @contextmanager
        def plan_cache_max_size(device, n):
            # device=None selects the un-indexed cache object; otherwise the
            # cache of the given device is used.
            if device is None:
                plan_cache = torch.backends.cuda.cufft_plan_cache
            else:
                plan_cache = torch.backends.cuda.cufft_plan_cache[device]
            original = plan_cache.max_size
            plan_cache.max_size = n
            try:
                yield
            finally:
                # Restore the limit even when the body raises, so a failing
                # assertion does not leave a modified cache size behind.
                plan_cache.max_size = original

        with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
            self._test_fft_ifft_rfft_irfft(devices[0], dtype)

        with plan_cache_max_size(devices[0], 0):
            self._test_fft_ifft_rfft_irfft(devices[0], dtype)

        torch.backends.cuda.cufft_plan_cache.clear()

        # check that it still works after clearing the cache
        with plan_cache_max_size(devices[0], 10):
            self._test_fft_ifft_rfft_irfft(devices[0], dtype)

        with self.assertRaisesRegex(RuntimeError, r"must be non-negative"):
            torch.backends.cuda.cufft_plan_cache.max_size = -1

        with self.assertRaisesRegex(RuntimeError, r"read-only property"):
            torch.backends.cuda.cufft_plan_cache.size = -1

        with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
            torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10]

        # Multigpu tests
        if len(devices) > 1:
            # Test that different GPU has different cache
            x0 = torch.randn(2, 3, 3, device=devices[0])
            x1 = x0.to(devices[1])
            self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1)))
            # If a plan is used across different devices, the following line (or
            # the assert above) would trigger illegal memory access. Other ways
            # to trigger the error include
            #   (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and
            #   (2) printing a device 1 tensor.
            x0.copy_(x1)

            # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device
            with plan_cache_max_size(devices[0], 10):
                with plan_cache_max_size(devices[1], 11):
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
                    with torch.cuda.device(devices[1]):
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
                        with torch.cuda.device(devices[0]):
                            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0

                self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                with torch.cuda.device(devices[1]):
                    with plan_cache_max_size(None, 11):  # default is cuda:1
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
                        with torch.cuda.device(devices[0]):
                            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1

    # passes on ROCm w/ python 2.7, fails w/ python 3.6
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_stft(self, device, dtype):
        """Compare ``Tensor.stft`` against librosa and check that invalid arguments raise."""
        if not TEST_LIBROSA:
            raise unittest.SkipTest('librosa not found')

        def librosa_stft(x, n_fft, hop_length, win_length, window, center):
            # Reference implementation: run librosa on each row and return the
            # result with real and imaginary parts stacked along a new last dimension.
            if window is not None:
                window = window.cpu().numpy()
            else:
                window = np.ones(win_length if win_length is not None else n_fft)

            is_unbatched = x.dim() == 1
            batch = x.view(1, -1) if is_unbatched else x

            # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding)
            # however, we use the pre-0.9 default ('reflect')
            frames = []
            for row in batch:
                spec = librosa.stft(row.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                                    win_length=win_length, window=window, center=center,
                                    pad_mode='reflect')
                frames.append(torch.from_numpy(np.stack([spec.real, spec.imag], -1)))
            stacked = torch.stack(frames, 0)
            return stacked[0] if is_unbatched else stacked

        def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
                  center=True, expected_error=None):
            x = torch.randn(*sizes, dtype=dtype, device=device)
            window = None
            if win_sizes is not None:
                window = torch.randn(*win_sizes, dtype=dtype, device=device)

            if expected_error is not None:
                self.assertRaises(expected_error,
                                  lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
                return

            real_view = x.stft(n_fft, hop_length, win_length, window,
                               center=center, return_complex=False)
            # NB: librosa defaults to np.complex64 output, no matter what
            # the input dtype
            reference = librosa_stft(x, n_fft, hop_length, win_length, window, center)
            self.assertEqual(real_view, reference, atol=7e-6, rtol=0,
                             msg='stft comparison against librosa', exact_dtype=False)
            # With return_complex=True, the result is the same but viewed as complex instead of real
            complex_view = x.stft(n_fft, hop_length, win_length, window,
                                  center=center, return_complex=True)
            self.assertEqual(complex_view, torch.view_as_complex(real_view))

        # (sizes, n_fft, extra keyword arguments)
        valid_cases = (
            ((10,), 7, {}),
            ((10, 4000), 1024, {}),
            ((10,), 7, {'hop_length': 2}),
            ((10, 4000), 1024, {'hop_length': 512}),
            ((10,), 7, {'hop_length': 2, 'win_sizes': (7,)}),
            ((10, 4000), 1024, {'hop_length': 512, 'win_sizes': (1024,)}),
            # spectral oversample
            ((10,), 7, {'hop_length': 2, 'win_length': 5}),
            ((10, 4000), 1024, {'hop_length': 512, 'win_length': 100}),
        )
        for center in [True, False]:
            for sizes, n_fft, extra in valid_cases:
                _test(sizes, n_fft, center=center, **extra)

        # Argument combinations that stft is expected to reject
        invalid_cases = (
            ((10, 4, 2), 1, {'hop_length': 1}),
            ((10,), 11, {'hop_length': 1, 'center': False}),
            ((10,), -1, {'hop_length': 1}),
            ((10,), 3, {'win_length': 5}),
            ((10,), 5, {'hop_length': 4, 'win_sizes': (11,)}),
            ((10,), 5, {'hop_length': 4, 'win_sizes': (1, 1)}),
        )
        for sizes, n_fft, extra in invalid_cases:
            _test(sizes, n_fft, expected_error=RuntimeError, **extra)

    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_istft_against_librosa(self, device, dtype):
        """Compare ``Tensor.istft`` against ``librosa.istft`` on spectrograms produced by ``Tensor.stft``."""
        if not TEST_LIBROSA:
            raise unittest.SkipTest('librosa not found')

        def librosa_istft(x, n_fft, hop_length, win_length, window, length, center):
            # With no explicit window, use a window of ones of the effective window length
            if window is not None:
                window = window.cpu().numpy()
            else:
                window = np.ones(win_length if win_length is not None else n_fft)

            spectrogram = x.cpu().numpy()
            return librosa.istft(spectrogram, n_fft=n_fft, hop_length=hop_length,
                                 win_length=win_length, length=length, window=window, center=center)

        def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None,
                  length=None, center=True):
            signal = torch.randn(size, dtype=dtype, device=device)
            window = None
            if win_sizes is not None:
                window = torch.randn(*win_sizes, dtype=dtype, device=device)

            spectrogram = signal.stft(n_fft, hop_length, win_length, window, center=center,
                                      onesided=True, return_complex=True)

            reference = librosa_istft(spectrogram, n_fft, hop_length, win_length,
                                      window, length, center)
            recovered = spectrogram.istft(n_fft, hop_length, win_length, window,
                                          length=length, center=center)
            self.assertEqual(recovered, reference)

        # (size, n_fft, extra keyword arguments)
        cases = (
            (10, 7, {}),
            (4000, 1024, {}),
            (4000, 1024, {'length': 4000}),
            (10, 7, {'hop_length': 2}),
            (4000, 1024, {'hop_length': 512}),
            (4000, 1024, {'hop_length': 512, 'length': 4000}),
            (10, 7, {'hop_length': 2, 'win_sizes': (7,)}),
            (4000, 1024, {'hop_length': 512, 'win_sizes': (1024,)}),
            (4000, 1024, {'hop_length': 512, 'win_sizes': (1024,), 'length': 4000}),
        )
        for center in [True, False]:
            for size, n_fft, extra in cases:
                _test(size, n_fft, center=center, **extra)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_complex_stft_roundtrip(self, device, dtype):
        """Check that ``istft(stft(x))`` recovers ``x`` for real and complex inputs."""
        signals = (
            torch.randn(600, device=device, dtype=dtype),
            torch.randn(807, device=device, dtype=dtype),
            torch.randn(12, 60, device=device, dtype=dtype),
        )
        # Complex inputs are only exercised with onesided=False
        onesided_options = (False,) if dtype.is_complex else (True, False)

        grid = product(
            signals,
            (50, 27),                             # n_fft
            (None, 10),                           # hop_length
            (True,),                              # center
            ("constant", "reflect", "circular"),  # pad_mode
            (True, False),                        # normalized
            onesided_options,                     # onesided
        )

        for x, n_fft, hop_length, center, pad_mode, normalized, onesided in grid:
            shared = dict(n_fft=n_fft, hop_length=hop_length, center=center,
                          normalized=normalized, onesided=onesided)
            want_complex = dtype.is_complex
            length = x.size(-1)

            # Functional interface
            spec = torch.stft(x, pad_mode=pad_mode, return_complex=True, **shared)
            recovered = torch.istft(spec, return_complex=want_complex,
                                    length=length, **shared)
            self.assertEqual(recovered, x)

            # Tensor method interface
            spec = x.stft(pad_mode=pad_mode, return_complex=True, **shared)
            recovered = torch.istft(spec, return_complex=want_complex,
                                    length=length, **shared)
            self.assertEqual(recovered, x)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_stft_roundtrip_complex_window(self, device, dtype):
        """Check the stft -> istft round trip when the window is complex.

        The spectrogram is expected to be ``cdouble`` and two-sided. The
        reconstruction is requested as complex; for real inputs its imaginary
        part must be (close to) zero and its real part must match the input.
        """
        test_args = list(product(
            # input
            (torch.randn(600, device=device, dtype=dtype),
             torch.randn(807, device=device, dtype=dtype),
             torch.randn(12, 60, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (None, 10),
            # pad_mode
            ("constant", "reflect", "replicate", "circular"),
            # normalized
            (True, False),
        ))
        for x, n_fft, hop_length, pad_mode, normalized in test_args:
            window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
            x_stft = torch.stft(
                x, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, pad_mode=pad_mode, normalized=normalized)
            self.assertEqual(x_stft.dtype, torch.cdouble)
            self.assertEqual(x_stft.size(-2), n_fft)  # Not onesided

            x_roundtrip = torch.istft(
                x_stft, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, normalized=normalized, length=x.size(-1),
                return_complex=True)
            # Check the reconstruction's dtype (the original re-checked x_stft here)
            self.assertEqual(x_roundtrip.dtype, torch.cdouble)

            if not dtype.is_complex:
                self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag),
                                 atol=1e-6, rtol=0)
                self.assertEqual(x_roundtrip.real, x)
            else:
                self.assertEqual(x_roundtrip, x)


    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_definition(self, device, dtype):
        """Compare ``torch.stft`` on complex input against the ``_stft_reference`` helper."""
        signals = (
            torch.randn(600, device=device, dtype=dtype),
            torch.randn(807, device=device, dtype=dtype),
        )
        n_ffts = (50, 27)
        hop_lengths = (10, 15)

        for x, n_fft, hop_length in product(signals, n_ffts, hop_lengths):
            # The window has n_fft elements
            window = torch.randn(n_fft, device=device, dtype=dtype)
            expected = _stft_reference(x, hop_length, window)
            actual = torch.stft(x, n_fft, hop_length, window=window, center=False)
            self.assertEqual(actual, expected)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_real_equiv(self, device, dtype):
        """Compare ``torch.stft`` on complex input against the ``_complex_stft`` helper."""
        signals = (
            torch.rand(600, device=device, dtype=dtype),
            torch.rand(807, device=device, dtype=dtype),
            torch.rand(14, 50, device=device, dtype=dtype),
            torch.rand(6, 51, device=device, dtype=dtype),
        )
        grid = product(
            signals,
            (50, 27),                             # n_fft
            (None, 10),                           # hop_length
            (None, 20),                           # win_length
            (False, True),                        # center
            ("constant", "reflect", "circular"),  # pad_mode
            (True, False),                        # normalized
        )

        for x, n_fft, hop_length, win_length, center, pad_mode, normalized in grid:
            # Both implementations receive the same keyword arguments
            options = dict(hop_length=hop_length, win_length=win_length,
                           pad_mode=pad_mode, center=center, normalized=normalized)
            expected = _complex_stft(x, n_fft, **options)
            actual = torch.stft(x, n_fft, **options)
            self.assertEqual(expected, actual)

    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_istft_real_equiv(self, device, dtype):
        """Compare ``torch.istft`` with complex output against the ``_complex_istft`` helper."""
        spectrograms = (
            torch.rand(40, 20, device=device, dtype=dtype),
            torch.rand(25, 1, device=device, dtype=dtype),
            torch.rand(4, 20, 10, device=device, dtype=dtype),
        )
        grid = product(
            spectrograms,
            (None, 10),     # hop_length
            (False, True),  # center
            (True, False),  # normalized
        )

        for x, hop_length, center, normalized in grid:
            # n_fft is taken from the second-to-last dimension of the input
            n_fft = x.size(-2)
            options = dict(hop_length=hop_length, center=center, normalized=normalized)
            expected = _complex_istft(x, n_fft, **options)
            actual = torch.istft(x, n_fft, return_complex=True, **options)
            self.assertEqual(expected, actual)

    @skipCPUIfNoFFT
    def test_complex_stft_onesided(self, device):
        """Check that ``onesided=True`` is rejected whenever the input or the window is complex."""
        # stft of complex input cannot be onesided
        dtype_choices = (torch.double, torch.cdouble)
        for x_dtype, window_dtype in product(dtype_choices, repeat=2):
            x = torch.rand(100, device=device, dtype=x_dtype)
            window = torch.rand(10, device=device, dtype=window_dtype)

            all_real = not (x_dtype.is_complex or window_dtype.is_complex)
            if all_real:
                y = x.stft(10, window=window, pad_mode='constant', onesided=True,
                           return_complex=True)
                self.assertEqual(y.dtype, torch.cdouble)
                self.assertEqual(y.size(), (6, 51))
                continue

            with self.assertRaisesRegex(RuntimeError, 'complex'):
                x.stft(10, window=window, pad_mode='constant', onesided=True)

        # Same check with no window passed at all
        complex_signal = torch.rand(100, device=device, dtype=torch.cdouble)
        with self.assertRaisesRegex(RuntimeError, 'complex'):
            complex_signal.stft(10, pad_mode='constant', onesided=True)

    # stft is currently warning that it requires return-complex while an upgrader is written
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_stft_requires_complex(self, device):
        """Call ``stft`` on a real input without ``return_complex`` and expect it not to raise.

        The stricter expectation is kept below as a commented-out assertion
        until the call is turned into an error.
        """
        # Create the input on the device under test rather than the default device
        x = torch.rand(100, device=device)
        x.stft(10, pad_mode='constant')
        # with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
        #     y = x.stft(10, pad_mode='constant')

    @skipCPUIfNoFFT
    def test_fft_input_modification(self, device):
        """Check that the n-dimensional FFT functions leave their input tensors unchanged."""
        # FFT functions should not modify their input (gh-34551)
        dims = (-2, -1)

        signal = torch.ones((2, 2, 2), device=device)
        signal_backup = signal.clone()
        spectrum = torch.fft.fftn(signal, dim=dims)
        self.assertEqual(signal, signal_backup)

        spectrum_backup = spectrum.clone()
        torch.fft.ifftn(spectrum, dim=dims)
        self.assertEqual(spectrum, spectrum_backup)

        half_spectrum = torch.fft.rfftn(signal, dim=dims)
        self.assertEqual(signal, signal_backup)

        # Here the clone is what gets transformed; it must still equal the untouched original
        half_spectrum_backup = half_spectrum.clone()
        torch.fft.irfftn(half_spectrum_backup, s=(2, 2), dim=dims)
        self.assertEqual(half_spectrum, half_spectrum_backup)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_fft_plan_repeatable(self, device):
        """Check that transforming a tensor and its clone gives equal results."""
        # Regression test for gh-58724 and gh-63152
        transforms = (
            (torch.fft.fftn, torch.complex64),
            (torch.fft.rfft, torch.float64),
        )
        for n in [2048, 3199, 5999]:
            for transform, sample_dtype in transforms:
                sample = torch.randn(n, device=device, dtype=sample_dtype)
                first = transform(sample)
                second = transform(sample.clone())
                self.assertEqual(first, second)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_simple_cases(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test(signal, n_fft, length):
            spectrogram = torch.stft(signal, n_fft=n_fft, return_complex=True)
            recovered = torch.istft(spectrogram, n_fft=n_fft, length=length)
            self.assertEqual(signal, recovered, exact_dtype=True)

        # Constant signals: all ones, then all zeros
        for make_signal in (torch.ones, torch.zeros):
            _test(make_signal(4, dtype=dtype, device=device), 4, 4)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_various_params(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test_istft_is_inverse_of_stft(stft_kwargs):
            # Generate a random signal for each trial, then run stft followed by
            # istft and check that the signal is reconstructed.
            num_trials = 100
            # istft is called with the same arguments minus pad_mode
            istft_kwargs = {key: value for key, value in stft_kwargs.items()
                            if key != 'pad_mode'}
            for sizes in [(2, 20), (3, 15), (4, 10)]:
                for _ in range(num_trials):
                    original = torch.randn(*sizes, dtype=dtype, device=device)
                    spectrogram = torch.stft(original, return_complex=True, **stft_kwargs)
                    recovered = torch.istft(spectrogram, length=original.size(1), **istft_kwargs)
                    self.assertEqual(
                        recovered, original, msg='istft comparison against original',
                        atol=7e-6, rtol=0, exact_dtype=True)

        def hann(size):
            return torch.hann_window(size, dtype=dtype, device=device)

        def hamming(size):
            return torch.hamming_window(size, dtype=dtype, device=device)

        patterns = [
            # hann_window, centered, normalized, onesided
            dict(n_fft=12, hop_length=4, win_length=12, window=hann(12),
                 center=True, pad_mode='reflect', normalized=True, onesided=True),
            # hann_window, centered, not normalized, not onesided
            dict(n_fft=12, hop_length=2, win_length=8, window=hann(8),
                 center=True, pad_mode='reflect', normalized=False, onesided=False),
            # hamming_window, centered, normalized, not onesided
            dict(n_fft=15, hop_length=3, win_length=11, window=hamming(11),
                 center=True, pad_mode='constant', normalized=True, onesided=False),
            # hamming_window, centered, not normalized, onesided
            # window same size as n_fft
            dict(n_fft=5, hop_length=2, win_length=5, window=hamming(5),
                 center=True, pad_mode='constant', normalized=False, onesided=True),
        ]
        for pattern in patterns:
            _test_istft_is_inverse_of_stft(pattern)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_with_padding(self, device, dtype):
        """long hop_length or not centered may cause length mismatch in the inversed signal"""
        def _test_istft_is_inverse_of_stft_with_padding(pattern):
            # Generate a random signal for each trial, then run stft followed by
            # istft and check that the signal is reconstructed.
            num_trials = 100
            # Work on a copy so the caller's dict is left untouched
            stft_kwargs = dict(pattern)
            sizes = stft_kwargs.pop('size')
            # istft is called with the same arguments minus pad_mode
            istft_kwargs = stft_kwargs.copy()
            del istft_kwargs['pad_mode']
            for _ in range(num_trials):
                original = torch.randn(*sizes, dtype=dtype, device=device)
                stft = torch.stft(original, return_complex=True, **stft_kwargs)
                with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."):
                    inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs)
                n_frames = stft.size(-1)
                if stft_kwargs["center"] is True:
                    len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1)
                else:
                    len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1)
                # trim the original for case when constructed signal is shorter than original
                padding = inversed[..., len_expected:]
                inversed = inversed[..., :len_expected]
                original = original[..., :len_expected]
                # test the padding points of the inversed signal are all zeros
                zeros = torch.zeros_like(padding)
                self.assertEqual(
                    padding, zeros, msg='istft padding values against zeros',
                    atol=7e-6, rtol=0, exact_dtype=True)
                self.assertEqual(
                    inversed, original, msg='istft comparison against original',
                    atol=7e-6, rtol=0, exact_dtype=True)

        patterns = [
            # hamming_window, not centered, not normalized, not onesided
            # window same size as n_fft
            {
                'size': [2, 20],
                'n_fft': 3,
                'hop_length': 2,
                'win_length': 3,
                'window': torch.hamming_window(3, dtype=dtype, device=device),
                'center': False,
                'pad_mode': 'reflect',
                'normalized': False,
                'onesided': False,
            },
            # hamming_window, centered, not normalized, onesided, long hop_length
            # window same size as n_fft
            {
                'size': [2, 500],
                'n_fft': 256,
                'hop_length': 254,
                'win_length': 256,
                'window': torch.hamming_window(256, dtype=dtype, device=device),
                'center': True,
                'pad_mode': 'constant',
                'normalized': False,
                'onesided': True,
            },
        ]
        for pattern in patterns:
            _test_istft_is_inverse_of_stft_with_padding(pattern)

    @onlyNativeDeviceTypes
    def test_istft_throws(self, device):
        """istft should throw exception for invalid parameters"""
        # Every tensor is created on `device`, so each RuntimeError comes from
        # the invalid parameters and not from mixing devices.
        stft = torch.zeros((3, 5, 2), device=device)
        # the window is size 1 but it hops 20 so there is a gap which throw an error
        self.assertRaises(
            RuntimeError, torch.istft, stft, n_fft=4,
            hop_length=20, win_length=1, window=torch.ones(1, device=device))
        # A window of zeros does not meet NOLA
        invalid_window = torch.zeros(4, device=device)
        self.assertRaises(
            RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
        # Input cannot be empty
        self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2), device=device), 2)
        self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2), device=device), 2)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_of_sine(self, device, dtype):
        def _test(amplitude, L, n):
            # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
            x = torch.arange(2 * L + 1, device=device, dtype=dtype)
            original = amplitude * torch.sin(2 * math.pi / L * x * n)
            # stft = torch.stft(original, L, hop_length=L, win_length=L,
            #                   window=torch.ones(L), center=False, normalized=False)
            stft = torch.zeros((L // 2 + 1, 2, 2), device=device, dtype=dtype)
            stft_largest_val = (amplitude * L) / 2.0
            if n < stft.size(0):
                stft[n, :, 1] = -stft_largest_val

            if 0 <= L - n < stft.size(0):
                # symmetric about L // 2
                stft[L - n, :, 1] = stft_largest_val

            inverse = torch.istft(
                stft, L, hop_length=L, win_length=L,
                window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False)
            # There is a larger error due to the scaling of amplitude
            original = original[..., :inverse.size(-1)]
            self.assertEqual(inverse, original, atol=1e-3, rtol=0)

        _test(amplitude=123, L=5, n=1)
        _test(amplitude=150, L=5, n=2)
        _test(amplitude=111, L=5, n=3)
        _test(amplitude=160, L=7, n=4)
        _test(amplitude=145, L=8, n=5)
        _test(amplitude=80, L=9, n=6)
        _test(amplitude=99, L=10, n=7)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_linearity(self, device, dtype):
        num_trials = 100

        def _test(data_size, kwargs):
            for i in range(num_trials):
                tensor1 = torch.randn(data_size, device=device, dtype=dtype)
                tensor2 = torch.randn(data_size, device=device, dtype=dtype)
                a, b = torch.rand(2, dtype=dtype, device=device)
                # Also compare method vs. functional call signature
                istft1 = tensor1.istft(**kwargs)
                istft2 = tensor2.istft(**kwargs)
                istft = a * istft1 + b * istft2
                estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs)
                self.assertEqual(istft, estimate, atol=1e-5, rtol=0)
        patterns = [
            # hann_window, centered, normalized, onesided
            (
                (2, 7, 7, 2),
                {
                    'n_fft': 12,
                    'window': torch.hann_window(12, device=device, dtype=dtype),
                    'center': True,
                    'normalized': True,
                    'onesided': True,
                },
            ),
            # hann_window, centered, not normalized, not onesided
            (
                (2, 12, 7, 2),
                {
                    'n_fft': 12,
                    'window': torch.hann_window(12, device=device, dtype=dtype),
                    'center': True,
                    'normalized': False,
                    'onesided': False,
                },
            ),
            # hamming_window, centered, normalized, not onesided
            (
                (2, 12, 7, 2),
                {
                    'n_fft': 12,
                    'window': torch.hamming_window(12, device=device, dtype=dtype),
                    'center': True,
                    'normalized': True,
                    'onesided': False,
                },
            ),
            # hamming_window, not centered, not normalized, onesided
            (
                (2, 7, 3, 2),
                {
                    'n_fft': 12,
                    'window': torch.hamming_window(12, device=device, dtype=dtype),
                    'center': False,
                    'normalized': False,
                    'onesided': True,
                },
            )
        ]
        for data_size, kwargs in patterns:
            _test(data_size, kwargs)

    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_batch_istft(self, device):
        original = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
        ], device=device)

        single = original.repeat(1, 1, 1, 1)
        multi = original.repeat(4, 1, 1, 1)

        i_original = torch.istft(original, n_fft=4, length=4)
        i_single = torch.istft(single, n_fft=4, length=4)
        i_multi = torch.istft(multi, n_fft=4, length=4)

        self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True)
        self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True)

    @onlyCUDA
    @skipIf(not TEST_MKL, "Test requires MKL")
    def test_stft_window_device(self, device):
        # Test the (i)stft window must be on the same device as the input
        x = torch.randn(1000, dtype=torch.complex64)
        window = torch.randn(100, dtype=torch.complex64)

        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
            torch.stft(x, n_fft=100, window=window.to(device))

        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
            torch.stft(x.to(device), n_fft=100, window=window)

        X = torch.stft(x, n_fft=100, window=window)

        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
            torch.istft(X, n_fft=100, window=window.to(device))

        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
            torch.istft(x.to(device), n_fft=100, window=window)


class FFTDocTestFinder:
    '''The default doctest finder doesn't like that function.__module__ doesn't
    match torch.fft. It assumes the functions are leaked imports.
    '''
    def __init__(self):
        self.parser = doctest.DocTestParser()

    def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
        '''Build one DocTest per documented routine listed in ``obj.__all__``.

        Entries of ``obj.__all__`` that are not routines (according to
        ``inspect.isroutine``) or that have no docstring are skipped. Each
        returned DocTest is named after the bare attribute name.

        Args:
            obj: object exposing ``__all__`` and the attributes it lists.
            name, module, extraglobs: accepted but not used.
            globs: globals handed to the parser for every doctest; an empty
                dict is used when ``None``.

        Returns:
            list of ``doctest.DocTest`` objects in ``obj.__all__`` order.
        '''
        globs = {} if globs is None else globs

        doctests = []
        for fname in obj.__all__:
            func = getattr(obj, fname)
            if not inspect.isroutine(func):
                continue

            docstring = inspect.getdoc(func)
            if docstring is None:
                continue

            doctests.append(self.parser.get_doctest(
                docstring, globs=globs, name=fname, filename=None, lineno=None))

        return doctests


class TestFFTDocExamples(TestCase):
    # Intentionally empty: generate_doc_test() attaches one ``test_<name>``
    # method to this class for each doctest it is given.
    pass

def generate_doc_test(doc_test):
    '''Attach a test method that runs ``doc_test`` to TestFFTDocExamples.

    The method is registered as ``test_<doc_test.name>`` and wrapped with
    ``skipCPUIfNoFFT``. It fails when the doctest runner reports any failure.
    '''
    def test(self, device):
        self.assertEqual(device, 'cpu')
        runner = doctest.DocTestRunner()
        # DocTestRunner.run() clears doc_test.globs when it finishes. Snapshot
        # the globals and put them back so that executing this test method
        # again (e.g. a repeated or retried run) still has its initial names.
        initial_globs = doc_test.globs.copy()
        try:
            runner.run(doc_test)
        finally:
            doc_test.globs.update(initial_globs)

        if runner.failures != 0:
            runner.summarize()
            self.fail('Doctest failed')

    setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test))

# Register one generated test per doctest found in torch.fft; ``torch`` is the
# only name pre-seeded into the globals the examples run with.
for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
    generate_doc_test(doc_test)


# NOTE(review): instantiate_device_type_tests is defined outside this chunk;
# presumably it generates the per-device variants of each test class into this
# module's globals - confirm against its definition.
instantiate_device_type_tests(TestFFT, globals())
# The generated doc-example tests assert device == 'cpu', hence only_for='cpu'.
instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')

if __name__ == '__main__':
    run_tests()