import argparse
import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


def barf():
    """Halt execution and open an interactive pdb session at this point."""
    from pdb import set_trace
    set_trace()


def assertEqual(tensor, expected, threshold=0.001):
    """Check that ``tensor`` matches ``expected`` to within ``threshold``.

    Args:
        tensor: A value supporting ``(tensor - expected).abs().max()``, or a
            (possibly nested) list/tuple of such values.
        expected: The reference value(s), structured like ``tensor``.
        threshold: Largest absolute elementwise difference that still counts
            as equal. It applies at every nesting level.

    On a mismatch this calls ``barf()`` (which opens the debugger) rather
    than raising an exception.

    NOTE: sequences are paired with ``zip``, so when the two sequences have
    different lengths the surplus elements of the longer one are not compared.
    """
    if isinstance(tensor, (list, tuple)):
        for actual_item, expected_item in zip(tensor, expected):
            # Forward the threshold so nested comparisons use the caller's
            # tolerance instead of falling back to the default.
            assertEqual(actual_item, expected_item, threshold)
    elif (tensor - expected).abs().max() > threshold:
        barf()


def filter_requires_grad(tensors):
    """Return a list of the items in ``tensors`` whose ``requires_grad`` is truthy.

    The relative order of the kept items matches the input order.
    """
    kept = []
    for candidate in tensors:
        if candidate.requires_grad:
            kept.append(candidate)
    return kept


def test_rnns(experim_creator, control_creator, check_grad=True, verbose=False,
              seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
              miniBatch=64, device='cuda', seed=17):
    """Compare an experimental RNN implementation against a control one.

    Both creators are called with the same keyword arguments (``seqLength``,
    ``numLayers``, ``inputSize``, ``hiddenSize``, ``miniBatch``, ``device``,
    ``seed``). Each must return an object exposing ``inputs``, ``params``,
    ``forward``, ``backward_setup`` and ``backward``.

    Args:
        experim_creator: Factory for the implementation under test.
        control_creator: Factory for the reference implementation.
        check_grad: When true, also run the backward pass on both models and
            compare the ``.grad`` of every parameter. When false, only the
            inputs, parameters and forward outputs are compared.
        verbose: When true, print ``experim.forward.graph_for(...)`` at the
            end (this requires ``forward`` to provide ``graph_for``).
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            Forwarded unchanged to both creators; ``seed`` is also passed to
            ``backward_setup``.

    Mismatches are reported through ``assertEqual``.
    """
    creator_args = dict(seqLength=seqLength, numLayers=numLayers,
                        inputSize=inputSize, hiddenSize=hiddenSize,
                        miniBatch=miniBatch, device=device, seed=seed)

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition: both models must start from the same inputs and weights,
    # otherwise the output comparison below is meaningless.
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    # Honour the check_grad flag; the backward pass is skipped entirely when
    # the caller asked for a forward-only comparison.
    if check_grad:
        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(control_outputs, seed)
        experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
    print('')


def test_vl_py(**test_args):
    """Compare the 'vl_py' runner against ``varlen_pytorch_lstm_creator``.

    Keyword Args:
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            Required; forwarded to both creators. ``seed`` is also passed to
            ``backward_setup``.
        check_grad: Optional (default True). When false the backward pass and
            the gradient comparison are skipped.
        verbose: Optional (default False). When true, print
            ``experim.forward.graph_for(...)`` at the end.

    Any other keyword arguments are ignored. Mismatches are reported through
    ``assertEqual``.
    """
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners('vl_py')[0]
    # Read the optional flags with the same defaults test_rnns uses, so a
    # caller that omits them does not hit a KeyError.
    check_grad = test_args.get('check_grad', True)
    verbose = test_args.get('verbose', False)
    with context():
        print('testing {}...'.format(name))
        creator_keys = [
            'seqLength', 'numLayers', 'inputSize',
            'hiddenSize', 'miniBatch', 'device', 'seed'
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition: only the first two control inputs are compared with
        # the experimental inputs.
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        # The experimental model returns sequences of tensors; pad/concatenate
        # them so they can be compared with the control's single tensors.
        # (presumably one entry per batch element -- TODO confirm)
        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        if check_grad:
            print("Checking grads...")
            assert control.backward_setup is not None
            assert experim.backward_setup is not None
            assert control.backward is not None
            assert experim.backward is not None
            control_backward_inputs = control.backward_setup(
                (control_out, control_hiddens), test_args['seed'])
            experim_backward_inputs = experim.backward_setup(
                (experim_out, experim_hiddens), test_args['seed'])

            control.backward(*control_backward_inputs)
            experim.backward(*experim_backward_inputs)

            control_grads = [p.grad for p in control.params]
            experim_grads = [p.grad for p in experim.params]
            assertEqual(experim_grads, control_grads)

        if verbose:
            print(experim.forward.graph_for(*experim.inputs))
        print('')


def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` is not usable for this: ``bool('False')`` is True because
    any non-empty string is truthy. Matching is case-insensitive and ignores
    surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False: it was the only value that produced
    # False under the old ``type=bool`` parsing.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test lstm correctness')

    parser.add_argument('--seqLength', default=100, type=int)
    parser.add_argument('--numLayers', default=1, type=int)
    parser.add_argument('--inputSize', default=512, type=int)
    parser.add_argument('--hiddenSize', default=512, type=int)
    parser.add_argument('--miniBatch', default=64, type=int)
    parser.add_argument('--device', default='cuda', type=str)
    # str2bool instead of bool, so that "--check_grad False" really disables
    # the gradient check.
    parser.add_argument('--check_grad', default=True, type=str2bool)
    parser.add_argument('--variable_lstms', action='store_true')
    parser.add_argument('--seed', default=17, type=int)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--rnns', nargs='*',
                        help='What to run. jit_premul, jit, etc')
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ['jit_premul', 'jit']
    print(args)

    if 'cuda' in args.device:
        assert torch.cuda.is_available(), 'CUDA device requested but not available'

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    # Everything except the runner-selection options is forwarded to the test
    # functions as keyword arguments. Build a copy so ``args`` stays intact.
    driver_only_keys = ('rnns', 'variable_lstms')
    test_args = {key: value for key, value in vars(args).items()
                 if key not in driver_only_keys}

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print('testing {}...'.format(name))
            test_rnns(creator, pytorch_lstm_creator, **test_args)
