from torch._C import _TensorBase
import torch
import functools
from typing import Callable, Dict, cast
HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
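
# Decorator registering an autograd.Function as the handler that
# ExpandedWeight.__torch_function__ will use for calls to `torch_function`.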
def implements_per_sample_grads(torch_function):
    @functools.wraps(torch_function)
    def decorator(autograd_func):
        HANDLED_FUNCTIONS[torch_function] = autograd_func
        return autograd_func
    return decorator
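
# Illustrative registration sketch (not part of the original module): the name
# _LinearPerSampleGrad is hypothetical. Per the dispatch in
# ExpandedWeight.__torch_function__ below, a handler's forward receives
# (kwarg_names, func, *positional_args_and_kwarg_values). This sketch assumes a
# kwarg-free torch.nn.functional.linear(input, weight) call with a 2D input and
# no bias, and it ignores loss_reduction.
@implements_per_sample_grads(torch.nn.functional.linear)
class _LinearPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, func, *expanded_args_and_kwargs):
        # Split the positional args back out of the flattened argument list.
        num_positional = len(expanded_args_and_kwargs) - len(kwarg_names)
        input, weight = expanded_args_and_kwargs[:num_positional]
        ctx.input, ctx.weight = input, weight
        # Run the unexpanded op against the underlying weight tensor.
        return func(input, weight.orig_weight)

    @staticmethod
    def backward(ctx, grad_output):
        # One outer product per batch element, stashed on the original
        # (unexpanded) weight instead of being accumulated into .grad.
        ctx.weight.orig_weight.grad_sample = torch.einsum(
            "ni,nj->nij", grad_output, ctx.input)
        # One gradient per forward input (kwarg_names, func, input, weight);
        # the weight's per-sample grads were stored above, so it gets None.
        return None, None, grad_output.matmul(ctx.weight.orig_weight), None
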
# ExpandedWeight represents a weight (parameter) Tensor that has an expanded
# batch dimension. Operations on the ExpandedWeight Tensor act exactly like
# those without an expanded batch dimension, but a call to .backward() populates
# the original (unexpanded) tensor's grad_sample field with per-sample gradients
#
# ExpandedWeight has a fallback that always fails: since we cannot know the batch
# dimension of a regular tensor input, we cannot tell whether a given call is valid
#
# This is a __torch_function__ object but it could have also been a Tensor Extension
# with a dispatch key.
#
# Needs to be a tensor subclass to allow reparametrization
class ExpandedWeight(torch.Tensor):
    def __init__(self, orig_weight, batch_size, loss_reduction):
        self.batch_size = batch_size
        self.orig_weight = orig_weight
        self.loss_reduction = loss_reduction

    handled_functions = HANDLED_FUNCTIONS

    def __new__(cls, orig_weight, batch_size, loss_reduction):
        if not isinstance(orig_weight, torch.Tensor):
            raise RuntimeError(f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}")
        if not orig_weight.requires_grad:
            raise RuntimeError("Can only build ExpandedWeights objects of tensors that require_grad")
        ret = torch.Tensor._make_subclass(cast(_TensorBase, cls), orig_weight, True)
        return ret

    @classmethod
    def __torch_function__(cls, func, _, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in cls.handled_functions:
            return cls.handled_functions[func].apply(tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())))
        # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
        # i.e. torch.add(torch.Tensor, ExpandedWeight)
        raise RuntimeError(f"Expanded Weights encountered but cannot handle function {func.__name__}")
    @property
    def dtype(self):
        return self.orig_weight.dtype

    @property
    def shape(self):
        return self.orig_weight.shape
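

# Illustrative usage (not part of the original module); assumes the
# hypothetical _LinearPerSampleGrad handler sketched above has been registered.
if __name__ == "__main__":
    batch_size = 8
    weight = torch.randn(4, 3, requires_grad=True)
    expanded = ExpandedWeight(weight, batch_size, loss_reduction="sum")
    inp = torch.randn(batch_size, 3)
    # linear is in handled_functions, so __torch_function__ routes this call
    # to the registered autograd.Function instead of the normal kernel.
    out = torch.nn.functional.linear(inp, expanded)
    # backward() leaves per-sample gradients on the original weight's
    # grad_sample field rather than accumulating a summed .grad.
    out.sum().backward()
    print(weight.grad_sample.shape)  # torch.Size([8, 4, 3])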