from six.moves import range

from .gpuarray import _split, _concatenate, dtype_to_typecode
from .dtypes import upcast
from . import asarray


def atleast_1d(*arys):
    res = []
    for ary in arys:
        ary = asarray(ary)
        if len(ary.shape) == 0:
            result = ary.reshape((1,))
        else:
            result = ary
        res.append(result)
    if len(res) == 1:
        return res[0]
    else:
        return res


def atleast_2d(*arys):
    res = []
    for ary in arys:
        ary = asarray(ary)
        if len(ary.shape) == 0:
            result = ary.reshape((1, 1))
        elif len(ary.shape) == 1:
            result = ary.reshape((1, ary.shape[0]))
        else:
            result = ary
        res.append(result)
    if len(res) == 1:
        return res[0]
    else:
        return res


def atleast_3d(*arys):
    res = []
    for ary in arys:
        ary = asarray(ary)
        if len(ary.shape) == 0:
            result = ary.reshape((1, 1, 1))
        elif len(ary.shape) == 1:
            result = ary.reshape((1, ary.shape[0], 1))
        elif len(ary.shape) == 2:
            result = ary.reshape(ary.shape + (1,))
        else:
            result = ary
        res.append(result)
    if len(res) == 1:
        return res[0]
    else:
        return res


def split(ary, indices_or_sections, axis=0):
    try:
        len(indices_or_sections)
    except TypeError:
        if ary.shape[axis] % indices_or_sections != 0:
            raise ValueError("array split does not result in an "
                             "equal division")
    return array_split(ary, indices_or_sections, axis)


def array_split(ary, indices_or_sections, axis=0):
    try:
        indices = list(indices_or_sections)
        res = _split(ary, indices, axis)
    except TypeError:
        if axis < 0:
            axis += ary.ndim
        if axis < 0:
            raise ValueError('axis out of bounds')
        nsec = int(indices_or_sections)
        if nsec <= 0:
            raise ValueError('number of sections must be larger than 0.')
        neach, extra = divmod(ary.shape[axis], nsec)
        # this madness is to support the numpy interface
        # it is supported by tests, but little else
        divs = (list(range(neach + 1, (neach + 1) * extra + 1, neach + 1)) +
                list(range((neach + 1) * extra + neach,
                           ary.shape[axis], neach)))
        res = _split(ary, divs, axis)
    return res


def hsplit(ary, indices_or_sections):
    if len(ary.shape) == 0:
        raise ValueError('hsplit only works on arrays of 1 or more dimensions')
    if len(ary.shape) > 1:
        axis = 1
    else:
        axis = 0
    return split(ary, indices_or_sections, axis=axis)


def vsplit(ary, indices_or_sections):
    if len(ary.shape) < 2:
        raise ValueError('vsplit only works on arrays of 2 or more dimensions')
    return split(ary, indices_or_sections, axis=0)


def dsplit(ary, indices_or_sections):
    if len(ary.shape) < 3:
        raise ValueError('vsplit only works on arrays of 3 or more dimensions')
    return split(ary, indices_or_sections, axis=2)


def concatenate(arys, axis=0, context=None):
    if len(arys) == 0:
        raise ValueError("concatenation of zero-length sequences is "
                         "impossible")
    if axis < 0:
        axis += arys[0].ndim
    if axis < 0:
        raise ValueError('axis out of bounds')
    al = [asarray(a, context=context) for a in arys]
    if context is None:
        context = al[0].context
    outtype = upcast(*[a.dtype for a in arys])
    return _concatenate(al, axis, dtype_to_typecode(outtype), type(al[0]),
                        context)


def vstack(tup, context=None):
    return concatenate([atleast_2d(a) for a in tup], 0, context)


def hstack(tup, context=None):
    tup = [atleast_1d(a) for a in tup]
    if tup[0].ndim == 1:
        return concatenate(tup, 0, context)
    else:
        return concatenate(tup, 1, context)


def dstack(tup, context=None):
    return concatenate([atleast_3d(a) for a in tup], 2, context)
