1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
|
import torch
from torch.autograd import functional
import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import NamedTuple, Callable, List, Any
try:
import functorch as ft
has_functorch = True
print(f"Found functorch: {ft.__version__}")
except ImportError:
has_functorch = False
import ppl_models
import vision_models
import audio_text_models
from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType
def get_task_func(task: str) -> Callable:
    """Return the ``torch.autograd.functional`` callable that runs ``task``.

    The vectorized variants ("hessian_fwdrev", "hessian_revrev", "jacfwd",
    "jacrev") are local wrappers. They always call ``functional.hessian`` /
    ``functional.jacobian`` with ``vectorize=True`` and ``strict=False``, and
    they ignore the ``strict`` argument they receive. Presumably this is
    because autograd.functional does not accept ``strict=True`` together with
    ``vectorize=True``; confirm against the torch docs.

    Any other name is looked up directly on ``torch.autograd.functional``, so
    an unknown task raises ``AttributeError``.
    """

    def jacrev(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True)

    def jacfwd(model, inp, strict=None):
        return functional.jacobian(
            model, inp, strict=False, vectorize=True, strategy="forward-mode"
        )

    def hessian_revrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True)

    def hessian_fwdrev(model, inp, strict=None):
        return functional.hessian(
            model,
            inp,
            strict=False,
            vectorize=True,
            outer_jacobian_strategy="forward-mode",
        )

    # Each wrapper is registered under its own function name, which is also
    # the task name used on the command line.
    vectorized_wrappers = {
        fn.__name__: fn for fn in (hessian_fwdrev, hessian_revrev, jacfwd, jacrev)
    }
    selected = vectorized_wrappers.get(task)
    return selected if selected is not None else getattr(functional, task)
