# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from enum import Enum
import importlib
import inspect
import logging
from pathlib import Path
import pkgutil
import shutil
import sys
import tempfile

from msrest.serialization import Model
from msrest.paging import Paged

_LOGGER = logging.getLogger(__name__)

# Standard AutoRest license / "generated code" banner. Files written by
# write_model_file, write_paging_file and write_init start with it (directly in
# __init__.py, or as the prefix of the headers below). The enum file is copied
# verbatim and does not go through it.
copyright_header = b"""# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is
# regenerated.
# --------------------------------------------------------------------------

"""

# Preamble of the compacted models_py3.py when track2 is False.
header = copyright_header + b"""from msrest.serialization import Model
from msrest.exceptions import HttpOperationError
"""

# Preamble of the compacted models_py3.py when track2 is True. write_model_file
# drops everything before the first "class " line of each source file, so the
# names the class bodies rely on must be imported here instead.
# NOTE(review): presumably these are exactly the imports used by track2 model
# files — confirm against a freshly generated file.
track2_header = copyright_header + b"""import datetime
import msrest
import msrest.serialization
from typing import Dict, List, Optional, Union
from msrest.exceptions import HttpOperationError
"""

# Preamble of the compacted paged_models.py.
paging_header = copyright_header + b"""from msrest.paging import Paged
"""

# Template for the generated models/__init__.py. Placeholders, in order: the
# module tried first, the fallback module imported when that raises
# SyntaxError/ImportError, and the paging module.
# NOTE(review): solve_one_model passes "models" as the fallback, but it only
# writes models_py3.py, paged_models.py and (optionally) an enum file, so the
# fallback target does not appear to exist — confirm it is never reached.
init_file = """
try:
    from .{} import *
except (SyntaxError, ImportError):
    from .{} import *
from .{} import *
"""


def parse_input(input_parameter):
    """Split a ``package_name#submodule`` specification.

    The package name is the text before the first ``#``. The module name is
    that package name with every dash turned into a dot. When a ``#`` is
    present, the segment right after it is appended as a sub-module; anything
    after a second ``#`` is ignored.

    Example: ``"azure-mgmt-compute#v2019_07_01"`` gives
    ``("azure-mgmt-compute", "azure.mgmt.compute.v2019_07_01")``.

    :return: tuple ``(package_name, module_name)``
    """
    package_name, separator, remainder = input_parameter.partition('#')
    dotted_name = package_name.replace("-", ".")
    if not separator:
        return package_name, dotted_name
    # Only the first segment after '#' is used as the sub-module.
    submodule = remainder.split('#', 1)[0]
    return package_name, dotted_name + "." + submodule


def solve_mro(models, track2=False):
    """Replace each models package in ``models`` by its compacted version.

    For every package, the compacted files are generated by
    ``solve_one_model`` into a temporary folder. The original package
    directory is then deleted and the generated folder is moved into its place.

    :param models: Imported models packages (modules exposing ``__path__``).
    :param bool track2: Forwarded to ``solve_one_model``. When False, a package
        that already contains ``models_py3.py`` is considered patched and is
        skipped.
    """
    for models_module in models:
        models_path = models_module.__path__[0]
        _LOGGER.info("Working on %s", models_path)

        if not track2 and Path(models_path, "models_py3.py").exists():
            _LOGGER.info("Skipping since already patched")
            # Only this package is done: the other API-version packages of a
            # multi-API SDK may still need patching, so keep going.
            continue

        # Build the new files in a temp folder so the original package is
        # only removed once generation has succeeded.
        with tempfile.TemporaryDirectory() as temp_folder:
            final_models_path = Path(temp_folder, "models")
            final_models_path.mkdir()
            solve_one_model(models_module, final_models_path, track2=track2)

            # Switch the files
            shutil.rmtree(models_path)
            shutil.move(str(final_models_path), models_path)


