# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import logging
import os
import pathlib
import shutil
import tempfile

from util import logger, run

# module-level logger for this script. starts at INFO; main() lowers it to DEBUG when --debug is specified
_log = logger.get_logger("fix_long_lines", logging.INFO)


def _process_files(filenames, clang_exe, tmpdir):
    """Look for long lines in each file and, if found, run clang-format on just those lines.

    A file containing at least one line longer than 120 characters is copied into `tmpdir`, clang-format is run
    in-place on the copy restricted to the long lines, and the copy is then written back over the original.

    Args:
        filenames: Iterable of paths of the files to check.
        clang_exe: Name or path of the clang-format executable.
        tmpdir: Scratch directory. The file is copied here and clang-format is run with this as its working
            directory (main() writes the .clang-format config into this directory before calling).
    """
    max_line_length = 120

    for path in filenames:
        _log.debug(f"Checking {path}")

        with open(path, encoding="UTF8") as f:
            # clang-format line numbers start at 1.
            # The line terminator is not part of the line's width, so strip it before measuring. Without this a
            # line that is exactly at the limit would be treated as too long and get reformatted.
            bad_lines = [
                line_num for line_num, line in enumerate(f, start=1) if len(line.rstrip("\n")) > max_line_length
            ]

        if not bad_lines:
            continue

        _log.info(f"Updating {path}")
        target = os.path.join(tmpdir, os.path.basename(path))
        shutil.copy(path, target)

        # run clang-format to update just the long lines in the file
        cmd = [clang_exe, "-i"]
        cmd.extend(f"--lines={line_num}:{line_num}" for line_num in bad_lines)
        cmd.append(target)

        # The command is an argument list, so it must not go through the shell: with shell=True on POSIX only the
        # first list element is used as the command and clang-format would be started without any arguments.
        # NOTE(review): assumes `run` forwards `shell` to subprocess - confirm against util.run.
        run(*cmd, cwd=tmpdir, check=True, shell=False)

        # copy updated file back to original location
        shutil.copy(target, path)


# file extensions we process. compared against the lower-cased extension of each candidate file name
_EXTENSIONS = [".cc", ".h"]


def _get_branch_diffs(ort_root, branch):
    """Yield the files with an extension in _EXTENSIONS that `git diff` reports as differing from `branch`.

    Args:
        ort_root: Root directory of the repository. git is run from this directory and the names it reports
            are joined to it.
        branch: Branch (or other revision accepted by `git diff`) to compare against.

    Yields:
        `ort_root` joined with each reported file name that has a matching extension and still exists on disk.
    """
    command = ["git", "diff", branch, "--name-only"]
    # run git from the repository root so the result does not depend on the directory the script is started from
    result = run(*command, cwd=ort_root, capture_stdout=True, check=True)

    # stdout is bytes. one filename per line. decode, split, and filter to the extensions we are looking at
    for name in result.stdout.decode("utf-8").splitlines():
        if os.path.splitext(name.lower())[-1] not in _EXTENSIONS:
            continue

        path = os.path.join(ort_root, name)
        # the diff also lists files that were deleted relative to the branch. those cannot be opened so skip them
        if os.path.isfile(path):
            yield path


def _get_file_list(path):
    """Walk `path` recursively and yield the full path of every file whose extension is in _EXTENSIONS.

    The extension comparison is done on the lower-cased file name.
    """
    for dirpath, _dirnames, filenames in os.walk(path):
        yield from (
            os.path.join(dirpath, name) for name in filenames if os.path.splitext(name.lower())[1] in _EXTENSIONS
        )


def main():
    """Parse the command line and run clang-format on the long lines of the selected .cc and .h files.

    With --all_files every matching file under include/onnxruntime and onnxruntime/core is processed.
    Otherwise only the files that differ from --branch are processed. A temporary directory holding the
    .clang-format config is used as the working area and is removed on exit.
    """
    argparser = argparse.ArgumentParser(
        # the first positional parameter of ArgumentParser is `prog`, so the text has to be passed as `description`
        # or it ends up as the program name in the usage line
        description="Script to fix long lines in the source using clang-format. "
        "Only lines that exceed the 120 character maximum are altered in order to minimize the impact. "
        "Checks .cc and .h files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    argparser.add_argument(
        "--branch",
        type=str,
        default="origin/main",
        help="Limit changes to files that differ from this branch. Use origin/main when preparing a PR.",
    )

    argparser.add_argument(
        "--all_files",
        action="store_true",
        help="Process all files under /include/onnxruntime and /onnxruntime/core. Ignores --branch value.",
    )

    argparser.add_argument(
        "--clang-format",
        type=pathlib.Path,
        required=False,
        default="clang-format",
        help="Path to clang-format executable",
    )

    argparser.add_argument("--debug", action="store_true", help="Set log level to DEBUG.")

    args = argparser.parse_args()

    if args.debug:
        _log.setLevel(logging.DEBUG)

    # the repository root is two levels above the directory containing this script
    script_dir = os.path.dirname(os.path.realpath(__file__))
    ort_root = os.path.abspath(os.path.join(script_dir, "..", ".."))

    clang_format = str(args.clang_format)
    if os.path.dirname(clang_format):
        # clang-format is run with the temporary directory as its working directory, so a relative path to the
        # executable must be made absolute while we are still in the directory the user specified it from.
        # a bare executable name is left as-is so that it is looked up via PATH.
        clang_format = os.path.abspath(clang_format)

    with tempfile.TemporaryDirectory() as tmpdir:
        # create config in tmpdir
        with open(os.path.join(tmpdir, ".clang-format"), "w", encoding="utf-8") as f:
            f.write(
                """
            BasedOnStyle: Google
            ColumnLimit: 120
            DerivePointerAlignment: false
            """
            )

        if args.all_files:
            include_path = os.path.join(ort_root, "include", "onnxruntime")
            src_path = os.path.join(ort_root, "onnxruntime", "core")
            _process_files(_get_file_list(include_path), clang_format, tmpdir)
            _process_files(_get_file_list(src_path), clang_format, tmpdir)
        else:
            _process_files(_get_branch_diffs(ort_root, args.branch), clang_format, tmpdir)


# only run when executed as a script, not when the module is imported
if __name__ == "__main__":
    main()