def get_task_functorch(task: str) -> Callable:
    """Return a functorch-based callable that runs ``task``.

    Every returned callable has the signature ``(model, inp, v=None,
    strict=None)`` so it can be used interchangeably with the functions from
    ``get_task_func``. ``strict`` is accepted only for that compatibility and
    is never read. ``inp`` is treated as a tuple of positional inputs, and
    every input is differentiated (``argnums`` covers all of them). Each
    callable runs under ``torch.no_grad()``.

    Raises:
        RuntimeError: for "jacobian" (no functorch equivalent of
            ``autograd.functional.jacobian`` with ``vectorize=False``) and for
            any other unknown task name.
    """

    @torch.no_grad()
    def vjp(model, inp, v=None, strict=None):
        assert v is not None
        out, vjpfunc = ft.vjp(model, *inp)
        return out, vjpfunc(v)

    @torch.no_grad()
    def jvp(model, inp, v=None, strict=None):
        assert v is not None
        return ft.jvp(model, inp, v)

    @torch.no_grad()
    def vhp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        # grad_and_value makes the gradient the differentiated output and
        # returns the model value as aux, so the result mirrors
        # autograd.functional.vhp's (output, vhp) pair.
        _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
        return aux, vjpfunc(v)

    @torch.no_grad()
    def hvp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, hvp_out, aux = ft.jvp(ft.grad_and_value(model, argnums), inp, v, has_aux=True)
        return aux, hvp_out

    @torch.no_grad()
    def jacfwd(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(model, argnums)(*inp)

    @torch.no_grad()
    def jacrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(model, argnums)(*inp)

    @torch.no_grad()
    def hessian(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.hessian(model, argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_fwdrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_revrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    # Explicit registry. The previous `task in locals()` lookup also matched
    # the parameter name itself, so get_task_functorch("task") returned the
    # string "task" instead of raising.
    supported = {
        "vjp": vjp,
        "jvp": jvp,
        "vhp": vhp,
        "hvp": hvp,
        "jacfwd": jacfwd,
        "jacrev": jacrev,
        "hessian": hessian,
        "hessian_fwdrev": hessian_fwdrev,
        "hessian_revrev": hessian_revrev,
    }
    if task in supported:
        return supported[task]
    if task == "jacobian":
        raise RuntimeError("functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet")
    raise RuntimeError(f"Unsupported task: {task}")
# Task catalogue. Each list extends the previous one, so the cheaper subsets
# are always prefixes of the broader ones.

# Only "vjp": the subset run for models that cannot double-backward.
FAST_TASKS_NO_DOUBLE_BACK = ["vjp"]
# Recommended cheap tasks for most models.
FAST_TASKS = [*FAST_TASKS_NO_DOUBLE_BACK, "vhp", "jvp"]
# Every task benchmarked without vectorization.
ALL_TASKS_NON_VECTORIZED = [*FAST_TASKS, "hvp", "jacobian", "hessian"]
# Tasks treated as needing double backward (models that cannot do it list
# these as unsupported).
DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]
# Tasks that get_task_func runs with vectorize=True.
VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]
# Full task set, used when slow tasks are requested.
ALL_TASKS = [*ALL_TASKS_NON_VECTORIZED, *VECTORIZED_TASKS]
# Description of one benchmarked model. Fields:
# - name: the model's display name (used as the results key).
# - getter: called with the device the model should run on; returns the
#   forward function and the input Tensors it is called with. The forward
#   must *not* have side effects.
# - tasks: recommended tasks that finish in a reasonable time for this model.
# - unsupported: tasks this model cannot run; they are always skipped.
ModelDef = NamedTuple(
    "ModelDef",
    [
        ("name", str),
        ("getter", GetterType),
        ("tasks", List[str]),
        ("unsupported", List[str]),
    ],
)
# Registry of benchmarked models; main() iterates it in this order.
MODELS = [
    # From vision_models.
    ModelDef(name="resnet18", getter=vision_models.get_resnet18, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="fcn_resnet", getter=vision_models.get_fcn_resnet, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="detr", getter=vision_models.get_detr, tasks=FAST_TASKS, unsupported=[]),
    # From ppl_models: small enough to run every task by default.
    ModelDef(name="ppl_simple_reg", getter=ppl_models.get_simple_regression, tasks=ALL_TASKS, unsupported=[]),
    ModelDef(name="ppl_robust_reg", getter=ppl_models.get_robust_regression, tasks=ALL_TASKS, unsupported=[]),
    # From audio_text_models.
    ModelDef(name="wav2letter", getter=audio_text_models.get_wav2letter, tasks=FAST_TASKS, unsupported=[]),
    # deepspeech is restricted to single-backward tasks.
    ModelDef(
        name="deepspeech",
        getter=audio_text_models.get_deepspeech,
        tasks=FAST_TASKS_NO_DOUBLE_BACK,
        unsupported=DOUBLE_BACKWARD_TASKS,
    ),
    ModelDef(name="transformer", getter=audio_text_models.get_transformer, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="multiheadattn", getter=audio_text_models.get_multiheadattn, tasks=FAST_TASKS, unsupported=[]),
]
def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    """Build the random ``v`` argument that ``task`` needs, or None.

    - "vjp": a uniform random tensor shaped like ``model(*inp)``. This runs
      the model once to learn the output shape.
    - "jvp", "hvp", "vhp": uniform random tensors shaped like the inputs. The
      result is a tuple when ``inp`` is a tuple, else a single tensor.
    - any other task: None.
    """
    if task == "vjp":
        return torch.rand_like(model(*inp))
    if task not in ("jvp", "hvp", "vhp"):
        return None
    if isinstance(inp, tuple):
        return tuple(torch.rand_like(t) for t in inp)
    return torch.rand_like(inp)
def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
    """Run ``task`` once through torch.autograd.functional and discard the result.

    ``v`` is forwarded only when it is not None, and ``strict=True`` is always
    requested. Extra keyword arguments, such as ``maybe_check_consistency``
    passed during warmup, are accepted and ignored. This keeps the call
    signature the same as run_once_functorch.
    """
    task_fn = get_task_func(task)
    optional = {} if v is None else {"v": v}
    task_fn(model, inp, strict=True, **optional)
def run_once_functorch(model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False) -> None:
    """Run ``task`` once through functorch, optionally validating the result.

    When ``maybe_check_consistency`` is true, the same task is also run through
    torch.autograd.functional. The two results must agree within
    ``rtol=1e-5`` and ``atol=5e-3`` (``1e-2`` for "vhp"); otherwise
    ``torch.testing.assert_close`` raises.
    """
    call_kwargs = {"strict": True}
    if v is not None:
        call_kwargs["v"] = v

    res = get_task_functorch(task)(model, inp, **call_kwargs)
    if not maybe_check_consistency:
        return

    expected = get_task_func(task)(model, inp, **call_kwargs)
    # vhp gets a looser absolute tolerance than the other tasks.
    atol = 1e-2 if task == "vhp" else 5e-3
    torch.testing.assert_close(res, expected, rtol=1e-5, atol=atol, msg=f"Consistency fail for task '{task}'")
def run_model(model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once) -> List[float]:
    """Benchmark ``task`` on the model built by ``model_getter``.

    Args:
        model_getter: called with the target ``torch.device``; returns
            ``(model, inp)``.
        args: parsed CLI namespace. Reads ``args.gpu`` (-1 selects CPU; any
            other value is used as a CUDA device index) and ``args.num_iters``.
        task: task name understood by ``run_once_fn``.
        run_once_fn: executes the task once. Either ``run_once``
            (autograd.functional) or ``run_once_functorch``.

    Returns:
        Elapsed time in seconds for each of the ``args.num_iters`` timed runs.
        The single warmup run is not included.
    """
    if args.gpu == -1:
        device = torch.device("cpu")

        def do_sync():
            pass
    else:
        device = torch.device(f"cuda:{args.gpu}")
        # CUDA work is queued asynchronously; synchronizing before and after
        # each run makes the measurement cover the whole computation.
        do_sync = torch.cuda.synchronize

    model, inp = model_getter(device)
    v = get_v_for(model, inp, task)

    # Warmup. maybe_check_consistency=True makes run_once_functorch compare its
    # result against autograd.functional; run_once swallows it via **kwargs.
    run_once_fn(model, inp, task, v, maybe_check_consistency=True)

    elapsed = []
    for _ in range(args.num_iters):
        do_sync()
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump with system clock adjustments.
        start = time.perf_counter()
        run_once_fn(model, inp, task, v)
        do_sync()
        elapsed.append(time.perf_counter() - start)
    return elapsed
def _mean_and_var(runtimes: List[float]):
    """Return the mean and the ``Tensor.var()`` variance of ``runtimes`` as Python floats."""
    timings = torch.tensor(runtimes)
    return timings.mean().item(), timings.var().item()


def main():
    """Parse the CLI, benchmark every selected (model, task) pair and report.

    For each model in ``MODELS`` that passes ``--model-filter``, runs its
    recommended tasks, or every task in ``ALL_TASKS`` with
    ``--run-slow-tasks``. Tasks listed as unsupported for the model, and tasks
    rejected by ``--task-filter``, are skipped.

    Each task is timed with autograd.functional and, when functorch was
    imported, with functorch too. Per-task mean/variance is printed and, if
    ``--output`` is given, written as a markdown table.

    Note: both filters are plain substring tests against the filter string,
    so e.g. ``--task-filter hessian_fwdrev`` also selects "hessian".
    """
    parser = ArgumentParser("Main script to benchmark functional API of the autograd.")
    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect")
    parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks")
    parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter")
    parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter")
    parser.add_argument("--num-threads", type=int, default=10,
                        help="Number of concurrent threads to use when running on cpu")
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # torch.manual_seed also seeds CUDA when it is available.
    torch.manual_seed(args.seed)

    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        if args.model_filter and name not in args.model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in args.task_filter:
                continue

            # Print plain floats; formatting the 0-d tensors directly produced
            # output like "tensor(0.0123)s".
            mean, var = _mean_and_var(run_model(model_getter, args, task))
            results[name][task] = (mean, var)
            print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

            if not has_functorch:
                continue
            try:
                runtimes = run_model(model_getter, args, task, run_once_fn=run_once_functorch)
            except RuntimeError as e:
                # Unsupported functorch tasks or failed consistency checks are
                # reported and skipped rather than aborting the whole run.
                print(f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t", e)
                continue
            mean, var = _mean_and_var(runtimes)
            results[name][f"functorch {task}"] = (mean, var)
            print("Results for model {} on task {} using Functorch: {}s (var: {})".format(name, task, mean, var))

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(to_markdown_table(results))
# Script entry point: run the benchmark suite using the command-line arguments.
if __name__ == "__main__":
    main()
|