1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
import os
import re
import sys
from typing import List
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
"check_code_for_cuda_kernel_launches",
"check_cuda_kernel_launches",
]
# FILES TO EXCLUDE (match is done with suffix using `endswith`)
# Any file whose path ends with one of these suffixes is skipped by the
# checker. You wouldn't drive without a seatbelt, though, so why would you
# launch a kernel without some safety? Use this only as a quick workaround
# for a problem with the checker: fix the checker, then de-exclude
# the files in question.
exclude_files: List[str] = []
# Without using a C++ AST we can't 100% detect kernel launches, so we
# model them as having the pattern "<<<parameters>>>(arguments);"
# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
# the next statement.
#
# We model the next statement as ending at the next `}` or `;`.
# If we reach a `}` first, the enclosing clause ended without a check (bad).
# If we reach a semi-colon first, we expect the launch check just before it.
#
# Since the kernel launch can include lambda statements, it's important
# to find the correct end-paren of the kernel launch. Doing this with
# pure regex requires recursive regex, which aren't part of the Python
# standard library. To avoid an additional dependency, we build a prefix
# regex that finds the start of a kernel launch, use a paren-matching
# algorithm to find the end of the launch, and then another regex to
# determine if a launch check is present.
# Finds potential starts of kernel launches. A match runs from the beginning
# of a line through `<<<...>>>` and ends just past the `(` that opens the
# launch's argument list, so `match.end() - 1` is the index of that `(`.
# The launch configuration between the chevrons is matched with `[^>]+`, so
# it must not itself contain a `>` character.
kernel_launch_start = re.compile(
r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
)
# This pattern should start at the character after the final paren of the
# kernel launch. NOTE: despite its name, it returns a match when the launch
# check is NOT the next statement. It consumes optional whitespace and the `;`
# that terminates the launch; the negative lookahead then rejects the match
# if `C10_CUDA_KERNEL_LAUNCH_CHECK();` follows before any other `;` or `}`.
has_check = re.compile(
r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
)
def find_matching_paren(s: str, startpos: int) -> int:
    """Locate the end of the parenthesised group that opens at ``startpos``.

    Args:
        s: Text to scan, of the form "prefix (unknown number of characters) suffix".
        startpos: Index in ``s`` of the opening ``(``.

    Returns:
        The index one past the ``)`` that balances the opening paren,
        accounting for nested parentheses.

    Raises:
        IndexError: If the text ends before the parentheses balance.
    """
    depth = 0
    # Offsets are counted from 1 so that `startpos + offset` is already the
    # index one past the character currently being examined.
    for offset, ch in enumerate(s[startpos:], start=1):
        if ch == '(':
            depth += 1
            continue
        if ch != ')':
            continue
        depth -= 1
        if depth == 0:
            return startpos + offset
    raise IndexError("Closing parens not found!")
def should_exclude_file(filename) -> bool:
    """Tell whether ``filename`` ends with any suffix listed in ``exclude_files``.

    Args:
        filename: Path of the file being considered.

    Returns:
        True if at least one entry of ``exclude_files`` is a suffix of
        ``filename``, False otherwise (including when the list is empty).
    """
    return any(filename.endswith(suffix) for suffix in exclude_files)
def check_code_for_cuda_kernel_launches(code, filename=None):
    """Checks code for CUDA kernel launches without cuda error checks.

    Every launch that is not followed by `C10_CUDA_KERNEL_LAUNCH_CHECK();`
    as its next statement is reported on stderr, together with the
    line-numbered text of the launch.

    Args:
        code     - The code to check
        filename - Filename of file containing the code. Used only for display
                   purposes, so you can put anything here.

    Returns:
        The number of unsafe kernel launches in the code
    """
    if filename is None:
        filename = "##Python Function Call##"

    # Prefix every line with its line number so the reported context points
    # at the problem area. Numbering starts at 1, matching what editors and
    # compilers display. The prefix only adds digits, `:` and a space at the
    # start of each line, so it does not change which launches are detected.
    numbered_code = '\n'.join(
        f"{lineno}: {linecode}"
        for lineno, linecode in enumerate(code.split("\n"), start=1)
    )

    num_launches_without_checks = 0
    for m in kernel_launch_start.finditer(numbered_code):
        # The start pattern ends with the `(` that opens the argument list,
        # so `m.end() - 1` is the index of that paren.
        end_paren = find_matching_paren(numbered_code, m.end() - 1)
        # `has_check` matches only when the launch check is NOT the next
        # statement after the launch.
        if has_check.match(numbered_code, end_paren):
            num_launches_without_checks += 1
            context = numbered_code[m.start():end_paren + 1]
            print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
    return num_launches_without_checks
def check_file(filename):
    """Checks a file for CUDA kernel launches without cuda error checks

    Files that do not end in `.cu` or `.cuh`, and files matched by
    `should_exclude_file`, are not opened and count as having no unsafe
    launches.

    Args:
        filename - File to check

    Returns:
        The number of unsafe kernel launches in the file
    """
    is_cuda_source = filename.endswith((".cu", ".cuh"))
    if not is_cuda_source or should_exclude_file(filename):
        return 0
    with open(filename, "r") as source:
        text = source.read()
    return check_code_for_cuda_kernel_launches(text, filename)
def check_cuda_kernel_launches():
    """Checks all pytorch code for CUDA kernel launches without cuda error checks

    Walks the directory tree rooted two levels above the directory that
    contains this file, runs `check_file` on every file found, and prints a
    summary to stderr when any unchecked launches are found.

    Returns:
        The number of unsafe kernel launches in the codebase
    """
    torch_dir = os.path.dirname(os.path.realpath(__file__))
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2

    # `$BASE/build` and `$BASE/torch/include` are generated
    # so we don't want to flag their contents. The paths are built from
    # separate components so they use the platform's separator and therefore
    # compare equal to the `root` values `os.walk` yields on every OS.
    generated_dirs = (
        os.path.join(torch_dir, "build"),
        os.path.join(torch_dir, "torch", "include"),
    )

    kernels_without_checks = 0
    files_without_checks = []
    for root, dirnames, filenames in os.walk(torch_dir):
        if root in generated_dirs:
            # Curtail search by modifying dirnames in place
            # Yes, this is the way to do this, see `help(os.walk)`
            dirnames[:] = []
            continue

        for x in filenames:
            filename = os.path.join(root, x)
            file_result = check_file(filename)
            if file_result > 0:
                kernels_without_checks += file_result
                files_without_checks.append(filename)

    if kernels_without_checks > 0:
        count_str = f"Found {kernels_without_checks} instances in " \
                    f"{len(files_without_checks)} files where kernel " \
                    "launches didn't have checks."
        # The summary is printed both before and after the file list.
        print(count_str, file=sys.stderr)
        print("Files without checks:", file=sys.stderr)
        for x in files_without_checks:
            print(f"\t{x}", file=sys.stderr)
        print(count_str, file=sys.stderr)

    return kernels_without_checks
if __name__ == "__main__":
    # Exit status is 0 when no unchecked launches were found, 1 otherwise.
    sys.exit(1 if check_cuda_kernel_launches() != 0 else 0)
|