File: modules.py

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (131 lines) | stat: -rw-r--r-- 3,739 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Wrap our operator (and gradient) in autograd."""

# We need to import torch before loading the custom modules
import torch as th
import halide_ops as ops

# TODO(mgharbi): maybe find a way to wrap function and module directly in C++
# instead of generating the C++ wrapper on the fly?


def _dispatch(opname, optype=th.float32, cuda=False):
    """
    Helper function that matches an opname and type to the Halide backend.

    This is based on the naming convention we use in this example. Functions are
    named: <opname>[_cuda]_<optype>.

    Args:
      opname(str): name of the base Halide function.
      optype(torch.dtype): pytorch's tensor datatype.
      cuda(bool): whether the operator should use cuda.

    Returns:
      op: a python function wrapping the requested Halide operator.
    """

    assert type(opname) is str, "opname should be a string"
    assert type(optype) == th.dtype, "optype should be a tensor datatype (torch.dtype)"

    if cuda:
        opname += "_cuda"

    if optype == th.float32:
        opname += "_float32"
    elif optype == th.float64:
        opname += "_float64"
    else:
        raise ValueError("Optype {} not recognized {}".format(*optype))
    op = getattr(ops, opname)
    if not hasattr(ops, opname):
        raise ValueError(f"Module has no operator {opname}")
    return op


def _forward_common(ctx, input_a, input_b):
    tp = input_a.dtype
    cuda = input_a.is_cuda
    assert tp == input_b.dtype, "inputs should have the same type"
    assert cuda == input_b.is_cuda, "inputs should be on the same device (cpu/gpu)"

    ctx.save_for_backward(input_a, input_b)

    fn_ = _dispatch("add", optype=tp, cuda=cuda)

    # Create an output tensor with the proper dimensions
    out = input_a.new()
    out.resize_(input_a.shape)

    fn_(input_a, input_b, out)
    return out


def _backward_common(ctx, d_out, backward_op):
    """Shared backward pass: dispatch to a Halide gradient operator.

    Args:
      ctx: autograd context holding the two inputs saved by the forward pass.
      d_out(torch.Tensor): incoming gradient. Its dtype and device select which
        operator variant is dispatched.
      backward_op(str): base name of the Halide operator to dispatch to.

    Returns:
      tuple: (d_input_a, d_input_b), two newly allocated tensors with d_out's
      shape, which are handed to the operator as its last two arguments.
    """
    dtype = d_out.dtype
    on_gpu = d_out.is_cuda

    saved = ctx.saved_tensors
    input_a, input_b = saved[0], saved[1]

    # Pick the Halide operator matching the gradient's type/device
    grad_op = _dispatch(backward_op, optype=dtype, cuda=on_gpu)

    # One buffer per input, each with the same dtype/device/shape as d_out
    d_input_a = d_out.new().resize_(d_out.shape)
    d_input_b = d_out.new().resize_(d_out.shape)

    grad_op(input_a, input_b, d_out.contiguous(), d_input_a, d_input_b)
    return d_input_a, d_input_b


# TODO(srj): surely there's a better way to do this,
# but PyTorch seems to make it tricky to pass in
# extra info to the backward() method.
class AddFunction_Grad(th.autograd.Function):
    """Version using the manually-written backprop

    Used through ``AddFunction_Grad.apply(input_a, input_b)``. Both methods are
    static and all state travels through ``ctx``, so the class defines no
    constructor of its own.
    """

    @staticmethod
    def forward(ctx, input_a, input_b):
        """Run the shared forward pass on the two inputs and return its output."""
        return _forward_common(ctx, input_a, input_b)

    @staticmethod
    def backward(ctx, d_out):
        """Return the gradients for both inputs via the "add_grad" operator."""
        return _backward_common(ctx, d_out, "add_grad")


class AddFunction_HalideGrad(th.autograd.Function):
    """Version using the Halide-generated backprop

    Used through ``AddFunction_HalideGrad.apply(input_a, input_b)``. Both
    methods are static and all state travels through ``ctx``, so the class
    defines no constructor of its own.
    """

    @staticmethod
    def forward(ctx, input_a, input_b):
        """Run the shared forward pass on the two inputs and return its output."""
        return _forward_common(ctx, input_a, input_b)

    @staticmethod
    def backward(ctx, d_out):
        """Return the gradients for both inputs via the "add_halidegrad" operator."""
        return _backward_common(ctx, d_out, "add_halidegrad")


class Add(th.nn.Module):
    """Defines a module that uses our autograd function.

    This is so we can use it as an operator.

    Args:
      backward_op(str): which backward implementation to use, either
        "add_grad" or "add_halidegrad".

    Raises:
      ValueError: if backward_op is not one of the two names above.
    """

    def __init__(self, backward_op):
        super().__init__()
        if backward_op == "add_grad":
            self._adder = AddFunction_Grad
        elif backward_op == "add_halidegrad":
            self._adder = AddFunction_HalideGrad
        else:
            # Raise instead of asserting: an assert is stripped under -O, which
            # would leave the module without an _adder until forward() fails.
            raise ValueError(
                f"Unknown backward_op {backward_op!r}; "
                "expected 'add_grad' or 'add_halidegrad'"
            )

    def forward(self, a, b):
        """Apply the selected autograd function to a and b."""
        return self._adder.apply(a, b)