1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
|
import inspect
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch._decomp
from torch import Tensor
# Short alias for the aten operator namespace.
aten = torch.ops.aten

# Re-exported core decomposition machinery from torch._decomp.
register_decomposition = torch._decomp.register_decomposition
decomposition_table = torch._decomp.decomposition_table

# Decompositions meant only for forward-mode AD (jvp). When registering JIT
# decompositions, entries here take precedence over ``decomposition_table``.
decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
# IF any inputs have forward grad
# AND there is no forward AD formula implemented
# AND the function is actually differentiable
# run the decomposition
# See run_jit_decomposition_with_args_for_jvp
# We currently use Python decompositions that we compile with TorchScript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have an explicit allowlist that controls
# when the decomposition is enabled.
def maybe_register_decomposition(op):
    """Build a decorator that tries to register a function as the decomposition of ``op``.

    Registration is best-effort: if calling ``register_decomposition`` raises
    any ``Exception``, the decorated function is handed back unchanged (and
    unregistered) instead of propagating the error.
    """

    def _try_register(fn):
        try:
            registered = register_decomposition(op)(fn)
        except Exception:
            # Best-effort: fall back to the plain, unregistered function.
            registered = fn
        return registered

    return _try_register
# Functions where we need a special decomposition for jvp but there's another version that
# should be used more generally (e.g. for jvp we need to recompute the mean and variance for
# the backwards of a normalization function; without jvp, the saved values should be used).
#
# NOTE: ``decomposition_table_for_jvp`` is declared (with its type annotation) at the top
# of this module; it is intentionally not re-created here.
def register_decomposition_for_jvp(fn):
    """Return a decorator that registers a jvp-only decomposition for ``fn``.

    Args:
        fn: the aten op (or op overload packet) the decorated function decomposes.

    The decorated function is registered via ``register_decomposition`` with
    ``decomposition_table_for_jvp`` as the registry, so it is stored in the
    jvp-specific table rather than the module's ``decomposition_table``.
    """
    return register_decomposition(fn, registry=decomposition_table_for_jvp)
def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    """Build a TorchScript graph for the decomposition of ``decomp`` and register it.

    The Python decomposition is looked up first in
    ``decomposition_table_for_jvp`` and, failing that, in ``decomposition_table``.
    The resulting graph is passed to ``torch.jit._register_decomposition``.

    Args:
        decomp: the op overload whose decomposition is registered (callers in
            this file pass ``torch.ops.aten.<op>.default``).
        use_python: if True, the decomposition is wrapped with
            ``torch.jit.ignore`` (so TorchScript calls it as a Python function).
            A small TorchScript wrapper with the same signature is generated and
            compiled. If False, the decomposition itself is compiled with
            ``torch.jit.script``.

    Raises:
        RuntimeError: if neither table has an entry for ``decomp``.
    """
    # The jvp-specific override wins over the general decomposition.
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]
    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)

        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            # str(Parameter) keeps the annotation (and default, if any);
            # the keys are the bare parameter names used for the call.
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]
            return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"

        f_str = get_function_def(sig)
        # NOTE(review): the generated source refers to ``decomp_fn`` by name. It
        # appears to be resolved from this function's local scope when the
        # CompilationUnit is built, so renaming that local (or moving this call
        # into a helper) would presumably break compilation. Confirm before
        # refactoring.
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)
# The only decompositions here are temporary or hacks for the purposes of jvp
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    """Decompose ``aten.trace``: sum of the main diagonal of the matrix ``self``."""
    main_diagonal = torch.diag(self)
    return main_diagonal.sum()
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    """Decompose ``aten.log_sigmoid_forward``.

    Computes ``log(sigmoid(x))`` as ``min(0, x) - log1p(exp(-|x|))``, which
    avoids overflow for large ``|x|``.

    Returns:
        ``(output, buffer)``. For CUDA inputs ``buffer`` is an empty tensor of
        shape ``(0,)``; otherwise it is the intermediate ``exp(-|self|)``.
    """
    exp_neg_abs = torch.exp(-torch.abs(self))
    clamped_at_zero = torch.minimum(self.new_zeros(()), self)
    buffer = self.new_zeros((0,)) if self.is_cuda else exp_neg_abs
    return clamped_at_zero - torch.log1p(exp_neg_abs), buffer
