1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
from mpi4py import MPI
import numpy as np
# Public API of this module: only the FileBase class is exported by
# ``from <module> import *``.
__all__ = ('FileBase',)
# The MPI world communicator (MPI.COMM_WORLD). It is not referenced by the
# FileBase code visible in this file; presumably it is used by the concrete
# backend modules that import it -- TODO confirm.
comm = MPI.COMM_WORLD
class FileBase(object):
    """Base class for reading/writing distributed arrays

    This class only supplies the backend-independent logic (dispatching in
    :meth:`write`, closing the file and a few slice helpers). The methods
    ``open``, ``read``, ``backend``, ``_check_domain``, ``_write_group`` and
    ``_write_slice_step`` raise ``NotImplementedError`` here and must be
    provided by a concrete backend subclass.

    Parameters
    ----------
    filename : str, optional
        Name of backend file used to store data
    domain : sequence, optional
        An optional spatial mesh or domain to go with the data.
        Sequence of either

            - 2-tuples, where each 2-tuple contains the (origin, length)
              of each dimension, e.g., (0, 2*pi).
            - Arrays of coordinates, e.g., np.linspace(0, 2*pi, N). One
              array per dimension.
    """

    def __init__(self, filename=None, domain=None):
        # Handle of the backend file. It is None until something assigns an
        # open file object to it (presumably a subclass' ``open``).
        self.f = None
        self.filename = filename
        self.domain = domain

    def _check_domain(self, group, field):
        """Check dimensions and store (if missing) self.domain"""
        raise NotImplementedError

    def write(self, step, fields, **kw):
        """Write snapshot ``step`` of ``fields`` to file

        Parameters
        ----------
        step : int
            Index of snapshot.
        fields : dict
            The fields to be dumped to file. Keys are group names (str) and
            values are lists (or tuples) of fields. Each field is either an
            array, which is stored completely, or a 2-tuple
            ``(array, slices)`` where ``slices`` are *global* slices
            selecting the part of the array to store.
        as_scalar : boolean, optional
            Whether to store rank > 0 arrays as scalars. Default is False.
            When true, every tensor component of an array is written to its
            own group, named by appending the component indices to the
            group name (e.g. ``u0``, ``u1`` or ``T00``, ``T01``, ...).

        Note
        ----
        All keyword arguments (including ``as_scalar``) are forwarded
        unchanged to ``_write_group`` / ``_write_slice_step``.
        """
        as_scalar = kw.get("as_scalar", False)

        def _write(name, array, slices):
            # Complete arrays and sliced arrays use different backend hooks.
            if slices is None:
                self._write_group(name, array, step, **kw)
            else:
                self._write_slice_step(name, step, slices, array, **kw)

        for group, list_of_fields in fields.items():
            assert isinstance(list_of_fields, (tuple, list))
            assert isinstance(group, str)
            for field in list_of_fields:
                if isinstance(field, (tuple, list)):
                    u, sl = field[0], field[1]
                else:
                    u, sl = field, None

                if not as_scalar or u.rank == 0:
                    self._check_domain(group, u)
                    _write(group, u, sl)
                    continue

                # as_scalar is true and u.rank > 0: write each tensor
                # component separately. The leading ``u.rank`` axes are the
                # tensor axes; iterating them with ndindex (C order) handles
                # any rank instead of silently skipping rank > 2.
                for index in np.ndindex(*u.shape[:u.rank]):
                    name = group + ''.join(str(i) for i in index)
                    # Keep plain integer indexing for vectors (u[k]).
                    component = u[index[0]] if len(index) == 1 else u[index]
                    self._check_domain(name, component)
                    _write(name, component, sl)

    def read(self, u, name, **kw):
        """Read field ``name`` into distributed array ``u``

        Parameters
        ----------
        u : array
            The :class:`.DistArray` to read into.
        name : str
            Name of field to be read.
        step : int, optional
            Index of field to be read. Default is 0.
        """
        raise NotImplementedError

    def close(self):
        """Close the self.filename file

        Does nothing if no file handle has been stored in ``self.f`` (e.g.,
        the file was never opened).
        """
        if self.f is not None:
            self.f.close()

    def open(self, mode='r+'):
        """Open the self.filename file for reading or writing

        Parameters
        ----------
        mode : str
            Open file in this mode. Default is 'r+'.
        """
        raise NotImplementedError

    @staticmethod
    def backend():
        """Return which backend is used to store data"""
        raise NotImplementedError

    def _write_slice_step(self, name, step, slices, field, **kwargs):
        """Backend hook: store the part of ``field`` selected by ``slices``
        as snapshot ``step`` of group ``name``. Called by :meth:`write` for
        fields given as ``(array, slices)``."""
        raise NotImplementedError

    def _write_group(self, name, u, step, **kwargs):
        """Backend hook: store the complete array ``u`` as snapshot ``step``
        of group ``name``. Called by :meth:`write` for plain arrays."""
        raise NotImplementedError

    @staticmethod
    def _get_slice_name(slices):
        """Return a string label for ``slices``

        Each ``slice`` object contributes ``'slice'`` and every other item
        contributes ``str(item)``; the parts are joined by underscores, e.g.
        ``(slice(None), 3)`` -> ``'slice_3'``. Empty input gives ``''``.
        """
        return '_'.join('slice' if isinstance(ss, slice) else str(ss)
                        for ss in slices)

    @staticmethod
    def _get_local_slices(slices, s):
        """Translate global integer indices in ``slices`` to local ones

        Parameters
        ----------
        slices : sequence
            Global index per axis; items are integers or slices.
        s : sequence of slices
            One slice per axis; per the original comment this is the part of
            the global array held by this processor.

        Returns
        -------
        (list, int)
            A new list where every integer index that lies inside the
            corresponding ``s`` slice has been shifted by that slice's
            start, and a flag that is 1 if all checked integer indices lie
            inside ``s`` and 0 otherwise. Axes where ``s`` is ``slice(None)``
            and non-integer items are left untouched. The input ``slices``
            is not modified.
        """
        local = list(slices)  # work on a copy; never mutate the caller's list
        inside = 1
        for i, (index, owned) in enumerate(zip(local, s)):
            # NumPy integers are accepted as well as Python ints.
            if not isinstance(index, (int, np.integer)) or owned == slice(None):
                continue
            if owned.start <= index < owned.stop:
                local[i] = index - owned.start
            else:
                inside = 0
        return local, inside
|