File: test_fftw.py

package info (click to toggle)
mpi4py-fft 2.0.6-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 720 kB
  • sloc: python: 3,053; ansic: 87; makefile: 42; sh: 33
file content (164 lines) | stat: -rw-r--r-- 8,359 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import print_function
from time import time
import numpy as np
from scipy.fftpack import dctn as scipy_dctn
from scipy.fftpack import dstn as scipy_dstn
import scipy.fftpack # pylint: disable=unused-import
from mpi4py_fft import fftw

# pyfftw is an optional dependency: record whether it could be imported so
# the tests can skip the pyfftw comparisons when it is missing.
try:
    import pyfftw
except ImportError:
    has_pyfftw = False
else:
    has_pyfftw = True

# Absolute tolerances keyed by the lower-cased dtype character of the array
# being checked (looked up in allclose).
abstol = {'f': 5e-4, 'd': 1e-12, 'g': 1e-14}

# Short transform name -> FFTW real-to-real kind constant.
kinds = dict(
    dst4=fftw.FFTW_RODFT11,  # no scipy to compare with
    dct4=fftw.FFTW_REDFT11,  # no scipy to compare with
    dst3=fftw.FFTW_RODFT01,
    dct3=fftw.FFTW_REDFT01,
    dct2=fftw.FFTW_REDFT10,
    dst2=fftw.FFTW_RODFT10,
    dct1=fftw.FFTW_REDFT00,
    dst1=fftw.FFTW_RODFT00,
)

# Reverse lookup: FFTW kind constant -> short transform name.
rkinds = dict((val, key) for key, val in kinds.items())

def allclose(a, b):
    """Return True when *a* and *b* agree within an absolute tolerance.

    The tolerance is taken from ``abstol`` using the lower-cased dtype
    character of *a*; the relative tolerance is fixed at zero.
    """
    precision = a.dtype.char.lower()
    return np.allclose(a, b, rtol=0, atol=abstol[precision])

