File: test_nn.py

package info (click to toggle)
python-einx 0.3.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,112 kB
  • sloc: python: 11,619; makefile: 13
file content (288 lines) | stat: -rw-r--r-- 10,566 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import importlib
import importlib.util
from functools import partial

import numpy as np
import pytest

import einx

# Shared Norm test cases used by every backend section below.
# Each entry is an ``(expression, extra_kwargs)`` pair: tests unpack it as
# ``expr, kwargs = expr_kwargs`` and forward ``kwargs`` to the backend's
# ``Norm`` constructor alongside ``expr``.
norms = [
    ("[b...] c", dict()),
    ("b [s...] (g [c])", dict(g=2)),
    ("b [s...] c", dict()),
    ("b... [c]", dict()),
    ("b [s...] ([g] c)", dict(g=2)),
]

if importlib.util.find_spec("torch"):
    import torch
    import einx.nn.torch

    # Prefer the public ``torch.compiler`` namespace; fall back to the private
    # ``torch._dynamo`` module on torch builds that do not list it. Both are
    # used only for ``reset()`` below (presumably clearing compile caches
    # between tests — see torch.compiler.reset docs).
    if "compiler" not in dir(torch):
        import torch._dynamo as compiler
    else:
        compiler = torch.compiler

    def test_torch_linear():
        """Linear maps the trailing axis 3 -> 32, both eagerly and compiled."""
        compiler.reset()
        inputs = torch.zeros((4, 128, 128, 3))
        expected = (4, 128, 128, 32)

        module = einx.nn.torch.Linear("b... [c1->c2]", c2=32)
        assert module.forward(inputs).shape == expected
        compiled = torch.compile(module)
        assert compiled.forward(inputs).shape == expected

    @pytest.mark.parametrize("expr_kwargs", norms)
    @pytest.mark.parametrize("mean", [True, False])
    @pytest.mark.parametrize("scale", [True, False])
    @pytest.mark.parametrize("decay_rate", [None, 0.9])
    def test_torch_norm(expr_kwargs, mean, scale, decay_rate):
        """Norm preserves the input shape in train and eval mode, eager then compiled."""
        compiler.reset()
        expr, kwargs = expr_kwargs
        inputs = torch.zeros((4, 128, 128, 32))
        expected = (4, 128, 128, 32)

        def check_both_modes(net):
            # Run one forward pass in training mode, then one in eval mode.
            net.train()
            assert net.forward(inputs).shape == expected
            net.eval()
            assert net.forward(inputs).shape == expected

        module = einx.nn.torch.Norm(
            expr, mean=mean, scale=scale, decay_rate=decay_rate, **kwargs
        )
        check_both_modes(module)
        # Compile only after the eager passes, matching the original ordering.
        check_both_modes(torch.compile(module, fullgraph=True))

    def test_torch_dropout():
        """Dropout preserves the input shape in train and eval mode, eager and compiled."""
        compiler.reset()
        inputs = torch.zeros((4, 128, 128, 3))
        expected = (4, 128, 128, 3)

        # A fresh layer per mode; ``train(False)`` is equivalent to ``eval()``.
        for training in (True, False):
            module = einx.nn.torch.Dropout("[b] ... [c]", drop_rate=0.2)
            module.train(training)
            assert module.forward(inputs).shape == expected
            compiled = torch.compile(module)
            assert compiled.forward(inputs).shape == expected


