1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
|
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
import os
import unittest
from collections import namedtuple

import torch
import torch.utils._pytree as pytree
from functorch import vmap
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db

from functorch_additional_op_db import additional_op_db
# True only when the environment variable FUNCTORCH_TEST_FBCODE is exactly '1'
# (presumably set by the fbcode test runner; verify against the CI config).
IS_FBCODE = os.environ.get('FUNCTORCH_TEST_FBCODE') == '1'
def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    """Reference implementation of vmap: call ``op`` once per batch index and stack.

    For each ``idx`` in ``range(batch_size)``, every flattened argument whose
    in_dim is not None is sliced with ``select(in_dim, idx)``; arguments whose
    in_dim is None are passed through unchanged. ``op`` is then called on the
    re-assembled arguments plus ``kwarg_values``.

    Returns:
        A tensor when ``op`` returns a tensor, otherwise a list holding one
        stacked tensor per element of ``op``'s (sequence) output. In both cases
        the per-index results are stacked along ``out_dim``.
    """
    # The argument structure is identical for every index, so flatten once.
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims, dims_spec = pytree.tree_flatten(in_dims)
    assert args_spec == dims_spec

    outs = []
    for idx in range(batch_size):
        new_args = [a.select(in_dim, idx) if in_dim is not None else a
                    for a, in_dim in zip(flat_args, flat_dims)]
        outs.append(op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values))

    if isinstance(outs[0], torch.Tensor):
        # Stack along out_dim, consistent with the sequence-output branch below.
        return torch.stack(outs, out_dim)
    return [torch.stack([out[i] for out in outs], out_dim) for i in range(len(outs[0]))]
def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2, *batched_args, **kwarg_values):
    """Like ``loop`` but for 2 levels of vmap.

    If more levels are ever needed, generalizing ``loop`` is probably possible,
    but it seemed too complicated for this.

    The outer level iterates over ``batch_size1``, slicing each flattened
    argument with ``in_dims1``; the inner level iterates over ``batch_size2``,
    slicing the already-sliced arguments with ``in_dims2`` (None means "pass
    through unsliced"). The inner per-index results are stacked along
    ``out_dim1``, and those stacks are then stacked along ``out_dim2``.

    Returns:
        A tensor when ``op`` returns a tensor, otherwise a list holding one
        stacked tensor per element of ``op``'s (sequence) output.
    """
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert args_spec == dims_spec1
    assert args_spec == dims_spec2
    assert len(flat_dims1) == len(flat_dims2)

    def stack_outputs(outs, dim):
        # Stack per-index outputs that are either all tensors or all sequences
        # of tensors (stacked element-wise in the latter case).
        if isinstance(outs[0], torch.Tensor):
            return torch.stack(outs, dim)
        return [torch.stack([out[i] for out in outs], dim) for i in range(len(outs[0]))]

    loop_out = []
    for idx1 in range(batch_size1):
        arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a
                     for a, in_dim1 in zip(flat_args, flat_dims1)]
        out_split = []
        for idx2 in range(batch_size2):
            new_args = [a.select(in_dim, idx2) if in_dim is not None else a
                        for a, in_dim in zip(arg_split, flat_dims2)]
            out_split.append(op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values))
        loop_out.append(stack_outputs(out_split, out_dim1))
    # Check the element type (loop_out itself is always a list), so tensor
    # outputs are stacked into a single tensor.
    return stack_outputs(loop_out, out_dim2)
def is_valid_inplace_sample_input(sample_input, op, inplace_variant):
    """Return True if ``sample_input`` can be fed to the in-place variant of ``op``.

    That requires an in-place variant to exist, the sample not to broadcast
    its ``input``, and the out-of-place result of ``op`` on the sample to have
    the same dtype as ``input`` (an in-place op cannot change that dtype).
    Note that this calls ``op`` once on the sample.
    """
    if inplace_variant is None or sample_input.broadcasts_input:
        return False
    call_args = (sample_input.input,) + sample_input.args
    result = op(*call_args, **sample_input.kwargs)
    return sample_input.input.dtype == result.dtype
