1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
from libmpi import ffi, lib
def ring(comm, count=1, loop=1, skip=0):
    """Pass a ``count``-byte message around the ring of processes in ``comm``.

    Every process receives from rank ``(rank - 1) % size`` and sends to rank
    ``(rank + 1) % size``.  Rank 0 originates the message (bytes of value 42)
    and waits for it to come back; every other rank forwards the bytes it
    received.  When the communicator holds a single process the trip is one
    ``MPI_Sendrecv`` with itself.

    Args:
        comm: MPI communicator handle, passed unchanged to the ``lib.MPI_*``
            calls.
        count: Message length in bytes.
        loop: Number of timed trips around the ring.
        skip: Number of untimed warm-up trips performed before the timer is
            (re)started.

    Returns:
        The difference between two ``MPI_Wtime`` readings taken on this
        process: one when iteration ``skip`` begins and one after the last
        iteration.  If iteration ``skip`` is never reached (for example
        ``loop <= 0``), the first reading is the one taken just before the
        loops, so the call still returns a number instead of failing on an
        unbound ``tic``.
    """
    size_p = ffi.new("int*")
    rank_p = ffi.new("int*")
    lib.MPI_Comm_size(comm, size_p)
    lib.MPI_Comm_rank(comm, rank_p)
    size = size_p[0]
    rank = rank_p[0]

    # Ring neighbours; the modulo wraps rank 0 back to the last rank.
    source = (rank - 1) % size
    dest = (rank + 1) % size
    sbuf = ffi.new("unsigned char[]", [42] * count)
    rbuf = ffi.new("unsigned char[]", [0] * count)

    total = loop + skip

    # Take an initial reading so `tic` is always bound, even when the loops
    # below never reach iteration `skip`.  It is overwritten at i == skip.
    tic = lib.MPI_Wtime()

    if size == 1:
        # Lone process: send to and receive from itself in a single call.
        for i in range(total):
            if i == skip:
                tic = lib.MPI_Wtime()
            lib.MPI_Sendrecv(
                sbuf,
                count,
                lib.MPI_BYTE,
                dest,
                0,
                rbuf,
                count,
                lib.MPI_BYTE,
                source,
                0,
                comm,
                lib.MPI_STATUS_IGNORE,
            )
    elif rank == 0:
        # Rank 0 starts each trip, then waits for the message to return.
        for i in range(total):
            if i == skip:
                tic = lib.MPI_Wtime()
            lib.MPI_Send(sbuf, count, lib.MPI_BYTE, dest, 0, comm)
            lib.MPI_Recv(
                rbuf,
                count,
                lib.MPI_BYTE,
                source,
                0,
                comm,
                lib.MPI_STATUS_IGNORE,
            )
    else:
        # Other ranks relay: send out exactly the buffer that was received.
        sbuf = rbuf
        for i in range(total):
            if i == skip:
                tic = lib.MPI_Wtime()
            lib.MPI_Recv(
                rbuf,
                count,
                lib.MPI_BYTE,
                source,
                0,
                comm,
                lib.MPI_STATUS_IGNORE,
            )
            lib.MPI_Send(sbuf, count, lib.MPI_BYTE, dest, 0, comm)
    toc = lib.MPI_Wtime()

    # Only rank 0 still holds the original payload to compare against, and
    # the comparison is meaningless if no exchange took place (rbuf would
    # still hold its initial zeros).
    if rank == 0 and total > 0 and ffi.string(sbuf) != ffi.string(rbuf):
        import traceback
        import warnings

        try:
            warnings.warn("received message does not match!", stacklevel=2)
        except UserWarning:
            # Reached only when the active warning filters raise UserWarning
            # as an exception; in that case report it and abort the job.
            traceback.print_exc()
            lib.MPI_Abort(comm, 2)
    return toc - tic
def ringtest(comm):
    """Time one single-byte trip around the ring and report it from rank 0.

    Args:
        comm: MPI communicator handle on which the ring is run.
    """
    nbytes, trips, warmup = 1, 1, 0

    # Synchronize all processes before the measurement starts.
    lib.MPI_Barrier(comm)
    seconds = ring(comm, nbytes, trips, warmup)

    nprocs_ptr = ffi.new("int*")
    myrank_ptr = ffi.new("int*")
    lib.MPI_Comm_size(comm, nprocs_ptr)
    lib.MPI_Comm_rank(comm, myrank_ptr)

    # Only rank 0 prints the result.
    if myrank_ptr[0] != 0:
        return

    nprocs = nprocs_ptr[0]
    report = (
        f"time for {trips} loops = {seconds:g} seconds "
        f"({nprocs:d} processes, {nbytes:d} bytes)"
    )
    print(report)
def main():
    """Initialize MPI, run the ring test on ``MPI_COMM_WORLD``, finalize MPI."""
    no_args = ffi.NULL  # MPI_Init is given NULL for both of its arguments
    lib.MPI_Init(no_args, no_args)
    world = lib.MPI_COMM_WORLD
    ringtest(world)
    lib.MPI_Finalize()
# Run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|