# The purpose of this test is to check that we have implementation parity between
# a Python `torch.nn` module and its corresponding C++ `torch::nn` module. Concretely,
# this test does the following:
#
# 1. Get a test params dict from common_nn.py, run forward and backward on the
# Python module created using the test params.
#
# 2. Serialize the Python module's parameters / buffers and its forward input
# arguments, deserialize them in C++ and load them into the C++ module.
#
# 3. Run the same forward and backward passes on the C++ module, and serialize
# the C++ module's forward output and backward gradients.
#
# 4. Compare Python/C++ module's forward output and backward gradients. If they
# are the same, then we have implementation parity between Python/C++ module.

import tempfile
from string import Template
import types
import pprint
import os

import torch
from cpp_api_parity.utils import TorchNNModuleTestParams, TORCH_NN_COMMON_TEST_HARNESS, \
    compile_cpp_code_inline, set_python_tensors_requires_grad, move_python_tensors_to_device, \
    add_test, compute_cpp_args_construction_stmts_and_forward_arg_symbols, serialize_arg_dict_as_script_module, \
    compute_arg_dict, decorate_test_fn, compute_temp_file_path, generate_error_msg, is_torch_nn_functional_test, \
    try_remove_folder
from cpp_api_parity.sample_module import SAMPLE_MODULE_CPP_SOURCE

# C++ source template for one generated parity-test function. It is instantiated
# once per module variant by `generate_test_cpp_sources`, and all instantiations
# are compiled together by `build_cpp_tests`.
#
# Expected substitutions:
#
# ${module_variant_name}  (e.g. `Linear_no_bias`, or `Linear_no_bias_cuda` for a
#     non-CPU device); prefix of the generated C++ function name
# ${module_qualified_name}  (e.g. `torch::nn::Linear`); C++ type of the module under test
# ${cpp_args_construction_stmts}  C++ statements emitted right after `arg_dict` is loaded
# ${cpp_constructor_args}  either empty, or a parenthesized C++ constructor argument list
# ${device}  device string passed to `module->to(...)`
# ${cpp_forward_args_symbols}  comma-separated expressions passed to the module's forward call
TORCH_NN_MODULE_TEST_FORWARD_BACKWARD = Template("""
void ${module_variant_name}_test_forward_backward(
    const std::string& arg_dict_file_path,
    const std::string& module_file_path,
    const std::string& forward_output_file_path,
    const std::string& backward_grad_dict_file_path) {
  pybind11::gil_scoped_release no_gil;

  // Declare arguments
  auto arg_dict = load_dict_from_file(arg_dict_file_path);
  ${cpp_args_construction_stmts};

  // Construct module and load params/buffers from Python module
  ${module_qualified_name} module${cpp_constructor_args};
  module->to(std::string("${device}"));
  torch::load(module, module_file_path);

  // Some modules (such as `RReLU`) create random tensors in their forward pass.
  // To make sure the random tensors created are the same in Python/C++, we need
  // to set the RNG seed manually.
  torch::manual_seed(0);

  // Forward pass
  auto cpp_output = module(${cpp_forward_args_symbols});

  // Save the output into a file to be compared in Python later
  write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);

  // Backward pass
  cpp_output.sum().backward();

  // Put all gradients into a c10::Dict, save it into a file to be compared in Python later
  c10::Dict<std::string, torch::Tensor> grad_dict;
  for (const auto& param : module->named_parameters()) {
    torch::Tensor grad = param.value().grad();
    if (grad.is_sparse()) {
      grad_dict.insert(param.key() + "_grad_indices", grad.coalesce().indices());
      grad_dict.insert(param.key() + "_grad_values", grad.coalesce().values());
    } else {
      grad_dict.insert(param.key() + "_grad", grad);
    }
  }

  write_ivalue_to_file(torch::IValue(grad_dict), backward_grad_dict_file_path);
}
""")

