# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import os
import azure.cli.command_modules
import re
import shutil
import py_compile
import logging

# Module logger: compaction progress is reported at INFO level, per-file parsing/writing at DEBUG.
_LOGGER = logging.getLogger(__name__)


class CompactorCtx:
    """Track helper code pieces across command group namespaces while compacting.

    ``_code_pieces`` maps a code key (a helper classmethod name) to
    ``{namespace: {"max_count_code": code, "codes": {code: count}}}``, where a namespace is
    a folder path joined with ``os.sep``. A namespace is filled once in write mode
    (``add_code_piece``) and then queried in read mode (``fetch_code_piece`` /
    ``fetch_helper_code_piece``). A code counts as shared in a namespace when it is the
    most frequent code of its key there and occurs more than once.
    """

    def __init__(self):
        self._code_pieces = {}
        self._namespace = None
        self._write_mode = False
        self._written_namespaces = set()
        self._parent_namespaces = set()

    @staticmethod
    def _new_namespace_entry(code):
        """Return the bookkeeping entry of a namespace that has seen ``code`` once."""
        return {
            "max_count_code": code,
            "codes": {
                code: 1
            }
        }

    def add_code_piece(self, key, code):
        """Count one occurrence of ``code`` for ``key`` in the current namespace.

        An occurrence that a parent namespace already shares is not recorded, so the
        current namespace will link to the parent's copy instead.
        """
        assert self._namespace is not None and self._write_mode, "Namespace is not set or readonly"
        namespaces = self._code_pieces.get(key)
        if namespaces is None:
            self._code_pieces[key] = {self._namespace: self._new_namespace_entry(code)}
            return

        entry = namespaces.get(self._namespace)
        if entry is not None and code in entry["codes"]:
            # the code is already defined in the current namespace: increase the count
            # and update the max_count_code
            codes = entry["codes"]
            codes[code] += 1
            if codes[code] > codes[entry["max_count_code"]]:
                entry["max_count_code"] = code
            return

        # the code is defined in other _helpers
        if self.find_code_piece_in_parent(key, code):
            return

        if entry is None:
            # add only the current namespace; the entries of parent and sibling namespaces
            # must survive so that namespaces processed later can still link to them
            namespaces[self._namespace] = self._new_namespace_entry(code)
            return

        entry["codes"][code] = 1

    def set_current_namespace(self, namespace, write_mode):
        """Switch to ``namespace`` (falsy to clear it) in write or read mode.

        A namespace may be opened in write mode only once. Its parent namespaces are the
        proper prefixes of ``namespace`` except the first path component alone
        (presumably the profile folder, which is never compacted as a group itself).
        """
        assert not write_mode or namespace not in self._written_namespaces
        self._namespace = namespace
        self._write_mode = write_mode
        # also reset on clear, so no stale parents are consulted without a namespace
        self._parent_namespaces.clear()
        if not namespace:
            # clean namespace
            return

        pieces = namespace.split(os.sep)
        for i in range(1, len(pieces) - 1):
            self._parent_namespaces.add(os.sep.join(pieces[:-i]))

        if write_mode:
            self._written_namespaces.add(namespace)

    def find_code_piece_in_parent(self, key, code):
        """Return the parent namespace in which ``code`` is the shared code of ``key``, else None."""
        for namespace, value in self._code_pieces.get(key, {}).items():
            if namespace not in self._parent_namespaces:
                continue
            if value["max_count_code"] == code and value["codes"][code] > 1:
                return namespace
        return None

    def fetch_code_piece(self, key, code):
        """Return the namespace whose shared copy of ``code`` should be linked, or None if it
        must be inlined. A parent namespace takes precedence over the current one."""
        assert not self._write_mode, "Fetch code piece is not supported in write mode"
        parent_namespace = self.find_code_piece_in_parent(key, code)
        if parent_namespace:
            return parent_namespace
        entry = self._code_pieces.get(key, {}).get(self._namespace)
        if entry is not None and code == entry["max_count_code"] and entry["codes"][code] > 1:
            return self._namespace
        return None

    def fetch_helper_code_piece(self, key):
        """Return the code of ``key`` shared in the current namespace, or None if there is none."""
        assert not self._write_mode, "Fetch code piece is not supported in write mode"
        entry = self._code_pieces.get(key, {}).get(self._namespace)
        if entry is None:
            return None
        code = entry["max_count_code"]
        if entry["codes"][code] > 1:
            return code
        return None


# Text prepended to every file written by MainModuleCompactor._write_py_file: the license
# banner, a "generated code" notice, and markers that skip pylint and flake8 for the file.
_PY_HEADER = """# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#
# Code generated by aaz-dev-tools
# --------------------------------------------------------------------------------------------
# pylint: skip-file
# flake8: noqa

"""