def memoize(fn):
    """Cache ``fn``'s results keyed on its (hashable) positional arguments.

    This is kind of dangerous, please think carefully before using it.
    Known risks:
    - the return better not be mutated, since every caller receives the same
      cached object, so it's best to return immutable types (e.g. prefer
      tuples to lists);
    - don't hash tensors in a global context: the cache keeps a strong
      reference to every key and result, so they stay alive as long as the
      memoized function does.

    Only positional arguments are accepted; passing keyword arguments raises
    ``TypeError``.
    """
    memo = {}

    # Preserve fn's name/docstring and expose it as __wrapped__.
    @functools.wraps(fn)
    def wrapped(*args):
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]
    return wrapped
@memoize
def get_bdim_choices(num_tensors):
    """Return the batch-dim assignments to try for ``num_tensors`` tensor args.

    The first choice batches every tensor at dim 0. It is followed by every
    combination (Cartesian product) of -1 (batched at the last dim) and None
    (unbatched), minus the all-None combination, in which nothing would be
    batched.

    NB: the result has O(2 ** num_tensors) entries. num_tensors ranges from 1
    to 10, with 2-4 being most common, so avoid making this any more
    expensive if you modify it. Results are memoized and returned as a tuple
    so the shared cached value cannot be mutated.
    """
    all_zero = (0,) * num_tensors
    combos = list(itertools.product((-1, None), repeat=num_tensors))
    # product() yields the all-None combination last; it is dropped below.
    assert combos[-1] == (None,) * num_tensors
    return (all_zero,) + tuple(combos[:-1])
def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
    """Return the bdim choices to try for batch norm / instance norm in training mode.

    Args:
        num_tensors: number of tensor arguments being batched.
        _: the op's input (unused here).
        running_mean, running_var: the op's running statistics, or None.
        *args: any remaining op arguments (ignored).

    NB: like ``get_bdim_choices`` this is O(2 ** num_tensors); num_tensors
    ranges from 1 to 10, with 2-4 being most common, so keep it cheap.
    """
    options = (-1, None)
    if running_mean is None or running_var is None:
        # Instance norm turns missing running stats into unbatched 0 tensors,
        # so the input (first tensor) is never batched; only the rest vary.
        num_rest = num_tensors - 1
        candidates = [(None,) + (0,) * num_rest]
        candidates += [(None,) + combo for combo in itertools.product(options, repeat=num_rest)]
    else:
        # Batch norm doesn't work if the input is batched but running_mean/var
        # are not, so every other combination is tested.
        # NOTE(review): assumes tensors #2 and #3 are running_mean / running_var.
        candidates = [(0,) * num_tensors]
        for combo in itertools.product(options, repeat=num_tensors):
            input_bdim, mean_bdim, var_bdim = combo[0], combo[1], combo[2]
            if input_bdim is not None and (mean_bdim is None or var_bdim is None):
                continue
            candidates.append(combo)
    # The last product entry is all-None (nothing batched) and is dropped.
    assert candidates[-1] == (None,) * num_tensors
    return tuple(candidates[:-1])
def add_batch_dim(arg, bdim, batch_size=3):
    """Materialize ``batch_size`` copies of tensor ``arg`` along a new dim.

    ``bdim`` must be 0 (new leading dim) or -1 (new trailing dim). The result
    is a real copy, not a broadcast view: ``repeat`` for dim 0, and
    ``expand(...).contiguous()`` for dim -1.

    Returns:
        ``(batched_tensor, bdim)``.
    """
    assert bdim in (0, -1)
    assert isinstance(arg, torch.Tensor)
    if bdim == 0:
        repeats = [batch_size, *([1] * arg.dim())]
        return (arg.repeat(repeats), bdim)
    if bdim == -1:
        tiled = arg.unsqueeze(-1).expand(*arg.shape, batch_size)
        return (tiled.contiguous(), bdim)
def construct_in_dims(bdim_choice_for_tensors, is_tensors):
    """Expand per-tensor bdims into a per-argument tuple of in_dims.

    Walks ``is_tensors`` in order: each tensor argument takes the next entry
    of ``bdim_choice_for_tensors``, and each non-tensor argument gets None.
    """
    remaining = iter(bdim_choice_for_tensors)
    # List comprehension (not a generator) so a StopIteration from a
    # too-short choice propagates unchanged.
    return tuple([next(remaining) if is_tensor else None for is_tensor in is_tensors])
