1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
from six.moves import range
from .gpuarray import _split, _concatenate, dtype_to_typecode
from .dtypes import upcast
from . import asarray
def atleast_1d(*arys):
    """Convert the inputs to arrays with at least one dimension.

    Every argument is passed through `asarray`.  A result whose shape is
    empty is reshaped to ``(1,)``; anything else is kept as returned.

    Returns the single array when exactly one argument was given,
    otherwise a list with one array per argument (an empty list when
    called without arguments).
    """
    def _promote(obj):
        arr = asarray(obj)
        if len(arr.shape) != 0:
            return arr
        return arr.reshape((1,))

    promoted = [_promote(item) for item in arys]
    return promoted[0] if len(promoted) == 1 else promoted
def atleast_2d(*arys):
    """Convert the inputs to arrays with at least two dimensions.

    Every argument is passed through `asarray`.  An empty shape becomes
    ``(1, 1)``, a shape ``(N,)`` becomes ``(1, N)``, and arrays that
    already have two or more dimensions are kept as returned.

    Returns the single array when exactly one argument was given,
    otherwise a list with one array per argument (an empty list when
    called without arguments).
    """
    def _promote(obj):
        arr = asarray(obj)
        rank = len(arr.shape)
        if rank == 0:
            return arr.reshape((1, 1))
        if rank == 1:
            # A vector becomes a single row.
            return arr.reshape((1, arr.shape[0]))
        return arr

    promoted = [_promote(item) for item in arys]
    return promoted[0] if len(promoted) == 1 else promoted
def atleast_3d(*arys):
    """Convert the inputs to arrays with at least three dimensions.

    Every argument is passed through `asarray`, then reshaped by rank:

    * empty shape  -> ``(1, 1, 1)``
    * ``(N,)``     -> ``(1, N, 1)``
    * ``(M, N)``   -> ``(M, N, 1)``
    * three or more dimensions -> kept as returned

    Returns the single array when exactly one argument was given,
    otherwise a list with one array per argument (an empty list when
    called without arguments).
    """
    def _promote(obj):
        arr = asarray(obj)
        rank = len(arr.shape)
        if rank == 0:
            return arr.reshape((1, 1, 1))
        if rank == 1:
            return arr.reshape((1, arr.shape[0], 1))
        if rank == 2:
            # Only a trailing unit axis is appended.
            return arr.reshape(arr.shape + (1,))
        return arr

    promoted = [_promote(item) for item in arys]
    return promoted[0] if len(promoted) == 1 else promoted
def split(ary, indices_or_sections, axis=0):
    """Split `ary` along `axis`, insisting on equal parts for a count.

    `indices_or_sections` is either something with a length (treated as
    explicit split indices and forwarded untouched) or a section count.
    For a count, the extent of `axis` must be an exact multiple of it.

    Raises ValueError when a section count does not divide the axis
    evenly.  The actual splitting is delegated to `array_split`.
    """
    try:
        len(indices_or_sections)
    except TypeError:
        # No length: the argument is a section count, so check that it
        # divides the axis without a remainder.
        extent = ary.shape[axis]
        leftover = extent % indices_or_sections
        if leftover != 0:
            raise ValueError("array split does not result in an "
                             "equal division")
    return array_split(ary, indices_or_sections, axis)
def array_split(ary, indices_or_sections, axis=0):
    """Split `ary` along `axis`, allowing sections of unequal size.

    `indices_or_sections` is either an iterable of split indices, which is
    handed to `_split` as a list, or a section count.  For a count of
    ``n`` over an axis of length ``l``, the first ``l % n`` sections get
    ``l // n + 1`` items and the remaining ones ``l // n`` items, which
    mirrors the numpy interface.

    A negative `axis` is counted from the last dimension (using
    ``ary.ndim``) for both kinds of argument.

    Raises ValueError if `axis` is still negative after that adjustment
    or if the section count is not positive.

    Returns whatever `_split` returns for the computed split points.
    """
    if axis < 0:
        axis += ary.ndim
        if axis < 0:
            raise ValueError('axis out of bounds')

    # Guard only the conversion: a TypeError raised inside `_split` must
    # reach the caller rather than be mistaken for "not a sequence".
    try:
        indices = list(indices_or_sections)
    except TypeError:
        indices = None
    if indices is not None:
        return _split(ary, indices, axis)

    nsec = int(indices_or_sections)
    if nsec <= 0:
        raise ValueError('number of sections must be larger than 0.')
    neach, extra = divmod(ary.shape[axis], nsec)

    # Accumulate section sizes into the nsec - 1 interior split points.
    # Unlike a range() with step `neach`, this also works when
    # neach == 0, i.e. when the axis holds fewer items than sections.
    divs = []
    boundary = 0
    for section in range(nsec - 1):
        boundary += neach + 1 if section < extra else neach
        divs.append(boundary)
    return _split(ary, divs, axis)
