Example: Common neural network operations
#########################################

einx allows formulating many common operations of deep learning models as concise expressions. This page provides a few examples.

..  code-block:: python

    import einx
    # Choose the submodule that matches your deep learning framework:
    import einx.nn.{torch|flax|haiku|equinox|keras} as einn

LayerScale
----------

Multiply the input tensor ``x`` by a learnable per-channel parameter that is initialized to a small value:

..  code-block:: python

    x = einx.multiply("... [c]", x, einn.param(init=1e-5))

Reference: `LayerScale explained <https://paperswithcode.com/method/layerscale>`_
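
For orientation, a minimal plain-PyTorch sketch of the same operation (the shapes are assumed for illustration):

..  code-block:: python

    import torch

    c = 256                                              # assumed channel count
    scale = torch.nn.Parameter(torch.full((c,), 1e-5))   # one learnable factor per channel
    x = torch.randn(8, 196, c)                           # e.g. (batch, tokens, channels)
    y = x * scale                                        # broadcasts over all leading axes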

Prepend class-token
-------------------

Flatten the spatial axes of an n-dimensional input tensor ``x`` and prepend a learnable class token:

..  code-block:: python

    x = einx.rearrange("b s... c, c -> b (1 + (s...)) c", x, einn.param(name="class_token"))

Reference: `Classification token in Vision Transformer <https://paperswithcode.com/method/vision-transformer>`_
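
A plain-PyTorch sketch of the same idea for the 2D (image) case, with assumed shapes:

..  code-block:: python

    import torch

    b, h, w, c = 8, 14, 14, 256                          # assumed input shape
    x = torch.randn(b, h, w, c)
    class_token = torch.nn.Parameter(torch.zeros(c))

    tokens = x.reshape(b, h * w, c)                      # flatten the spatial axes
    cls = class_token.expand(b, 1, c)                    # broadcast the token to the batch
    x = torch.cat([cls, tokens], dim=1)                  # shape: (b, 1 + h*w, c)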

Positional embedding
--------------------

Add a learnable positional embedding to all tokens of the input ``x``. Works with n-dimensional inputs (text, image, video, ...):

..  code-block:: python

    # "nn.initializers.normal" is the initializer of the chosen backend,
    # e.g. flax.linen.initializers.normal or jax.nn.initializers.normal
    x = einx.add("b [s... c]", x, einn.param(name="pos_embed", init=nn.initializers.normal(stddev=0.02)))

Reference: `Position embeddings in Vision Transformer <https://paperswithcode.com/method/vision-transformer>`_
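
In plain PyTorch, the 2D case corresponds roughly to the following sketch (shapes assumed for illustration):

..  code-block:: python

    import torch

    b, h, w, c = 8, 14, 14, 256                          # assumed input shape
    x = torch.randn(b, h, w, c)
    pos_embed = torch.nn.Parameter(torch.randn(h, w, c) * 0.02)
    x = x + pos_embed                                    # broadcasts over the batch axis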

Word embedding
--------------

Retrieve a learnable embedding vector for each token in the input sequence ``x``:

..  code-block:: python

    x = einx.get_at("[v] c, b t -> b t c", einn.param(name="vocab_embed"), x, v=50257, c=1024)

Reference: `Torch tutorial on word embeddings <https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html>`_
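
This is an embedding lookup; a plain-PyTorch sketch with the same vocabulary and channel sizes:

..  code-block:: python

    import torch

    vocab_embed = torch.nn.Parameter(torch.randn(50257, 1024) * 0.02)
    x = torch.randint(0, 50257, (8, 128))                # assumed token ids of shape (b, t)
    x = vocab_embed[x]                                   # shape: (b, t, 1024)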

Layer normalization
-------------------

Compute the mean and variance along the channel axis, and normalize the tensor by subtracting the mean and dividing by the standard deviation.
Apply learnable scale and bias:

..  code-block:: python

    # epsilon is a small constant for numerical stability, e.g. 1e-5
    mean = einx.mean("... [c]", x, keepdims=True)
    var = einx.var("... [c]", x, keepdims=True)
    x = (x - mean) * torch.rsqrt(var + epsilon)

    x = einx.multiply("... [c]", x, einn.param(name="scale"))
    x = einx.add("... [c]", x, einn.param(name="bias"))

This can similarly be achieved using the ``einn.Norm`` layer:

..  code-block:: python

    import einx.nn.{torch|flax|haiku|...} as einn
    x = einn.Norm("... [c]")(x)

Reference: `Layer normalization explained <https://paperswithcode.com/method/layer-normalization>`_
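
For comparison, PyTorch's built-in functional layer normalization over the last axis (without the learnable affine parameters) is a close counterpart:

..  code-block:: python

    import torch

    x = torch.randn(8, 196, 256)                         # assumed input shape
    x = torch.nn.functional.layer_norm(x, normalized_shape=(x.shape[-1],), eps=1e-5)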

Multihead attention
-------------------

Compute multihead attention for the queries ``q``, keys ``k`` and values ``v`` with ``h = 8`` heads:

..  code-block:: python

    a = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=8)
    a = einx.softmax("b q [k] h", a)
    x = einx.dot("b q k h, b k (h c) -> b q (h c)", a, v)

Reference: `Multi-Head Attention <https://paperswithcode.com/method/multi-head-attention>`_
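
The snippet above omits the usual scaling of the attention logits by the inverse square root of the head dimension. With PyTorch's fused helper, an equivalent scaled computation looks roughly as follows (shapes assumed for illustration):

..  code-block:: python

    import torch

    b, s, h, c = 8, 128, 8, 64                           # assumed batch, sequence, heads, head dim
    q = k = v = torch.randn(b, s, h * c)

    # reshape to (b, h, s, c) as expected by scaled_dot_product_attention
    qh, kh, vh = (t.reshape(b, s, h, c).transpose(1, 2) for t in (q, k, v))
    x = torch.nn.functional.scaled_dot_product_attention(qh, kh, vh)  # scales by 1/sqrt(c)
    x = x.transpose(1, 2).reshape(b, s, h * c)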

Shifted window attention
------------------------

Shift and partition the input tensor ``x`` into windows with side length ``w``, compute self-attention within each window, then merge the windows and undo the shift. Works with
n-dimensional inputs (text, image, video, ...):

..  code-block:: python

    # Solve for the axis values once, so that s and w need not be passed to every call below
    consts = einx.solve("b (s w)... c", x, w=16)

    # Shift and partition windows
    x = einx.roll("b [...] c", x, shift=-shift)
    x = einx.rearrange("b (s w)... c -> (b s...) (w...) c", x, **consts)

    # Compute attention
    ...

    # Unshift and merge windows
    x = einx.rearrange("(b s...) (w...) c -> b (s w)... c", x, **consts)
    x = einx.roll("b [...] c", x, shift=shift)

Reference: `Swin Transformer <https://paperswithcode.com/method/swin-transformer>`_
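
A plain-PyTorch sketch of the shift-and-partition step for the 2D (image) case, with assumed shapes:

..  code-block:: python

    import torch

    b, H, W, c = 8, 64, 64, 96                           # assumed input shape
    w, shift = 16, 8
    x = torch.randn(b, H, W, c)

    x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))  # shift the spatial axes
    x = x.reshape(b, H // w, w, W // w, w, c)                # split each spatial axis into (s, w)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, w * w, c)    # (b * num_windows, w*w, c)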

Multilayer Perceptron along spatial axes (MLP-Mixer)
----------------------------------------------------

Apply a weight matrix multiplication along the spatial axes of the input tensor:

..  code-block:: python

    x = einx.dot("b [s...->s2] c", x, einn.param(name="weight1"))
    ...
    x = einx.dot("b [s2->s...] c", x, einn.param(name="weight2"), s=(256, 256))

Or with the ``einn.Linear`` layer, which includes a bias term:

..  code-block:: python

    x = einn.Linear("b [s...->s2] c")(x)
    ...
    x = einn.Linear("b [s2->s...] c", s=(256, 256))(x)

Reference: `MLP-Mixer <https://paperswithcode.com/method/mlp-mixer>`_
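
For the 2D case, the first of these operations corresponds roughly to the following plain-PyTorch sketch (shapes and names assumed for illustration):

..  code-block:: python

    import torch

    b, h, w, c = 8, 16, 16, 512                          # assumed input shape
    s2 = 256
    x = torch.randn(b, h, w, c)

    weight1 = torch.nn.Parameter(torch.randn(h * w, s2) * 0.02)
    x = x.reshape(b, h * w, c)                           # flatten spatial axes to a single axis s
    x = torch.einsum("bsc,st->btc", x, weight1)          # mix along the spatial axis, not the channels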

The following page provides an example implementation of GPT-2 with ``einx`` and ``einn`` using many of these operations and validates
their correctness by loading pretrained weights and generating some example text.