# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from functorch.dim import dims, dimlists, softmax, cat
import math


class Linear(nn.Linear):
    def forward(self, input):
        ci, co = dims()
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
        return result.order(b, co)
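

# A minimal usage sketch (not part of the original example): the dim-aware Linear above
# should behave like nn.Linear on ordinary positional tensors, with the dimlist `b`
# capturing any number of leading batch dimensions. The sizes and the helper name
# `_example_linear` are illustrative assumptions only.
def _example_linear():
    layer = Linear(16, 32)
    x = torch.randn(4, 8, 16)   # two leading batch dimensions, both captured by `b`
    y = layer(x)                # equivalent to nn.Linear(16, 32) applied to the last dim
    assert y.shape == (4, 8, 32)
    return y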


class BertSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads,
                 attention_probs_dropout_prob, position_embedding_type=None,
                 max_position_embeddings=None, linear=Linear):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)

        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, v normally
        # the meaning of a linear layer is well understood, so no need to use explicit dimensions
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        # introduce values that represent each dimension. dimensions are 'first class'
        # because they are actual python values introduced here
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # bind the positional dimensions in q, k, and v against our values.
        # the sizes of each dimension are determined by this binding, and when a
        # dimension is used twice (e.g. batch), its size is checked for consistency
        # across both uses.
        # The group (heads, features) splits a single positional dimension apart
        # into two dimensions. Since heads.size*features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]
        # this option allows the model to attend not just to the elements of the current
        # sequence but also to previous elements and additional tokens.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]

            # cat introduces a new dimension extended_key_sequence, because it is twice as long
            # as the original key_sequence
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)

            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer-product and summation are explicitly represented here,
        # and like einsum, will be pattern matched to an efficient matrix multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)
        # relative positional embeddings give a unique embedding based on the distance between
        # the query and key tokens in the sequence, e.g.
        #  0  1  2  3
        # -1  0  1  2
        # -2 -1  0  1
        # -3 -2 -1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object used as a tensor is the indices along its dimension,
            # so we can directly subtract the two dimensions to get a 2D tensor of (query_sequence x key_sequence)
            # with the distance between them
            distance = query_sequence - key_sequence

            assert key_sequence.size <= self.max_position_embeddings

            # we can then use that as an indirect index into the embedding table to look up the features for that distance.
            # this is just a `gather` primitive op. The resulting tensor will
            # have all the dimensions of the indexing tensor (query_sequence x key_sequence),
            # plus all the dimensions of the embedding weight that were not indirectly accessed (features).
            # this form of indirect indexing is more straightforward than either advanced indexing or torch.gather, which both
            # have a lot of dependencies on the positions of the indexing tensors.
            positional_embedding = self.distance_embedding.weight[self.max_position_embeddings - 1 + distance, features]
if self.position_embedding_type == "relative_key":
# these were einsum ops in the positional code because they are not easy to fit to existing matmul operators
# eventhough they are degenerate matmuls
relative_position_scores = (q * positional_embedding).sum(features)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = (q * positional_embedding).sum(features)
relative_position_scores_key = (k * positional_embedding).sum(features)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(attention_probs, p=self.dropout_prob)

        # similarly, we can replace the matmul with a direct listing of the outer product, which makes it clear
        # we are weighting the values v across all keys with the attention scores.
        context_layer = (attention_probs * v).sum(key_sequence)
        # finally, we convert back to a standard positional tensor by describing the layout of its dimensions.
        # working in reverse of the binding above, the (heads, features) group flattens the two dimensions into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
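

# A minimal smoke-test sketch (not part of the original example); the sizes below are
# illustrative assumptions. With position_embedding_type left as None and no
# past_key_value, the output should keep the input's (batch, sequence, hidden) shape.
if __name__ == "__main__":
    torch.manual_seed(0)
    attention = BertSelfAttention(hidden_size=64, num_attention_heads=8,
                                  attention_probs_dropout_prob=0.1)
    hidden_states = torch.randn(2, 5, 64)   # (batch, sequence, hidden_size)
    context = attention(hidden_states)
    assert context.shape == (2, 5, 64)
    print("context shape:", tuple(context.shape))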