import argparse
import os
import pathlib
import shlex
import subprocess
import sys
from collections import Counter
from multiprocessing import Pool, cpu_count
from typing import Dict, List, Tuple

from test_utils import PY_PACKAGE, ROOT, cd, print_time, record_time


class LintersPaths:
    """The paths each linter runs on.

    Every entry is a file or directory path that the ``run_*`` helpers in this
    file join onto ``ROOT`` before handing it to the tool. ``main`` visits the
    entries of each tuple one at a time, in the order listed here.
    """

    # Paths checked (or rewritten, with --fix) by black via ``run_black``.
    BLACK = (
        # core
        "python-package/",
        # tests
        "tests/python/test_config.py",
        "tests/python/test_callback.py",
        "tests/python/test_collective.py",
        "tests/python/test_data_iterator.py",
        "tests/python/test_dmatrix.py",
        "tests/python/test_demos.py",
        "tests/python/test_eval_metrics.py",
        "tests/python/test_early_stopping.py",
        "tests/python/test_multi_target.py",
        "tests/python/test_objectives.py",
        "tests/python/test_predict.py",
        "tests/python/test_quantile_dmatrix.py",
        "tests/python/test_tracker.py",
        "tests/python/test_tree_regularization.py",
        "tests/python/test_training_continuation.py",
        "tests/python/test_shap.py",
        "tests/python/test_updaters.py",
        "tests/python/test_model_io.py",
        "tests/python/test_with_pandas.py",
        "tests/python-gpu/",
        "tests/python-sycl/",
        "tests/test_distributed/test_federated/",
        "tests/test_distributed/test_gpu_federated/",
        "tests/test_distributed/test_with_dask/",
        "tests/test_distributed/test_gpu_with_dask/",
        "tests/test_distributed/test_with_spark/",
        "tests/test_distributed/test_gpu_with_spark/",
        # demo
        "demo/dask/",
        "demo/rmm_plugin",
        "demo/guide-python/continuation.py",
        "demo/guide-python/cat_in_the_dat.py",
        "demo/guide-python/callbacks.py",
        "demo/guide-python/categorical.py",
        "demo/guide-python/cat_pipeline.py",
        "demo/guide-python/cross_validation.py",
        "demo/guide-python/feature_weights.py",
        "demo/guide-python/model_parser.py",
        "demo/guide-python/sklearn_parallel.py",
        "demo/guide-python/sklearn_examples.py",
        "demo/guide-python/sklearn_evals_result.py",
        "demo/guide-python/spark_estimator_examples.py",
        "demo/guide-python/external_memory.py",
        "demo/guide-python/distributed_extmem_basic.py",
        "demo/guide-python/individual_trees.py",
        "demo/guide-python/quantile_regression.py",
        "demo/guide-python/multioutput_regression.py",
        "demo/guide-python/learning_to_rank.py",
        "demo/guide-python/quantile_data_iterator.py",
        "demo/guide-python/update_process.py",
        "demo/aft_survival/aft_survival_viz_demo.py",
        # CI
        "ops/",
    )

    # Paths checked (or rewritten, with --fix) by isort via ``run_isort``.
    ISORT = (
        # core
        "python-package/",
        # tests
        "tests/test_distributed/",
        "tests/python/",
        "tests/python-gpu/",
        # demo
        "demo/",
        # misc
        "dev/",
        "doc/",
        # CI
        "ops/",
    )

    # Paths type-checked by mypy via ``run_mypy``.
    MYPY = (
        # core
        "python-package/",
        # tests
        "tests/python/test_collective.py",
        "tests/python/test_demos.py",
        "tests/python/test_data_iterator.py",
        "tests/python/test_multi_target.py",
        "tests/python/test_objectives.py",
        "tests/python-gpu/test_gpu_data_iterator.py",
        "tests/python-gpu/load_pickle.py",
        "tests/python-gpu/test_gpu_training_continuation.py",
        "tests/python/test_model_io.py",
        "tests/test_distributed/test_federated/",
        "tests/test_distributed/test_gpu_federated/",
        "tests/test_distributed/test_with_dask/test_ranking.py",
        "tests/test_distributed/test_with_dask/test_external_memory.py",
        "tests/test_distributed/test_with_spark/test_data.py",
        "tests/test_distributed/test_gpu_with_spark/test_data.py",
        "tests/test_distributed/test_gpu_with_dask/",
        # demo
        "demo/dask/",
        "demo/guide-python/external_memory.py",
        "demo/guide-python/distributed_extmem_basic.py",
        "demo/guide-python/sklearn_examples.py",
        "demo/guide-python/continuation.py",
        "demo/guide-python/callbacks.py",
        "demo/guide-python/cat_in_the_dat.py",
        "demo/guide-python/categorical.py",
        "demo/guide-python/cat_pipeline.py",
        "demo/guide-python/feature_weights.py",
        "demo/guide-python/model_parser.py",
        "demo/guide-python/individual_trees.py",
        "demo/guide-python/quantile_regression.py",
        "demo/guide-python/quantile_data_iterator.py",
        "demo/guide-python/multioutput_regression.py",
        "demo/guide-python/learning_to_rank.py",
        "demo/aft_survival/aft_survival_viz_demo.py",
        # CI
        "ops/",
    )