def run_python_forward_backward(unit_test_class, test_params):
    """Run forward and backward on the Python module described by `test_params`.

    Args:
        unit_test_class: Not used by this function; kept so existing callers
            keep working.
        test_params: Test params object. This function reads `device`,
            `test_instance` (its `constructor` and `constructor_args`) and
            `arg_dict` (its `input`, `target` and `extra_args` entries, each an
            iterable of `(name, value)` pairs).

    Returns:
        A tuple `(script_module, python_output, python_grad_dict)` where
        `script_module` is a traced copy of the module whose forward is the
        identity (only useful for transferring parameters / buffers),
        `python_output` is the forward output, and `python_grad_dict` maps
        `<param>_grad` (dense) or `<param>_grad_indices` / `<param>_grad_values`
        (sparse) to the corresponding gradient tensors.

    Note:
        Every parameter is expected to have received a gradient; a parameter
        whose `.grad` is None makes the `is_sparse` lookup below raise
        AttributeError.
    """
    device = test_params.device
    test_instance = test_params.test_instance
    module = test_instance.constructor(*test_instance.constructor_args).to(device)

    def values_on_device(arg_group):
        # Drop the argument names and move the values of one arg group to `device`.
        return move_python_tensors_to_device(
            [arg_value for _, arg_value in test_params.arg_dict[arg_group]], device)

    # Only the `input` group is marked as requiring grad.
    inputs = set_python_tensors_requires_grad(values_on_device('input'))
    inputs += values_on_device('target')
    inputs += values_on_device('extra_args')

    # Some modules (such as `RReLU`) create random tensors in their forward pass.
    # To make sure the random tensors created are the same in Python/C++, we need
    # to set the RNG seed manually.
    torch.manual_seed(0)

    # Forward pass
    python_output = module(*inputs)

    # NOTE: This is a workaround to allow any module to be traced.
    # We can do this because we are only interested in transferring
    # the Python module's parameters and buffers to the C++ module.
    module.forward = types.MethodType(lambda self, input: input, module)
    script_module = torch.jit.trace(module, torch.tensor(0))

    # Backward pass
    python_output.sum().backward()

    # Put all gradients into a dict, to be compared later
    python_grad_dict = {}
    for name, param in module.named_parameters():
        grad = param.grad
        if grad.is_sparse:
            # Coalesce once and reuse the result for both indices and values.
            coalesced_grad = grad.coalesce()
            python_grad_dict[name + "_grad_indices"] = coalesced_grad.indices()
            python_grad_dict[name + "_grad_values"] = coalesced_grad.values()
        else:
            python_grad_dict[name + "_grad"] = grad

    return script_module, python_output, python_grad_dict

def test_forward_backward(unit_test_class, test_params):
    """Check forward/backward parity between a Python module and its C++ counterpart.

    Runs the Python module, saves its parameters / buffers and its arguments
    into `test_params.cpp_tmp_folder`, calls the compiled C++ function
    `<module_variant_name>_test_forward_backward`, loads the files it wrote and
    asserts that the forward outputs and all parameter gradients match.

    Args:
        unit_test_class: Test case instance. Must provide `assertEqual`,
            `assertTrue` and the attribute `module_impl_check_cpp_module`
            holding the compiled C++ test functions.
        test_params: Test params object; `module_variant_name`,
            `cpp_tmp_folder` and `arg_dict` are read here.

    The temporary folder is removed when the function exits, whether the
    checks passed or raised.
    """
    module_variant_name = test_params.module_variant_name
    cpp_tmp_folder = test_params.cpp_tmp_folder
    # Remove the temporary folder if it exists already
    try_remove_folder(cpp_tmp_folder)
    os.mkdir(cpp_tmp_folder)

    try:
        # Run forward and backward on Python module
        script_module, python_output, python_grad_dict = run_python_forward_backward(unit_test_class, test_params)

        # Save Python module and arguments to be used from C++ function
        module_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'module')
        arg_dict_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'arg_dict')
        script_module.save(module_file_path)
        serialize_arg_dict_as_script_module(test_params.arg_dict).save(arg_dict_file_path)

        cpp_test_name = '{}_test_forward_backward'.format(module_variant_name)
        cpp_test_fn = getattr(unit_test_class.module_impl_check_cpp_module, cpp_test_name)

        # Files that the C++ function writes its results into
        forward_output_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'forward_output')
        backward_grad_dict_file_path = compute_temp_file_path(cpp_tmp_folder, module_variant_name, 'backward_grad_dict')

        cpp_test_fn(arg_dict_file_path, module_file_path, forward_output_file_path, backward_grad_dict_file_path)
        cpp_output = torch.load(forward_output_file_path)
        cpp_grad_dict = torch.load(backward_grad_dict_file_path)

        # Check that forward outputs are equal
        unit_test_class.assertEqual(python_output, cpp_output,
                                    msg=generate_error_msg("forward output", cpp_output, python_output))

        # Check that module parameter gradients are equal after backward pass
        unit_test_class.assertEqual(
            len(python_grad_dict), len(cpp_grad_dict),
            msg=generate_error_msg("# of parameters", len(cpp_grad_dict), len(python_grad_dict)))
        for key in python_grad_dict:
            # Recover the parameter name by stripping the gradient suffix from the key
            param_name = None
            for suffix in ['_grad', '_grad_indices', '_grad_values']:
                if key.endswith(suffix):
                    param_name = key[:-len(suffix)]
                    break
            assert param_name is not None
            sparsity_str = 'sparse' if key.endswith(('_grad_indices', '_grad_values')) else 'dense'

            unit_test_class.assertTrue(
                key in cpp_grad_dict,
                msg=generate_error_msg(
                    "\"Does module have a parameter named `{}` with {} gradient?\"".format(param_name, sparsity_str),
                    False, True))
            unit_test_class.assertEqual(
                python_grad_dict[key], cpp_grad_dict[key],
                msg=generate_error_msg(
                    "`{}`'s {} gradient (`{}`)".format(param_name, sparsity_str, key),
                    cpp_grad_dict[key], python_grad_dict[key]))
    finally:
        # Remove temporary folder that stores C++ outputs, even when a check
        # above failed, so failing tests do not leave folders behind.
        try_remove_folder(cpp_tmp_folder)