def hsplit(ary, indices_or_sections):
    """Split `ary` along its second axis, or its only axis if it has one.

    Raises ValueError for an array whose shape is empty.  The work is
    delegated to `split`, so its equal-division rule applies.
    """
    rank = len(ary.shape)
    if rank == 0:
        raise ValueError('hsplit only works on arrays of 1 or more dimensions')
    # One-dimensional input has no second axis; split the only one.
    return split(ary, indices_or_sections, axis=1 if rank > 1 else 0)
def vsplit(ary, indices_or_sections):
    """Split `ary` along its first axis.

    Raises ValueError when `ary` has fewer than two dimensions.  The work
    is delegated to `split`, so its equal-division rule applies.
    """
    if len(ary.shape) >= 2:
        return split(ary, indices_or_sections, axis=0)
    raise ValueError('vsplit only works on arrays of 2 or more dimensions')
def dsplit(ary, indices_or_sections):
    """Split `ary` along its third axis (``axis=2``).

    Raises ValueError when `ary` has fewer than three dimensions.  The
    work is delegated to `split`, so its equal-division rule applies.
    """
    if len(ary.shape) < 3:
        # The message names this function, not vsplit.
        raise ValueError('dsplit only works on arrays of 3 or more dimensions')
    return split(ary, indices_or_sections, axis=2)
def concatenate(arys, axis=0, context=None):
    """Join a sequence of arrays along an existing axis.

    Every element of `arys` is converted with ``asarray(a, context=context)``
    and all further decisions are based on the converted arrays: the
    negative-axis adjustment uses the first one's ``ndim``, the result
    dtype is the `upcast` of their dtypes, and the first one supplies the
    result class as well as the context when `context` is None.

    Raises ValueError if `arys` is empty or if `axis` is still negative
    after being counted from the last dimension.

    Returns whatever `_concatenate` returns for the converted arrays.
    """
    if len(arys) == 0:
        raise ValueError("concatenation of zero-length sequences is "
                         "impossible")
    # Convert first: the raw inputs are not guaranteed to expose `.ndim`
    # or `.dtype`, only the results of asarray are relied upon below.
    al = [asarray(a, context=context) for a in arys]
    if axis < 0:
        axis += al[0].ndim
        if axis < 0:
            raise ValueError('axis out of bounds')
    if context is None:
        context = al[0].context
    outtype = upcast(*[a.dtype for a in al])
    return _concatenate(al, axis, dtype_to_typecode(outtype), type(al[0]),
                        context)
def vstack(tup, context=None):
    """Stack the items of `tup` along the first axis.

    Each item is first promoted with `atleast_2d`; the promoted arrays
    are then joined by `concatenate` on axis 0 with the given `context`.
    """
    rows = [atleast_2d(item) for item in tup]
    return concatenate(rows, 0, context)
def hstack(tup, context=None):
    """Stack the items of `tup` horizontally.

    Each item is first promoted with `atleast_1d`.  If the first promoted
    array is one-dimensional the arrays are joined on axis 0, otherwise
    on axis 1; the joining is done by `concatenate` with the given
    `context`.
    """
    parts = [atleast_1d(item) for item in tup]
    # The first array alone decides which axis is used for all of them.
    join_axis = 0 if parts[0].ndim == 1 else 1
    return concatenate(parts, join_axis, context)
def dstack(tup, context=None):
    """Stack the items of `tup` along the third axis.

    Each item is first promoted with `atleast_3d`; the promoted arrays
    are then joined by `concatenate` on axis 2 with the given `context`.
    """
    slabs = [atleast_3d(item) for item in tup]
    return concatenate(slabs, 2, context)
|