def is_batch_norm_training(op_name, kwarg_values):
    """Return whether ``op_name`` is batch norm / instance norm in training mode.

    Instance norm calls batch norm, so both are covered. Batch norm and
    instance norm require the mode flag to be a plain bool, so the mode is
    read from the single bool among ``kwarg_values``'s values. When there is
    none, instance norm defaults to training and batch norm does not. Any
    other op returns False.
    """
    if op_name not in ("nn.functional.batch_norm", "nn.functional.instance_norm"):
        return False
    bool_flags = [value for value in kwarg_values.values() if isinstance(value, bool)]
    if not bool_flags:
        return op_name == "nn.functional.instance_norm"
    assert len(bool_flags) == 1
    return bool_flags[0]
def generate_vmap_inputs(arg_values, kwarg_values, is_batch_norm_and_training=False, batch_size=2):
    """Yield ``(batched_args, in_dims, kwarg_values)`` for every bdim choice.

    Tensor leaves of ``arg_values`` get a new batch dim of size ``batch_size``
    (via ``add_batch_dim``) according to each choice from ``get_bdim_choices``,
    or ``get_bdim_choices_batch_norm`` when ``is_batch_norm_and_training``.
    Unbatched tensors and non-tensor leaves are passed through with an in_dim
    of None. ``batched_args`` and ``in_dims`` share the pytree structure of
    ``tuple(arg_values)``; ``kwarg_values`` is yielded unchanged.

    For batch norm in training mode with only one tensor (the input), nothing
    is yielded: it can't be batched because running_mean/var would be seen as
    unbatched tensors.
    """
    leaves, spec = pytree.tree_flatten(tuple(arg_values))
    tensor_mask = [isinstance(leaf, torch.Tensor) for leaf in leaves]
    tensor_count = sum(tensor_mask)

    if is_batch_norm_and_training:
        if tensor_count == 1:
            return
        choices = get_bdim_choices_batch_norm(tensor_count, *arg_values)
    else:
        choices = get_bdim_choices(tensor_count)

    # Memoized per generator, so each (tensor, bdim) pair is expanded once and
    # the same batched tensor is reused across choices.
    @memoize
    def batched(tensor, bdim):
        assert isinstance(tensor, torch.Tensor)
        assert bdim is not None
        return add_batch_dim(tensor, bdim, batch_size)[0]

    for choice in choices:
        leaf_in_dims = construct_in_dims(choice, tensor_mask)
        batched_leaves = tuple(
            leaf if in_dim is None else batched(leaf, in_dim)
            for leaf, in_dim in zip(leaves, leaf_in_dims))
        yield (pytree.tree_unflatten(batched_leaves, spec),
               pytree.tree_unflatten(leaf_in_dims, spec),
               kwarg_values)
def clone_if_tensor(x):
    """Return ``x.clone()`` for tensors; return any other value as-is (same object)."""
    return x.clone() if isinstance(x, torch.Tensor) else x
