# Generates the Python interface file 'datapipe.pyi' for functional DataPipes from the template 'datapipe.pyi.in'.
import os
import pathlib
from typing import Any, Dict, List, Set, Tuple, Union
def materialize_lines(lines: List[str], indentation: int) -> str:
    """
    Joins `lines` into a single string in which every line after the first is indented.

    Each entry is separated from the previous one by a newline followed by `indentation`
    spaces. Newlines embedded inside an entry receive the same treatment, so multi-line
    entries stay aligned. The very first line is never prefixed.
    """
    separator = "\n" + " " * indentation
    return separator.join(entry.replace("\n", separator) for entry in lines)
def gen_from_template(dir: str, template_name: str, output_name: str, replacements: List[Tuple[str, Any, int]]):
    """
    Renders a template file into an output file located in the same directory.

    Args:
        dir: directory that contains the template and receives the output file.
        template_name: file name of the template inside `dir`.
        output_name: file name of the generated file inside `dir`.
        replacements: tuples of (placeholder, lines, indentation). Every occurrence of
            `placeholder` in the template is replaced by `lines` rendered through
            `materialize_lines` with the given `indentation`.

    The output is written exactly once, after all placeholders have been substituted.
    It is therefore also produced when `replacements` is empty, in which case it is a
    verbatim copy of the template.
    """
    template_path = os.path.join(dir, template_name)
    output_path = os.path.join(dir, output_name)

    with open(template_path, "r") as f:
        content = f.read()

    # Apply every substitution in memory first so the output file is opened only once.
    for placeholder, lines, indentation in replacements:
        content = content.replace(placeholder, materialize_lines(lines, indentation))

    with open(output_path, "w") as f:
        f.write(content)
def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
    """
    Collects the paths of the Python files that sit directly inside the given directories.

    An entry is kept when its name ends with ".py" and is not listed in `files_to_exclude`.
    This function does NOT recursively traverse subdirectories.
    """
    collected: Set[str] = set()
    for directory in dir_paths:
        for entry in os.listdir(directory):
            if not entry.endswith(".py") or entry in files_to_exclude:
                continue
            collected.add(os.path.join(directory, entry))
    return collected
def extract_method_name(line: str) -> str:
    """
    Pulls the method name out of a decorator line shaped like "@functional_datapipe({method_name})".

    The name must be a quoted string; double quotes are looked for before single quotes.
    Raises RuntimeError when the line contains neither '("' nor "('".
    """
    for quote in ('"', "'"):
        opener = "(" + quote
        if opener in line:
            closer = quote + ")"
            break
    else:
        raise RuntimeError(f"Unable to find appropriate method name within line:\n{line}")
    name_begin = line.find(opener) + len(opener)
    name_end = line.find(closer)
    return line[name_begin:name_end]
def extract_class_name(line: str) -> str:
    """
    Pulls the class name out of a class definition shaped like "class {CLASS_NAME}({Type}):"

    The name is the text between the first "class " and the first "(" on the line.
    """
    keyword = "class "
    name_begin = line.find(keyword) + len(keyword)
    name_end = line.find("(")
    return line[name_begin:name_end]
