#!/usr/bin/env python3
# mypy: allow-untyped-defs
"""
model_dump: a one-stop shop for TorchScript model inspection.
The goal of this tool is to provide a simple way to extract lots of
useful information from a TorchScript model and make it easy for humans
to consume. It (mostly) replaces zipinfo, common uses of show_pickle,
and various ad-hoc analysis notebooks.
The tool extracts information from the model and serializes it as JSON.
That JSON can then be rendered by an HTML+JS page, either by
loading the JSON over HTTP or producing a fully self-contained page
with all of the code and data burned-in.
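
Example invocations (illustrative; "model.pt" stands in for any
TorchScript file):
  python -m torch.utils.model_dump --style=json model.pt > model_info.json
  python -m torch.utils.model_dump --style=html model.pt > model.html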
"""
# Maintainer notes follow.
"""
The implementation strategy has to balance three competing goals:
- Small file size.
- Fully self-contained.
- Easy, modern JS environment.
Using Preact and HTM achieves 1 and 2 with a decent result for 3.
However, the models I tested with result in ~1MB JSON output,
so even using something heavier like full React might be tolerable
if the build process can be worked out.
One principle I have followed that I think is very beneficial
is to keep the JSON data as close as possible to the model
and do most of the rendering logic on the client.
This makes for easier development (just refresh, usually),
allows for more laziness and dynamism, and lets us add more
views of the same data without bloating the HTML file.
Currently, this code doesn't actually load the model or even
depend on any part of PyTorch. I don't know if that's an important
feature to maintain, but it's probably worth preserving the ability
to run at least basic analysis on models that cannot be loaded.
I think the easiest way to develop this code is to cd into model_dump and
run "python -m http.server", then load http://localhost:8000/skeleton.html
in the browser. In another terminal, run
"python -m torch.utils.model_dump --style=json FILE > \
torch/utils/model_dump/model_info.json"
every time you update the Python code or model.
When you update JS, just refresh.
Possible improvements:
- Fix various TODO comments in this file and the JS.
- Make the HTML much less janky, especially the auxiliary data panel.
- Make the auxiliary data panel start small, expand when
data is available, and have a button to clear/contract.
- Clean up the JS. There's a lot of copypasta because
I don't really know how to use Preact.
- Make the HTML render and work nicely inside a Jupyter notebook.
- Add the ability for JS to choose the URL to load the JSON based
on the page URL (query or hash). That way we could publish the
inlined skeleton once and have it load various JSON blobs.
- Add a button to expand all expandable sections so ctrl-F works well.
- Add hyperlinking from data to code, and code to code.
- Add hyperlinking from debug info to Diffusion.
- Make small tensor contents available.
- Do something nice for quantized models
(they probably don't work at all right now).
"""
import argparse
import io
import json
import os
import pickle
import pprint
import re
import sys
import urllib.parse
import warnings
import zipfile
from pathlib import Path
from typing import Dict

import torch.utils.show_pickle
DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024
__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton',
'burn_in_info', 'get_info_and_burn_skeleton']
def get_storage_info(storage):
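    """Flatten a storage's persistent-ID FakeObject into a JSON-friendly list.

    Per the asserts below, storage.args[0] is a 5-tuple
    ("storage", StorageClass, key, location, numel); the result is
    [dtype_name, key, location, numel], where dtype_name is the storage
    class name with the "Storage" suffix stripped. (The field names here
    are descriptive guesses, not part of the serialization format.)
    """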
assert isinstance(storage, torch.utils.show_pickle.FakeObject)
assert storage.module == "pers"
assert storage.name == "obj"
assert storage.state is None
assert isinstance(storage.args, tuple)
assert len(storage.args) == 1
sa = storage.args[0]
assert isinstance(sa, tuple)
assert len(sa) == 5
assert sa[0] == "storage"
assert isinstance(sa[1], torch.utils.show_pickle.FakeClass)
assert sa[1].module == "torch"
assert sa[1].name.endswith("Storage")
storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:])
return storage_info
def hierarchical_pickle(data):
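    """Convert DumpUnpickler output into a JSON-friendly structure.

    Scalars pass through unchanged; tuples and dicts are wrapped in
    tagged objects (JSON has no native representation for either), and
    known FakeObject types (script modules, tensors, qtensors, devices,
    etc.) are handled case by case below.
    """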
if isinstance(data, (bool, int, float, str, type(None))):
return data
if isinstance(data, list):
return [hierarchical_pickle(d) for d in data]
if isinstance(data, tuple):
return {
"__tuple_values__": hierarchical_pickle(list(data)),
}
if isinstance(data, dict):
return {
"__is_dict__": True,
"keys": hierarchical_pickle(list(data.keys())),
"values": hierarchical_pickle(list(data.values())),
}
if isinstance(data, torch.utils.show_pickle.FakeObject):
typename = f"{data.module}.{data.name}"
if (
typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
):
assert data.args == ()
return {
"__module_type__": typename,
"state": hierarchical_pickle(data.state),
}
if typename == "torch._utils._rebuild_tensor_v2":
assert data.state is None
storage, offset, size, stride, requires_grad, *_ = data.args
storage_info = get_storage_info(storage)
return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]}
if typename == "torch._utils._rebuild_qtensor":
assert data.state is None
storage, offset, size, stride, quantizer, requires_grad, *_ = data.args
storage_info = get_storage_info(storage)
assert isinstance(quantizer, tuple)
assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass)
assert quantizer[0].module == "torch"
if quantizer[0].name == "per_tensor_affine":
assert len(quantizer) == 3
assert isinstance(quantizer[1], float)
assert isinstance(quantizer[2], int)
quantizer_extra = list(quantizer[1:3])
else:
quantizer_extra = []
quantizer_json = [quantizer[0].name] + quantizer_extra
return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
if typename == "torch.jit._pickle.restore_type_tag":
assert data.state is None
obj, typ = data.args
assert isinstance(typ, str)
return hierarchical_pickle(obj)
if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
assert data.state is None
ls, = data.args
assert isinstance(ls, list)
return hierarchical_pickle(ls)
if typename == "torch.device":
assert data.state is None
name, = data.args
assert isinstance(name, str)
# Just forget that it was a device and return the name.
return name
if typename == "builtin.UnicodeDecodeError":
assert data.state is None
msg, = data.args
assert isinstance(msg, str)
# Hack: Pretend this is a module so we don't need custom serialization.
# Hack: Wrap the message in a tuple so it looks like a nice state object.
# TODO: Undo at least that second hack. We should support string states.
return {
"__module_type__": typename,
"state": hierarchical_pickle((msg,)),
}
raise Exception(f"Can't prepare fake object of type for JS: {typename}") # noqa: TRY002
raise Exception(f"Can't prepare data of type for JS: {type(data)}") # noqa: TRY002
def get_model_info(
path_or_file,
title=None,
extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
"""Get JSON-friendly information about a model.
The result is suitable for being saved as model_info.json,
or passed to burn_in_info.
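
    Usage sketch (the path is illustrative):

        info = get_model_info("model.pt")
        print(json.dumps(info, sort_keys=True))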
"""
if isinstance(path_or_file, os.PathLike):
default_title = os.fspath(path_or_file)
file_size = path_or_file.stat().st_size # type: ignore[attr-defined]
elif isinstance(path_or_file, str):
default_title = path_or_file
file_size = Path(path_or_file).stat().st_size
else:
default_title = "buffer"
path_or_file.seek(0, io.SEEK_END)
file_size = path_or_file.tell()
path_or_file.seek(0)
title = title or default_title
with zipfile.ZipFile(path_or_file) as zf:
path_prefix = None
zip_files = []
for zi in zf.infolist():
prefix = re.sub("/.*", "", zi.filename)
if path_prefix is None:
path_prefix = prefix
elif prefix != path_prefix:
raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}") # noqa: TRY002
zip_files.append(dict(
filename=zi.filename,
compression=zi.compress_type,
compressed_size=zi.compress_size,
file_size=zi.file_size,
))
assert path_prefix is not None
version = zf.read(path_prefix + "/version").decode("utf-8").strip()
def get_pickle(name):
assert path_prefix is not None
with zf.open(path_prefix + f"/{name}.pkl") as handle:
raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
return hierarchical_pickle(raw)
model_data = get_pickle("data")
constants = get_pickle("constants")
# Intern strings that are likely to be re-used.
# Pickle automatically detects shared structure,
# so re-used strings are stored efficiently.
# However, JSON has no way of representing this,
# so we have to do it manually.
        interned_strings: Dict[str, int] = {}
def ist(s):
if s not in interned_strings:
interned_strings[s] = len(interned_strings)
return interned_strings[s]
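        # Dicts preserve insertion order, so list(interned_strings)
        # (as serialized below) maps each index back to its string,
        # which is presumably how the client resolves references.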
code_files = {}
for zi in zf.infolist():
if not zi.filename.endswith(".py"):
continue
with zf.open(zi) as handle:
raw_code = handle.read()
with zf.open(zi.filename + ".debug_pkl") as handle:
raw_debug = handle.read()
# Parse debug info and add begin/end markers if not present
# to ensure that we cover the entire source code.
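            # Each record appears to be
            # (byte_offset_into_code, ((source_text, filename, line), range_start, range_end)),
            # possibly followed by a tag; see the unpacking below.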
debug_info_t = pickle.loads(raw_debug)
text_table = None
if (len(debug_info_t) == 3 and
isinstance(debug_info_t[0], str) and
debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
_, text_table, content = debug_info_t
def parse_new_format(line):
# (0, (('', '', 0), 0, 0))
num, ((text_indexes, fname_idx, offset), start, end), tag = line
text = ''.join(text_table[x] for x in text_indexes) # type: ignore[index]
fname = text_table[fname_idx] # type: ignore[index]
return num, ((text, fname, offset), start, end), tag
debug_info_t = map(parse_new_format, content)
debug_info = list(debug_info_t)
if not debug_info:
debug_info.append((0, (('', '', 0), 0, 0)))
if debug_info[-1][0] != len(raw_code):
debug_info.append((len(raw_code), (('', '', 0), 0, 0)))
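            # Each part built below is
            # [text, interned_filename, line, interned_source_text, s_start, s_end].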
code_parts = []
for di, di_next in zip(debug_info, debug_info[1:]):
start, source_range, *_ = di
end = di_next[0]
assert end > start
source, s_start, s_end = source_range
s_text, s_file, s_line = source
# TODO: Handle this case better. TorchScript ranges are in bytes,
# but JS doesn't really handle byte strings.
                # If bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
if len(s_text) != len(s_text.encode("utf-8")):
s_start = 0
s_end = 0
text = raw_code[start:end]
code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end])
code_files[zi.filename] = code_parts
extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
extra_files_jsons = {}
for zi in zf.infolist():
if not extra_files_json_pattern.fullmatch(zi.filename):
continue
if zi.file_size > extra_file_size_limit:
continue
with zf.open(zi) as handle:
try:
json_content = json.load(handle)
extra_files_jsons[zi.filename] = json_content
except json.JSONDecodeError:
extra_files_jsons[zi.filename] = "INVALID JSON"
always_render_pickles = {
"bytecode.pkl",
}
extra_pickles = {}
for zi in zf.infolist():
if not zi.filename.endswith(".pkl"):
continue
with zf.open(zi) as handle:
# TODO: handle errors here and just ignore the file?
# NOTE: For a lot of these files (like bytecode),
# we could get away with just unpickling, but this should be safer.
obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
buf = io.StringIO()
pprint.pprint(obj, buf)
contents = buf.getvalue()
            # Check the rendered length instead of the file size
            # because pickles with shared structure can explode in size when rendered.
if os.path.basename(zi.filename) not in always_render_pickles and \
len(contents) > extra_file_size_limit:
continue
extra_pickles[zi.filename] = contents
return {"model": dict(
title=title,
file_size=file_size,
version=version,
zip_files=zip_files,
interned_strings=list(interned_strings),
code_files=code_files,
model_data=model_data,
constants=constants,
extra_files_jsons=extra_files_jsons,
extra_pickles=extra_pickles,
)}
def get_inline_skeleton():
"""Get a fully-inlined skeleton of the frontend.
The returned HTML page has no external network dependencies for code.
It can load model_info.json over HTTP, or be passed to burn_in_info.
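
    Typical use, as in get_info_and_burn_skeleton (path illustrative):

        page = burn_in_info(get_inline_skeleton(), get_model_info("model.pt"))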
"""
import importlib.resources
skeleton = importlib.resources.read_text(__package__, "skeleton.html")
js_code = importlib.resources.read_text(__package__, "code.js")
for js_module in ["preact", "htm"]:
js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)
return skeleton
def burn_in_info(skeleton, info):
"""Burn model info into the HTML skeleton.
The result will render the hard-coded model info and
have no external network dependencies for code or data.
"""
# Note that Python's json serializer does not escape slashes in strings.
# Since we're inlining this JSON directly into a script tag, a string
# containing "</script>" would end the script prematurely and
    # mess up our page. Unconditionally escaping the slashes fixes that.
return skeleton.replace(
"BURNED_IN_MODEL_INFO = null",
"BURNED_IN_MODEL_INFO = " + json.dumps(info, sort_keys=True).replace("/", "\\/"))
def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
model_info = get_model_info(path_or_bytesio, **kwargs)
skeleton = get_inline_skeleton()
page = burn_in_info(skeleton, model_info)
return page
def main(argv, *, stdout=None):
warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.")
parser = argparse.ArgumentParser()
parser.add_argument("--style", choices=["json", "html"])
parser.add_argument("--title")
parser.add_argument("model")
args = parser.parse_args(argv[1:])
info = get_model_info(args.model, title=args.title)
output = stdout or sys.stdout
if args.style == "json":
output.write(json.dumps(info, sort_keys=True) + "\n")
elif args.style == "html":
skeleton = get_inline_skeleton()
page = burn_in_info(skeleton, info)
output.write(page)
else:
raise Exception("Invalid style") # noqa: TRY002