Example: GPT-2
##############

    We succeeded in taking that picture, and, if you look at it, you see a dot. That's here. That's home. That's us. On it, *we wrote, "We are the people."*
    
    -- Carl Sagan & GPT-2

In this example, we will reimplement the GPT-2 architecture using einx and the deep learning framework `Haiku <https://github.com/google-deepmind/dm-haiku>`_, load
pretrained weights from Hugging Face, and validate the model by generating some text.

..  code-block:: python

    import haiku as hk
    import jax, einx
    import jax.numpy as jnp
    from functools import partial
    import einx.nn.haiku as einn
    import numpy as np

    # Define some layer types we will use.
    # 1. Use channels-last layout
    # 2. Use layer normalization, and an epsilon of 1e-5 as in the original implementation
    Linear = partial(einn.Linear, "... [_->channels]")
    Norm = partial(einn.Norm, "... [c]", epsilon=1e-5)
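
For illustration, this is how the two factories behave when instantiated inside an ``hk.transform``-ed function (as in the blocks below): ``Linear`` maps the
last axis of its input onto a requested number of channels, and ``Norm`` normalizes over the last axis. A minimal sketch with arbitrary shapes:

..  code-block:: python

    # Hypothetical usage, shapes chosen only for illustration:
    # x: (batch, tokens, 64)
    x = Linear(channels=128)(x)  # maps the last axis: (batch, tokens, 64) -> (batch, tokens, 128)
    x = Norm()(x)                # normalizes over the last axis, shape unchanged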

The main building block of GPT-2 consists of multi-head self-attention followed by a multi-layer perceptron (MLP). Each sub-block uses a residual connection,
with layer normalization applied at the start of the residual branch:

..  code-block:: python

    class Block(hk.Module):
        heads: int = 25
        mlp_ratio: int = 4

        def __call__(self, x):
            # ########### Attention block ###########
            x0 = x
            x = Norm()(x)

            # Predict queries, keys and values
            x = Linear(channels=3 * x.shape[-1])(x)
            q, k, v = jnp.split(x, 3, axis=-1)

            # Compute attention matrix over h heads
            q = q * ((q.shape[-1] // self.heads) ** -0.5)
            attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=self.heads)

            # Apply causal mask
            mask = jnp.tril(jnp.ones((q.shape[1], q.shape[1]), dtype=bool))
            attn = einx.where("q k, b q k h,", mask, attn, -jnp.inf)

            # Apply softmax and compute weighted average over the input tokens
            attn = einx.softmax("b q [k] h", attn)
            x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)

            # Output projection
            x = Linear(channels=x.shape[-1])(x)

            x = x + x0

            # ########### MLP block ###########
            x0 = x
            x = Norm()(x)

            x = Linear(channels=x.shape[-1] * self.mlp_ratio)(x)
            x = jax.nn.gelu(x)
            x = Linear(channels=x0.shape[-1])(x)

            x = x + x0

            return x

The multi-head attention requires no additional statements to split the channel axis into multiple heads or to merge the heads back into a single axis.
Instead, we specify the channel axis as an :ref:`axis composition <axiscomposition>` of ``h`` heads and ``c`` channels per head:

..  code-block:: python

    attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=self.heads)
    ...
    x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)

We can verify the correctness of these operations by inspecting the function that einx just-in-time compiles for them, which can be retrieved by passing ``graph=True``:

>>> graph = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=self.heads, graph=True)
>>> print(graph)
import jax.numpy as jnp
def op0(i0, i1):
    x0 = jnp.reshape(i0, (1, 1024, 25, 64))
    x1 = jnp.reshape(i1, (1, 1024, 25, 64))
    x2 = jnp.einsum("abcd,aecd->abec", x0, x1)
    return x2
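
For comparison, the same attention matrix can be computed without einx by splitting the heads explicitly. The following sketch (the helper name is ours, and the
head count is taken from the trace above) mirrors the generated einsum ``abcd,aecd->abec``:

..  code-block:: python

    import jax.numpy as jnp

    def attention_scores(q, k, heads=25):
        # Split the channel axis of both inputs into (heads, channels per head) ...
        q = q.reshape(q.shape[0], q.shape[1], heads, -1)
        k = k.reshape(k.shape[0], k.shape[1], heads, -1)
        # ... and contract over the per-head channel axis: "b q h c, b k h c -> b q k h"
        return jnp.einsum("bqhc,bkhc->bqkh", q, k)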

The final GPT-2 model first embeds the input tokens and adds positional embeddings. It then applies a number of main blocks and maps the output onto next-token
logits using a linear layer:

..  code-block:: python

    class GPT2(hk.Module):
        channels: int = 1600
        depth: int = 48
        vocab_size: int = 50257
        block_size: int = 1024

        def __call__(self, x):
            # Word embedding: Retrieve embedding for each token from the word_embed table
            x = einx.get_at("[v] c, b t -> b t c", einn.param(name="word_embed"), x, v=self.vocab_size, c=self.channels)

            # Positional embedding
            x = einx.add("b [t c]", x, einn.param(name="pos_embed", init=hk.initializers.RandomNormal(stddev=0.02)))

            # Blocks
            for i in range(self.depth):
                x = Block(name=f"block{i}")(x)
            x = Norm()(x)

            # Classifier
            x = Linear(channels=self.vocab_size, bias=False)(x)

            return x

We use tensor factories with ``einn.param`` to construct the word and positional embeddings (see 
:doc:`Tutorial: Neural networks </gettingstarted/tutorial_neuralnetworks>`).
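
Roughly speaking, a tensor factory is a callable that einx invokes with the parameter shape it infers from the expression. For the positional embedding above,
the bracketed axes ``[t c]`` of ``x`` determine that shape; a hand-written Haiku equivalent would look something like this sketch:

..  code-block:: python

    # Sketch of what the positional-embedding line resolves to:
    # x has shape (batch, block_size, channels); the parameter covers the
    # bracketed axes [t c] and is broadcast over the batch axis b.
    pos_embed = hk.get_parameter(
        "pos_embed",
        shape=x.shape[1:],  # (block_size, channels)
        init=hk.initializers.RandomNormal(stddev=0.02),
    )
    x = x + pos_embed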

With this, we're done with the model definition. Next, we define some input text that the model will be applied to and encode it into tokens:

..  code-block:: python

    text = ("We succeeded in taking that picture, and, if you look at it, you see a dot."
            "That's here. That's home. That's us. On it,")
    print(f"Input: {text}")

    # Encode text to tokens
    import tiktoken
    encoder = tiktoken.get_encoding("gpt2")
    tokens = np.asarray(encoder.encode_ordinary(text))
    n = len(tokens)

    # Pad tokens to input block size
    tokens = np.pad(tokens, (0, GPT2.block_size - n), constant_values=0)
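
As a quick sanity check (not part of the original script), the unpadded tokens can be decoded back into the input text:

..  code-block:: python

    # Round-trip check: decoding the first n tokens should reproduce the input text.
    assert encoder.decode(tokens[:n].tolist()) == text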

The model is initialized using a dummy batch (see `Haiku Basics <https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html>`_):

..  code-block:: python

    import time
    rng = jax.random.PRNGKey(int(time.time() * 1000))
    model = hk.transform(lambda x: GPT2()(x))
    params = model.init(rng, tokens[np.newaxis]) # Add batch axis to tokens using np.newaxis
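
Before loading pretrained weights, another quick check (again not part of the original script) is to count the parameters; for the XL configuration above this
should come out at roughly 1.5 billion:

..  code-block:: python

    # Total number of parameters across the Haiku parameter tree.
    num_params = sum(int(np.prod(p.shape)) for p in jax.tree_util.tree_leaves(params))
    print(f"Parameters: {num_params:,}")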

At this point, ``params`` contains only randomly initialized weights. We download the original model weights for the XL variant of GPT-2 from
`Hugging Face <https://huggingface.co/gpt2-xl>`_ and load them into our model using the
`weightbridge 🌉 <https://github.com/fferflo/weightbridge>`_ library:

..  code-block:: python

    # Download original weights
    import transformers # only used to download weights
    pretrained_params = {k: np.asarray(v) for k, v in transformers.GPT2LMHeadModel.from_pretrained("gpt2-xl").state_dict().items()}
    pretrained_params["lm_head.weight"] = np.transpose(pretrained_params["lm_head.weight"], (1, 0))
    pretrained_params = {k: v for k, v in pretrained_params.items() if not k.endswith(".attn.bias") and not k.endswith(".attn.masked_bias")}

    # Map weights to our model implementation
    import weightbridge
    params = weightbridge.adapt(pretrained_params, params, hints=[("norm_1", "ln_2")])

Finally, we can run several forward passes to predict next tokens:

..  code-block:: python

    apply = jax.jit(model.apply) # Just-in-time compile the forward pass
    temperature = 0.3
    for _ in range(10): # Predict 10 next tokens
        logits = apply(params, rng, tokens[np.newaxis])[0]
        logits = logits[n - 1] # Get logits for next token
        tokens[n] = jax.random.categorical(rng, logits / temperature) # Sample next token
        n += 1
    print(f"Prediction: {encoder.decode(tokens[:n])}")

Input:

    We succeeded in taking that picture, and, if you look at it, you see a dot. That's here. That's home. That's us. On it,
    
Prediction:

    We succeeded in taking that picture, and, if you look at it, you see a dot. That's here. That's home. That's us. On it, we wrote, "We are the people."
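
The loop above samples stochastically with a low temperature; a deterministic greedy variant (not part of the original script) would simply pick the most likely
token at each step:

..  code-block:: python

    # Greedy variant of the generation loop: always take the argmax token.
    for _ in range(10):
        logits = apply(params, rng, tokens[np.newaxis])[0]
        tokens[n] = int(jnp.argmax(logits[n - 1]))
        n += 1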

The `full example script can be found here <https://github.com/fferflo/weightbridge/blob/master/examples/gpt2haiku.py>`_, and a similar example script for the
`Mamba language model using Flax can be found here <https://github.com/fferflo/weightbridge/blob/master/examples/mamba2flax.py>`_.