if importlib.util.find_spec("haiku"):
    import haiku as hk
    import jax.numpy as jnp
    import jax
    import einx.nn.haiku

    def test_haiku_linear():
        """Linear maps the trailing axis 3 -> 32 under jax.jit."""
        inputs = jnp.zeros((4, 128, 128, 3))
        key = jax.random.PRNGKey(42)

        def forward(x):
            return einx.nn.haiku.Linear("b... [c1->c2]", c2=32)(x)

        transformed = hk.transform_with_state(forward)
        params, state = transformed.init(rng=key, x=inputs)

        out, state = jax.jit(transformed.apply)(params=params, state=state, x=inputs, rng=key)
        assert out.shape == (4, 128, 128, 32)

    @pytest.mark.parametrize("expr_kwargs", norms)
    @pytest.mark.parametrize("mean", [True, False])
    @pytest.mark.parametrize("scale", [True, False])
    @pytest.mark.parametrize("decay_rate", [None, 0.9])
    def test_haiku_norm(expr_kwargs, mean, scale, decay_rate):
        """Norm preserves the input shape for an eval pass followed by a training pass."""
        expr, kwargs = expr_kwargs
        inputs = jnp.zeros((4, 128, 128, 32))
        key = jax.random.PRNGKey(42)

        def forward(x, training):
            norm = einx.nn.haiku.Norm(
                expr, mean=mean, scale=scale, decay_rate=decay_rate, **kwargs
            )
            return norm(x, training)

        transformed = hk.transform_with_state(forward)
        params, state = transformed.init(rng=key, x=inputs, training=True)

        # Each pass consumes the state returned by the previous one.
        for training in (False, True):
            apply_fn = jax.jit(partial(transformed.apply, training=training))
            out, state = apply_fn(params=params, state=state, x=inputs, rng=key)
            assert out.shape == (4, 128, 128, 32)

    def test_haiku_dropout():
        """Dropout preserves the input shape for a training pass then an eval pass."""
        inputs = jnp.zeros((4, 128, 128, 3))
        key = jax.random.PRNGKey(42)

        def forward(x, training):
            dropout = einx.nn.haiku.Dropout("[b] ... [c]", drop_rate=0.2)
            return dropout(x, training=training)

        transformed = hk.transform_with_state(forward)
        params, state = transformed.init(rng=key, x=inputs, training=True)

        for training in (True, False):
            apply_fn = jax.jit(partial(transformed.apply, training=training))
            out, state = apply_fn(params=params, state=state, x=inputs, rng=key)
            assert out.shape == (4, 128, 128, 3)


if importlib.util.find_spec("flax"):
    import flax.linen as nn
    import jax.numpy as jnp
    import jax
    import flax
    import einx.nn.flax

    def test_flax_linear():
        """Linear maps the trailing axis 3 -> 32 under jax.jit."""
        inputs = jnp.zeros((4, 128, 128, 3))
        key = jax.random.PRNGKey(0)

        module = einx.nn.flax.Linear("b... [c1->c2]", c2=32)
        variables = module.init(key, inputs)

        out = jax.jit(module.apply)(variables, x=inputs)
        assert out.shape == (4, 128, 128, 32)

    @pytest.mark.parametrize("expr_kwargs", norms)
    @pytest.mark.parametrize("mean", [True, False])
    @pytest.mark.parametrize("scale", [True, False])
    @pytest.mark.parametrize("decay_rate", [None, 0.9])
    def test_flax_norm(expr_kwargs, mean, scale, decay_rate):
        """Norm preserves the input shape for an eval pass followed by a training pass."""
        expr, kwargs = expr_kwargs
        inputs = jnp.zeros((4, 128, 128, 32))
        key = jax.random.PRNGKey(42)

        module = einx.nn.flax.Norm(
            expr, mean=mean, scale=scale, decay_rate=decay_rate, **kwargs
        )

        variables = module.init(key, inputs, training=True)
        # Split the "params" collection off; every remaining collection is
        # passed back in as mutable and re-captured after each pass.
        state, params = flax.core.pop(variables, "params")

        for training in (False, True):
            apply_fn = jax.jit(
                partial(module.apply, training=training, mutable=list(state.keys()))
            )
            out, state = apply_fn({"params": params, **state}, x=inputs)
            assert out.shape == (4, 128, 128, 32)

    def test_flax_dropout():
        """Dropout preserves the input shape for a training pass then an eval pass."""
        inputs = jnp.zeros((4, 128, 128, 3))
        key = jax.random.PRNGKey(0)

        module = einx.nn.flax.Dropout("[b] ... [c]", drop_rate=0.2)
        variables = module.init({"params": key, "dropout": key}, inputs, training=True)

        for training in (True, False):
            apply_fn = jax.jit(partial(module.apply, training=training))
            out = apply_fn(variables, x=inputs, rngs={"dropout": key})
            assert out.shape == (4, 128, 128, 3)


