1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
|
import functools
from caffe2.python import brew, rnn_cell
class GRUCell(rnn_cell.RNNCell):
    """GRU cell whose gate pre-activations are wired up in the caffe2 graph.

    The reset, update and output gate inputs are assembled here with FC, Sum,
    Sigmoid and Mul operators, concatenated, and handed to the ``GRUUnit``
    operator, which produces the next hidden state.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,  # Currently unused! Values here will be ignored.
        memory_optimization,
        drop_states=False,
        linear_before_reset=False,
        **kwargs
    ):
        """Store the cell configuration.

        Args:
            input_size: ``dim_in`` of the input projection in `prepare_input`.
            hidden_size: Width of the hidden state; every hidden-to-hidden FC
                layer maps ``hidden_size`` to ``hidden_size``.
            forget_bias: Converted to ``float`` and forwarded to ``GRUUnit``.
                Per the original author's note it is currently unused and
                its value is ignored.
            memory_optimization: Stored on the instance; not read by any
                method of this class.
            drop_states: Forwarded to ``GRUUnit`` as ``drop_states``.
            linear_before_reset: When truthy, the reset gate multiplies the
                output of the hidden FC layer (cudnn semantics); otherwise it
                multiplies the previous hidden state before that FC layer.
            **kwargs: Passed through to the ``rnn_cell.RNNCell`` constructor.
        """
        super(GRUCell, self).__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.linear_before_reset = linear_before_reset

    # In contrast to LSTMCell, a GRU feeds the result of one gate into another
    # (reset gate -> output gate). For that reason the reset gate output and
    # the modified output gate input are computed here, in the graph
    # definition. Everything else lives in gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Add the operators for one GRU step to ``model.net``.

        Args:
            model: Model helper whose ``net`` receives the operators.
            input_t: Blob holding the concatenated per-gate inputs; it is
                split three ways along ``axis=2``.
            seq_lengths: Optional blob; when not ``None`` it is passed to
                ``GRUUnit`` and ``sequence_lengths`` is set to ``True``.
            states: Sequence whose first element is the previous hidden state.
            timestep: Blob passed as the last ``GRUUnit`` input.
            extra_inputs: Accepted for interface compatibility; not used here.

        Returns:
            A 1-tuple containing the new hidden state blob.
        """
        net = model.net
        scope = self.scope
        hidden_dim = self.hidden_size
        hidden_t_prev = states[0]

        def hidden_fc(blob_in, blob_out):
            # Every hidden-to-hidden projection shares the same shape/axis.
            return brew.fc(
                model,
                blob_in,
                blob_out,
                dim_in=hidden_dim,
                dim_out=hidden_dim,
                axis=2,
            )

        # One slice of the projected input per gate.
        in_reset, in_update, in_output = net.Split(
            [input_t],
            [
                scope('input_t_reset'),
                scope('input_t_update'),
                scope('input_t_output'),
            ],
            axis=2,
        )

        # Hidden-state contributions to the reset and update gates.
        reset_pre = hidden_fc(hidden_t_prev, scope('reset_gate_t'))
        update_pre = hidden_fc(hidden_t_prev, scope('update_gate_t'))

        # The reset gate is completed first because the output gate needs it.
        reset_pre = net.Sum([reset_pre, in_reset], scope('reset_gate_t'))
        reset_sigmoid = net.Sigmoid(reset_pre, scope('reset_gate_t_sigmoid'))

        if not self.linear_before_reset:
            # Gate the previous hidden state, then project it.
            gated_hidden = net.Mul(
                [reset_sigmoid, hidden_t_prev],
                scope('modified_hidden_t_prev'),
            )
            output_pre = hidden_fc(gated_hidden, scope('output_gate_t'))
        else:
            # cudnn semantics: project first, then apply the reset gate.
            projected_hidden = hidden_fc(hidden_t_prev, scope('output_gate_t'))
            output_pre = net.Mul(
                [reset_sigmoid, projected_hidden],
                scope('output_gate_t_mul'),
            )

        # Input contributions for the update and output gates; the reset gate
        # already received its share (in-place) above.
        update_pre = net.Sum([update_pre, in_update], scope('update_gate_t'))
        output_pre = net.Sum(
            [output_pre, in_output],
            scope('output_gate_t_summed'),
        )

        # Pack the three gate blobs into the single tensor GRUUnit consumes.
        gates_t, _gates_t_concat_dims = net.Concat(
            [reset_pre, update_pre, output_pre],
            [scope('gates_t'), scope('_gates_t_concat_dims')],
            axis=2,
        )

        has_seq_lengths = seq_lengths is not None
        unit_inputs = [hidden_t_prev, gates_t]
        if has_seq_lengths:
            unit_inputs.append(seq_lengths)
        unit_inputs.append(timestep)

        hidden_t = net.GRUUnit(
            unit_inputs,
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=has_seq_lengths,
        )
        net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        """Project ``input_blob`` to the three gate inputs with one FC layer.

        The layer is scoped as ``i2h`` and maps ``input_size`` to
        ``3 * hidden_size`` along ``axis=2``.
        """
        gates_width = 3 * self.hidden_size
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=gates_width,
            axis=2,
        )

    def get_state_names(self):
        """Return a 1-tuple with the scoped name of the hidden state blob."""
        hidden_name = self.scope('hidden_t')
        return (hidden_name,)

    def get_output_dim(self):
        """Return the output width, which equals ``hidden_size``."""
        return self.hidden_size
# GRU is rnn_cell._LSTM with GRUCell pre-bound as its first positional
# argument; any further arguments are supplied by the caller.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)
|