File: io_utils.py (keras 2.3.1)

"""Utilities related to disk I/O."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from collections import defaultdict
import sys
import contextlib


import six
try:
    import h5py
    HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
    h5py = None


if sys.version_info[0] == 3:
    import pickle
else:
    import cPickle as pickle


class HDF5Matrix(object):
    """Representation of HDF5 dataset to be used instead of a Numpy array.

    # Example

    ```python
        x_data = HDF5Matrix('input/file.hdf5', 'data')
        model.predict(x_data)
    ```

    Providing `start` and `end` allows use of a slice of the dataset.

    Optionally, a normalizer function (or lambda) can be given. This will
    be called on every slice of data retrieved.

    # Arguments
        datapath: string, path to an HDF5 file
        dataset: string, name of the HDF5 dataset in the file specified
            in datapath
        start: int, start of desired slice of the specified dataset
        end: int, end of desired slice of the specified dataset
        normalizer: function to be called on data when retrieved

    # Returns
        An array-like HDF5 dataset.
    """
    refs = defaultdict(int)

    def __init__(self, datapath, dataset, start=0, end=None, normalizer=None):
        if h5py is None:
            raise ImportError('The use of HDF5Matrix requires '
                              'HDF5 and h5py to be installed.')

        if datapath not in list(self.refs.keys()):
            # Open read-only: HDF5Matrix never writes to the file.
            f = h5py.File(datapath, mode='r')
            self.refs[datapath] = f
        else:
            f = self.refs[datapath]
        self.data = f[dataset]
        self.start = start
        if end is None:
            self.end = self.data.shape[0]
        else:
            self.end = end
        self.normalizer = normalizer
        if self.normalizer is not None:
            first_val = self.normalizer(self.data[0:1])
        else:
            first_val = self.data[0:1]
        self._base_shape = first_val.shape[1:]
        self._base_dtype = first_val.dtype

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, key):
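        # Keys are interpreted relative to `self.start`, and bounds are checked
        # against `self.end`, so lookups never leave the [start, end) window of
        # the underlying dataset.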
        if isinstance(key, slice):
            start, stop = key.start, key.stop
            if start is None:
                start = 0
            if stop is None:
                stop = self.shape[0]
            if stop + self.start <= self.end:
                idx = slice(start + self.start, stop + self.start)
            else:
                raise IndexError
        elif isinstance(key, (int, np.integer)):
            if key + self.start < self.end:
                idx = key + self.start
            else:
                raise IndexError
        elif isinstance(key, np.ndarray):
            if np.max(key) + self.start < self.end:
                idx = (self.start + key).tolist()
            else:
                raise IndexError
        else:
            # Assume list/iterable
            if max(key) + self.start < self.end:
                idx = [x + self.start for x in key]
            else:
                raise IndexError
        if self.normalizer is not None:
            return self.normalizer(self.data[idx])
        else:
            return self.data[idx]

    @property
    def shape(self):
        """Gets a numpy-style shape tuple giving the dataset dimensions.

        # Returns
            A numpy-style shape tuple.
        """
        return (self.end - self.start,) + self._base_shape

    @property
    def dtype(self):
        """Gets the datatype of the dataset.

        # Returns
            A numpy dtype string.
        """
        return self._base_dtype

    @property
    def ndim(self):
        """Gets the number of dimensions (rank) of the dataset.

        # Returns
            An integer denoting the number of dimensions (rank) of the dataset.
        """
        return self.data.ndim

    @property
    def size(self):
        """Gets the total dataset size (number of elements).

        # Returns
            An integer denoting the number of elements in the dataset.
        """
        return np.prod(self.shape)
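
# Illustrative usage sketch (not part of the original module); the file name,
# dataset name, slice bounds and normalizer below are hypothetical:
#
#     x_train = HDF5Matrix('datasets.hdf5', 'features', start=1000, end=2000,
#                          normalizer=lambda batch: batch / 255.0)
#     print(x_train.shape, x_train.dtype)   # (1000,) + base shape of 'features'
#     first_ten = x_train[0:10]             # reads rows 1000..1009 of the file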


def ask_to_proceed_with_overwrite(filepath):
    """Produces a prompt asking about overwriting a file.

    # Arguments
        filepath: the path to the file to be overwritten.

    # Returns
        True if we can proceed with overwrite, False otherwise.
    """
    overwrite = six.moves.input('[WARNING] %s already exists - overwrite? '
                                '[y/n]' % (filepath)).strip().lower()
    while overwrite not in ('y', 'n'):
        overwrite = six.moves.input('Enter "y" (overwrite) or "n" '
                                    '(cancel).').strip().lower()
    if overwrite == 'n':
        return False
    print('[TIP] Next time specify overwrite=True!')
    return True


class H5Dict(object):
    """ A dict-like wrapper around h5py groups (or dicts).

    This allows us to have a single serialization logic
    for both pickling and saving to disk.

    Note: This is not intended to be a generic wrapper.
    A lot of edge cases have been hardcoded, and it only makes
    sense in the context of model serialization/deserialization.

    # Arguments
        path: Either a string (path on disk), a Path, a dict, or a HDF5 Group.
        mode: File open mode (one of `{"a", "r", "w"}`).
    """

    def __init__(self, path, mode='a'):
        if isinstance(path, h5py.Group):
            self.data = path
            self._is_file = False
        elif isinstance(path, six.string_types) or _is_path_instance(path):
            self.data = h5py.File(path, mode=mode)
            self._is_file = True
        elif isinstance(path, dict):
            self.data = path
            self._is_file = False
            if mode == 'w':
                self.data.clear()
            # Flag to check if a dict is user defined data or a sub group:
            self.data['_is_group'] = True
        else:
            raise TypeError('Expected a Group, str, Path or dict. '
                            'Received: {}.'.format(type(path)))
        self.read_only = mode == 'r'

    @staticmethod
    def is_supported_type(path):
        """Check if `path` is of supported type for instantiating a `H5Dict`"""
        return (
            isinstance(path, h5py.Group) or
            isinstance(path, dict) or
            isinstance(path, six.string_types) or
            _is_path_instance(path)
        )

    def __setitem__(self, attr, val):
        if self.read_only:
            raise ValueError('Cannot set item in read-only mode.')
        is_np = type(val).__module__ == np.__name__
        if isinstance(self.data, dict):
            if isinstance(attr, bytes):
                attr = attr.decode('utf-8')
            if is_np:
                self.data[attr] = pickle.dumps(val)
                # We have to remember to unpickle in __getitem__
                self.data['_{}_pickled'.format(attr)] = True
            else:
                self.data[attr] = val
            return
        if isinstance(self.data, h5py.Group) and attr in self.data:
            raise KeyError('Cannot set attribute. '
                           'Group with name "{}" exists.'.format(attr))
        if is_np:
            dataset = self.data.create_dataset(attr, val.shape, dtype=val.dtype)
            if not val.shape:
                # scalar
                dataset[()] = val
            else:
                dataset[:] = val
        elif isinstance(val, (list, tuple)):
            # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
            # because in that case even chunking the array would not make the saving
            # possible.
            bad_attributes = [x for x in val if len(x) > HDF5_OBJECT_HEADER_LIMIT]

            # Expecting this to never be true.
            if bad_attributes:
                raise RuntimeError('The following attributes cannot be saved to '
                                   'HDF5 file because they are larger than '
                                   '%d bytes: %s' % (HDF5_OBJECT_HEADER_LIMIT,
                                                     ', '.join(bad_attributes)))

            if (val and sys.version_info[0] == 3 and isinstance(
                    val[0], six.string_types)):
                # convert to bytes
                val = [x.encode('utf-8') for x in val]

            data_npy = np.asarray(val)

            num_chunks = 1
            chunked_data = np.array_split(data_npy, num_chunks)

            # This will never loop forever thanks to the test above.
            is_too_big = lambda x: x.nbytes > HDF5_OBJECT_HEADER_LIMIT
            while any(map(is_too_big, chunked_data)):
                num_chunks += 1
                chunked_data = np.array_split(data_npy, num_chunks)
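            # Each chunk is written below under a numbered attribute name
            # ('%s0' % attr, '%s1' % attr, ...); __getitem__ reassembles the
            # full list by reading those attributes back in order.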

            if num_chunks > 1:
                for chunk_id, chunk_data in enumerate(chunked_data):
                    self.data.attrs['%s%d' % (attr, chunk_id)] = chunk_data
            else:
                self.data.attrs[attr] = val
        else:
            self.data.attrs[attr] = val

    def __getitem__(self, attr):
        if isinstance(self.data, dict):
            if isinstance(attr, bytes):
                attr = attr.decode('utf-8')
            if attr in self.data:
                val = self.data[attr]
                if isinstance(val, dict) and val.get('_is_group'):
                    val = H5Dict(val)
                elif '_{}_pickled'.format(attr) in self.data:
                    val = pickle.loads(val)
                return val
            else:
                if self.read_only:
                    raise ValueError('Cannot create group in read-only mode.')
                val = {'_is_group': True}
                self.data[attr] = val
                return H5Dict(val)
        if attr in self.data.attrs:
            val = self.data.attrs[attr]
            if type(val).__module__ == np.__name__:
                if val.dtype.type == np.string_:
                    val = val.tolist()
        elif attr in self.data:
            val = self.data[attr]
            if isinstance(val, h5py.Dataset):
                val = np.asarray(val)
            else:
                val = H5Dict(val)
        else:
            # could be chunked
            chunk_attr = '%s%d' % (attr, 0)
            is_chunked = chunk_attr in self.data.attrs
            if is_chunked:
                val = []
                chunk_id = 0
                while chunk_attr in self.data.attrs:
                    chunk = self.data.attrs[chunk_attr]
                    val.extend([x.decode('utf8') for x in chunk])
                    chunk_id += 1
                    chunk_attr = '%s%d' % (attr, chunk_id)
            else:
                if self.read_only:
                    raise ValueError('Cannot create group in read-only mode.')
                val = H5Dict(self.data.create_group(attr))
        return val

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def iter(self):
        return iter(self.data)

    def __getattr__(self, attr):
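        # Delegate unknown attributes to the wrapped container. If the wrapped
        # call returns another container of the same type (e.g. an h5py
        # sub-group), re-wrap the result in an H5Dict so chained access keeps
        # the same interface.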

        def wrapper(f):
            def h5wrapper(*args, **kwargs):
                out = f(*args, **kwargs)
                if isinstance(self.data, type(out)):
                    return H5Dict(out)
                else:
                    return out
            return h5wrapper

        return wrapper(getattr(self.data, attr))

    def close(self):
        if isinstance(self.data, h5py.Group):
            self.data.file.flush()
            if self._is_file:
                self.data.close()

    def update(self, *args):
        if isinstance(self.data, dict):
            self.data.update(*args)
            return
        raise NotImplementedError

    def __contains__(self, key):
        if isinstance(self.data, dict):
            return key in self.data
        else:
            return (key in self.data) or (key in self.data.attrs)

    def get(self, key, default=None):
        if key in self:
            return self[key]
        return default

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


h5dict = H5Dict
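
# Illustrative usage sketch (not part of the original module). `H5Dict` (and
# its lowercase alias `h5dict`) exposes the same dict-like interface over a
# plain dict and an HDF5 file; 'weights.h5' below is a hypothetical path:
#
#     d = H5Dict({})                      # in-memory (dict) backend
#     d['backend'] = 'tensorflow'
#     group = d['model_weights']          # creates a nested group on demand
#     group['layer_names'] = ['dense_1', 'dense_2']
#
#     with H5Dict('weights.h5', mode='w') as f:   # HDF5 file backend
#         f['backend'] = 'tensorflow'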


def load_from_binary_h5py(load_function, stream):
    """Calls `load_function` on a `h5py.File` read from the binary `stream`.

    # Arguments
        load_function: A function that takes a `h5py.File`, reads from it, and
            returns any object.
        stream: Any file-like object implementing the method `read` that returns
            `bytes` data (e.g. `io.BytesIO`) representing a valid HDF5 file image.

    # Returns
        The object returned by `load_function`.
    """
    # Implementation based on the solution suggested here:
    #   https://github.com/keras-team/keras/issues/9343#issuecomment-440903847
    binary_data = stream.read()
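    # Build a file-access property list that selects HDF5's in-memory 'core'
    # driver with no backing store, and hand it the raw bytes as a file image,
    # so the data can be opened without writing a temporary file to disk.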
    file_access_property_list = h5py.h5p.create(h5py.h5p.FILE_ACCESS)
    file_access_property_list.set_fapl_core(backing_store=False)
    file_access_property_list.set_file_image(binary_data)
    file_id_args = {'fapl': file_access_property_list,
                    'flags': h5py.h5f.ACC_RDONLY,
                    'name': b'in-memory-h5py'}  # name does not matter
    h5_file_args = {'backing_store': False,
                    'driver': 'core',
                    'mode': 'r'}
    with contextlib.closing(h5py.h5f.open(**file_id_args)) as file_id:
        with h5py.File(file_id, **h5_file_args) as h5_file:
            return load_function(h5_file)


def save_to_binary_h5py(save_function, stream):
    """Calls `save_function` on an in memory `h5py.File`.

    The file is subsequently written to the binary `stream`.

     # Arguments
        save_function: A function that takes a `h5py.File`, writes to it and
            (optionally) returns any object.
        stream: Any file-like object implementing the method `write` that accepts
            `bytes` data (e.g. `io.BytesIO`).
     """
    with h5py.File('in-memory-h5py', driver='core',
                   backing_store=False, mode='w') as h5file:
        # note that filename does not matter here.
        return_value = save_function(h5file)
        h5file.flush()
        binary_data = h5file.id.get_file_image()
    stream.write(binary_data)

    return return_value
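
# Illustrative round-trip sketch (not part of the original module): write
# through an in-memory HDF5 image and read it back from the same buffer.
# The attribute name and value below are hypothetical:
#
#     import io
#
#     def _write(h5file):
#         h5file.attrs['keras_version'] = b'2.3.1'
#
#     def _read(h5file):
#         return h5file.attrs['keras_version']
#
#     buf = io.BytesIO()
#     save_to_binary_h5py(_write, buf)
#     buf.seek(0)
#     version = load_from_binary_h5py(_read, buf)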


def _is_path_instance(path):
    # We can't use isinstance here because it would require
    # us to add pathlib2 to the Python 2 dependencies.
    class_name = type(path).__name__
    return class_name == 'PosixPath' or class_name == 'WindowsPath'