1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
import argparse
import json
from os import path
import torch
# Submodule paths, relative to ``torch``, that ``--all-submod`` walks through,
# in the order they are processed.
# NOTE(review): this file does not import the ``utils.*`` entries explicitly,
# and ``import torch`` may not load all of them, so whatever resolves these
# paths must be able to import a submodule on demand.
all_submod_list = [
    "",  # the top-level ``torch`` namespace itself
    "nn", "nn.functional", "nn.init",
    "optim", "autograd", "cuda", "sparse", "distributions",
    "fft", "linalg", "jit", "distributed", "futures", "onnx", "random",
    "utils.bottleneck", "utils.checkpoint", "utils.data", "utils.model_zoo",
]
def get_content(submod):
    """Return ``dir()`` of ``torch.<submod>``, or of ``torch`` when *submod* is empty.

    Args:
        submod: dotted path relative to ``torch``, e.g. ``"nn.functional"``.
            An empty string selects the top-level ``torch`` namespace.

    Returns:
        The list of names reported by ``dir()`` for the resolved object.

    Raises:
        AttributeError: if a component of *submod* is neither an attribute of
            its parent nor an importable submodule.
    """
    import importlib

    mod = torch
    if submod:
        qualname = "torch"
        for name in submod.split("."):
            qualname = f"{qualname}.{name}"
            try:
                mod = getattr(mod, name)
            except AttributeError as err:
                # `import torch` does not load every submodule (for example
                # torch.utils.model_zoo), so plain getattr can miss it.
                # Import the dotted path explicitly instead.
                try:
                    mod = importlib.import_module(qualname)
                except ImportError:
                    # Not a module either: surface the original AttributeError
                    # so callers keep seeing the same exception type.
                    raise err from None
    return dir(mod)
def namespace_filter(data):
    """Keep only the public names from *data*.

    Args:
        data: iterable of attribute-name strings.

    Returns:
        A set containing every name that does not start with an underscore.
        ``str.startswith`` is used instead of indexing, so an empty string
        is kept rather than raising ``IndexError``.
    """
    return {name for name in data if not name.startswith("_")}
def _save_snapshot(submod, filename, message):
    """Write ``get_content(submod)`` to *filename* as JSON, then print *message*."""
    names = get_content(submod)
    with open(filename, "w") as sink:
        json.dump(names, sink)
    print(message)


def _read_snapshot(filename):
    """Load the JSON list stored in *filename* and return it as a set."""
    with open(filename) as source:
        return set(json.load(source))


def _report_diff(prev_filename, new_filename, show_all):
    """Print the names added and removed between two snapshot files.

    Args:
        prev_filename: snapshot taken with the previous version.
        new_filename: snapshot taken with the new version.
        show_all: when false, underscore-prefixed names are filtered out of
            both snapshots before comparing.

    Raises:
        RuntimeError: if a snapshot file is missing; the previous-version file
            is checked first.
    """
    required = (
        (prev_filename, "Previous version data not collected"),
        (new_filename, "New version data not collected"),
    )
    for filename, complaint in required:
        if not path.exists(filename):
            raise RuntimeError(complaint)

    before = _read_snapshot(prev_filename)
    after = _read_snapshot(new_filename)
    if not show_all:
        # Diff only the public surface.
        before = namespace_filter(before)
        after = namespace_filter(after)

    if before == after:
        print("Nothing changed.")
        print("")
        return

    print("Things that were added:")
    print(after - before)
    print("")
    print("Things that were removed:")
    print(before - after)
    print("")


def run(args, submod):
    """Process ``torch.<submod>`` in the mode selected on the command line.

    Snapshots are stored in the current directory as
    ``prev_data_<submod>.json`` and ``new_data_<submod>.json``.

    Args:
        args: parsed CLI namespace. Exactly one of ``prev_version``,
            ``new_version`` or ``compare`` is expected to be set.
            ``show_all`` is read in compare mode.
        submod: dotted path relative to ``torch``; ``""`` means ``torch``.
    """
    print(f"## Processing torch.{submod}")
    prev_filename = f"prev_data_{submod}.json"
    new_filename = f"new_data_{submod}.json"

    if args.prev_version:
        _save_snapshot(submod, prev_filename, "Data saved for previous version.")
    elif args.new_version:
        _save_snapshot(submod, new_filename, "Data saved for new version.")
    else:
        assert args.compare
        _report_diff(prev_filename, new_filename, args.show_all)
def main():
    """Parse the command line and run the selected mode on each requested submodule.

    Exactly one of ``--prev-version`` / ``--new-version`` / ``--compare`` is
    required. ``--submod`` and ``--all-submod`` are mutually exclusive; without
    either, the top-level ``torch`` namespace (``--submod ""``) is processed.
    """
    parser = argparse.ArgumentParser(
        description="Tool to check namespace content changes"
    )

    mode = parser.add_mutually_exclusive_group(required=True)
    for flag in ("--prev-version", "--new-version", "--compare"):
        mode.add_argument(flag, action="store_true")

    target = parser.add_mutually_exclusive_group()
    target.add_argument("--submod", default="", help="part of the submodule to check")
    target.add_argument(
        "--all-submod",
        action="store_true",
        help="collects data for all main submodules",
    )

    parser.add_argument(
        "--show-all",
        action="store_true",
        help="show all the diff, not just public APIs",
    )

    args = parser.parse_args()
    targets = all_submod_list if args.all_submod else [args.submod]
    for submod in targets:
        run(args, submod)


if __name__ == "__main__":
    main()
|