def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    """Recompute mean and reciprocal std from ``input`` so both track its gradients.

    For most norm decompositions this is the only difference from the core
    version.

    The epsilon originally folded into ``rstd`` is recovered as
    ``1 / rstd**2 - var`` and detached. The returned rstd therefore equals the
    saved one numerically (up to rounding), but its gradient flows through the
    freshly computed (biased) variance of ``input``.

    Args:
        input: tensor to normalize.
        rstd: saved reciprocal standard deviation, broadcast-compatible with
            the reduced variance.
        inner_dim_indices: dimensions reduced over.
        keepdim: whether reduced dimensions are kept with size 1.

    Returns:
        ``(mean, rstd)`` recomputed from ``input``.
    """
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    # Back out epsilon from the saved rstd; detached so it is a constant.
    recovered_eps = (torch.pow(1 / rstd, 2) - var).detach()
    recomputed_rstd = 1 / torch.sqrt(var + recovered_eps)
    return mean, recomputed_rstd
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """jvp-only decomposition of ``aten.native_layer_norm_backward``.

    Unlike a backward that reuses the saved statistics, this recomputes the
    mean and rstd from ``input`` (via ``recompute_mean_var``), so the result is
    differentiable with respect to ``input``.

    Args:
        grad_out: gradient with respect to the layer-norm output.
        input: the forward input. Its trailing ``len(normalized_shape)``
            dimensions are the normalized ("inner") dims.
        normalized_shape: only its length is used here.
        mean: not read by this body; the mean is recomputed from ``input``.
            Presumably kept to match the aten op's signature.
        rstd: saved reciprocal std, used only to recover the epsilon.
        weight: optional affine scale.
        bias: optional affine shift.
        output_mask: which of ``(d_input, d_weight, d_bias)`` are requested.

    Returns:
        ``(d_input, d_weight, d_bias)``. Outputs that are not requested are zero
        tensors rather than ``None`` (see the inline comments). If any inner or
        outer dimension product is <= 0, zero tensors are returned for all three.
    """
    input_shape = input.shape
    input_ndim = input.dim()

    # Split dims into outer (batch-like) and inner (normalized) groups.
    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))

    # N: number of elements per normalized group; M: number of groups.
    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    # d_input = rstd / N * (N * g - sum(g) - x_hat * sum(g * x_hat)), where g is
    # the gradient w.r.t. the normalized input x_hat (scaled by weight if given).
    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input)  # should be None but doesn't work with vjp

    # Affine-parameter gradients reduce over the outer (non-normalized) dims.
    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias)  # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp
    return (d_input, d_weight, d_bias)
def prod(x: List[int]):
    """Return the product of the integers in ``x`` (1 for an empty list).

    Written as a plain loop so it stays compilable by TorchScript.
    """
    product = 1
    for factor in x:
        product = product * factor
    return product
@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """jvp-only decomposition of ``aten.native_batch_norm_backward``.

    Dimension 1 of ``input`` is treated as the channel axis; every other
    dimension is reduced over.

    In training mode, the mean and invstd are recomputed from ``input`` (via
    ``recompute_mean_var``) so they carry gradients; ``save_invstd`` is used
    only to recover the epsilon. In eval mode, the running statistics are used
    with ``invstd = rsqrt(running_var + eps)``.

    Args:
        grad_out: gradient with respect to the batch-norm output.
        input: forward input, rank >= 2.
        weight: optional per-channel scale.
        running_mean, running_var: required when ``train`` is False.
        save_mean, save_invstd: required (non-None) when ``train`` is True.
        train: selects recomputed batch statistics vs. running statistics.
        eps: epsilon added to ``running_var`` in eval mode.
        output_mask: which of ``(grad_input, grad_weight, grad_bias)`` are
            requested. ``output_mask[0]`` is not consulted; ``grad_input`` is
            always computed.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``. Outputs that are not requested
        are zero tensors rather than ``None`` (see the inline comments).
    """
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    axis = 1
    # Number of elements reduced per channel (float, from true division).
    num_features = prod(input_shape) / input_shape[axis]  # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"

        reduciton_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None  # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    # Narrow Optional types for TorchScript/mypy below.
    assert invstd is not None and mean is not None

    # Shape [1, C, 1, ..., 1] for broadcasting per-channel stats against input.
    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)

    # In training mode the batch statistics depend on input, which adds the
    # projection and mean-gradient terms; in eval mode the stats are constants.
    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale

    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        )  # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp

    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        )  # should be None but doesn't work with vjp
    return (grad_input, grad_weight, grad_bias)
# Register JIT graphs for every decomposition used by forward-mode AD.
# ``trace`` is registered with ``use_python=True`` (its Python decomposition is
# called from a generated TorchScript wrapper); the rest are scripted directly.
_register_jit_decomposition_for_jvp(aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(aten.cudnn_batch_norm_backward.default)
|