def compute_module_name(test_params_dict):
    """Return the module name encoded in a test params dict.

    A non-empty `fullname` entry takes precedence: the text before its first
    underscore is returned. Otherwise the `module_name` entry is returned, or
    None when that entry is absent too.
    """
    full_name = test_params_dict.get('fullname')
    if not full_name:
        # No usable `fullname`; fall back to the explicit module name.
        return test_params_dict.get('module_name')
    head, _, _ = full_name.partition('_')
    return head

def process_test_params_for_module(test_params_dict, device, test_instance_class):
    """Build the `TorchNNModuleTestParams` for one test params dict on one device.

    Args:
        test_params_dict: Module test params dict. If it has no `constructor`
            entry, one is added in place, pointing at the `torch.nn` attribute
            named after the module.
        device: Device string. Any value other than `'cpu'` is appended to the
            module variant name as a `_<device>` suffix.
        test_instance_class: Callable invoked as
            `test_instance_class(**test_params_dict)`; the result must provide
            `get_name()` returning a name that starts with `test_`.

    Returns:
        A `TorchNNModuleTestParams` whose `cpp_tmp_folder` is a freshly created
        temporary directory.

    Raises:
        AssertionError: If the test instance name does not start with `test_`,
            or if `constructor_args` is given without `cpp_constructor_args`.
    """
    module_name = compute_module_name(test_params_dict)
    # Only look the module up in `torch.nn` when the dict does not already
    # bring its own constructor.
    if 'constructor' not in test_params_dict:
        test_params_dict['constructor'] = getattr(torch.nn, module_name)
    test_instance = test_instance_class(**test_params_dict)
    assert test_instance.get_name().startswith('test_')
    # Example output: `BCELoss_weights_cuda`
    module_variant_name = test_instance.get_name()[len('test_'):] + (('_' + device) if device != 'cpu' else '')

    if 'constructor_args' in test_params_dict:
        # The newline after `{}` keeps the pretty-printed dict from running
        # into the guidance text that follows it.
        assert 'cpp_constructor_args' in test_params_dict, (
            "If `constructor_args` is present in test params dict, to enable C++ API parity test, "
            "`cpp_constructor_args` must be present in:\n{}\n"
            "If you are interested in adding the C++ API parity test, please see:\n"
            "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
            "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
        ).format(pprint.pformat(test_params_dict))

    return TorchNNModuleTestParams(
        module_name=module_name,
        module_variant_name=module_variant_name,
        test_instance=test_instance,
        cpp_constructor_args=test_params_dict.get('cpp_constructor_args', ''),
        arg_dict=compute_arg_dict(test_params_dict, test_instance),
        has_parity=test_params_dict.get('has_parity', True),
        device=device,
        cpp_tmp_folder=tempfile.mkdtemp(),
    )

