1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import logging
import os
import pathlib
import shutil
import tempfile
from util import logger, run
# Script-wide logger. Created at INFO level; main() switches it to DEBUG when --debug is passed.
_log = logger.get_logger("fix_long_lines", logging.INFO)
def _process_files(filenames, clang_exe, tmpdir):
    """Run clang-format on only the over-long lines of each file in *filenames*.

    A line counts as long when its text, excluding the line terminator, exceeds
    120 characters. For each file with at least one such line:
      1. the file is copied into *tmpdir*;
      2. clang-format is run in place on the copy, with one ``--lines=N:N`` range
         per long line so that nothing else in the file is touched;
      3. the formatted copy is copied back over the original path.

    Args:
        filenames: Iterable of paths to check. It is consumed lazily, so a
            generator is fine.
        clang_exe: Name or path of the clang-format executable.
        tmpdir: Scratch directory. The caller writes a .clang-format config there,
            and the copies are formatted in that directory.

    Raises:
        OSError: If a file cannot be read or copied.
        Presumably whatever ``run(..., check=True)`` raises when clang-format exits
        non-zero. NOTE(review): confirm against util.run.
    """
    for path in filenames:
        _log.debug(f"Checking {path}")
        with open(path, encoding="UTF8") as f:
            # clang-format line numbers are 1-based. Strip the terminator before
            # measuring, so that a line of exactly 120 characters followed by "\n"
            # is not treated as too long.
            bad_lines = [num for num, text in enumerate(f, start=1) if len(text.rstrip("\r\n")) > 120]

        if not bad_lines:
            continue

        _log.info(f"Updating {path}")
        # Format a copy inside tmpdir. Presumably this is so that clang-format
        # resolves the .clang-format config written there rather than one near
        # the original file.
        target = os.path.join(tmpdir, os.path.basename(path))
        shutil.copy(path, target)

        cmd = [clang_exe, "-i"]
        cmd.extend(f"--lines={num}:{num}" for num in bad_lines)
        cmd.append(target)
        # Do not use a shell. With shell=True and a list, POSIX runs only cmd[0];
        # the remaining items become arguments to /bin/sh, not to clang-format.
        run(*cmd, cwd=tmpdir, check=True, shell=False)

        # Copy the updated file back to its original location.
        shutil.copy(target, path)
# file extensions we process
_EXTENSIONS = [".cc", ".h"]
def _get_branch_diffs(ort_root, branch):
    """Yield paths of .cc/.h files whose working-tree content differs from *branch*.

    Args:
        ort_root: Root directory of the repository. git is run from here, and the
            names git reports are joined onto it.
        branch: Any revision accepted by ``git diff``, e.g. "origin/main".

    Yields:
        ``os.path.join(ort_root, <name reported by git>)`` for each changed file
        whose extension (case-insensitive) is in _EXTENSIONS and that still exists.
        Files deleted in the working tree are skipped, since there is nothing to
        format and opening them later would fail.
    """
    command = ["git", "diff", branch, "--name-only"]
    # Run from the repository root so the script works regardless of the
    # caller's current directory, and so the reported names resolve against ort_root.
    result = run(*command, cwd=ort_root, capture_stdout=True, check=True)
    # stdout is bytes with one file name per line.
    for name in result.stdout.decode("utf-8").splitlines():
        if os.path.splitext(name.lower())[-1] not in _EXTENSIONS:
            continue
        full_path = os.path.join(ort_root, name)
        if not os.path.isfile(full_path):
            _log.debug(f"Skipping {full_path}: not present in the working tree")
            continue
        yield full_path
def _get_file_list(path):
    """Recursively yield every file under *path* whose extension is in _EXTENSIONS.

    The extension match is case-insensitive. Results come in os.walk order, each
    built as os.path.join(<directory>, <file name>).
    """
    for dirpath, _subdirs, names in os.walk(path):
        yield from (
            os.path.join(dirpath, name)
            for name in names
            if os.path.splitext(name.lower())[-1] in _EXTENSIONS
        )
def main():
    """Command-line entry point: fix over-long lines in onnxruntime C++ sources.

    By default, only .cc/.h files that differ from ``--branch`` are processed.
    ``--all_files`` instead walks include/onnxruntime and onnxruntime/core under
    the repository root. A temporary directory holds the clang-format config
    (Google style, 120-column limit) and the scratch copies that _process_files
    formats.
    """
    argparser = argparse.ArgumentParser(
        # The text must go in `description`. The first positional parameter of
        # ArgumentParser is `prog`, which would replace the program name in the
        # usage line and leave the help without a description.
        description="Script to fix long lines in the source using clang-format. "
        "Only lines that exceed the 120 character maximum are altered in order to minimize the impact. "
        "Checks .cc and .h files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument(
        "--branch",
        type=str,
        default="origin/main",
        help="Limit changes to files that differ from this branch. Use origin/main when preparing a PR.",
    )
    argparser.add_argument(
        "--all_files",
        action="store_true",
        help="Process all files under /include/onnxruntime and /onnxruntime/core. Ignores --branch value.",
    )
    argparser.add_argument(
        "--clang-format",
        type=pathlib.Path,
        required=False,
        default="clang-format",
        help="Path to clang-format executable",
    )
    argparser.add_argument("--debug", action="store_true", help="Set log level to DEBUG.")
    args = argparser.parse_args()

    if args.debug:
        _log.setLevel(logging.DEBUG)

    script_dir = os.path.dirname(os.path.realpath(__file__))
    # The repository root is taken to be two directories above this script.
    ort_root = os.path.abspath(os.path.join(script_dir, "..", ".."))

    with tempfile.TemporaryDirectory() as tmpdir:
        # Formatting config for the scratch copies that _process_files places in tmpdir.
        with open(os.path.join(tmpdir, ".clang-format"), "w", encoding="utf-8") as f:
            f.write(
                "\n"
                "BasedOnStyle: Google\n"
                "ColumnLimit: 120\n"
                "DerivePointerAlignment: false\n"
            )

        clang_format = str(args.clang_format)
        if args.all_files:
            include_path = os.path.join(ort_root, "include", "onnxruntime")
            src_path = os.path.join(ort_root, "onnxruntime", "core")
            _process_files(_get_file_list(include_path), clang_format, tmpdir)
            _process_files(_get_file_list(src_path), clang_format, tmpdir)
        else:
            _process_files(_get_branch_diffs(ort_root, args.branch), clang_format, tmpdir)


if __name__ == "__main__":
    main()
|