1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
from pypy.interpreter.error import OperationError, oefmt
from pypy.interpreter.gateway import interp2app
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.module.micronumpy import loop
from pypy.module.micronumpy.base import convert_to_array
from pypy.module.micronumpy.concrete import BaseConcreteArray
from .ndarray import W_NDimArray
class FakeArrayImplementation(BaseConcreteArray):
    """Stand-in array implementation that lets W_FlatIterator act as an array.

    W_FlatIterator installs an instance of this class as its
    ``implementation`` so that descr_eq and the other array methods it
    inherits from W_NDimArray can operate on a flatiter.  The wrapped array
    is presented as a 1-d array of ``base.get_size()`` elements that shares
    the base array's dtype, storage and order.
    """

    def __init__(self, base):
        self._base = base
        self.dtype = base.get_dtype()
        # Flattened view: a single dimension holding every element.
        n_elems = base.get_size()
        self.shape = [n_elems]
        self.storage = base.implementation.storage
        self.order = base.get_order()

    def base(self):
        """Return the wrapped array."""
        return self._base

    def get_shape(self):
        """Return the flattened shape (a one-element list)."""
        return self.shape

    def get_size(self):
        """Return the element count of the wrapped array."""
        return self.base().get_size()

    def create_iter(self, shape=None, backward_broadcast=False):
        """Return the wrapped array's ``(iter, state)`` pair.

        ``shape`` and ``backward_broadcast`` are accepted but ignored.
        """
        wrapped = self.base()
        # NOTE(review): presumably narrows the type for the RPython
        # annotator -- confirm.
        assert isinstance(wrapped, W_NDimArray)
        return wrapped.create_iter()
class W_FlatIterator(W_NDimArray):
    """Interpreter-level implementation of ``numpy.flatiter``.

    Wraps an array (``self.base``) together with one iterator over it
    (``self.iter``) and that iterator's current position (``self.state``).
    ``next`` advances ``self.state``.  Indexing with an int or a slice
    addresses elements by flat position through ``self.iter.goto`` and
    calls ``self.iter.reset(self.state, mutate=True)`` when done.
    """

    def __init__(self, arr):
        self.base = arr
        self.iter, self.state = arr.create_iter()
        # this is needed to support W_NDimArray interface
        self.implementation = FakeArrayImplementation(self.base)

    def descr_base(self, space):
        """Return the array being iterated over."""
        return self.base

    def descr_index(self, space):
        """Return the current flat position of the iteration state."""
        return space.newint(self.state.index)

    def descr_coords(self, space):
        """Return the N-d coordinates of the current position as a tuple."""
        coords = self.iter.indices(self.state)
        return space.newtuple([space.newint(c) for c in coords])

    def descr_iter(self):
        """Return self: a flatiter is its own iterator."""
        return self

    def descr_len(self, space):
        """Return ``self.iter.size``, the number of elements iterated."""
        return space.newint(self.iter.size)

    def descr_next(self, space):
        """Return the current element and advance the state in place.

        Raises StopIteration once the iterator reports it is done.
        """
        if self.iter.done(self.state):
            raise OperationError(space.w_StopIteration, space.w_None)
        w_res = self.iter.getitem(self.state)
        self.iter.next(self.state, mutate=True)
        return w_res

    def descr_getitem(self, space, w_idx):
        """Implement ``flat[idx]`` for an int or slice index.

        An int index returns one element.  A slice always returns a new 1-d
        array, even when it selects exactly one element (as NumPy does).
        Any other index type raises IndexError.
        """
        is_slice = space.isinstance_w(w_idx, space.w_slice)
        if not (is_slice or space.isinstance_w(w_idx, space.w_int)):
            raise oefmt(space.w_IndexError, 'unsupported iterator index')
        try:
            start, stop, step, length = space.decode_index4_unsafe(
                w_idx, self.iter.size)
            state = self.iter.goto(start)
            # Decide on the index *type*, not on the slice length: a
            # one-element slice must still produce an array.
            if not is_slice:
                return self.iter.getitem(state)
            base = self.base
            res = W_NDimArray.from_shape(space, [length], base.get_dtype(),
                                         base.get_order(), w_instance=base)
            return loop.flatiter_getitem(res, self.iter, state, step)
        finally:
            self.iter.reset(self.state, mutate=True)

    def descr_setitem(self, space, w_idx, w_value):
        """Implement ``flat[idx] = value`` for an int or slice index.

        For an int index, ``w_value`` is coerced to the array dtype.  A
        coercion failure is reported as ValueError.  For a slice of any
        length, ``w_value`` is converted to an array and written through
        ``loop.flatiter_setitem``.  Any other index type raises IndexError.
        """
        is_slice = space.isinstance_w(w_idx, space.w_slice)
        if not (is_slice or space.isinstance_w(w_idx, space.w_int)):
            raise oefmt(space.w_IndexError, 'unsupported iterator index')
        start, stop, step, length = space.decode_index4_unsafe(
            w_idx, self.iter.size)
        try:
            state = self.iter.goto(start)
            dtype = self.base.get_dtype()
            # Only a genuine scalar index takes the single-item path; a
            # one-element slice goes through the same array path as any
            # other slice so sequence values such as [7] are accepted.
            if not is_slice:
                try:
                    val = dtype.coerce(space, w_value)
                except OperationError:
                    raise oefmt(space.w_ValueError,
                                "Error setting single item of array.")
                self.iter.setitem(state, val)
                return
            arr = convert_to_array(space, w_value)
            loop.flatiter_setitem(space, dtype, arr, self.iter, state, step,
                                  length)
        finally:
            self.iter.reset(self.state, mutate=True)

    def descr___array_wrap__(self, space, obj, w_context=None):
        """Return ``obj`` unchanged.

        NOTE(review): the typedef binds ``__array_wrap__`` to
        W_NDimArray.descr___array_wrap__, not to this method.  That is
        presumably because interp2app infers unwrap specs from argument
        names and ``obj`` lacks the ``w_`` prefix -- confirm.
        """
        return obj
# App-level type for numpy.flatiter.  The rich comparisons are not defined
# on W_FlatIterator itself; they come from W_NDimArray and work through the
# FakeArrayImplementation installed as W_FlatIterator.implementation.
W_FlatIterator.typedef = TypeDef(
    "numpy.flatiter",
    # attributes (getter only)
    base=GetSetProperty(W_FlatIterator.descr_base),
    index=GetSetProperty(W_FlatIterator.descr_index),
    coords=GetSetProperty(W_FlatIterator.descr_coords),
    # iteration protocol
    __iter__=interp2app(W_FlatIterator.descr_iter),
    __len__=interp2app(W_FlatIterator.descr_len),
    next=interp2app(W_FlatIterator.descr_next),
    # flat indexing by int or slice
    __getitem__=interp2app(W_FlatIterator.descr_getitem),
    __setitem__=interp2app(W_FlatIterator.descr_setitem),
    # comparisons (inherited from W_NDimArray)
    __eq__=interp2app(W_FlatIterator.descr_eq),
    __ne__=interp2app(W_FlatIterator.descr_ne),
    __lt__=interp2app(W_FlatIterator.descr_lt),
    __le__=interp2app(W_FlatIterator.descr_le),
    __gt__=interp2app(W_FlatIterator.descr_gt),
    __ge__=interp2app(W_FlatIterator.descr_ge),
    # NOTE(review): W_NDimArray's implementation is used here rather than
    # the W_FlatIterator override -- confirm this is intended.
    __array_wrap__=interp2app(W_NDimArray.descr___array_wrap__),
)
|