def write_test_to_test_class(
        unit_test_class, test_params_dict, test_instance_class, parity_table, devices):
    """Register one parity test per device on `unit_test_class`.

    For every device, the processed test params are stored in
    `unit_test_class.module_test_params_map` under the key
    `test_torch_nn_<module_variant_name>`, and a test method with that name is
    added to the class. The method looks its params up again through
    `self._testMethodName` when it runs.

    Raises:
        AssertionError: If the dict describes a functional test, names a module
            that `torch.nn` does not have, or names a module missing from the
            `torch::nn` section of `parity_table`.
    """
    assert not is_torch_nn_functional_test(test_params_dict)

    module_name = compute_module_name(test_params_dict)

    assert hasattr(torch.nn, module_name), (
        "`torch.nn` doesn't have module `{}`. "
        "If you are adding a new test, please set `fullname` using format `ModuleName_desc` "
        "or set `module_name` using format `ModuleName` in the module test dict:\n{}"
    ).format(module_name, pprint.pformat(test_params_dict))

    module_full_name = 'torch::nn::' + module_name
    nn_parity_entries = parity_table['torch::nn']

    assert module_full_name in nn_parity_entries, (
        "Please add `{}` entry to `torch::nn` section of `test/cpp_api_parity/parity-tracker.md`. "
        "(Discovered while processing\n{}.)").format(module_full_name, pprint.pformat(test_params_dict))

    for device in devices:
        params = process_test_params_for_module(
            test_params_dict=test_params_dict,
            device=device,
            test_instance_class=test_instance_class,
        )
        # Drop the folder again right away; the test run recreates it on demand.
        try_remove_folder(params.cpp_tmp_folder)
        unit_test_name = 'test_torch_nn_{}'.format(params.module_variant_name)
        unit_test_class.module_test_params_map[unit_test_name] = params

        def test_fn(self):
            # Resolve the params by method name at run time rather than
            # capturing the loop variable.
            stored_params = unit_test_class.module_test_params_map[self._testMethodName]
            test_forward_backward(unit_test_class=self, test_params=stored_params)

        wants_cuda = test_params_dict.get('test_cuda', True)
        # Parity needs both the parity table entry and the test dict to agree.
        parity_flag = nn_parity_entries[module_full_name][0] and \
            test_params_dict.get('has_parity', True)
        test_fn = decorate_test_fn(
            test_fn=test_fn,
            test_cuda=wants_cuda,
            has_impl_parity=parity_flag,
            device=device)

        add_test(unit_test_class, unit_test_name, test_fn)

def generate_test_cpp_sources(test_params, template):
    """Fill `template` with the C++ snippets derived from `test_params`.

    Non-empty constructor args are wrapped in parentheses; an empty string is
    passed through unchanged. Argument construction statements are joined with
    `";\\n  "` and forward argument symbols with `", "`.

    Returns:
        The result of `template.substitute(...)`.
    """
    raw_ctor_args = test_params.cpp_constructor_args
    ctor_args = '({})'.format(raw_ctor_args) if raw_ctor_args != '' else raw_ctor_args

    construction_stmts, forward_arg_symbols = \
        compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)

    substitutions = {
        'module_variant_name': test_params.module_variant_name,
        'module_qualified_name': 'torch::nn::{}'.format(test_params.module_name),
        'cpp_args_construction_stmts': ";\n  ".join(construction_stmts),
        'cpp_constructor_args': ctor_args,
        'cpp_forward_args_symbols': ", ".join(forward_arg_symbols),
        'device': test_params.device,
    }
    return template.substitute(**substitutions)

# Build all C++ tests together, instead of once per test.
def build_cpp_tests(unit_test_class, print_cpp_source=False):
    """Compile the C++ test functions for all registered module tests at once.

    One C++ function is generated per entry of
    `unit_test_class.module_test_params_map`; the functions are appended to the
    common test harness and sample module sources, compiled as a single
    extension named `module_impl_check`, and the result is stored on
    `unit_test_class.module_impl_check_cpp_module`.

    Args:
        unit_test_class: Test class holding a non-empty `module_test_params_map`.
        print_cpp_source: When true, print the full generated C++ source
            before compiling it.

    Raises:
        AssertionError: If no module tests have been registered.
    """
    assert len(unit_test_class.module_test_params_map) > 0
    # Collect the pieces and join once instead of growing a string in the loop.
    source_parts = [TORCH_NN_COMMON_TEST_HARNESS, SAMPLE_MODULE_CPP_SOURCE]
    functions = []
    for test_params in unit_test_class.module_test_params_map.values():
        source_parts.append(generate_test_cpp_sources(
            test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD_BACKWARD))
        functions.append('{}_test_forward_backward'.format(test_params.module_variant_name))
    cpp_sources = ''.join(source_parts)
    if print_cpp_source:
        print(cpp_sources)

    cpp_module = compile_cpp_code_inline(
        name='module_impl_check',
        cpp_sources=cpp_sources,
        functions=functions)
    unit_test_class.module_impl_check_cpp_module = cpp_module
