from collections import defaultdict, Sequence, Sized, Iterable, Callable
import inspect
import wrapt
from cytoolz import curry
from numpy import ndarray
from six import integer_types

from .exceptions import UndefinedOperatorError, DifferentLengthError
from .exceptions import ExpectedTypeError, ShapeMismatchError
from .exceptions import OutsideRangeError

def is_docs(arg_id, args, kwargs):
    from spacy.tokens.doc import Doc
    docs = args[arg_id]
    if not isinstance(docs, Sequence):
        raise ExpectedTypeError(type(docs), ['Sequence'])
    if not isinstance(docs[0], Doc):
        raise ExpectedTypeError(type(docs[0]), ['spacy.tokens.doc.Doc'])



def equal_length(*args):
    '''Check that arguments have the same length.
    '''
    for i, arg in enumerate(args):
        if not isinstance(arg, Sized):
            raise ExpectedTypeError(arg, ['Sized'])
        if i >= 1 and len(arg) != len(args[0]):
            raise DifferentLengthError(args, arg)


def equal_axis(*args, **axis):
    '''Check that elements have the same dimension on specified axis.
    '''
    axis = axis.get('axis', -1)
    for i, arg in enumerate(args):
        if not isinstance(arg, ndarray):
            raise ExpectedTypeError(arg, ['ndarray'])
        if axis >= 0 and (axis+1) < args[i].shape[axis]:
            raise ShapeError(
                "Shape: %s. Expected at least %d dimensions",
                shape, axis)
        if i >= 1 and arg.shape[axis] != args[0].shape[axis]:
            lengths = [a.shape[axis] for a in args]
            raise DifferentLengthError(lengths, arg)

@curry
def has_shape(shape, arg_id, args, kwargs):
    '''Check that a particular argument is an array with a given shape. The
    shape may contain string attributes, which will be fetched from arg0 to
    the function (usually self).
    '''
    self = args[0]
    arg = args[arg_id]
    if not hasattr(arg, 'shape'):
        raise ExpectedTypeError(arg, ['array'])
    shape_values = []
    for dim in shape:
        if not isinstance(dim, integer_types):
            dim = getattr(self, dim, None)
        shape_values.append(dim)
    if len(shape) != len(arg.shape):
        raise ShapeMismatchError(arg.shape, tuple(shape_values), shape)
    for i, dim in enumerate(shape_values):
        # Allow underspecified dimensions
        if dim != None and arg.shape[i] != dim:
            raise ShapeMismatchError(arg.shape, shape_values, shape)


def is_shape(arg_id, args, func_kwargs, **kwargs):
    arg = args[arg_id]
    if not isinstance(arg, Iterable):
        raise ExpectedTypeError(arg, ['iterable'])
    for value in arg:
        if not isinstance(value, integer_types) or value < 0:
            raise ExpectedTypeError(arg, ['valid shape (positive ints)'])


def is_sequence(arg_id, args, kwargs):
    arg = args[arg_id]
    if not isinstance(arg, Iterable) and not hasattr(arg, '__getitem__'):
        raise ExpectedTypeError(arg, ['iterable'])


def is_float(arg_id, args, func_kwargs, **kwargs):
    arg = args[arg_id]
    if not isinstance(arg, float):
        raise ExpectedTypeError(arg, ['float'])
    if 'min' in kwargs and arg < kwargs['min']:
        raise OutsideRangeError(arg, kwargs['min'], '>=')
    if 'max' in kwargs and arg > kwargs['max']:
        raise OutsideRangeError(arg, kwargs['max'], '<=')


def is_int(arg_id, args, func_kwargs, **kwargs):
    arg = args[arg_id]
    if not isinstance(arg, integer_types):
        raise ExpectedTypeError(arg, ['int'])
    if 'min' in kwargs and arg < kwargs['min']:
        raise OutsideRangeError(arg, kwargs['min'], '>=')
    if 'max' in kwargs and arg > kwargs['max']:
        raise OutsideRangeError(arg, kwargs['max'], '<=')


def is_array(arg_id, args, func_kwargs, **kwargs):
    arg = args[arg_id]
    if not isinstance(arg, ndarray):
        raise ExpectedTypeError(arg, ['ndarray'])


def is_int_array(arg_id, args, func_kwargs, **kwargs):
    arg = args[arg_id]
    if not isinstance(arg, ndarray) or 'i' not in arg.dtype.kind:
        raise ExpectedTypeError(arg, ['ndarray[int]'])


def operator_is_defined(op):
    @wrapt.decorator
    def checker(wrapped, instance, args, kwargs):
        if instance is None:
            instance = args[0]
        if instance is None:
            raise ExpectedTypeError(instance, ['Model'])
        if op not in instance._operators:
            raise UndefinedOperatorError(op, instance, args[0], instance._operators)
        else:
            return wrapped(*args, **kwargs)
    return checker


def arg(arg_id, *constraints):
    @wrapt.decorator
    def checked_function(wrapped, instance, args, kwargs):
        # for partial functions or other C-compiled functions
        if not hasattr(wrapped, 'checks'): # pragma: no cover
            return wrapped(*args, **kwargs)
        if instance is not None:
            fix_args = [instance] + list(args)
        else:
            fix_args = list(args)
        for arg_id, checks in wrapped.checks.items():
            for check in checks:
                if not isinstance(check, Callable):
                    raise ExpectedTypeError(check, ['Callable'])
                check(arg_id, fix_args, kwargs)
        return wrapped(*args, **kwargs)

    def arg_check_adder(func):
        if hasattr(func, 'checks'):
            func.checks.setdefault(arg_id, []).extend(constraints)
            return func
        else:
            wrapped = checked_function(func)
            wrapped.checks = {arg_id: list(constraints)}
            return wrapped
    return arg_check_adder


def args(*constraints):
    @wrapt.decorator
    def arg_check_adder(wrapped, instance, args, kwargs):
        for check in constraints:
            if not isinstance(check, Callable):
                raise ExpectedTypeError(check, ['Callable'])
            check(*args)
        return wrapped(*args, **kwargs)
    return arg_check_adder
