from __future__ import annotations

from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable
from warnings import warn

from tools.testing.target_determination.heuristics.interface import (
    HeuristicInterface,
    TestPrioritizations,
)
from tools.testing.target_determination.heuristics.utils import (
    normalize_ratings,
    query_changed_files,
)
from tools.testing.test_run import TestRun


# Directory three levels above the directory that contains this file.
# NOTE(review): not referenced anywhere in the visible code; confirm that this
# actually resolves to the repository root for this file's location.
REPO_ROOT = Path(__file__).parent.parent.parent.parent

# Maps a canonical keyword to the alternative spellings that are treated as
# the same keyword.  sanitize_name() collapses a path component that equals
# either the key or one of its listed values down to the key, and
# file_matches_keyword() also accepts a file whose keywords or raw path
# contain any of the values listed for the keyword being matched.
keyword_synonyms: dict[str, list[str]] = {
    "amp": ["mixed_precision"],
    "quant": ["quantized", "quantization", "quantize"],
    "decomp": ["decomposition", "decompositions"],
    "numpy": ["torch_np", "numpy_tests"],
    "ops": ["opinfo"],
    "hop": ["higher_order_op"],
    "aot": ["flex_attention", "autograd"],
    "inductor": ["dynamo", "export"],  # not actually synonyms but they interact a lot
}


# Per-keyword replacements for the plain "keyword in file" substring check
# that file_matches_keyword() falls back to.  Each matcher first masks out a
# longer string that merely contains the keyword ("onnx" contains "nn",
# "c10d" contains "c10") so that it does not count as a hit on its own.
custom_matchers: dict[str, Callable[[str], bool]] = {
    "nn": lambda x: "nn" in x.replace("onnx", "_"),
    "c10": lambda x: "c10" in x.replace("c10d", "_"),
}


def is_valid_keyword(keyword: str) -> bool:
    """Return whether ``keyword`` is specific enough to be used for matching.

    "nn" is always accepted.  Every other candidate must be longer than two
    characters and must not be one of the generic path components below.
    """
    if keyword == "nn":
        return True
    if len(keyword) <= 2:
        return False

    # Components that show up in too many paths to say anything about a file.
    too_generic = (
        "torch",
        "test",
        "tests",
        "util",
        "utils",
        "func",
        "src",
        "c",
        "ns",
        "tools",
        "internal",
    )
    return keyword not in too_generic


# Only the most recent file is cached: get_freq_dict() asks about the same
# test file once per keyword in a row, so a single slot covers that pattern.
@lru_cache(maxsize=1)
def get_keywords(file: str) -> list[str]:
    """Return the keywords derived from the path ``file``.

    Candidates are every directory component of the path followed by the
    "_"-separated pieces of the file name without its extension.  Each
    candidate goes through sanitize_name() and is kept only if
    is_valid_keyword() accepts it.  Order and duplicates are preserved.

    The returned list is the cached object itself, so callers must not
    mutate it.
    """
    path = Path(file)
    candidates = [sanitize_name(part) for part in path.parts[:-1]]
    candidates += [sanitize_name(piece) for piece in path.stem.split("_")]
    return list(filter(is_valid_keyword, candidates))


def sanitize_name(folder_name: str) -> str:
    """Normalize one path component into its keyword form.

    A single leading underscore is dropped.  If the result is a key of
    ``keyword_synonyms`` or appears in one of its value lists, the key is
    returned; otherwise the stripped name is returned as is.
    """
    name = folder_name[1:] if folder_name.startswith("_") else folder_name

    for canonical, aliases in keyword_synonyms.items():
        if name == canonical or name in aliases:
            return canonical

    return name


def file_matches_keyword(file: str, keyword: str) -> bool:
    """Return whether the path ``file`` is related to ``keyword``.

    The checks are tried in this order, stopping at the first hit:

    1. ``keyword`` is one of the keywords extracted from ``file``.
    2. Any synonym listed for ``keyword`` is one of those keywords or is a
       substring of ``file``.
    3. The custom matcher registered for ``keyword`` accepts ``file``; when
       none is registered, ``keyword`` being a substring of ``file`` counts.
    """
    file_keywords = get_keywords(file)
    if keyword in file_keywords:
        return True

    for alias in keyword_synonyms.get(keyword, []):
        if alias in file_keywords or alias in file:
            return True

    matcher = custom_matchers.get(keyword)
    if matcher is None:
        # No special handling for this keyword: plain substring search.
        return keyword in file
    return matcher(file)


def get_freq_dict(tests: list[str], changed_files: list[str]) -> dict[str, int]:
    """Rate each test by how often its matching keywords occur in the changes.

    Every keyword extracted from every changed file is counted, once per
    occurrence.  A test's rating is the sum of the counts of all keywords
    that file_matches_keyword() says it matches.  Tests that match no keyword
    get no entry in the returned mapping.
    """
    keyword_counts: dict[str, int] = defaultdict(int)
    for changed in changed_files:
        for kw in get_keywords(changed):
            keyword_counts[kw] += 1

    ratings: dict[str, int] = defaultdict(int)
    for test in tests:
        matched_counts = [
            count
            for kw, count in keyword_counts.items()
            if file_matches_keyword(test, kw)
        ]
        # Only create an entry when at least one keyword matched.
        if matched_counts:
            ratings[test] += sum(matched_counts)
    return ratings


class Filepath(HeuristicInterface):
    # Heuristic based on folders in the file path.  Takes each folder of each
    # changed file and attempts to find matches based on those folders
    def __init__(self, **kwargs: dict[str, Any]) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Rate ``tests`` by keyword overlap with the currently changed files.

        The changed files come from query_changed_files().  If that call
        raises, a warning is emitted and the heuristic continues with an
        empty change list, so no test receives a rating.

        The raw ratings from get_freq_dict() are converted to floats keyed by
        TestRun and handed to normalize_ratings() (called with 0.25 and
        min_value=0.125) before being wrapped in a TestPrioritizations.
        """
        try:
            changed_files = query_changed_files()
        except Exception as e:
            # Best effort: a failed query must not break test selection.
            warn(f"Can't query changed test files due to {e}")
            changed_files = []

        # Build the membership set once; testing against the list directly
        # is a linear scan for every rated test.
        known_tests = set(tests)
        keyword_ratings = get_freq_dict(tests, changed_files)
        test_ratings = {
            TestRun(name): float(score)
            for name, score in keyword_ratings.items()
            if name in known_tests
        }
        return TestPrioritizations(
            tests, normalize_ratings(test_ratings, 0.25, min_value=0.125)
        )


if __name__ == "__main__":
    # Quick thing so you can call the heuristic from the command line with a sha
    import subprocess
    import sys

    from tools.testing.discover_tests import TESTS

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <commit-sha>")

    sha = sys.argv[1]
    # Pass the arguments as a list so no shell interprets the sha (or the
    # trailing "^"), and fail loudly if git rejects it instead of silently
    # running the heuristic on an empty diff.
    git_diff = subprocess.run(
        ["git", "diff", "--name-only", sha, f"{sha}^"],
        stdout=subprocess.PIPE,
        text=True,
        check=True,
    )
    changed_files = git_diff.stdout.split("\n")
    # Drop the empty entry left by the trailing newline and ignore changed
    # files whose path starts with "test".
    freq_dict = get_freq_dict(
        TESTS, [x for x in changed_files if x != "" and not x.startswith("test")]
    )
    # Ascending order, so the highest-rated tests are printed last.
    for k, v in sorted(freq_dict.items(), key=lambda x: x[1], reverse=False):
        print(k, v)
    print(changed_files)