def check_cmd_print_failure_assistance(cmd: List[str]) -> bool:
    """Run a checker command and, if it fails, print how to reproduce it.

    Parameters
    ----------
    cmd :
        The command line as an argument list; ``cmd[0]`` is the executable.

    Returns
    -------
    ``True`` when the command exits with status 0. ``False`` when it exits with
    any other status or when the executable cannot be found.
    """
    try:
        if subprocess.run(cmd).returncode == 0:
            return True
    except FileNotFoundError:
        # A missing tool is a failed check, not a reason for a raw traceback.
        print(
            f"Command not found: {cmd[0]}. Please install it and try again.",
            file=sys.stderr,
        )
        return False

    # Show which version of the tool produced the failure.
    subprocess.run([cmd[0], "--version"])
    msg = """
Please run the following command on your machine to address the error:

    """
    # shlex.join quotes arguments containing spaces or shell metacharacters, so
    # the printed command can be pasted into a shell verbatim.
    msg += shlex.join(cmd)
    print(msg, file=sys.stderr)
    return False


@record_time
@cd(PY_PACKAGE)
def run_black(rel_path: str, fix: bool) -> bool:
    """Run black on ``rel_path`` (joined onto ``ROOT``).

    With ``fix`` set, black is invoked without ``--check``; otherwise
    ``--check`` is appended. Returns whether the command succeeded.
    """
    target = os.path.join(ROOT, rel_path)
    mode_flags = [] if fix else ["--check"]
    return check_cmd_print_failure_assistance(["black", "-q", target, *mode_flags])


@record_time
@cd(PY_PACKAGE)
def run_isort(rel_path: str, fix: bool) -> bool:
    """Run isort on ``rel_path`` (joined onto ``ROOT``).

    With ``fix`` set, isort is invoked without ``--check``; otherwise
    ``--check`` is appended. Returns whether the command succeeded.
    """
    target = os.path.join(ROOT, rel_path)
    # isort has trouble locating its config file on its own, hence the explicit
    # settings path and source root.
    base_cmd = ["isort", "--settings-path", PY_PACKAGE, f"--src={PY_PACKAGE}", target]
    mode_flags = [] if fix else ["--check"]
    return check_cmd_print_failure_assistance(base_cmd + mode_flags)


@record_time
@cd(PY_PACKAGE)
def run_mypy(rel_path: str) -> bool:
    """Run mypy on ``rel_path`` (joined onto ``ROOT``); return whether it passed."""
    target = os.path.join(ROOT, rel_path)
    return check_cmd_print_failure_assistance(["mypy", target])


