# mypy: allow-untyped-defs
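# Per-sample gradient rule for ``F.embedding`` used by the ExpandedWeights
# mechanism: ``forward`` simply delegates to ``F.embedding``, while
# ``backward`` fills in the usual ``autograd.Function`` gradient slots and,
# in addition, records a per-sample gradient for the embedding weight via
# ``set_grad_sample_if_exists``.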
from typing import List, Optional

import torch
import torch.nn.functional as F

from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
)


@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
        # ``kwarg_names`` and the (unused) op reference come from the
        # ExpandedWeights dispatch; ``standard_kwargs`` splits the remaining
        # values back into the positional args and kwargs of ``F.embedding``.
        expanded_args, expanded_kwargs = standard_kwargs(
            kwarg_names, expanded_args_and_kwargs
        )
        if len(expanded_args[0].shape) == 1:
            raise RuntimeError(
                f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}"
            )
        output = forward_helper(F.embedding, expanded_args, expanded_kwargs)
        # stash the tensors and flags that backward needs
        ctx.input, ctx.weight = expanded_args
        ctx.padding_idx, ctx.scale_grad_by_freq = (
            expanded_kwargs["padding_idx"],
            expanded_kwargs["scale_grad_by_freq"],
        )
        ctx.sparse = expanded_kwargs["sparse"]
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.input, ctx.weight
        padding_idx, scale_grad_by_freq, sparse = (
            ctx.padding_idx,
            ctx.scale_grad_by_freq,
            ctx.sparse,
        )

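        # Per-sample gradient for the embedding weight: for every sample ``b``,
        # accumulate ``grad_output[b]`` into the rows of a zeroed
        # ``(num_embeddings, embedding_dim)`` matrix selected by ``input[b]``.
        # Used below by ``set_grad_sample_if_exists`` to populate the weight's
        # ``grad_sample`` field.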
        def weight_per_sample_grad(weight):
            batch_size = input.shape[0]
            embedding_dim = weight.shape[1]
            index = (
                input.unsqueeze(-1)
                .expand(*input.shape, embedding_dim)
                .reshape(batch_size, -1, embedding_dim)
            )
            grad_sample = torch.zeros(
                batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
            )
            return grad_sample.scatter_add_(
                1, index, grad_output.reshape(batch_size, -1, embedding_dim)
            )

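        # ``backward`` must return one value per ``forward`` input: two
        # placeholders for ``kwarg_names`` and the op reference, then one slot
        # for each expanded arg/kwarg.  Only ``input`` may get a gradient here;
        # the weight's per-sample gradient is reported out of band below.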
        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
        if input.requires_grad:
            bw_fn = torch.ops.aten.embedding_backward
            results.append(
                bw_fn(
                    grad_output,
                    input,
                    weight.shape[0],
                    padding_idx,
                    scale_grad_by_freq,
                    sparse,
                )
            )
        else:
            results.append(None)

        # weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward)
        results = results + [None] * 6

        # set grad_sample field for weight with per sample gradients
        set_grad_sample_if_exists(weight, weight_per_sample_grad)
        return tuple(results)
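

# Illustrative sketch (not part of the module): what ``weight_per_sample_grad``
# computes.  The ``scatter_add_`` above is equivalent to accumulating each row
# of ``grad_output`` into the embedding row its index selected, independently
# per sample.  With hypothetical shapes B (batch), T (sequence length),
# V (num_embeddings), D (embedding_dim):
#
#     input = torch.randint(V, (B, T))
#     grad_output = torch.randn(B, T, D)
#     expected = torch.zeros(B, V, D)
#     for b in range(B):
#         for t in range(T):
#             expected[b, input[b, t]] += grad_output[b, t]
#
# ``expected[b]`` then matches the per-sample gradient stored on the weight's
# ``grad_sample`` field for sample ``b``.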