def test_fftw():
    """Check the ``fftw`` transform wrappers against inverses and references.

    For every thread count, available precision, array shape and axes
    selection generated below, this plans and runs:

    * r2c/c2r (``rfftn``/``irfftn``) and ``hfftn``/``ihfftn`` transforms,
    * c2c (``fftn``/``ifftn``) transforms,
    * r2r (``dctn``/``dstn``, types 1-4) transforms and their inverses,
    * a random mix of r2r kinds planned through ``get_planned_FFT``.

    Round trips are checked with ``allclose``. Forward results are also
    compared with pyfftw (when it could be imported) and with scipy.fftpack
    for the cases the code below enables.
    """
    from itertools import product

    dims = (1, 2, 3)
    sizes = (7, 8, 10)
    # Keep only the precisions for which get_fftw_lib returns something truthy.
    types = ''.join(t for t in 'fdg' if fftw.get_fftw_lib(t))
    fflags = (fftw.FFTW_ESTIMATE, fftw.FFTW_DESTROY_INPUT)
    iflags = (fftw.FFTW_ESTIMATE, fftw.FFTW_DESTROY_INPUT)

    for threads in (1, 2):
        for typecode in types:
            for dim in dims:
                for shape in product(*([sizes]*dim)):
                    allaxes = tuple(reversed(range(dim)))
                    # NOTE(review): range(i+1, dim) is empty for dim == 1 and
                    # never reaches dim, so 1-D shapes are skipped and the
                    # slice allaxes[i:j] never includes axis 0 -- confirm
                    # whether range(i+1, dim+1) was intended.
                    for i in range(dim):
                        for j in range(i+1, dim):
                            axes = allaxes[i:j]
                            # r2c - c2r
                            input_array = fftw.aligned(shape, dtype=typecode)
                            # The complex output is halved (n//2+1) along the
                            # last listed axis.
                            outshape = list(shape)
                            outshape[axes[-1]] = shape[axes[-1]]//2+1
                            output_array = fftw.aligned(outshape, dtype=typecode.upper())
                            oa = output_array if typecode == 'd' else None # Test for both types of signature
                            rfftn = fftw.rfftn(input_array, None, axes, threads, fflags, output_array=oa)
                            A = np.random.random(shape).astype(typecode)
                            input_array[:] = A
                            B = rfftn()
                            # Calling the plan must hand back its own output buffer.
                            assert B is rfftn.output_array
                            if has_pyfftw:
                                B2 = pyfftw.interfaces.numpy_fft.rfftn(input_array, axes=axes)
                                assert allclose(B, B2), np.linalg.norm(B-B2)
                            ia = input_array if typecode == 'd' else None
                            # An explicit real-space shape is passed only when
                            # the last transformed axis has odd length.
                            sa = np.take(input_array.shape, axes) if shape[axes[-1]] % 2 == 1 else None
                            irfftn = fftw.irfftn(output_array, sa, axes, threads, iflags, output_array=ia)
                            irfftn.input_array[...] = B
                            A2 = irfftn(normalize=True)
                            assert allclose(A, A2), np.linalg.norm(A-A2)
                            hfftn = fftw.hfftn(output_array, sa, axes, threads, fflags, output_array=ia)
                            hfftn.input_array[...] = B
                            AC = hfftn().copy()
                            ihfftn = fftw.ihfftn(input_array, None, axes, threads, iflags, output_array=oa)
                            A2 = ihfftn(AC, implicit=False, normalize=True)
                            # The message must be a value: print() returns None,
                            # which would leave the AssertionError without text.
                            assert allclose(A2, B), np.linalg.norm(A2-B)

                            # c2c
                            input_array = fftw.aligned(shape, dtype=typecode.upper())
                            output_array = fftw.aligned_like(input_array)
                            oa = output_array if typecode == 'd' else None
                            fftn = fftw.fftn(input_array, None, axes, threads, fflags, output_array=oa)
                            C = np.random.random(shape).astype(typecode.upper())
                            fftn.input_array[...] = C
                            D = fftn().copy()
                            ifftn = fftw.ifftn(input_array, None, axes, threads, iflags, output_array=oa)
                            ifftn.input_array[...] = D
                            C2 = ifftn(normalize=True)
                            assert allclose(C, C2), np.linalg.norm(C-C2)
                            if has_pyfftw:
                                D2 = pyfftw.interfaces.numpy_fft.fftn(C, axes=axes)
                                assert allclose(D, D2), np.linalg.norm(D-D2)

                            # r2r
                            input_array = fftw.aligned(shape, dtype=typecode)
                            output_array = fftw.aligned_like(input_array)
                            oa = output_array if typecode == 'd' else None
                            for r2r_type in (1, 2, 3, 4):
                                dct = fftw.dctn(input_array, None, axes, r2r_type, threads, fflags, output_array=oa)
                                B = dct(A).copy()
                                idct = fftw.idctn(input_array, None, axes, r2r_type, threads, iflags, output_array=oa)
                                A2 = idct(B, implicit=True, normalize=True)
                                assert allclose(A, A2), np.linalg.norm(A-A2)
                                # The scipy reference is skipped for long
                                # double ('g') and for type 4.
                                if typecode != 'g' and r2r_type != 4:
                                    B2 = scipy_dctn(A, axes=axes, type=r2r_type)
                                    assert allclose(B, B2), np.linalg.norm(B-B2)

                                dst = fftw.dstn(input_array, None, axes, r2r_type, threads, fflags, output_array=oa)
                                B = dst(A).copy()
                                idst = fftw.idstn(input_array, None, axes, r2r_type, threads, iflags, output_array=oa)
                                A2 = idst(B, implicit=True, normalize=True)
                                assert allclose(A, A2), np.linalg.norm(A-A2)
                                if typecode != 'g' and r2r_type != 4:
                                    B2 = scipy_dstn(A, axes=axes, type=r2r_type)
                                    assert allclose(B, B2), np.linalg.norm(B-B2)

                            # Different r2r transforms along all axes. Just pick
                            # any naxes transforms and compare with scipy
                            naxes = len(axes)
                            # Integers in [3, 11); presumably these are exactly
                            # the FFTW kind values stored in ``kinds`` (the
                            # rkinds lookup below would raise KeyError
                            # otherwise) -- TODO confirm.
                            kds = np.random.randint(3, 11, size=naxes) # get naxes random transforms
                            tsf = [rkinds[k] for k in kds]
                            T = fftw.get_planned_FFT(input_array, input_array.copy(), axes=axes,
                                                     kind=kds, threads=threads, flags=fflags)
                            C = T(A)
                            TI = fftw.get_planned_FFT(input_array.copy(), input_array.copy(), axes=axes,
                                                      kind=[fftw.inverse[kd] for kd in kds],
                                                      threads=threads, flags=iflags)

                            C2 = TI(C)
                            M = fftw.get_normalization(kds, input_array.shape, axes)
                            assert allclose(C2*M, A)
                            # Test vs scipy for transforms available in scipy
                            if typecode != 'g' and not any(f in kds for f in (fftw.FFTW_RODFT11, fftw.FFTW_REDFT11)):
                                # Apply the matching 1-D scipy transform along
                                # each axis in turn. Names in tsf look like
                                # 'dct2': function name plus a type digit.
                                expected = A
                                for m, ts in enumerate(tsf):
                                    transform = getattr(scipy.fftpack, ts[:-1])
                                    expected = transform(expected, axis=axes[m], type=int(ts[-1]))
                                assert allclose(C, expected), np.linalg.norm(C-expected)

def test_wisdom():
    """Smoke-test the wisdom API: export, import again, then forget."""
    wisdom_file = 'newwisdom.dat'
    fftw.export_wisdom(wisdom_file)
    fftw.import_wisdom(wisdom_file)
    fftw.forget_wisdom()

def test_timelimit():
    """Check that ``fftw.set_timelimit`` shortens FFTW_PATIENT planning.

    An ``rfftn`` plan for a 128x128 double array is timed with
    ``FFTW_PATIENT`` before and after wisdom is forgotten and a time limit
    is set. The limited run must be faster than the unlimited one and must
    take less than twice the limit.
    """
    limit = 0.01
    input_array = fftw.aligned((128, 128), dtype='d')
    try:
        t0 = time()
        fftw.rfftn(input_array, flags=fftw.FFTW_PATIENT)
        t1 = time()-t0
        # Forget wisdom before the second, time-limited planning run.
        fftw.forget_wisdom()
        fftw.set_timelimit(limit)
        t0 = time()
        fftw.rfftn(input_array, flags=fftw.FFTW_PATIENT)
        t2 = time()-t0
        assert t1 > t2, (t1, t2)
        # The message must be a value: print() returns None, which would
        # leave the AssertionError without any diagnostic text.
        assert abs(t2-limit) < limit, (abs(t2-limit), limit)
    finally:
        # Run cleanup even when an assertion fails so a failure here does
        # not skip it.
        fftw.cleanup()

if __name__ == '__main__':
    # Run every test, in definition order, when executed as a script.
    for run_test in (test_fftw, test_wisdom, test_timelimit):
        run_test()