# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from typing import Callable, Dict
import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map_only
HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher, so we need to manually use the same
# decompositions, indexed here by their torch-level equivalents
expanded_weights_rnn_decomps = {
# func: (input_decomp, data_decomp)
torch.rnn_relu: (
decomposition_table[aten.rnn_relu.input],
decomposition_table[aten.rnn_relu.data],
),
torch.rnn_tanh: (
decomposition_table[aten.rnn_tanh.input],
decomposition_table[aten.rnn_tanh.data],
),
torch.lstm: (
decomposition_table[aten.lstm.input],
decomposition_table[aten.lstm.data],
),
torch.gru: (
decomposition_table[aten.gru.input],
decomposition_table[aten.gru.data],
),
}


# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set
@contextmanager
def batch_second(args, kwargs):
def set_batch_second(ew):
ew.set_batch_first(False)
def reset_batch_first(ew):
ew.set_batch_first(True)
tree_map_only(ExpandedWeight, set_batch_second, args)
tree_map_only(ExpandedWeight, set_batch_second, kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, reset_batch_first, args)
tree_map_only(ExpandedWeight, reset_batch_first, kwargs)
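
# A minimal usage sketch (illustrative only; assumes `weight` is a parameter tensor
# with requires_grad=True). The batch_first flag is flipped for the duration of the
# context and restored on exit:
#
#   ew = ExpandedWeight(weight, batch_size=2, loss_reduction="sum")
#   with batch_second((ew,), {}):
#       assert not ew.batch_first   # decomps see the batch dimension second
#   assert ew.batch_first           # restored once the context exits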


# to support packed sequences, we need to allow for smaller batches. Expanded weights represent the largest batch size
@contextmanager
def allow_smaller_batches(args, kwargs):
def allow(ew):
ew.set_allow_smaller_batches(True)
def reset(ew):
ew.set_allow_smaller_batches(False)
tree_map_only(ExpandedWeight, allow, args)
tree_map_only(ExpandedWeight, allow, kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, reset, args)
tree_map_only(ExpandedWeight, reset, kwargs)


@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(
args, kwargs
):
yield


def implements_per_sample_grads(torch_function):
@functools.wraps(torch_function)
def decorator(autograd_func):
HANDLED_FUNCTIONS[torch_function] = autograd_func
return autograd_func
return decorator
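
# Illustrative sketch (not part of this module) of how a per-sample-gradient rule is
# registered; the concrete rules live elsewhere and the class name below is a
# hypothetical example. Note that the apply() call in __torch_function__ passes the
# kwarg names, the torch function, and then the positional args followed by the
# kwarg values:
#
#   @implements_per_sample_grads(torch.nn.functional.linear)
#   class LinearPerSampleGrad(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, kwarg_names, func, *expanded_args_and_kwargs):
#           ...  # run the regular forward, saving whatever backward needs
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           ...  # write per-sample grads into orig_weight.grad_sample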


# ExpandedWeight represents a weight (parameter) Tensor that has an expanded
# batch dimension. Operations on the ExpandedWeight Tensor act exactly like
# those without an expanded batch dimension, but a call to .backward() populates
# the grad_sample field of the original (unexpanded) tensor with per-sample gradients
#
# ExpandedWeight has a fallback that always fails since we cannot know what the batch
# dimension of the input tensor is and therefore cannot know if this is a valid call
#
# This is a __torch_function__ object but it could have also been a Tensor Extension
# with a dispatch key.
#
# Needs to be a tensor subclass to allow reparametrization
class ExpandedWeight(torch.Tensor):
def __init__(self, orig_weight, batch_size, loss_reduction):
self.batch_size = batch_size
self.batch_first = True
self.allow_smaller_batches = False
self.orig_weight = orig_weight
self.loss_reduction = loss_reduction
handled_functions = HANDLED_FUNCTIONS
def __new__(cls, orig_weight, batch_size, loss_reduction):
if not isinstance(orig_weight, torch.Tensor):
raise RuntimeError(
f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}"
)
if not orig_weight.requires_grad:
raise RuntimeError(
"Can only build ExpandedWeights objects of tensors that require_grad"
)
ret = torch.Tensor._make_subclass(cls, orig_weight, True)
return ret
@classmethod
def __torch_function__(cls, func, _, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in expanded_weights_rnn_decomps:
# in aten, choosing between the input and data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(
args[2], list
) # data variant uses a list here
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
if decomp is not None:
with setup_rnn(use_input_variant, args, kwargs):
return decomp(*args, **kwargs)
if func == torch._cudnn_rnn_flatten_weight:
# since we aren't using the fused cuda kernels for RNNs, don't do this
return
if func in cls.handled_functions:
return cls.handled_functions[func].apply(
tuple(kwargs.keys()), func, *(args + tuple(kwargs.values()))
)
# We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
# i.e. torch.add(torch.Tensor, ExpandedWeight)
raise RuntimeError(
f"Expanded Weights encountered but cannot handle function {func.__name__}"
)
@property
def dtype(self):
return self.orig_weight.dtype
@property
def data(self):
return self.orig_weight.data
@property
def shape(self):
return self.orig_weight.shape
@property
def device(self):
return self.orig_weight.device
@property
def is_cuda(self):
return self.orig_weight.is_cuda
def data_ptr(self):
return self.orig_weight.data_ptr()
def get_device(self):
return self.orig_weight.get_device()
def set_allow_smaller_batches(self, is_allow_smaller_batches):
self.allow_smaller_batches = is_allow_smaller_batches
def set_batch_first(self, is_batch_first=True):
self.batch_first = is_batch_first
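
# End-to-end sketch (illustrative; nothing here is executed by this module). It assumes
# torch.nn.functional.linear has been registered in HANDLED_FUNCTIONS by the
# per-sample-gradient modules that build on this one:
#
#   batch_size = 4
#   weight = torch.randn(3, 5, requires_grad=True)
#   ew = ExpandedWeight(weight, batch_size, loss_reduction="sum")
#   inp = torch.randn(batch_size, 5)
#   out = torch.nn.functional.linear(inp, ew)  # routed through __torch_function__
#   out.sum().backward()
#   weight.grad_sample.shape                   # per-sample grads, shape (4, 3, 5)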