def compute_quantities_for_vmap_test(
        op, orig_batched_args, orig_kwarg_values, in_dims,
        out_dim=0, batch_size=2, compute_loop_out=True,
        clone_inputs=False):
    """Yield ``(expected, actual)`` pairs checking ``vmap(op)`` on one input set.

    Two pairs are produced lazily:

    1. ``(loop_out, batched_out)``: ``loop_out`` is the per-example reference
       from ``loop`` (None when ``compute_loop_out`` is False), and
       ``batched_out`` is ``vmap(op, in_dims=in_dims, out_dims=out_dim)``.
    2. ``(expected, output)``: ``op`` run under two nested vmaps where the
       inner level sees no batched arguments, so a batching rule is reached
       with no bdims. Autogenerated plumbing should handle this; ops with
       manually written plumbing may need to handle it specially.
       ``expected`` is ``batched_out`` with a size-1 dim inserted at dim 1.

    When ``clone_inputs`` is True, each of the three op invocations gets
    fresh clones of the tensor inputs.
    """
    def fresh_inputs():
        # Optionally clone tensor leaves so each computation sees pristine inputs.
        if not clone_inputs:
            return orig_batched_args, orig_kwarg_values
        cloned_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
        cloned_kwargs = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
        return cloned_args, cloned_kwargs

    # Reference result: run op once per batch element and stack.
    args, kwargs = fresh_inputs()
    loop_out = loop(op, in_dims, out_dim, batch_size, *args, **kwargs) if compute_loop_out else None

    # To debug, trace the vmapped call with functorch's make_fx and print
    # in_dims together with the batched argument shapes.
    args, kwargs = fresh_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*args, **kwargs)
    yield (loop_out, batched_out)

    def unsqueeze_tensor(x):
        return x.unsqueeze(1) if isinstance(x, torch.Tensor) else x

    def ignore_dummy(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    dummy = torch.ones(batch_size, 1)
    # NOTE(review): the nested vmaps below use the default out_dims=0, so
    # inserting the size-1 dim at dim 1 presumably assumes out_dim == 0.
    expected = pytree.tree_map(unsqueeze_tensor, batched_out)
    inner_in_dims = (0,) + pytree.tree_map(lambda _: None, in_dims)
    outer_in_dims = (0,) + in_dims
    args, kwargs = fresh_inputs()
    output = vmap(vmap(ignore_dummy, inner_in_dims), outer_in_dims)(dummy, *args, **kwargs)
    yield (expected, output)
def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True):
    """Yield every ``(expected, actual)`` pair for ``op`` over all generated bdim choices.

    Inputs come from ``generate_vmap_inputs`` and each input set is checked by
    ``compute_quantities_for_vmap_test`` with out_dim 0 and batch size 2.
    """
    out_dim = 0
    batch_size = 2
    # Pass batch_size explicitly so the generated inputs and the checks agree.
    vmap_inputs = generate_vmap_inputs(arg_values, kwarg_values, is_batch_norm_and_training, batch_size)
    for batched_args, in_dims, batched_kwargs in vmap_inputs:
        quantities = compute_quantities_for_vmap_test(
            op, batched_args, batched_kwargs, in_dims,
            out_dim=out_dim, batch_size=batch_size, compute_loop_out=compute_loop_out)
        for pair in quantities:
            yield pair
def opinfo_in_dict(opinfo, d):
    """Return True if ``d`` contains the OpInfo's bare name or its ``name.variant`` key."""
    if opinfo.name in d:
        return True
    # Only consulted (and variant_test_name only read) when the bare name misses.
    return f'{opinfo.name}.{opinfo.variant_test_name}' in d
# One decorator to attach to an OpInfo: built by ``decorate`` (and therefore
# ``xfail`` / ``skip``) and consumed by ``skipOps``.
DecorateMeta = namedtuple(
    "DecorateMeta",
    ("op_name", "variant_name", "decorator", "device_type", "dtypes"),
)
def decorate(op_name, variant_name='', *, decorator=None, device_type=None, dtypes=None):
    """Build a ``DecorateMeta`` applying ``decorator`` to the given OpInfo.

    ``decorator`` is required (despite its None default); ``device_type`` and
    ``dtypes`` optionally restrict where it applies.
    """
    assert decorator is not None
    meta = DecorateMeta(op_name, variant_name, decorator, device_type, dtypes)
    return meta
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Return a ``DecorateMeta`` marking the OpInfo as ``unittest.expectedFailure``."""
    return decorate(op_name, variant_name,
                    decorator=unittest.expectedFailure,
                    device_type=device_type, dtypes=dtypes)
def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Return a ``DecorateMeta`` that unconditionally skips the OpInfo's test."""
    skipper = unittest.skip("Skipped!")
    return decorate(op_name, variant_name, decorator=skipper,
                    device_type=device_type, dtypes=dtypes)