def parse_datapipe_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str], Set[str]]:
    """
    Scans one source file line by line for classes decorated with "@functional_datapipe".

    Returns a tuple of:
      1. a dictionary of method names to processed `__init__`/`__new__` signatures,
      2. a dictionary of method names to the name of the decorated class,
      3. the set of method names whose class defines `__new__`.

    Lines inside triple-double-quoted blocks, and lines containing such quotes, are ignored.
    Raises RuntimeError if more parentheses are closed than were opened in a signature.
    """
    triple_quote = '"""'
    signatures: Dict[str, str] = {}
    class_names: Dict[str, str] = {}
    new_based_methods: Set[str] = set()

    current_method, current_class, partial_signature = "", "", ""
    paren_depth = 0
    inside_docstring = False

    with open(file_path) as source:
        for text in source:
            quote_hits = text.count(triple_quote)
            if quote_hits % 2 == 1:
                inside_docstring = not inside_docstring
            if inside_docstring or quote_hits > 0:  # Skipping comment/example blocks
                continue

            if "@functional_datapipe" in text:
                current_method = extract_method_name(text)
                continue
            if current_method and "class " in text:
                current_class = extract_class_name(text)
                continue

            if current_method:
                defines_new = "def __new__(" in text
                if defines_new or "def __init__(" in text:
                    if defines_new:
                        new_based_methods.add(current_method)
                    # The opening parenthesis of the definition is consumed here,
                    # so only the argument text after it is examined below.
                    paren_depth += 1
                    text = text[text.find("(") + 1:]

            if paren_depth <= 0:
                continue  # Not currently inside a signature

            paren_depth += text.count("(") - text.count(")")
            if paren_depth < 0:
                raise RuntimeError("open parenthesis count < 0. This shouldn't be possible.")
            if paren_depth > 0:
                # Signature continues on the following line
                partial_signature += text.strip('\n').strip(' ')
                continue

            # The signature is complete: keep everything before the last ')' and record it
            partial_signature += text[:text.rfind(")")]
            signatures[current_method] = process_signature(partial_signature)
            class_names[current_method] = current_class
            current_method, current_class, partial_signature = "", "", ""

    return signatures, class_names, new_based_methods
def parse_datapipe_files(file_paths: Set[str]) -> Tuple[Dict[str, str], Dict[str, str], Set[str]]:
    """
    Parses every file in `file_paths` with `parse_datapipe_file` and merges the results.

    Returns the combined method-to-signature dictionary, the combined method-to-class-name
    dictionary, and the combined set of methods that need a special output type.
    """
    all_signatures: Dict[str, str] = {}
    all_class_names: Dict[str, str] = {}
    all_special_methods: Set[str] = set()
    merged = (all_signatures, all_class_names, all_special_methods)
    for path in file_paths:
        for target, parsed_part in zip(merged, parse_datapipe_file(path)):
            target.update(parsed_part)
    return all_signatures, all_class_names, all_special_methods
def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
    """
    Splits a line of text on `delimiter` (a comma by default), except where the delimiter
    appears inside square brackets '[]'.

    The delimiter itself is dropped; all other characters, including spaces, are kept.
    """
    pieces: List[str] = []
    buffer: List[str] = []
    depth = 0
    for ch in line:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        elif depth == 0 and ch == delimiter:
            # Top-level delimiter: close the current piece and start a new one
            pieces.append("".join(buffer))
            buffer = []
            continue
        buffer.append(ch)
    pieces.append("".join(buffer))
    return pieces
def process_signature(line: str) -> str:
    """
    Given a raw function signature, clean it up by removing the self-referential datapipe argument,
    default arguments of input functions, newlines, and spaces.

    Specifically, for the comma-separated arguments (commas inside '[]' do not separate):
      - surrounding spaces are stripped from every argument,
      - an argument that is exactly "cls" is renamed to "self",
      - the argument right after "self" is dropped unless it starts with '*',
      - for an argument containing "Callable =", the default value is replaced by "...",
      - empty arguments are discarded and the rest are joined with ", ".
    """
    tokens: List[str] = split_outside_bracket(line)
    for i, token in enumerate(tokens):
        tokens[i] = token.strip(' ')
        if token == "cls":
            tokens[i] = "self"
        elif i > 0 and ("self" == tokens[i - 1]) and not tokens[i].startswith("*"):
            # Remove the datapipe after 'self' or 'cls' unless it has '*'.
            # `startswith` (rather than indexing) keeps an empty argument, e.g. the one
            # produced by a trailing comma right after 'self', from raising IndexError.
            tokens[i] = ""
        elif "Callable =" in token:  # Remove default argument if it is a function
            # Cut at the first '=' so a default that itself contains '=' (for instance a
            # lambda with a default argument) is dropped as a whole instead of failing.
            head, _, _ = token.partition("=")
            tokens[i] = head.strip(' ') + "= ..."
    tokens = [t for t in tokens if t != ""]
    line = ', '.join(tokens)
    return line
