File: reductions.py

package info (click to toggle)
mpi4py 4.0.3-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 4,196 kB
  • sloc: python: 32,170; ansic: 13,449; makefile: 602; sh: 314; f90: 178; cpp: 148
file content (103 lines) | stat: -rw-r--r-- 3,095 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from mpi4py import MPI

class Intracomm(MPI.Intracomm):
    """
    Intracommunicator class with scalable, point-to-point based
    implementations of global reduction operations.

    Every method exchanges Python objects with ``send``/``recv``/
    ``sendrecv`` and combines them by calling ``op(left, right)``.
    The ``recvobj`` parameters are accepted but their incoming value is
    never read; results are always delivered through the return value.
    """

    def __new__(cls, comm=None):
        """Create the communicator, forwarding *comm* to the base class."""
        return super().__new__(cls, comm)

    @staticmethod
    def _p2p_tag():
        """Return the message tag shared by all reductions in this class.

        The value is one below the ``TAG_UB`` attribute read from
        ``MPI.COMM_WORLD``.
        """
        return MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) - 1

    def reduce(self, sendobj=None, recvobj=None, op=MPI.SUM, root=0):
        """Combine *sendobj* from every rank and return the result on *root*.

        The values are folded along a binomial tree whose root is rank 0;
        when *root* is a different rank, rank 0 forwards the final value
        to it with one extra message.

        Returns the reduced object on *root* and ``None`` on every other
        rank.  *recvobj* is ignored.
        """
        size = self.size
        rank = self.rank
        assert 0 <= root < size
        tag = self._p2p_tag()

        result = sendobj
        mask = 1
        while mask < size:
            if rank & mask:
                # Lowest set bit reached: hand the partial result to the
                # parent (same rank with this bit cleared, hence always a
                # valid rank below this one) and stop.  Continuing the
                # loop would only exchange and combine values that are
                # never used by anyone.
                self.send(result, dest=rank & ~mask, tag=tag)
                break
            child = rank | mask
            if child < size:
                incoming = self.recv(None, source=child, tag=tag)
                # Own (lower-rank) contribution stays on the left.
                result = op(result, incoming)
            mask <<= 1

        # The tree delivers the result to rank 0; relay it if needed.
        if root != 0:
            if rank == 0:
                self.send(result, dest=root, tag=tag)
            elif rank == root:
                result = self.recv(None, source=0, tag=tag)

        return result if rank == root else None

    def allreduce(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """Reduce onto rank 0, then broadcast the result to every rank.

        Returns the reduced object on all ranks.  *recvobj* is ignored
        apart from being passed through to `reduce`.
        """
        reduced = self.reduce(sendobj, recvobj, op, 0)
        return self.bcast(reduced, 0)

    def scan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """Return the inclusive prefix reduction for the calling rank.

        Ranks pair up with ``rank ^ mask`` for ``mask = 1, 2, 4, ...``
        and swap their running partial results.  A value received from a
        lower-ranked partner is folded into this rank's answer; a value
        from a higher-ranked partner only extends the partial result that
        is passed on in later rounds.  *recvobj* is ignored.
        """
        size = self.size
        rank = self.rank
        tag = self._p2p_tag()

        result = sendobj
        partial = sendobj
        mask = 1
        while mask < size:
            peer = rank ^ mask
            if peer < size:
                incoming = self.sendrecv(partial, dest=peer, source=peer,
                                         sendtag=tag, recvtag=tag)
                if rank > peer:
                    # Partner covers lower ranks: its data goes on the left
                    # and contributes to this rank's prefix.
                    partial = op(incoming, partial)
                    result = op(incoming, result)
                else:
                    # Partner covers higher ranks: only the forwarded
                    # partial grows, the prefix is unaffected.
                    partial = op(partial, incoming)
            mask <<= 1

        return result

    def exscan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """Return the exclusive prefix reduction for the calling rank.

        Same exchange pattern as `scan`, except that the rank's own
        *sendobj* is left out of its answer: the answer is seeded by the
        first value received from a lower-ranked partner and extended by
        later ones.  Rank 0 has no lower ranks and gets ``None``.
        *recvobj* is ignored.
        """
        size = self.size
        rank = self.rank
        tag = self._p2p_tag()

        result = sendobj
        partial = sendobj
        seeded = False  # True once ``result`` holds received data only
        mask = 1
        while mask < size:
            peer = rank ^ mask
            if peer < size:
                incoming = self.sendrecv(partial, dest=peer, source=peer,
                                         sendtag=tag, recvtag=tag)
                if rank > peer:
                    # rank > peer >= 0, so this branch never runs on rank 0.
                    partial = op(incoming, partial)
                    if seeded:
                        result = op(incoming, result)
                    else:
                        # Discard the initial own value on first receipt.
                        result = incoming
                        seeded = True
                else:
                    partial = op(partial, incoming)
            mask <<= 1

        return None if rank == 0 else result