def solve_one_model(models_module, output_folder, track2=False):
    """Build the compacted models of ``models_module`` in ``output_folder``.

    Files written to ``output_folder``:

    * ``models_py3.py``: source of every ``Model`` subclass. Classes are ordered
      by ascending MRO length, so a base class is emitted before its subclasses.
    * ``paged_models.py``: source of every ``Paged`` subclass.
    * a copy of the file defining the first ``Enum`` subclass found, if any.
    * ``__init__.py``, which star-imports all of the above.

    :param models_module: Imported models package to compact.
    :param output_folder: Existing directory receiving the generated files.
    :param bool track2: Forwarded to ``write_model_file`` to select the header.
    """
    models_classes = []
    paged_models_classes = []
    enum_models_classes = []
    for model_name, model_class in vars(models_module).items():
        # Only capitalized classes are candidates. Other capitalized attributes
        # (constants, functions, modules) have no __mro__ and must be skipped
        # rather than crash the scan.
        if not model_name[0].isupper() or not inspect.isclass(model_class):
            continue
        mro = model_class.__mro__
        if Model in mro:
            models_classes.append((len(mro), inspect.getfile(model_class), model_class))
        if Paged in mro:
            paged_models_classes.append((inspect.getfile(model_class), model_class))
        if Enum in mro:
            enum_models_classes.append((inspect.getfile(model_class), model_class))

    # Classes are not orderable, so sort on the MRO length only (the sort is
    # stable: ties keep module attribute order).
    models_classes.sort(key=lambda entry: entry[0])

    if enum_models_classes:
        # NOTE(review): only the file of the first enum is copied and imported.
        # Enums defined in any other file would be lost — confirm they all
        # live in a single file.
        enum_file = Path(enum_models_classes[0][0])
        shutil.copyfile(enum_file, Path(output_folder, enum_file.name))
        enum_file_module_name = enum_file.with_suffix('').name
    else:
        enum_file_module_name = None

    write_model_file(Path(output_folder, "models_py3.py"), models_classes, track2=track2)
    write_paging_file(Path(output_folder, "paged_models.py"), paged_models_classes)
    write_init(
        Path(output_folder, "__init__.py"),
        "models_py3",
        "models",
        "paged_models",
        enum_file_module_name
    )


def write_model_file(output_file_path, classes_to_write, track2=False):
    """Concatenate the class definitions of the given model sources into one file.

    The output starts with ``track2_header`` (if ``track2``) or ``header``.
    Then, for each distinct source file in first-seen order, it contains two
    blank lines followed by the source from its first line starting with
    ``class `` to the end. The file preamble (license, imports) is dropped.

    :param output_file_path: File to create (overwritten if it exists).
    :param classes_to_write: Iterable of ``(mro_length, source_file_path, class)``
        tuples. Only the source file path is used.
    :param bool track2: Selects the header to write.
    :raises ValueError: If a source file has no line starting with ``class ``.
    """
    already_written = set()
    with open(output_file_path, "bw") as write_fd:
        write_fd.write(track2_header if track2 else header)

        for _, model_file_path, _ in classes_to_write:
            # A file defining several models is listed once per model; copying
            # it again would duplicate every class it contains.
            if model_file_path in already_written:
                continue
            already_written.add(model_file_path)

            with open(model_file_path, "rb") as read_fd:
                lines = read_fd.readlines()

            # Skip until it's "class XXXX", then keep everything.
            first_class = next(
                (index for index, line in enumerate(lines) if line.startswith(b"class ")),
                None,
            )
            if first_class is None:
                raise ValueError("Never found any class definition!")
            write_fd.write(b'\n')
            write_fd.write(b'\n')
            write_fd.writelines(lines[first_class:])


def write_paging_file(output_file_path, classes_to_write):
    """Concatenate the paged-model class definitions into one file.

    The output starts with ``paging_header``. Then, for each entry, it contains
    two blank lines followed by the source of that file from its first line
    starting with ``class `` to the end.

    :param output_file_path: File to create (overwritten if it exists).
    :param classes_to_write: Iterable of ``(source_file_path, class)`` tuples.
        Only the source file path is used.
    :raises ValueError: If a source file has no line starting with ``class ``.
    """
    with open(output_file_path, "bw") as write_fd:
        write_fd.write(paging_header)

        for model_file_path, _ in classes_to_write:
            with open(model_file_path, "rb") as read_fd:
                lines = read_fd.readlines()

            # Locate the class definition instead of assuming a fixed-length
            # generated preamble, which silently breaks if the header changes.
            first_class = next(
                (index for index, line in enumerate(lines) if line.startswith(b"class ")),
                None,
            )
            if first_class is None:
                raise ValueError("Never found any class definition!")
            write_fd.write(b'\n')
            write_fd.write(b'\n')
            write_fd.writelines(lines[first_class:])


