1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
from torch import nn
from torch.nn.utils import rnn as rnn_utils
class RnnModelWithPackedSequence(nn.Module):
    """Adapter that feeds a padded batch to a model expecting a packed sequence.

    The wrapper packs the padded input with the supplied sequence lengths,
    calls the wrapped model, and converts the model's first output back to a
    padded tensor. Any further outputs of the model are passed through
    unchanged.

    Args:
        model: Module called with the packed input followed by any extra
            positional arguments. Presumably an ``nn.RNN``/``nn.LSTM``/
            ``nn.GRU``-style module; this is not checked here.
        batch_first: Forwarded as ``batch_first`` to both
            ``pack_padded_sequence`` and ``pad_packed_sequence``.
    """

    def __init__(self, model, batch_first):
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, *args):
        """Pack ``input``, run the wrapped model, and unpack its first output.

        Args:
            input: Padded input batch.
            *args: Extra positional arguments for the wrapped model, followed
                by the sequence lengths as the LAST positional argument.

        Returns:
            A tuple whose first element is the re-padded first output of the
            model, followed by the model's remaining outputs as returned.

        Raises:
            IndexError: If no positional argument is given after ``input``
                (there is then no trailing lengths argument to read).
        """
        # The trailing positional argument is the lengths; everything before
        # it is handed to the wrapped model untouched.
        model_args = args[:-1]
        seq_lengths = args[-1]

        # NOTE: ``enforce_sorted`` is left at the library default, so the
        # lengths are expected in the order that default requires.
        packed_input = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed_input, *model_args)

        # Only the first output is a packed sequence that needs re-padding;
        # the unpacked lengths returned alongside it are not needed.
        padded_output, _ = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return (padded_output, *outputs[1:])
class RnnModelWithPackedSequenceWithoutState(nn.Module):
    """Packed-sequence adapter for a model called without an initial state.

    The padded input is packed with the given sequence lengths, the wrapped
    model is called with the packed input only, and the model's first output
    is converted back to a padded tensor.

    Args:
        model: Module called with the packed input as its sole argument.
            Presumably an ``nn.RNN``/``nn.LSTM``/``nn.GRU``-style module;
            this is not checked here.
        batch_first: Forwarded as ``batch_first`` to both
            ``pack_padded_sequence`` and ``pad_packed_sequence``.
    """

    def __init__(self, model, batch_first):
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, seq_lengths):
        """Pack ``input``, run the wrapped model, and unpack its first output.

        Args:
            input: Padded input batch.
            seq_lengths: Sequence lengths passed to ``pack_padded_sequence``.

        Returns:
            A list whose first element is the re-padded first output of the
            model, followed by the model's remaining outputs as returned.
        """
        # NOTE: ``enforce_sorted`` is left at the library default, so the
        # lengths are expected in the order that default requires.
        packed_input = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed_input)

        # Only the first output is a packed sequence that needs re-padding;
        # the unpacked lengths returned alongside it are not needed.
        padded_output, _ = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return [padded_output, *outputs[1:]]
class RnnModelWithPackedSequenceWithState(nn.Module):
    """Packed-sequence adapter for a model called with an explicit state.

    The padded input is packed with the given sequence lengths, the wrapped
    model is called with the packed input and ``hx``, and the model's first
    output is converted back to a padded tensor.

    Args:
        model: Module called as ``model(packed_input, hx)``. Presumably an
            ``nn.RNN``/``nn.LSTM``/``nn.GRU``-style module; this is not
            checked here.
        batch_first: Forwarded as ``batch_first`` to both
            ``pack_padded_sequence`` and ``pad_packed_sequence``.
    """

    def __init__(self, model, batch_first):
        super().__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, hx, seq_lengths):
        """Pack ``input``, run the wrapped model with ``hx``, unpack the output.

        Args:
            input: Padded input batch.
            hx: State object handed to the wrapped model unchanged as its
                second positional argument.
            seq_lengths: Sequence lengths passed to ``pack_padded_sequence``.

        Returns:
            A list whose first element is the re-padded first output of the
            model, followed by the model's remaining outputs as returned.
        """
        # NOTE: ``enforce_sorted`` is left at the library default, so the
        # lengths are expected in the order that default requires.
        packed_input = rnn_utils.pack_padded_sequence(
            input, seq_lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed_input, hx)

        # Only the first output is a packed sequence that needs re-padding;
        # the unpacked lengths returned alongside it are not needed.
        padded_output, _ = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return [padded_output, *outputs[1:]]
|