File: test_darray.py

package info (click to toggle)
mpi4py-fft 2.0.6-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 720 kB
  • sloc: python: 3,053; ansic: 87; makefile: 42; sh: 33
file content (137 lines) | stat: -rw-r--r-- 5,000 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
from mpi4py import MPI
from mpi4py_fft import DistArray, newDistArray, PFFT
from mpi4py_fft.pencil import Subcomm

# World communicator shared by the tests in this module.
comm = MPI.COMM_WORLD

def test_1Darray():
    """A 1D DistArray created with ``val=2`` holds 2 and has shape ``(8,)``."""
    shape = (8,)
    arr = DistArray(shape, val=2)
    assert arr[0] == 2
    assert arr.shape == shape

def test_2Darray():
    """Exercise DistArray on an (8, 8) global grid for several decompositions.

    For every ``subcomm`` choice and tensor ``rank`` in (0, 1, 2), build a
    DistArray of global shape ``(2,)*rank + (8, 8)`` filled with ones and:

    * check ``rank``/``global_shape`` and that ``commsizes`` covers all
      processes of ``comm``;
    * check that integer indexing of a rank>0 array lowers the rank by one;
    * gather 1D slices on rank 0 with ``get`` (skipped if ``get`` raises
      ModuleNotFoundError);
    * round-trip through ``redistribute`` and compare global squared norms;
    * check that redistributing to the current alignment returns ``a`` itself.
    """
    N = (8, 8)
    for subcomm in ((0, 1), (1, 0), None, Subcomm(comm, (0, 1))):
        for rank in (0, 1, 2):
            M = (2,)*rank + N
            alignment = None
            if subcomm is None and rank == 1:
                # Also cover an explicitly requested alignment.
                alignment = 1
            a = DistArray(M, subcomm=subcomm, val=1, rank=rank, alignment=alignment)
            assert a.rank == rank
            assert a.global_shape == M
            # Touch the remaining properties to make sure they evaluate.
            _ = a.substart
            _ = a.subcomm
            z = a.commsizes
            _ = a.pencil
            # The process grid must account for every process in comm.
            assert np.prod(np.array(z)) == comm.Get_size()
            if rank > 0:
                a0 = a[0]
                assert isinstance(a0, DistArray)
                assert a0.rank == rank-1
            aa = a.v
            assert isinstance(aa, np.ndarray)
            try:
                # get() returns the gathered slice on rank 0; the array is all
                # ones, so each slice sums to its length.
                k = a.get((0,)*rank+(0, slice(None)))
                if comm.Get_rank() == 0:
                    assert len(k) == N[1]
                    assert np.sum(k) == N[1]
                k = a.get((0,)*rank+(slice(None), 0))
                if comm.Get_rank() == 0:
                    assert len(k) == N[0]
                    assert np.sum(k) == N[0]
            except ModuleNotFoundError:
                # Deliberate skip: get() may need a module that is not
                # installed in this environment.
                pass
            _ = a.local_slice()
            newaxis = (a.alignment+1) % 2
            _ = a.get_pencil_and_transfer(newaxis)
            a[:] = comm.Get_rank()
            b = a.redistribute(newaxis)
            a = b.redistribute(out=a)
            a = b.redistribute(a.alignment, out=a)
            # Squared norms are summed across processes; the result is only
            # checked on rank 0.
            s0 = comm.reduce(np.linalg.norm(a)**2)
            s1 = comm.reduce(np.linalg.norm(b)**2)
            if comm.Get_rank() == 0:
                assert abs(s0-s1) < 1e-1
            # Redistributing to the current alignment must return a itself.
            c = a.redistribute(a.alignment)
            assert c is a