if importlib.util.find_spec("equinox"):
    import equinox as eqx
    import jax.numpy as jnp
    import einx.nn.equinox
    import jax

    def test_equinox_linear():
        """Linear maps the trailing axis 3 -> 32, before and after inference mode."""
        x = jnp.zeros((4, 128, 128, 3))
        rng = jax.random.PRNGKey(0)

        layer = einx.nn.equinox.Linear("b... [c1->c2]", c2=32)
        assert layer(x, rng=rng).shape == (4, 128, 128, 32)
        assert layer(x).shape == (4, 128, 128, 32)
        layer = eqx.nn.inference_mode(layer)
        assert layer(x).shape == (4, 128, 128, 32)
        assert layer(x).shape == (4, 128, 128, 32)

    @pytest.mark.parametrize("expr_kwargs", norms)
    @pytest.mark.parametrize("mean", [True, False])
    @pytest.mark.parametrize("scale", [True, False])
    # Stateful layers (decay_rate != None) are currently not supported for Equinox.
    @pytest.mark.parametrize("decay_rate", [None])
    def test_equinox_norm(expr_kwargs, mean, scale, decay_rate):
        """Norm preserves the input shape, before and after inference mode.

        Each parametrized case builds and checks exactly one layer configuration,
        so a failure is reported against the parameters that actually caused it.
        """
        expr, kwargs = expr_kwargs
        x = jnp.zeros((4, 128, 128, 32))

        layer = einx.nn.equinox.Norm(
            expr, mean=mean, scale=scale, decay_rate=decay_rate, **kwargs
        )
        assert layer(x).shape == (4, 128, 128, 32)
        assert layer(x).shape == (4, 128, 128, 32)
        layer = eqx.nn.inference_mode(layer)
        assert layer(x).shape == (4, 128, 128, 32)
        assert layer(x).shape == (4, 128, 128, 32)

    def test_equinox_dropout():
        """Dropout preserves the input shape, before and after inference mode."""
        x = jnp.zeros((4, 128, 128, 3))
        rng = jax.random.PRNGKey(0)

        layer = einx.nn.equinox.Dropout("[b] ... [c]", drop_rate=0.2)
        assert layer(x, rng=rng).shape == (4, 128, 128, 3)
        assert layer(x, rng=rng).shape == (4, 128, 128, 3)
        layer = eqx.nn.inference_mode(layer)
        assert layer(x, rng=rng).shape == (4, 128, 128, 3)
        assert layer(x, rng=rng).shape == (4, 128, 128, 3)


if importlib.util.find_spec("keras"):
    import keras

    # Compare only "major.minor"; the einx Keras tests below target Keras >= 3.
    version = tuple(int(i) for i in keras.__version__.split(".")[:2])
    # The tests build inputs with ``tf.zeros``. Keras 3 may be installed with a
    # non-TensorFlow backend and no TensorFlow at all; skip the section in that
    # case instead of raising ImportError and breaking collection of this module.
    if version >= (3, 0) and importlib.util.find_spec("tensorflow"):
        import tensorflow as tf
        import einx.nn.keras

        def test_keras_linear():
            """Linear maps the trailing axis 3 -> 32 in training and inference calls."""
            x = tf.zeros((4, 128, 128, 3))

            layer = einx.nn.keras.Linear("b... [c1->c2]", c2=32)
            model = keras.Sequential([layer])
            assert model(x, training=True).shape == (4, 128, 128, 32)
            assert model(x, training=True).shape == (4, 128, 128, 32)
            assert model(x, training=False).shape == (4, 128, 128, 32)
            assert model(x, training=False).shape == (4, 128, 128, 32)

        @pytest.mark.parametrize("expr_kwargs", norms)
        @pytest.mark.parametrize("mean", [True, False])
        @pytest.mark.parametrize("scale", [True, False])
        @pytest.mark.parametrize("decay_rate", [None, 0.9])
        def test_keras_norm(expr_kwargs, mean, scale, decay_rate):
            """Norm preserves the input shape in training and inference calls."""
            expr, kwargs = expr_kwargs
            x = tf.zeros((4, 128, 128, 32))

            layer = einx.nn.keras.Norm(
                expr, mean=mean, scale=scale, decay_rate=decay_rate, **kwargs
            )
            model = keras.Sequential([layer])
            assert model(x, training=True).shape == (4, 128, 128, 32)
            assert model(x, training=True).shape == (4, 128, 128, 32)
            assert model(x, training=False).shape == (4, 128, 128, 32)
            assert model(x, training=False).shape == (4, 128, 128, 32)

        def test_keras_dropout():
            """Dropout preserves the input shape in training and inference calls."""
            x = tf.zeros((4, 128, 128, 3))

            layer = einx.nn.keras.Dropout("[b] ... [c]", drop_rate=0.2)
            model = keras.Sequential([layer])
            assert model(x, training=True).shape == (4, 128, 128, 3)
            assert model(x, training=True).shape == (4, 128, 128, 3)
            assert model(x, training=False).shape == (4, 128, 128, 3)
            assert model(x, training=False).shape == (4, 128, 128, 3)