Tutorial: Neural networks
#########################

einx provides several neural network layer types for deep learning frameworks (`PyTorch <https://pytorch.org/>`_, `Flax <https://github.com/google/flax>`_,
`Haiku <https://github.com/google-deepmind/dm-haiku>`_, `Equinox <https://github.com/patrick-kidger/equinox>`_, `Keras <https://keras.io/>`_) in the ``einx.nn.*`` namespace 
based on the functions in ``einx.*``. These layers provide abstractions that can implement a wide variety of deep learning operations using einx notation.
The ``einx.nn.*`` namespace is entirely optional, and is imported as follows:

..  code::

    import einx.nn.{torch|flax|haiku|equinox|keras} as einn
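
For example, to use the PyTorch flavor of the template above:

..  code:: python

    import einx
    import einx.nn.torch as einn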

Motivation
----------

The main idea for implementing layers in einx is to exploit :ref:`tensor factories <lazytensorconstruction>` to initialize the weights of a layer.
For example, consider the following linear layer:

..  code::

    x = einx.dot("... [c1->c2]", x, w) # x * w
    x = einx.add("... [c2]", x, b)     # x + b

The arguments ``w`` and ``b`` represent the layer weights. Instead of computing their shapes in advance and creating the weights manually,
we define ``w`` and ``b`` as tensor factories that
are called inside the einx functions once the shapes have been determined. For example, in the Haiku framework, ``hk.get_parameter`` creates new weights
in the current module and can be wrapped in a tensor factory as follows:

..  code::

    import haiku as hk

    class Linear(hk.Module):
        def __call__(self, x):
            w = lambda shape: hk.get_parameter(name="weight", shape=shape, dtype="float32", init=hk.initializers.VarianceScaling(1.0, "fan_in", "truncated_normal"))
            b = lambda shape: hk.get_parameter(name="bias", shape=shape, dtype="float32", init=hk.initializers.Constant(0.0))

            x = einx.dot("b... [c1->c2]", x, w, c2=64)
            x = einx.add("b... [c2]", x, b)
            return x

Unlike a tensor, the tensor factory does not provide shape constraints to the expression solver and requires that we define the missing axes (``c2``) manually. Here,
this corresponds to specifying the number of output channels of the linear layer. All other axis values are determined implicitly from the input shapes.
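
To illustrate the shape solving, the following is a minimal sketch using a plain tensor factory and the NumPy backend (assuming NumPy arrays are used as inputs); the factory is called with the solved shape, just like the ``lambda shape: ...`` pattern above:

..  code:: python

    import numpy as np
    import einx

    x = np.zeros((8, 32))

    # The factory is called with the solved shape of w, here (32, 64):
    w = lambda shape: np.random.randn(*shape) * 0.02

    y = einx.dot("b... [c1->c2]", x, w, c2=64)
    print(y.shape)  # (8, 64)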

The weights are created when a layer is first run on an input batch. This is common practice in JAX-based frameworks like Flax and Haiku, where a model
is typically invoked once with a dummy batch to instantiate all weights.
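
A minimal sketch of this pattern with the Haiku ``Linear`` module defined above (the dummy shapes are chosen arbitrarily):

..  code:: python

    import jax
    import jax.numpy as jnp
    import haiku as hk

    def forward(x):
        return Linear()(x)  # the Linear module defined above

    model = hk.transform(forward)
    dummy = jnp.zeros((8, 32))

    params = model.init(jax.random.PRNGKey(0), dummy)  # weights are created here
    y = model.apply(params, None, dummy)               # y.shape == (8, 64)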

In PyTorch, we rely on `lazy modules <https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin>`_
by creating weights as ``torch.nn.parameter.UninitializedParameter`` in the constructor and calling their ``materialize`` method on the first input batch. This is
handled automatically by einx (see below).

Parameter definition with ``einn.param``
----------------------------------------

einx provides the function ``einn.param`` to create *parameter factories* for the respective deep learning framework. ``einn.param`` is simply a convenience wrapper for
the ``lambda shape: ...`` syntax that is used in the example above:

..  code:: python

    # w1 and w2 give the same result when used as tensor factories in einx functions:

    w1 = lambda shape: hk.get_parameter(name="weight", shape=shape, dtype="float32", init=...)

    w2 = einn.param(name="weight", dtype="float32", init=...)

The utility of ``einn.param`` comes from providing several useful default arguments that simplify the definition of parameters:

*   **Default argument for** ``init``

    The type of (random) initialization that is used for a parameter in neural networks typically depends on the operation that the parameter is used in. For example:

    * A bias parameter is used in an ``add`` operation and often initialized with zeros.
    * A weight parameter in linear layers is used in a ``dot`` operation and initialized e.g. using
      `Lecun normal initialization <https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.lecun_normal.html>`_
      based on the fan-in or fan-out of the layer.
    * A scale parameter is used in a ``multiply`` operation and e.g. initialized with ones in normalization layers.

    To allow ``einn.param`` to choose a default initialization method based on the operation it is used in, einx functions like :func:`einx.dot` and :func:`einx.add`
    forward their name as an optional argument to tensor factories. ``einn.param`` then defines a corresponding initializer in the respective framework and
    uses it as the default for ``init``. For example, in Flax:

    ..  code:: python

        from flax import linen as nn

        if init == "get_at" or init == "rearrange":
            init = nn.initializers.normal(stddev=0.02)
        elif init == "add":
            init = nn.initializers.zeros_init()
        elif init == "multiply":
            init = nn.initializers.ones_init()
        elif init == "dot":
            init = nn.initializers.lecun_normal(kwargs["in_axis"], kwargs["out_axis"], kwargs["batch_axis"])

    :func:`einx.dot` additionally determines ``in_axis``, ``out_axis`` and ``batch_axis`` from the einx expression and forwards them as optional arguments
    to tensor factories. In this case, they allow ``nn.initializers.lecun_normal`` to determine the fan-in of the layer and choose the initialization accordingly.

*   **Default argument for** ``name``

    A default name is determined implicitly from the operation that the parameter is used in, for example:

    .. list-table:: 
       :widths: 30 30
       :header-rows: 0

       * - Operation
         - Name
       * - :func:`einx.add`
         - ``bias``
       * - :func:`einx.multiply`
         - ``scale``
       * - :func:`einx.dot`
         - ``weight``
       * - :func:`einx.get_at`
         - ``embedding``
       * - :func:`einx.rearrange`
         - ``embedding``

*   **Default argument for** ``dtype``

    The default data type of the parameter is determined from the ``dtype`` member variable of the respective module if it exists, and chosen as ``float32`` otherwise.
    A sketch following this list spells out these defaults explicitly.
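
For example, inside a Flax module's ``__call__`` method, the following two calls are equivalent (a sketch that spells out the defaults; the explicit name and initializer mirror the rules above):

..  code:: python

    from flax import linen as nn

    # The data type is additionally taken from the module's `dtype` attribute
    # if it exists, and chosen as float32 otherwise.
    x = einx.add("b... [c]", x, einn.param(self))
    x = einx.add("b... [c]", x, einn.param(self, name="bias", init=nn.initializers.zeros_init()))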

Any default argument in ``einn.param`` can be overridden by simply passing the respective argument explicitly:

..  code::

    # Initialize bias with non-zero values
    einx.add("b... [c]", x, einn.param(init=nn.initializers.normal(stddev=0.02)))

    # Initialize layerscale with small value
    einx.multiply("b... [c]", x, einn.param(init=1e-5, name="layerscale"))

If no default argument can be determined (e.g. because there is no default initialization for an operation, or the module does not have a ``dtype`` member) and the
argument is not specified explicitly in ``einn.param``, an exception is raised.

Example layer using ``einn.param``
----------------------------------

The linear layer defined above using the ``lambda shape: ...`` syntax can be simplified with ``einn.param`` as shown below.

**Haiku**

..  code:: python

    import haiku as hk

    class Linear(hk.Module):
        dtype: str = "float32"
        def __call__(self, x):
            x = einx.dot("... [c1->c2]", x, einn.param(), c2=64)
            x = einx.add("... [c2]", x, einn.param())
            return x

In Haiku, ``hk.get_parameter`` and ``hk.get_state`` can be passed as the first parameter of ``einn.param`` to determine whether to create a parameter or state variable:

..  code:: python

    einx.add("... [c]", x, einn.param(hk.get_parameter))  # calls einn.param(hk.get_parameter)
    einx.add("... [c]", x, einn.param())                  # calls einn.param(hk.get_parameter)
    einx.add("... [c]", x, hk.get_parameter)              # calls einn.param(hk.get_parameter)
    einx.add("... [c]", x, einn.param(hk.get_state))      # calls einn.param(hk.get_state)
    einx.add("... [c]", x, hk.get_state)                  # calls einn.param(hk.get_state)

**Flax**

..  code:: python

    from flax import linen as nn

    class Linear(nn.Module):
        dtype: str = "float32"
        @nn.compact
        def __call__(self, x):
            x = einx.dot("... [c1->c2]", x, einn.param(self), c2=64)
            x = einx.add("... [c2]", x, einn.param(self))
            return x

In Flax, parameters are created by calling the ``self.param`` or ``self.variable`` method of the current module. For
convenience, einx provides several options to determine which one is used:

..  code:: python

    einx.add("... [c]", x, einn.param(self.param))                  # calls einn.param(self.param)
    einx.add("... [c]", x, einn.param(self))                        # calls einn.param(self.param)
    einx.add("... [c]", x, self.param)                              # calls einn.param(self.param)
    einx.add("... [c]", x, self)                                    # calls einn.param(self.param)
    einx.add("... [c]", x, einn.param(self.variable, col="stats"))  # calls einn.param(self.variable, col="stats")

**PyTorch**

..  code::

    import torch
    import torch.nn as nn

    class Linear(nn.Module):
        def __init__(self):
            super().__init__()
            self.w = nn.parameter.UninitializedParameter(dtype=torch.float32)
            self.b = nn.parameter.UninitializedParameter(dtype=torch.float32)

        def forward(self, x):
            x = einx.dot("b... [c1->c2]", x, self.w, c2=64)
            x = einx.add("b... [c2]", x, self.b)
            return x

In PyTorch, parameters have to be created in the constructor of the module as ``nn.parameter.UninitializedParameter`` and ``nn.parameter.UninitializedBuffer``
(see `lazy modules <https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin>`_). They can
be passed to einx functions directly, or by using ``einn.param`` (e.g. to specify additional arguments):

..  code:: python

    einx.add("... [c]", x, einn.param(self.w))        # calls einn.param(self.w)
    einx.add("... [c]", x, self.w)                    # calls einn.param(self.w)

For PyTorch, ``einn.param`` does not support the ``dtype`` and ``name`` arguments, since these are already specified in the constructor.
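
A minimal usage sketch, assuming the PyTorch ``Linear`` module defined above; the uninitialized parameters are materialized by einx on the first forward pass:

..  code:: python

    import torch

    layer = Linear()
    x = torch.randn(8, 32)
    y = layer(x)      # w and b are materialized here
    print(y.shape)    # torch.Size([8, 64])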

**Equinox**

..  code::

    import jax
    import equinox as eqx

    class Linear(eqx.Module):
        w: jax.Array
        b: jax.Array
        dtype: str = "float32"

        def __init__(self):
            self.w = None
            self.b = None

        def __call__(self, x, rng=None):
            x = einx.dot("b... [c1->c2]", x, einn.param(self, name="weight", rng=rng), c2=64)
            x = einx.add("b... [c2]", x, einn.param(self, name="bias", rng=rng))
            return x

In Equinox, parameters have to be specified as dataclass member variables of the module. In einx, these variables are set to ``None`` in the constructor and initialized in the
``__call__`` method instead by passing the module and member variable name to ``einn.param``. This initializes the parameter and stores it in the respective
member variable, such that the module can be used as a regular Equinox module. When a parameter is initialized randomly, it also requires passing a random key ``rng`` to
``einn.param`` on the first call:

..  code:: python

    einx.add("... [c]", x, einn.param(self, rng=rng))

Stateful layers are currently not supported for Equinox, since they require the shape of the state variable to be known in the constructor.

**Keras**

..  code::

    import einx.nn.keras as einn

    class Linear(einn.Layer):
        def call(self, x):
            x = einx.dot("b... [c1->c2]", x, einn.param(self, name="weight"), c2=64)
            x = einx.add("b... [c2]", x, einn.param(self, name="bias"))
            return x

In Keras, parameters can be created in a layer's ``build`` method instead of the ``__init__`` method, which gives access to the shapes of the layer's input arguments. The regular
forward pass is defined in the ``call`` method. einx provides the base class ``einn.Layer``, which simply implements the ``build`` method to call the layer's ``call`` method
with dummy arguments and thereby initialize the layer parameters. As shown above, the layer itself is passed to ``einn.param`` to create a parameter:

..  code:: python

    einx.add("... [c]", x, einn.param(self))

Layers
------

einx provides the layer types ``einn.{Linear|Norm|Dropout}`` that are implemented as outlined above.

**einn.Norm** implements a normalization layer with optional exponential moving average (EMA) over the computed statistics. The first parameter is an einx expression for
the axes along which the statistics for normalization are computed. The second parameter is an einx expression for the axes corresponding to the bias and scale terms, and
defaults to ``b... [c]``. The different sub-steps can be toggled by passing ``True`` or ``False`` for the ``mean``, ``var``, ``scale`` and ``bias`` parameters. The EMA is used only if 
``decay_rate`` is passed.

A variety of normalization layers can be implemented using this abstraction:

..  code::

    layernorm       = einn.Norm("b... [c]")
    instancenorm    = einn.Norm("b [s...] c")
    groupnorm       = einn.Norm("b [s...] (g [c])", g=8)
    batchnorm       = einn.Norm("[b...] c", decay_rate=0.9)
    rmsnorm         = einn.Norm("b... [c]", mean=False, bias=False)
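
For instance, a layer-norm-style ``einn.Norm`` can be applied directly to a tensor (a sketch using the PyTorch flavor; the weights are created lazily on the first call):

..  code:: python

    import torch
    import einx.nn.torch as einn

    norm = einn.Norm("b... [c]")  # statistics over the channel axis, learned scale and bias
    x = torch.randn(4, 16, 64)
    y = norm(x)                   # same shape as x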

**einn.Linear** implements a linear layer with an optional bias term. The first parameter is an operation string that is forwarded to :func:`einx.dot` to multiply the input with the weight matrix.
A bias is added corresponding to the marked output expressions and can be disabled by passing ``bias=False``.

..  code::

    channel_mix     = einn.Linear("b... [c1->c2]", c2=64)
    spatial_mix1    = einn.Linear("b [s...->s2] c", s2=64)
    spatial_mix2    = einn.Linear("b [s2->s...] c", s=(64, 64))
    patch_embed     = einn.Linear("b (s [s2->])... [c1->c2]", s2=4, c2=64)
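
The following sketch (PyTorch flavor, with hypothetical shapes) illustrates how the first two variants map input shapes:

..  code:: python

    import torch
    import einx.nn.torch as einn

    x = torch.randn(2, 16, 16, 32)  # b s1 s2 c

    channel_mix = einn.Linear("b... [c1->c2]", c2=64)
    print(channel_mix(x).shape)     # torch.Size([2, 16, 16, 64])

    spatial_mix = einn.Linear("b [s...->s2] c", s2=64)
    print(spatial_mix(x).shape)     # torch.Size([2, 64, 32])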

**einn.Dropout** implements stochastic dropout. The first parameter specifies, in einx notation, the shape of the dropout mask that is applied to the input tensor.

..  code::

    dropout         = einn.Dropout("[...]",       drop_rate=0.2)
    spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2)
    droppath        = einn.Dropout("[b] ...",     drop_rate=0.2)

The following is an example of a simple fully-connected network for image classification using ``einn`` in Flax:

..  code::

    from flax import linen as nn
    import einx.nn.flax as einn

    class Net(nn.Module):
        @nn.compact
        def __call__(self, x, training):
            for c in [1024, 512, 256]:
                x = einn.Linear("b [...->c]", c=c)(x)
                x = einn.Norm("[b] c", decay_rate=0.99)(x, training=training)
                x = nn.gelu(x)
                x = einn.Dropout("[...]", drop_rate=0.2)(x, training=training)
            x = einn.Linear("b [...->c]", c=10)(x) # 10 classes
            return x

Example trainings on CIFAR-10 with models implemented using ``einn`` are provided in ``examples/train_{torch|flax|haiku|equinox|keras}.py``. ``einn`` layers can be seamlessly
combined with other layers or used as submodules in the respective framework.

The following page provides examples of common operations in neural networks using ``einx`` and ``einn`` notation.