def write_init(output_file_path, model_file_name, model_file_name_py2, paging_file_name, enum_file_name):
    """Write the ``__init__.py`` of a compacted models package.

    The file contains ``copyright_header`` followed by ``init_file`` filled
    with the three module names. When ``enum_file_name`` is truthy, a
    star-import of the enum module is appended.

    :param output_file_path: ``__init__.py`` path to create (overwritten).
    :param str model_file_name: Module star-imported first.
    :param str model_file_name_py2: Module star-imported instead when the first
        import raises ``SyntaxError`` or ``ImportError``.
    :param str paging_file_name: Module holding the paged models.
    :param enum_file_name: Enum module name, or a falsy value for no enum import.
    """
    content = init_file.format(
        model_file_name,
        model_file_name_py2,
        paging_file_name,
    )
    if enum_file_name:
        # Terminate the line so the generated file ends with a newline.
        content += "from .{} import *\n".format(enum_file_name)

    with open(output_file_path, "bw") as write_fd:
        write_fd.write(copyright_header)
        write_fd.write(content.encode('utf8'))


def find_models_to_change(module_name):
    """Return the models packages of ``module_name`` that should be compacted.

    If ``module_name`` exposes a ``models`` attribute that is a package (has
    ``__path__``), the SDK is single-API and that package is returned alone.
    Otherwise the SDK is treated as multi-API: the ``models`` sub-package of
    every sub-package except ``aio`` is imported and returned.
    """
    main_module = importlib.import_module(module_name)

    single_api_models = getattr(main_module, "models", None)
    if hasattr(single_api_models, "__path__"):
        # A real package: that's a single API package.
        return [single_api_models]

    # "models" is missing or is a plain (fake) module: multi-api, load every
    # API-version package's models.
    api_version_labels = (
        label
        for _, label, is_package in pkgutil.iter_modules(main_module.__path__)
        if is_package and label != 'aio'
    )
    return [
        importlib.import_module('.' + label + '.models', main_module.__name__)
        for label in api_version_labels
    ]


def find_autorest_generated_folder(module_prefix="azure.mgmt"):
    """Return the dotted names of AutoRest-generated packages under ``module_prefix``.

    A sub-module counts as generated when ``<sub_module>.models`` can be
    imported. When that import raises ``ModuleNotFoundError`` and the
    sub-module is a package, its children are searched recursively. Results
    are returned in discovery order.
    """
    _LOGGER.info("Looking for Autorest generated package in %s", module_prefix)
    found = []
    root_module = importlib.import_module(module_prefix)
    children = pkgutil.iter_modules(root_module.__path__, module_prefix + ".")
    for _, candidate, is_package in children:
        _LOGGER.debug("Try %s", candidate)
        try:
            importlib.import_module(".models", candidate)
        except ModuleNotFoundError:
            # No "models" here; a package might still hold one deeper down.
            if is_package:
                found.extend(find_autorest_generated_folder(candidate))
            continue
        _LOGGER.info("Found %s", candidate)
        found.append(candidate)
    return found


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Keep a trailing comma on every entry: a missing one makes Python
    # silently concatenate two adjacent string literals into one bogus name.
    track2_packages = [
        'azure.mgmt.keyvault',
        'azure.mgmt.storage',
        'azure.mgmt.compute',
        'azure.mgmt.monitor',
        'azure.mgmt.rdbms',
        'azure.mgmt.loganalytics',
        'azure.mgmt.web',
        'azure.mgmt.cosmosdb',
        'azure.mgmt.privatedns',
        'azure.mgmt.dms',
        'azure.mgmt.sqlvirtualmachine',
    ]
    prefix = sys.argv[1] if len(sys.argv) >= 2 else "azure.mgmt"
    for autorest_package in find_autorest_generated_folder(prefix):
        models = find_models_to_change(autorest_package)
        # Match the listed package itself or one of its sub-packages, but not a
        # sibling sharing the prefix (e.g. "azure.mgmt.storagesync" is not
        # "azure.mgmt.storage").
        track2 = any(
            autorest_package == track2_pkg or autorest_package.startswith(track2_pkg + ".")
            for track2_pkg in track2_packages
        )
        solve_mro(models, track2=track2)