class MainModuleCompactor:
    """Compact the generated ``aaz`` command sources of one CLI command module.

    ``compact`` rebuilds ``<module>/aaz`` as ``<module>/aaz_compact``: for every API profile
    folder that exists, each command group folder gets an ``__init__.py`` and, when it has
    any content, one ``__cmds.py`` holding the command group class, all command classes and
    their ``_XxxHelper`` classes. Helper classmethods that are shared within a group (or
    already shared by a parent group) are emitted once and attached with ``@link_helper``.
    ``replace`` then swaps ``aaz_compact`` in for ``aaz``.
    """

    _command_group_pattern = re.compile(r'^class\s+(.*)\(.*AAZCommandGroup.*\)\s*:\s*$')
    _command_pattern = re.compile(r'^class\s+(.*)\(.*AAZ(Wait)?Command.*\)\s*:\s*$')
    _command_helper_pattern = re.compile(r'^class\s+(_(.+)Helper)\s*:\s*$')
    _command_helper_func_pattern = re.compile(r'^\s{4}def (.+)\(cls.*\)\s*:\s*$')
    _file_end_pattern = re.compile(r'^__all__ = \[.*]\s*$')
    _class_method_register = "    @classmethod"
    # first line of every generated __cmds.py
    _aaz_import_line = "from azure.cli.core.aaz import *\n"

    def __init__(self, mod_name, compiled_to_pyc=True):
        """
        :param mod_name: command module name; '-' is replaced by '_' and the name is
            lower-cased to locate its folder under ``azure.cli.command_modules``.
        :param compiled_to_pyc: byte-compile every python file that is written.
        :raises ValueError: if the module folder does not exist.
        """
        self._mod_name = mod_name
        self._modules_dir = azure.cli.command_modules.__path__[0]
        self._folder = self._get_module_folder()
        self._compiled_to_pyc = compiled_to_pyc

    def compact(self):
        """Build the ``aaz_compact`` folder for every API profile present under ``aaz``."""
        _LOGGER.info("Compacting {} folder:".format(self._get_aaz_folder()))
        self._create_compact_aaz_folder()
        from azure.cli.core.profiles import AZURE_API_PROFILES
        from azure.cli.core.aaz.utils import get_aaz_profile_module_name
        for profile in AZURE_API_PROFILES:
            _LOGGER.info("Compacting profile {}".format(profile))

            profile_mod_name = get_aaz_profile_module_name(profile)
            profile_path = self._get_aaz_rg_path(profile_mod_name)
            if not os.path.exists(profile_path):
                continue

            # create profile folders
            compact_folder = self._get_compact_aaz_rg_path(profile_mod_name)
            os.mkdir(compact_folder)
            self._write_py_file(os.path.join(compact_folder, '__init__.py'), content="")

            # a fresh context per profile: helper code is never linked across profiles
            ctx = CompactorCtx()
            self.compact_sub_resource_groups(ctx, profile_mod_name)

    def replace(self):
        """Replace the ``aaz`` folder with the ``aaz_compact`` folder built by ``compact``.

        :raises ValueError: if the compact folder does not exist. This is checked before
            anything is deleted, so calling ``replace`` without ``compact`` cannot wipe ``aaz``.
        """
        aaz_folder = self._get_aaz_folder()
        compact_aaz_folder = self._get_compact_aaz_folder()
        if not os.path.isdir(compact_aaz_folder):
            raise ValueError("Compact folder does not exist: {}".format(compact_aaz_folder))
        _LOGGER.info("Removing {} folder".format(aaz_folder))
        shutil.rmtree(aaz_folder)
        _LOGGER.info("Move {} folder to {}".format(compact_aaz_folder, aaz_folder))
        shutil.move(compact_aaz_folder, aaz_folder)

    def _write_py_file(self, path, content):
        """Write ``_PY_HEADER + content`` to ``path`` as UTF-8 and optionally byte-compile it.

        :raises py_compile.PyCompileError: if byte-compiling is enabled and the file is not
            valid python.
        """
        _LOGGER.debug("Writing python file {}".format(path))
        with open(path, 'w', encoding='utf-8') as f:
            f.write(_PY_HEADER + content)
        if self._compiled_to_pyc:
            # fail loudly instead of only printing to stderr: a broken generated file must
            # not silently end up replacing the original aaz sources
            py_compile.compile(path, doraise=True)

    def _create_compact_aaz_folder(self):
        """(Re)create an empty ``aaz_compact`` package, discarding any previous one."""
        folder = self._get_compact_aaz_folder()
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.mkdir(folder)
        self._write_py_file(os.path.join(folder, '__init__.py'), content="")

    def compact_sub_resource_groups(self, ctx: "CompactorCtx", dirs):
        """Compact every entry of ``aaz/<dirs>``; entries without an ``__init__.py``
        (plain files, cache folders) are skipped by ``compact_resource_group``."""
        folder = self._get_aaz_rg_path(dirs)
        assert os.path.isdir(folder), f'Invalid folder path {folder}'
        # sorted so the output does not depend on the file system's listing order
        for name in sorted(os.listdir(folder)):
            self.compact_resource_group(ctx, os.path.join(dirs, name))

    def compact_resource_group(self, ctx: "CompactorCtx", dirs):
        """Compact the command group package ``aaz/<dirs>`` and then, recursively, its sub groups.

        Parents are processed before children so that children can link to helper code
        shared by a parent.

        :raises ValueError: if ``__init__.py`` exists but is not a file.
        """
        folder = self._get_aaz_rg_path(dirs)
        init_file = os.path.join(folder, '__init__.py')
        if not os.path.exists(init_file):
            return
        if not os.path.isfile(init_file):
            raise ValueError("Invalid init file: {}".format(init_file))

        compact_folder = self._get_compact_aaz_rg_path(dirs)
        os.mkdir(compact_folder)
        self.compact_resource_group_commands(ctx, dirs)
        self.compact_sub_resource_groups(ctx, dirs)

    def compact_resource_group_commands(self, ctx: "CompactorCtx", dirs):
        """Write ``__init__.py`` and, when there is content, ``__cmds.py`` for group ``dirs``.

        Helper classmethods are first registered in ``ctx`` (write mode) to count repeated
        code; each helper class is then rendered (read mode) with its code either inlined
        or linked to a shared copy in this or a parent group's ``__cmds.py``.

        :raises ValueError: if the source folder already contains a ``__cmds.py``.
        """
        folder = self._get_aaz_rg_path(dirs)
        cmds_file = os.path.join(folder, '__cmds.py')
        if os.path.exists(cmds_file):
            raise ValueError("Module is already compacted: {}".format(cmds_file))

        ctx.set_current_namespace(dirs, write_mode=True)

        compact_folder = self._get_compact_aaz_rg_path(dirs)

        _LOGGER.debug("Parsing folder {}".format(folder))
        cmds_content, grp_cls = self._parse_cmd_group_file(folder, None)
        cmds_content, cmd_clses, cmd_cls_helpers = self._parse_cmd_files(folder, cmds_content)

        if cmd_cls_helpers:
            for helper in cmd_cls_helpers.values():
                for k, v in helper["codes"].items():
                    ctx.add_code_piece(k, v)

            ctx.set_current_namespace(dirs, write_mode=False)
            helper_content, link_helper_codes = self._render_helpers(ctx, dirs, cmd_cls_helpers)

            if link_helper_codes:
                cmds_content += '\n\nclass _Helper:\n'
                for v in link_helper_codes.values():
                    assert v.startswith(self._class_method_register)
                    # code shared inside this group lives in _Helper as a staticmethod
                    cmds_content += "\n    @staticmethod" + v[len(self._class_method_register):]

            cmds_content += helper_content

        init_content = "from .__cmds import *\n" if cmds_content else ""
        self._write_py_file(os.path.join(compact_folder, '__init__.py'), content=init_content)

        if cmds_content:
            all_clses = [f'"{cmd_cls}"' for cmd_cls in cmd_clses]
            if grp_cls:
                all_clses.append(f'"{grp_cls}"')
            if all_clses:
                cmds_content += f'\n\n__all__ = [{",".join(all_clses)}]\n'
            self._write_py_file(os.path.join(compact_folder, '__cmds.py'), content=cmds_content)

        ctx.set_current_namespace(None, write_mode=False)

    def _render_helpers(self, ctx: "CompactorCtx", dirs, cmd_cls_helpers):
        """Render the helper classes of group ``dirs``; ``ctx`` must be in read mode on ``dirs``.

        :return: ``(helper_content, link_helper_codes)``. ``link_helper_codes`` maps the names
            of helper classmethods shared inside ``dirs`` to the code that must be emitted once
            in this group's ``_Helper`` class.
        """
        helper_content = ""
        link_helper_codes = {}
        for name, helper in cmd_cls_helpers.items():
            helper_cls_links = []
            body = helper["properties"]
            for k, v in helper["codes"].items():
                link_dirs = ctx.fetch_code_piece(k, v)
                if not link_dirs:
                    # not shared: keep the code inside this helper class
                    body += '\n' + v
                    continue
                assert dirs.startswith(link_dirs)
                if link_dirs == dirs:
                    # shared within this group: emitted once in the local _Helper class
                    helper_cls_links.append(f'    ("{k}", _Helper),')
                    linker_code = ctx.fetch_helper_code_piece(k)
                    assert linker_code
                    link_helper_codes[k] = linker_code
                else:
                    # shared by a parent group: reference that group's __cmds module relatively
                    depth = len(dirs.split(os.sep)) - len(link_dirs.split(os.sep))
                    relative_path = '.' * depth + '.__cmds'
                    helper_cls_links.append(f'    ("{k}", "{relative_path}"),')
            if not body.strip():
                # every member was linked away: keep the class body syntactically valid
                body = '    pass\n'
            helper_cls_content = f'class {name}:\n{body}'
            if helper_cls_links:
                decorator = '\n'.join(['@link_helper(', '    __package__,', *helper_cls_links, ')'])
                helper_cls_content = decorator + '\n' + helper_cls_content
            helper_content += '\n\n' + helper_cls_content
        return helper_content, link_helper_codes

    @staticmethod
    def _rstrip_blank_lines(lines):
        """Return ``lines`` without its trailing whitespace-only lines."""
        end = len(lines)
        while end and not lines[end - 1].strip():
            end -= 1
        return lines[:end]

    def _parse_cmd_group_file(self, folder, cmds_content):
        """Append the command group class of ``folder/__cmd_group.py`` to ``cmds_content``.

        The copied source starts at the ``@register_command_group(`` decorator (or the class
        line when there is no decorator) and ends before the ``__all__`` line or at EOF.

        :return: ``(cmds_content, grp_cls)``; ``grp_cls`` is None when there is no group file
            or it defines no command group class.
        """
        cmd_group_file = os.path.join(folder, '__cmd_group.py')
        if not os.path.isfile(cmd_group_file):
            return cmds_content, None

        _LOGGER.debug("Parsing command group file {}".format(cmd_group_file))

        grp_cls = None

        if not cmds_content:
            cmds_content = self._aaz_import_line

        cg_lines = []
        with open(cmd_group_file, 'r', encoding='utf-8') as f:
            # iterating the file stops at EOF, so a file without an __all__ line cannot
            # make this loop spin forever
            for line in f:
                if not cg_lines and line.startswith('@register_command_group('):
                    cg_lines.append(line)
                    continue
                if not grp_cls:
                    match = self._command_group_pattern.match(line)
                    if match:
                        grp_cls = match[1]
                        cg_lines.append(line)
                        continue
                if not cg_lines:
                    continue

                if self._file_end_pattern.match(line):
                    break
                cg_lines.append(line)

        cmds_content += ''.join(['\n', '\n', *self._rstrip_blank_lines(cg_lines)])

        _LOGGER.debug("Parsed CommandGroup Class: {}".format(grp_cls))

        return cmds_content, grp_cls

    def _parse_cmd_files(self, folder, cmds_content):
        """Append every command class found in ``folder/_<name>.py`` to ``cmds_content``.

        :return: ``(cmds_content, cmd_clses, cmd_cls_helpers)``; ``cmd_cls_helpers`` maps a
            helper class name to ``{"properties": str, "codes": {method name: code}}``.
        """
        cmd_clses = []
        cmd_cls_helpers = {}
        # sorted so the generated file does not depend on the file system's listing order
        for name in sorted(os.listdir(folder)):
            if name.startswith('__') or not name.startswith('_') or not name.endswith('.py'):
                continue
            cmd_file = os.path.join(folder, name)
            if not os.path.isfile(cmd_file):
                continue
            cmd_cls, cmd_lines, cmd_helper_lines = self._parse_cmd_file(cmd_file)
            if not cmd_cls:
                continue
            cmd_clses.append(cmd_cls)
            if not cmds_content:
                # no command group file seeded the content: the aaz import must still come first
                cmds_content = self._aaz_import_line
            cmds_content += ''.join(['\n', '\n', *cmd_lines])

            if cmd_helper_lines:
                helper_name, helper_properties, helper_codes = self._parse_cmd_helper_lines(cmd_helper_lines)
                if helper_properties or helper_codes:
                    cmd_cls_helpers[helper_name] = {
                        "properties": helper_properties,
                        "codes": helper_codes
                    }

        _LOGGER.debug("Parsed Command Classes: {}".format(cmd_clses))

        return cmds_content, cmd_clses, cmd_cls_helpers

    def _parse_cmd_file(self, cmd_file):
        """Split a command file into its command class and its optional helper class.

        :return: ``(cmd_cls, cmd_lines, cmd_helper_lines)``; both line lists run until the
            ``__all__`` line or EOF and have trailing blank lines removed.
        """
        _LOGGER.debug("Parsing command file {}".format(cmd_file))

        cmd_cls = None
        cmd_lines = []
        cmd_helper_lines = []
        with open(cmd_file, 'r', encoding='utf-8') as f:
            # a file object is its own iterator: the second loop resumes where the first stopped
            lines = iter(f)
            # read the cmd_cls definition
            for line in lines:
                if not cmd_lines and line.startswith('@register_command('):
                    cmd_lines.append(line)
                    continue
                if not cmd_cls:
                    match = self._command_pattern.match(line)
                    if match:
                        cmd_cls = match[1]
                        cmd_lines.append(line)
                        continue
                if not cmd_lines:
                    continue

                if self._file_end_pattern.match(line):
                    break
                if self._command_helper_pattern.match(line):
                    cmd_helper_lines.append(line)
                    break
                cmd_lines.append(line)

            # read the cmd_helper_cls definition
            if cmd_helper_lines:
                for line in lines:
                    if self._file_end_pattern.match(line):
                        break
                    cmd_helper_lines.append(line)

        return cmd_cls, self._rstrip_blank_lines(cmd_lines), self._rstrip_blank_lines(cmd_helper_lines)

    def _parse_cmd_helper_lines(self, cmd_helper_lines):
        """Split the source lines of a ``_XxxHelper`` class into its parts.

        Lines before the class header are ignored. A line starting with ``"    _"`` opens a
        run of class properties lasting until the next ``    @classmethod``. Each
        ``@classmethod`` whose ``def`` line matches ``_command_helper_func_pattern`` is
        collected (decorator included, trailing blank lines dropped) under its method name.
        Any other line outside a collected method (e.g. the class docstring) is dropped.

        :return: ``(helper_name, properties_code, helper_codes)``.
        """
        helper_name, helper_codes = None, {}
        total = len(cmd_helper_lines)
        idx = 0
        code_key = None
        code_lines = []
        properties_lines = []
        while idx < total:
            line = cmd_helper_lines[idx]
            if not helper_name:
                # skip everything up to and including the helper class header
                match = self._command_helper_pattern.match(line)
                if match:
                    helper_name = match[1]
                idx += 1
                continue
            if line.startswith(self._class_method_register) or line.startswith("    _"):
                # a new member starts: store the classmethod collected so far
                if code_key and code_lines:
                    helper_codes[code_key] = ''.join(self._rstrip_blank_lines(code_lines))
                code_lines = []
                code_key = None
                if line.startswith("    _"):
                    while idx < total and not cmd_helper_lines[idx].startswith(self._class_method_register):
                        properties_lines.append(cmd_helper_lines[idx])
                        idx += 1
                    if idx >= total:
                        break
                    line = cmd_helper_lines[idx]

            if line.startswith(self._class_method_register):
                code_lines.append(line)
                idx += 1
                if idx >= total:
                    # dangling decorator at the very end: nothing to collect
                    break
                line = cmd_helper_lines[idx]
                match = self._command_helper_func_pattern.match(line)
                if match:
                    code_key = match[1]

            if code_key:
                code_lines.append(line)

            idx += 1

        if code_key and code_lines:
            helper_codes[code_key] = ''.join(self._rstrip_blank_lines(code_lines))

        properties_code = ''.join(properties_lines)

        return helper_name, properties_code, helper_codes

    def _get_module_folder(self):
        """Return the module's folder, raising ValueError if it does not exist."""
        module_folder = os.path.join(self._modules_dir, self._mod_name.replace('-', '_').lower())
        if not os.path.exists(module_folder):
            raise ValueError("Module folder is not exist: {}".format(module_folder))
        return module_folder

    def _get_aaz_folder(self):
        return os.path.join(self._folder, 'aaz')

    def _get_aaz_rg_path(self, dirs):
        return os.path.join(self._get_aaz_folder(), dirs)

    def _get_compact_aaz_folder(self):
        return os.path.join(self._folder, 'aaz_compact')

    def _get_compact_aaz_rg_path(self, dirs):
        return os.path.join(self._get_compact_aaz_folder(), dirs)


if __name__ == "__main__":
    # Script entry: compact the aaz folders of the listed command modules and swap the
    # compacted folders in, logging every step at DEBUG level.
    logging.basicConfig(level=logging.DEBUG)
    for mod_name in ("network", "vm"):
        module_compactor = MainModuleCompactor(mod_name)
        module_compactor.compact()
        module_compactor.replace()