def test_3Darray():
    """Exercise DistArray on an (8, 8, 8) global grid for several decompositions.

    ``subcomm`` ranges over every 0/1 axis mask with one or two ones, ``None``,
    and an explicit Subcomm; tensor ``rank`` ranges over (0, 1, 2). For each
    DistArray of global shape ``(3,)*rank + (8, 8, 8)`` filled with ones:

    * check ``rank``/``global_shape`` and that ``commsizes`` covers all
      processes of ``comm``;
    * check that integer indexing lowers the rank (two integers on a rank-2
      array give rank 0);
    * gather 1D slices on rank 0 with ``get`` (skipped if ``get`` raises
      ModuleNotFoundError);
    * round-trip through ``redistribute`` (implicit and explicit axis) and
      compare global squared norms;
    * check that redistributing to the current alignment returns ``a`` itself.
    """
    N = (8, 8, 8)
    subcomms = ((0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 1, 1), (1, 0, 1),
                (1, 1, 0), None, Subcomm(comm, (0, 0, 1)))
    for subcomm in subcomms:
        for rank in (0, 1, 2):
            M = (3,)*rank + N
            alignment = None
            if subcomm is None and rank == 1:
                # Also cover an explicitly requested alignment.
                alignment = 2
            a = DistArray(M, subcomm=subcomm, val=1, rank=rank, alignment=alignment)
            assert a.rank == rank
            assert a.global_shape == M
            # Touch the remaining properties to make sure they evaluate.
            _ = a.substart
            _ = a.subcomm
            z = a.commsizes
            _ = a.pencil
            # The process grid must account for every process in comm.
            assert np.prod(np.array(z)) == comm.Get_size()
            if rank > 0:
                a0 = a[0]
                assert isinstance(a0, DistArray)
                assert a0.rank == rank-1
            if rank == 2:
                a0 = a[0, 1]
                assert isinstance(a0, DistArray)
                assert a0.rank == 0
            aa = a.v
            assert isinstance(aa, np.ndarray)
            try:
                # get() returns the gathered slice on rank 0; the array is all
                # ones, so each slice sums to its length.
                k = a.get((0,)*rank+(0, 0, slice(None)))
                if comm.Get_rank() == 0:
                    assert len(k) == N[2]
                    assert np.sum(k) == N[2]
                k = a.get((0,)*rank+(slice(None), 0, 0))
                if comm.Get_rank() == 0:
                    assert len(k) == N[0]
                    assert np.sum(k) == N[0]
            except ModuleNotFoundError:
                # Deliberate skip: get() may need a module that is not
                # installed in this environment.
                pass
            _ = a.local_slice()
            newaxis = (a.alignment+1) % 3
            _ = a.get_pencil_and_transfer(newaxis)
            a[:] = comm.Get_rank()
            b = a.redistribute(newaxis)
            a = b.redistribute(out=a)
            # Also exercise the explicit-axis form together with ``out``.
            a = b.redistribute(a.alignment, out=a)
            # Squared norms are summed across processes; the result is only
            # checked on rank 0.
            s0 = comm.reduce(np.linalg.norm(a)**2)
            s1 = comm.reduce(np.linalg.norm(b)**2)
            if comm.Get_rank() == 0:
                assert abs(s0-s1) < 1e-1
            # Redistributing to the current alignment must return a itself.
            c = a.redistribute(a.alignment)
            assert c is a

def test_newDistArray():
    """Check newDistArray for every forward_output / view / rank combination.

    With ``view=True`` the result must be a plain np.ndarray whose ``base``
    reports the requested rank. With ``view=False`` it must be a DistArray of
    the requested rank whose rank-0 component can seed a new PFFT.
    """
    pfft = PFFT(MPI.COMM_WORLD, (8, 8, 8))
    for forward_output in (True, False):
        for view in (True, False):
            for rank in (0, 1, 2):
                arr = newDistArray(pfft, forward_output=forward_output,
                                   rank=rank, view=view)
                if view:
                    assert isinstance(arr, np.ndarray)
                    assert arr.base.rank == rank
                    continue

                assert isinstance(arr, DistArray)
                assert arr.rank == rank
                # Strip tensor indices down to a rank-0 component.
                if rank == 0:
                    component = arr
                elif rank == 1:
                    component = arr[0]
                else:
                    component = arr[0, 0]
                qfft = PFFT(MPI.COMM_WORLD, darray=component)
                qfft.destroy()
    pfft.destroy()

if __name__ == '__main__':
    # Run the tests in definition order when the module is executed as a script.
    for _test in (test_1Darray, test_2Darray, test_3Darray, test_newDistArray):
        _test()