class PyLint:
    """A helper for running pylint, mostly copied from dmlc-core/scripts."""

    MESSAGE_CATEGORIES = {
        "Fatal",
        "Error",
        "Warning",
        "Convention",
        "Refactor",
        "Information",
    }
    # Maps the leading letter of a message id (the "C" in "C0114") to its category.
    MESSAGE_PREFIX_TO_CATEGORY = {
        category[0]: category for category in MESSAGE_CATEGORIES
    }

    @classmethod
    def _line_category(cls, line: str) -> str:
        """Return the message category reported on one pylint output line.

        An empty string is returned when the line does not look like a message.
        """
        if ":" not in line:
            return ""

        fields = [field.strip() for field in line.split(":")]
        # Prefer an explicit message id (one category letter followed by four
        # digits). Scanning every field keeps the classification correct when the
        # message text itself contains a colon.
        for field in fields[:-1]:
            if (
                len(field) == 5
                and field[1:].isdigit()
                and field[0] in cls.MESSAGE_PREFIX_TO_CATEGORY
            ):
                return cls.MESSAGE_PREFIX_TO_CATEGORY[field[0]]

        # Otherwise fall back to the first letter of the second-to-last field.
        # The field may be empty (e.g. a line containing "x[::2]"), which must
        # not raise.
        candidate = fields[-2]
        if not candidate:
            return ""
        return cls.MESSAGE_PREFIX_TO_CATEGORY.get(candidate[0], "")

    @classmethod
    @cd(PY_PACKAGE)
    def get_summary(cls, path: str) -> Tuple[str, Dict[str, int], str, str, bool]:
        """Get the summary of pylint's errors, warnings, etc.

        Returns ``(path, category -> count, stdout, stderr, succeeded)`` where
        ``succeeded`` tells whether pylint exited with status 0.
        """
        ret = subprocess.run(["pylint", path], capture_output=True)
        # Undecodable bytes in the tool output should not abort the whole run.
        stdout = ret.stdout.decode("utf-8", errors="replace")
        stderr = ret.stderr.decode("utf-8", errors="replace")

        emap: Dict[str, int] = Counter()
        for line in stdout.splitlines():
            if category := cls._line_category(line):
                emap[category] += 1

        return path, emap, stdout, stderr, ret.returncode == 0

    @staticmethod
    def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int:
        """Print summary of certain result map.

        Returns the number of files that have at least one recorded message.
        """
        if not result_map:
            return 0

        ftype = "Python"
        nfail = sum(map(bool, result_map.values()))
        print(
            f"====={len(result_map) - nfail}/{len(result_map)} {ftype} files passed check====="
        )
        for fname, emap in result_map.items():
            if emap:
                print(
                    f"{fname}: {sum(emap.values())} Errors of {len(emap)} Categories map={emap}"
                )
        return nfail

    @classmethod
    def run(cls) -> bool:
        """Run pylint with parallelization on a batch of paths.

        Returns ``True`` only when no file has a recorded message and every
        pylint invocation that produced no recognizable message exited with
        status 0.
        """
        all_errors: Dict[str, Dict[str, int]] = {}
        # Files where pylint failed but nothing in its output could be classified
        # (e.g. pylint itself crashed). These must not count as passing.
        unexplained_failures: List[str] = []

        with Pool(cpu_count()) as pool:
            error_maps = pool.map(
                cls.get_summary,
                (os.fspath(file) for file in pathlib.Path(PY_PACKAGE).glob("**/*.py")),
            )
            for path, emap, out, err, succeeded in error_maps:
                all_errors[path] = emap
                if succeeded:
                    continue

                print(out)
                if err:
                    print(err)
                if not emap:
                    unexplained_failures.append(path)

        nerr = cls.print_summary_map(all_errors)
        for path in unexplained_failures:
            print(
                f"{path}: pylint exited with a non-zero status "
                "but no messages were recognized",
                file=sys.stderr,
            )
        return nerr == 0 and not unexplained_failures


@record_time
def run_pylint() -> bool:
    """Run the pylint pass through ``PyLint.run`` and report whether it passed."""
    passed = PyLint.run()
    return passed


@record_time
def main(args: argparse.Namespace) -> None:
    """Run the enabled checker groups in order.

    The process exits with status -1 as soon as one group (black, isort, mypy,
    pylint) reports a failure; later groups are then not run.
    """
    if args.format == 1:
        formatters = ((run_black, LintersPaths.BLACK), (run_isort, LintersPaths.ISORT))
        for formatter, paths in formatters:
            # Build the full list first so every path is visited before the
            # group verdict is taken.
            outcomes = [formatter(path, args.fix) for path in paths]
            if not all(outcomes):
                sys.exit(-1)

    if args.type_check == 1:
        outcomes = [run_mypy(path) for path in LintersPaths.MYPY]
        if not all(outcomes):
            sys.exit(-1)

    if args.pylint == 1 and not run_pylint():
        sys.exit(-1)


if __name__ == "__main__":
    # Command-line entry point: parse the switches, run the checkers and always
    # print the recorded timings, even when a checker makes the script exit.
    description = (
        "Run static checkers for XGBoost, see `python_lint.yml' "
        "conda env file for a list of dependencies."
    )
    parser = argparse.ArgumentParser(description=description)
    # Each checker group is toggled with a 0/1 switch and is enabled by default.
    for switch in ("--format", "--type-check", "--pylint"):
        parser.add_argument(switch, type=int, choices=[0, 1], default=1)
    parser.add_argument(
        "--fix",
        action="store_true",
        help="Fix the formatting issues instead of emitting an error.",
    )
    args = parser.parse_args()
    try:
        main(args)
    finally:
        print_time()
