from . import benchmark
import torch
class RNNEltwise(benchmark.Benchmark):
    """Benchmark the element-wise stage of an LSTM cell.

    The matmul-heavy part of the cell is assumed to have already happened:
    ``input`` and ``hx`` arrive pre-projected with shape [b, 4 * hs], so
    ``forward`` measures only the gate activations and state updates.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        # One random operand per LSTM-cell input. Insertion order below is
        # the RNG draw order, so it must match the original implementation
        # (input, hx, cx, b_ih, b_hh) exactly.
        operand_shapes = {
            "input": [b, 4 * hs],
            "hx": [b, 4 * hs],
            "cx": [b, hs],
            "b_ih": [b, 4 * hs],
            "b_hh": [b, 4 * hs],
        }
        for name, shape in operand_shapes.items():
            setattr(
                self,
                name,
                self.rand(
                    shape,
                    device=device,
                    dtype=dtype,
                    requires_grad=self.requires_grad,
                ),
            )
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    def forward(self, input, hx, cx, b_ih, b_hh):
        """Run the element-wise LSTM math and return (hy, cy)."""
        # All four gates are packed along dim 1; split them apart.
        gates = input + hx + b_ih + b_hh
        in_g, forget_g, cell_g, out_g = gates.chunk(4, 1)
        in_g = torch.sigmoid(in_g)
        forget_g = torch.sigmoid(forget_g)
        cell_g = torch.tanh(cell_g)
        out_g = torch.sigmoid(out_g)
        # Standard LSTM state update.
        cy = forget_g * cx + in_g * cell_g
        hy = out_g * torch.tanh(cy)
        return hy, cy

    def config(self):
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        def nbytes(t):
            return t.numel() * t.element_size()

        # Reads: every operand tensor. Writes: the two [b, hs] outputs
        # (hy and cy), each the same size as cx.
        read_bytes = sum(nbytes(t) for t in self.inputs)
        write_bytes = 2 * nbytes(self.cx)
        io_size = read_bytes + write_bytes
        return {"sol": io_size, "algorithmic": io_size}

    @staticmethod
    def default_configs():
        return [[64, 512]]
# Make the benchmark discoverable by the framework's runner.
benchmark.register_benchmark_class(RNNEltwise)
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of :class:`RNNEltwise`.

    Re-samples the operand shapes on every ``instantiate_input`` call via
    the ``DynamicShape`` mixin, reusing the parent's ``forward`` math.
    """

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        """Rebuild all operand tensors at a freshly sampled (b, hs)."""
        b, hs = self.rand_shape([self.b, self.hs])

        def fresh(shape):
            return self.rand(
                shape,
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )

        # Same creation order as RNNEltwise.__init__ so RNG draws line up.
        self.input = fresh([b, 4 * hs])
        self.hx = fresh([b, 4 * hs])
        self.cx = fresh([b, hs])
        self.b_ih = fresh([b, 4 * hs])
        self.b_hh = fresh([b, 4 * hs])
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    @staticmethod
    def module():
        return "dynamic_lstm"
# Make the dynamic-shape benchmark discoverable by the framework's runner.
benchmark.register_benchmark_class(DynamicLSTM)