"""Wrap our operator (and gradient) in autograd."""
# We need to import torch before loading the custom modules
import torch as th
import halide_ops as ops
# TODO(mgharbi): maybe find a way to wrap function and module directly in C++
# instead of generating the C++ wrapper on the fly?
def _dispatch(opname, optype=th.float32, cuda=False):
    """Helper that matches an opname and dtype to the Halide backend function.

    This is based on the naming convention we use in this example: functions
    are named <opname>[_cuda]_<optype>.

    Args:
        opname(str): name of the base Halide function.
        optype(torch.dtype): PyTorch tensor datatype.
        cuda(bool): whether the operator should use CUDA.

    Returns:
        op: a Python function wrapping the requested Halide operator.
    """
    assert type(opname) is str, "opname should be a string"
    assert type(optype) == th.dtype, "optype should be a tensor datatype (torch.dtype)"

    if cuda:
        opname += "_cuda"

    if optype == th.float32:
        opname += "_float32"
    elif optype == th.float64:
        opname += "_float64"
    else:
        raise ValueError("Optype {} not recognized".format(optype))

    # Check for the operator before fetching it, so a missing symbol raises a
    # helpful error instead of an AttributeError.
    if not hasattr(ops, opname):
        raise ValueError(f"Module has no operator {opname}")
    return getattr(ops, opname)
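

# Example (sketch): following the naming convention above, _dispatch("add",
# th.float32, cuda=True) resolves to ops.add_cuda_float32, and
# _dispatch("add_grad", th.float64) to ops.add_grad_float64 (assuming those
# variants were compiled into the halide_ops extension).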


def _forward_common(ctx, input_a, input_b):
    tp = input_a.dtype
    cuda = input_a.is_cuda
    assert tp == input_b.dtype, "inputs should have the same type"
    assert cuda == input_b.is_cuda, "inputs should be on the same device (cpu/gpu)"

    ctx.save_for_backward(input_a, input_b)

    fn_ = _dispatch("add", optype=tp, cuda=cuda)

    # Create an output tensor with the proper dimensions
    out = input_a.new()
    out.resize_(input_a.shape)

    fn_(input_a, input_b, out)

    return out


def _backward_common(ctx, d_out, backward_op):
    tp = d_out.dtype
    cuda = d_out.is_cuda

    input_a = ctx.saved_tensors[0]
    input_b = ctx.saved_tensors[1]

    # Fetch the correct Halide operator for the type/device used
    fn_ = _dispatch(backward_op, optype=tp, cuda=cuda)

    d_input_a = d_out.new()
    d_input_b = d_out.new()
    d_input_a.resize_(d_out.shape)
    d_input_b.resize_(d_out.shape)

    fn_(input_a, input_b, d_out.contiguous(), d_input_a, d_input_b)

    return d_input_a, d_input_b
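

# For an elementwise add, the gradient with respect to each input is simply
# d_out, so the two backward variants below should produce the same values:
# "add_grad" uses the manually-written backward pipeline, while
# "add_halidegrad" uses one derived by Halide's automatic differentiation.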


# TODO(srj): surely there's a better way to do this, but PyTorch seems to
# make it tricky to pass in extra info to the backward() method.
class AddFunction_Grad(th.autograd.Function):
    """Version using the manually-written backprop."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input_a, input_b):
        return _forward_common(ctx, input_a, input_b)

    @staticmethod
    def backward(ctx, d_out):
        return _backward_common(ctx, d_out, "add_grad")


class AddFunction_HalideGrad(th.autograd.Function):
    """Version using the Halide-generated backprop."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input_a, input_b):
        return _forward_common(ctx, input_a, input_b)

    @staticmethod
    def backward(ctx, d_out):
        return _backward_common(ctx, d_out, "add_halidegrad")


class Add(th.nn.Module):
    """Module wrapping our autograd function, so it can be used as an operator."""

    def __init__(self, backward_op):
        super().__init__()
        if backward_op == "add_grad":
            self._adder = AddFunction_Grad
        elif backward_op == "add_halidegrad":
            self._adder = AddFunction_HalideGrad
        else:
            raise ValueError("Unknown backward op: {}".format(backward_op))

    def forward(self, a, b):
        return self._adder.apply(a, b)
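

# Minimal usage sketch (an assumption beyond the original module): this only
# runs when halide_ops has been built for CPU float32, and the tensor rank may
# need to match whatever the compiled Halide "add" operator expects.
if __name__ == "__main__":
    adder = Add("add_grad")
    a = th.randn(4, 4, requires_grad=True)
    b = th.randn(4, 4, requires_grad=True)
    out = adder(a, b)
    out.sum().backward()
    # For an elementwise add, both gradients should be all ones.
    print(a.grad)
    print(b.grad)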