1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
import io
import six
try:
import numpy
NUMPY_AVAILABLE = True
except ImportError:
numpy = None
NUMPY_AVAILABLE = False
from .base import BaseItertool
def iter_(obj):
    """Return an iterator over ``obj``, picklable for the types known here.

    Intended as a stand-in for the built-in ``iter()``. The checks run in
    this order and the first match wins:

    * ``dict`` -- ``ordered_sequence_iterator`` over a list of its keys.
    * file objects (``file`` on Python 2, ``io.IOBase`` on Python 3) --
      ``file_iterator``.
    * Python 2 only: ``list`` / ``tuple`` -- ``ordered_sequence_iterator``;
      ``xrange`` -- ``range_iterator``.
    * ``numpy.ndarray``, when NumPy could be imported --
      ``ordered_sequence_iterator``.
    * Python 3 only: dict views (items / values / keys) --
      ``ordered_sequence_iterator`` over a list copy of the view.
    * anything else -- whatever the built-in ``iter(obj)`` returns.
    """
    if six.PY2:
        file_types = (file,)  # noqa
    if six.PY3:
        file_types = (io.IOBase,)
        # The dict view classes have no importable name; take them from
        # the views of a throwaway empty dict.
        probe = {}
        dict_view = (
            probe.items().__class__,
            probe.values().__class__,
            probe.keys().__class__,
        )

    if isinstance(obj, dict):
        # Snapshot the keys so iteration is index-based over a list.
        key_list = list(obj.keys())
        return ordered_sequence_iterator(key_list)

    if isinstance(obj, file_types):
        return file_iterator(obj)

    if six.PY2:
        if isinstance(obj, (list, tuple)):
            return ordered_sequence_iterator(obj)
        if isinstance(obj, xrange):  # noqa
            return range_iterator(obj)

    if NUMPY_AVAILABLE and isinstance(obj, numpy.ndarray):
        return ordered_sequence_iterator(obj)

    # ``dict_view`` only exists on Python 3, hence the version guard first.
    if six.PY3 and isinstance(obj, dict_view):
        view_items = list(obj)
        return ordered_sequence_iterator(view_items)

    return iter(obj)
class range_iterator(BaseItertool):
    """Picklable iterator over a Python 2 ``xrange`` object.

    The iterator keeps only plain integers: the range's start, stop and
    step, plus the next value to hand out.
    """

    def __init__(self, xrange_):
        # The second element of the reduce tuple is unpacked as the
        # (start, stop, step) triple describing the range.
        bounds = xrange_.__reduce__()[1]
        self._start, self._stop, self._step = bounds
        self._n = self._start

    def __next__(self):
        """Return the next value, or raise ``StopIteration`` when done."""
        current = self._n
        if self._step > 0:
            has_more = current < self._stop
        elif self._step < 0:
            has_more = current > self._stop
        else:
            # A zero step never yields anything.
            has_more = False
        if not has_more:
            raise StopIteration
        self._n = current + self._step
        return current
class file_iterator(BaseItertool):
    """A picklable file iterator.

    Yields the lines of the wrapped file object one ``readline()`` at a
    time. The pickled state is the file's ``name``, its current ``tell()``
    position and its ``mode``; unpickling reopens the file by name and
    seeks back to that position. The wrapped object must therefore expose
    ``name``, ``mode``, ``tell()`` and ``readline()``.
    """

    def __init__(self, f):
        self._f = f

    def __next__(self):
        """Return the next line, or raise ``StopIteration`` at end of file."""
        # NOTE(review): readline() is presumably used instead of the file's
        # own iterator so that tell() in __getstate__ stays usable -- confirm.
        line = self._f.readline()
        if not line:
            # readline() returns an empty string/bytes only at end of file.
            raise StopIteration
        return line

    def __getstate__(self):
        """Return ``(name, position, mode)`` describing the open file."""
        name, pos, mode = self._f.name, self._f.tell(), self._f.mode
        return name, pos, mode

    def __setstate__(self, state):
        """Reopen the file named in ``state`` and seek to the saved offset."""
        name, pos, mode = state
        self._f = open(name, mode=self._reopen_mode(mode))
        self._f.seek(pos)

    @staticmethod
    def _reopen_mode(mode):
        """Return a mode for reopening that cannot destroy the file.

        Reopening with the saved mode verbatim would truncate the file for
        ``'w'`` modes and fail for ``'x'`` modes (the file already exists).
        Both are mapped to the matching read/update mode (``'w+'`` becomes
        ``'r+'``, ``'wb'`` becomes ``'rb+'``). Every other mode, and any
        non-string mode value, is returned unchanged.
        """
        if not isinstance(mode, str):
            return mode
        if 'w' not in mode and 'x' not in mode:
            return mode
        safe_mode = mode.replace('w', 'r').replace('x', 'r')
        if '+' not in safe_mode:
            # Keep the handle writable, as the original mode was.
            safe_mode += '+'
        return safe_mode
class ordered_sequence_iterator(BaseItertool):
    """Picklable, index-based iterator over a sequence.

    Stands in for the built-in list and tuple iterators. The wrapped
    object must support ``len()`` and integer indexing; the state is just
    the sequence itself and the index of the next element.
    """

    def __init__(self, sequence):
        self._position = 0
        self._sequence = sequence

    def __next__(self):
        """Return the element at the cursor and advance by one."""
        index = self._position
        # The length is re-read on every call rather than cached.
        if index >= len(self._sequence):
            raise StopIteration
        item = self._sequence[index]
        # Advance only after the lookup succeeded.
        self._position = index + 1
        return item
|