def skipOps(test_case_name, base_test_name, to_skip):
    """Attach each ``DecorateMeta`` in ``to_skip`` to its OpInfo for one test.

    For every entry, the unique OpInfo in ``op_db + additional_op_db`` whose
    (name, variant_test_name) matches gets a ``DecorateInfo`` appended to its
    ``decorators``, scoped to ``test_case_name`` / ``base_test_name`` and the
    entry's device_type / dtypes. This mutates those shared OpInfo objects.

    Returns:
        An identity decorator, so this can be applied with ``@`` syntax.
    """
    known_opinfos = op_db + additional_op_db
    for meta in to_skip:
        found = [info for info in known_opinfos
                 if info.name == meta.op_name and info.variant_test_name == meta.variant_name]
        assert len(found) > 0, f"Couldn't find OpInfo for {meta}"
        assert len(found) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({meta.op_name}, {meta.variant_name})"
        )
        target = found[0]
        extra = DecorateInfo(meta.decorator,
                             test_case_name, base_test_name,
                             device_type=meta.device_type,
                             dtypes=meta.dtypes)
        target.decorators = tuple(target.decorators) + (extra,)

    def identity(fn):
        # Registration already happened above; the test itself is untouched.
        return fn
    return identity
def expectedFailureIf(condition):
    """Return a decorator applying ``unittest.expectedFailure`` only when ``condition`` is truthy."""
    def decorator(fn):
        return unittest.expectedFailure(fn) if condition else fn
    return decorator
def tol2(op_name, variant_name, override_dct, *, device_type=None):
    """Pack one tolerance override as ``(op_name, variant_name, override_dct, device_type)``."""
    packed = op_name, variant_name, override_dct, device_type
    return packed
def tol1(op_name, override_dct, *, device_type=None):
    """Like ``tol2`` for an OpInfo with no variant (variant_name '')."""
    return tol2(op_name, variant_name='', override_dct=override_dct, device_type=device_type)
def opsToleranceOverride(test_case_name, base_test_name, overrides):
    """Attach per-op tolerance overrides to OpInfos for one test.

    ``overrides`` is an iterable of tuples as produced by ``tol1`` / ``tol2``:
    ``(op_name, variant_name, override_dct, device_type)``. For each entry,
    the unique OpInfo in ``op_db + additional_op_db`` matching
    (name, variant_test_name) gets a
    ``DecorateInfo(toleranceOverride(override_dct), ...)`` appended to its
    ``decorators``, scoped to ``test_case_name`` / ``base_test_name`` and
    ``device_type``. This mutates those shared OpInfo objects.

    Returns:
        An identity decorator, so this can be applied with ``@`` syntax.
    """
    all_opinfos = op_db + additional_op_db
    for override_spec in overrides:
        # Unpack into distinct names so the full spec (including the op name)
        # is still available for the error messages below.
        op_name, variant_name, override_dct, device_type = override_spec
        matching_opinfos = [o for o in all_opinfos
                            if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {override_spec}"
        assert len(matching_opinfos) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({op_name}, {variant_name})"
        )
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        decorators.append(DecorateInfo(
            toleranceOverride(override_dct),
            test_case_name, base_test_name, device_type=device_type))
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
class DisableVmapFallback:
    """Context manager that turns off functorch's vmap fallback while active.

    The previous setting is saved on entry (in ``prev_state``) and restored on
    exit. ``__enter__`` returns None, and exceptions are not suppressed.
    """

    def __enter__(self):
        functorch_c = torch._C._functorch
        self.prev_state = functorch_c._is_vmap_fallback_enabled()
        functorch_c._set_vmap_fallback_enabled(False)

    def __exit__(self, *ignored):
        # Restore whatever was in effect before entering.
        torch._C._functorch._set_vmap_fallback_enabled(self.prev_state)
def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False):
    """Run ``thunk`` with the vmap fallback disabled.

    If ``thunk`` raises, the exception propagates unless ``dry_run`` is True.
    In that case an ``xfail(...)`` line for ``opinfo`` is printed instead, ready
    to paste into an xfail list. ``test_case`` is unused.
    """
    try:
        with DisableVmapFallback():
            thunk()
    except Exception:
        if dry_run:
            if opinfo.variant_test_name:
                print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
            else:
                print(f"xfail('{opinfo.name}'),")
        else:
            raise
|