def get_method_definitions(file_path: Union[str, List[str]],
                           files_to_exclude: Set[str],
                           deprecated_files: Set[str],
                           default_output_type: str,
                           method_to_special_output_type: Dict[str, str],
                           root: str = "") -> List[str]:
    """
    .pyi generation for functional DataPipes Process
    # 1. Find files that we want to process (exclude the ones who don't)
    # 2. Parse method name and signature
    # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces

    `file_path` is one directory or a list of directories, each joined onto `root`; when
    `root` is empty, the directory containing this script is used. Returns one stub
    definition per functional method, sorted by the "def ..." line of each definition.
    """
    if root == "":
        root = str(pathlib.Path(__file__).parent.resolve())

    relative_dirs = [file_path] if isinstance(file_path, str) else file_path
    search_dirs = [os.path.join(root, rel_dir) for rel_dir in relative_dirs]
    excluded_files = files_to_exclude.union(deprecated_files)
    source_files = find_file_paths(search_dirs, files_to_exclude=excluded_files)

    signatures, class_names, special_methods = parse_datapipe_files(source_files)
    # Every method that has an explicit output type counts as special as well
    special_methods.update(method_to_special_output_type)

    definitions = []
    for method_name, arguments in signatures.items():
        class_name = class_names[method_name]
        output_type = (method_to_special_output_type[method_name]
                       if method_name in special_methods else default_output_type)
        definitions.append(f"# Functional form of '{class_name}'\n"
                           f"def {method_name}({arguments}) -> {output_type}: ...")

    # The second line of each definition starts with "def {method_name}", so this sorts by method name
    return sorted(definitions, key=lambda s: s.split('\n')[1])
# These settings live at module level, rather than inside main(), so they can be imported by TorchData.

# Settings for IterDataPipe: the directory to scan, the file names to skip, and the
# return types of the methods that do not return the default output type.
iterDP_file_path: str = "iter"
iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
iterDP_deprecated_files: Set[str] = set()
iterDP_method_to_special_output_type: Dict[str, str] = {
    "demux": "List[IterDataPipe]",
    "fork": "List[IterDataPipe]",
}

# Settings for MapDataPipe, with the same meaning as the IterDataPipe ones.
mapDP_file_path: str = "map"
mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
mapDP_deprecated_files: Set[str] = set()
mapDP_method_to_special_output_type: Dict[str, str] = {
    "shuffle": "IterDataPipe",
}
def main() -> None:
    """
    # Inject file into template datapipe.pyi.in

    Builds the IterDataPipe and MapDataPipe method definitions and substitutes them for the
    `${IterDataPipeMethods}` and `${MapDataPipeMethods}` placeholders, writing the result to
    datapipe.pyi next to this script.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
          interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    iter_definitions = get_method_definitions(
        iterDP_file_path,
        iterDP_files_to_exclude,
        iterDP_deprecated_files,
        "IterDataPipe",
        iterDP_method_to_special_output_type,
    )
    map_definitions = get_method_definitions(
        mapDP_file_path,
        mapDP_files_to_exclude,
        mapDP_deprecated_files,
        "MapDataPipe",
        mapDP_method_to_special_output_type,
    )

    script_dir = str(pathlib.Path(__file__).parent.resolve())
    gen_from_template(
        dir=script_dir,
        template_name="datapipe.pyi.in",
        output_name="datapipe.pyi",
        replacements=[
            ('${IterDataPipeMethods}', iter_definitions, 4),
            ('${MapDataPipeMethods}', map_definitions, 4),
        ],
    )
# Script entry point: announce the generation, then run it.
if __name__ == "__main__":
    print("Generating Python interface file 'datapipe.pyi'...")
    main()
|