1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
from mpi4py import MPI
from array import array as _array
import struct as _struct
# --------------------------------------------------------------------
class Counter:
    """Shared integer counter held in an MPI one-sided (RMA) window.

    Rank 0 of *comm* exposes a single C ``int`` that holds the counter;
    every other rank exposes an empty window.  The constructor allocates
    the window with ``MPI.Win.Allocate`` and :meth:`free` releases it with
    ``Win.Free``; both are collective, so every rank of *comm* must call
    them.
    """

    def __init__(self, comm):
        rank = comm.Get_rank()
        itemsize = MPI.INT.Get_size()
        # Only rank 0 provides storage for the counter.
        n = 1 if rank == 0 else 0
        self.win = MPI.Win.Allocate(n * itemsize, itemsize,
                                    MPI.INFO_NULL, comm)
        if rank == 0:
            # Allocated window memory is not guaranteed to be zeroed.  Do the
            # initializing local store inside a lock epoch on our own window
            # so it is synchronized with remote RMA access to the counter.
            self.win.Lock(0)
            mem = self.win.tomemory()
            mem[:] = _struct.pack('i', 0)
            self.win.Unlock(0)
        # Without this barrier another rank could accumulate into the counter
        # before rank 0 stored the initial 0, and that store would then
        # silently discard the other rank's update.
        comm.Barrier()

    def free(self):
        """Free the counter's RMA window (collective over the comm)."""
        self.win.Free()

    def next(self, increment=1):
        """Atomically add *increment* to the counter and return the old value.

        :param increment: amount added to the shared counter (default 1).
        :returns: the counter value observed just before the addition.
        """
        incr = _array('i', [increment])
        nval = _array('i', [0])
        # Passive-target epoch on rank 0, which owns the counter
        # (mpi4py's default lock type for Win.Lock is exclusive).
        self.win.Lock(0)
        # Get_accumulate with SUM fetches the previous value into ``nval``
        # and adds ``incr`` as a single read-modify-write on the target.
        self.win.Get_accumulate([incr, 1, MPI.INT],
                                [nval, 1, MPI.INT],
                                0, op=MPI.SUM)
        self.win.Unlock(0)
        return nval[0]
# -----------------------------------------------------------------------------
class Mutex:
    """Distributed mutual-exclusion lock layered on a shared ``Counter``.

    The counter tracks how many ranks are currently claiming the lock.  A
    claim wins only when the counter was 0 immediately before it was made.
    Instances also work as context managers: ``with mutex: ...``.
    """

    def __init__(self, comm):
        self.counter = Counter(comm)

    def lock(self):
        """Busy-wait until this rank owns the mutex."""
        # Each attempt bumps the counter.  A previous value of 0 means nobody
        # else held it, so we own it now.  Otherwise withdraw the claim and
        # try again.
        while self.counter.next(+1) != 0:
            self.counter.next(-1)

    def unlock(self):
        """Release the mutex by withdrawing this rank's claim."""
        withdraw = -1
        self.counter.next(withdraw)

    def __enter__(self):
        self.lock()
        return self

    def __exit__(self, *exc):
        # Implicitly returns None (falsy), so an exception raised inside
        # the ``with`` body keeps propagating after the mutex is released.
        self.unlock()

    def free(self):
        """Free the underlying counter."""
        self.counter.free()
# -----------------------------------------------------------------------------
def test_counter():
    """Each rank draws five values; together they must cover 0..N-1 once."""
    world = MPI.COMM_WORLD
    counter = Counter(world)
    mine = [counter.next() for _ in range(5)]
    counter.free()
    # Lowercase allreduce with its default SUM op concatenates the lists.
    everyone = world.allreduce(mine)
    assert sorted(everyone) == list(range(len(everyone)))
def test_mutex():
    """Acquire and release a COMM_WORLD mutex once, then free it."""
    mutex = Mutex(MPI.COMM_WORLD)
    # __enter__ calls lock() and __exit__ calls unlock().
    with mutex:
        pass
    mutex.free()
if __name__ == '__main__':
    # Run the demo checks in order; both perform collective MPI calls, so
    # every rank of COMM_WORLD has to execute them.
    for check in (test_counter, test_mutex):
        check()
# -----------------------------------------------------------------------------
|