1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
|
import modulefinder
import os
import pathlib
import sys
import warnings
from typing import Any, Dict, List, Set
# Third ancestor directory of this file's resolved location.
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent

# The tests named here are slow enough that it pays off to first work out
# whether the patch touched any file they depend on. The list is maintained
# by hand; for every run with --determine-from, another list is generated
# from this one together with the previous test stats.
TARGET_DET_LIST = [
    # test_autograd.py is deliberately absent: it is not slow. Adding it back
    # would also hit https://bugs.python.org/issue40350, because it imports
    # files under test/autograd/.
    "test_binary_ufuncs",
    "test_cpp_extensions_aot_ninja",
    "test_cpp_extensions_aot_no_ninja",
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cuda",
    "test_cuda_primary_ctx",
    "test_dataloader",
    "test_determination",
    "test_futures",
    "test_jit",
    "test_jit_legacy",
    "test_jit_profiling",
    "test_linalg",
    "test_multiprocessing",
    "test_nn",
    "test_numpy_interop",
    "test_optim",
    "test_overrides",
    "test_pruning_op",
    "test_quantization",
    "test_reductions",
    "test_serialization",
    "test_shape_ops",
    "test_sort_and_select",
    "test_tensorboard",
    "test_testing",
    "test_torch",
    "test_utils",
    "test_view_ops",
]

# Memoized dependency sets, keyed by test name (filled by get_dep_modules).
_DEP_MODULES_CACHE: Dict[str, Set[str]] = {}
def should_run_test(
    target_det_list: List[str], test: str, touched_files: List[str], options: Any
) -> bool:
    """Decide whether ``test`` has to run given the files a patch touched.

    Args:
        target_det_list: Names of tests for which determination is attempted.
            Any test not in this list is always run.
        test: Test identifier; only the part before the first ``.`` is used.
        touched_files: Paths of the files changed by the patch, split on
            ``os.sep`` when converted to module names.
        options: Object exposing a boolean ``verbose`` attribute that controls
            whether the decision is reported on stderr.

    Returns:
        ``True`` if the test should run, ``False`` if no touched file was
        found to affect it.
    """
    test = parse_test_module(test)

    # Some tests are faster to execute than to determine.
    if test not in target_det_list:
        if options.verbose:
            print_to_stderr(f"Running {test} without determination")
        return True

    # HACK: "no_ninja" is not a real module, so drop the build-flavour
    # suffixes before resolving the test script. The order matters:
    # "_no_ninja" is tried before the shorter "_ninja".
    for suffix in ("_no_ninja", "_ninja"):
        if test.endswith(suffix):
            test = test[: -len(suffix)]

    dep_modules = get_dep_modules(test)
    own_module = test.replace("/", ".")

    for touched_file in touched_files:
        file_type = test_impact_of_file(touched_file)
        if file_type == "NONE":
            continue
        if file_type in ("CI", "UNKNOWN"):
            # CI configuration changes force every test to run, and
            # uncategorized source files are assumed to affect every test.
            log_test_reason(file_type, touched_file, test, options)
            return True
        if file_type in ("TORCH", "CAFFE2", "TEST"):
            touched_module = _touched_module_name(touched_file)
            # NOTE(review): a touched "pkg/__init__.py" maps to
            # "pkg.__init__" here; confirm whether it should match "pkg".
            if touched_module in dep_modules or touched_module == own_module:
                log_test_reason(file_type, touched_file, test, options)
                return True

    # If nothing has determined the test has to run, don't run the test.
    if options.verbose:
        print_to_stderr(f"Determination is skipping {test}")
    return False


def _touched_module_name(touched_file: str) -> str:
    """Convert a touched file path into the dotted module name used for matching."""
    module_name = ".".join(os.path.splitext(touched_file)[0].split(os.sep))
    # The test/ path does not have a "test." namespace, so remove that leading
    # prefix only. Slicing instead of splitting on every "test." occurrence
    # keeps names such as "test.foo.test.bar" intact ("foo.test.bar").
    prefix = "test."
    if module_name.startswith(prefix):
        module_name = module_name[len(prefix):]
    return module_name
def test_impact_of_file(filename: str) -> str:
"""Determine what class of impact this file has on test runs.
Possible values:
TORCH - torch python code
CAFFE2 - caffe2 python code
TEST - torch test code
UNKNOWN - may affect all tests
NONE - known to have no effect on test outcome
CI - CI configuration files
"""
parts = filename.split(os.sep)
if parts[0] in [".jenkins", ".circleci"]:
return "CI"
if parts[0] in ["docs", "scripts", "CODEOWNERS", "README.md"]:
return "NONE"
elif parts[0] == "torch":
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
return "TORCH"
elif parts[0] == "caffe2":
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
return "CAFFE2"
elif parts[0] == "test":
if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
return "TEST"
return "UNKNOWN"
def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None:
    """Report on stderr which touched file caused ``test`` to be run.

    Nothing is printed unless ``options.verbose`` is truthy.
    """
    if not options.verbose:
        return
    reason = "Determination found {} file {} -- running {}".format(
        file_type, filename, test
    )
    print_to_stderr(reason)
def get_dep_modules(test: str) -> Set[str]:
    """Return the names of the modules ``modulefinder`` records for a test script.

    The analysed script is ``REPO_ROOT / "test" / f"{test}.py"``. The result
    is stored in ``_DEP_MODULES_CACHE`` so a repeated request for the same
    test returns the previously computed set.
    """
    # Cache hit: hand back the stored set without re-scanning.
    try:
        return _DEP_MODULES_CACHE[test]
    except KeyError:
        pass  # cache miss, compute below

    script_path = REPO_ROOT / "test" / f"{test}.py"

    # Ideally every third-party module would be excluded, to speed up the
    # calculation.
    skipped_packages = [
        "scipy",
        "numpy",
        "numba",
        "multiprocessing",
        "sklearn",
        "setuptools",
        "hypothesis",
        "llvmlite",
        "joblib",
        "email",
        "importlib",
        "unittest",
        "urllib",
        "json",
        "collections",
        # The next three are excluded because they hit
        # https://bugs.python.org/issue40350
        # (AttributeError: 'NoneType' object has no attribute 'is_package').
        "mpl_toolkits",
        "google",
        "onnx",
        # Triggers RecursionError
        "mypy",
    ]

    # HACK remark carried over from the original: some platforms default to
    # ascii, so run_script could not simply be used.
    # NOTE(review): run_script is nevertheless what is called below; confirm
    # whether that remark is stale.
    finder = modulefinder.ModuleFinder(excludes=skipped_packages)

    # Silence any warnings raised while the script is being analysed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        finder.run_script(str(script_path))

    found = set(finder.modules)
    _DEP_MODULES_CACHE[test] = found
    return found
def parse_test_module(test: str) -> str:
    """Return the portion of ``test`` that precedes its first ``.``.

    A string without any ``.`` is returned unchanged.
    """
    module_name, _dot, _remainder = test.partition(".")
    return module_name
def print_to_stderr(message: str) -> None:
    """Print ``message`` to standard error."""
    err_stream = sys.stderr
    print(message, file=err_stream)
|