File: test_experimental.py

package info (click to toggle)
python-einx 0.3.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,112 kB
  • sloc: python: 11,619; makefile: 13
file content (63 lines) | stat: -rw-r--r-- 2,132 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import importlib
import einx
import numpy as np

if importlib.util.find_spec("jax"):
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh

    def assert_sharding(x, mesh=None, partition=None):
        assert {**x.sharding.mesh.shape} == mesh
        assert tuple(x.sharding.spec) == partition

    def test_sharding():
        mesh24 = Mesh(np.asarray(jax.devices("cpu")).reshape(2, 4), axis_names=("d1", "d2"))
        mesh42 = Mesh(np.asarray(jax.devices("cpu")).reshape(4, 2), axis_names=("d1", "d2"))
        mesh4 = Mesh(np.asarray(jax.devices("cpu"))[:4], axis_names=("d1",))

        # Pass mesh=jax.devices("cpu") instead of mesh=None since we cannot set
        # global device to cpu here
        x = jnp.ones((128, 64))
        assert_sharding(
            einx.experimental.shard("([d1] a) b", x, mesh=jax.devices("cpu")), {"d1": 8}, ("d1",)
        )
        assert_sharding(
            einx.experimental.shard("([d1] a) ([d2] b)", x, d2=2, mesh=jax.devices("cpu")),
            {"d1": 4, "d2": 2},
            ("d1", "d2"),
        )
        assert_sharding(
            einx.experimental.shard("([batch] _) ...", x, d2=2, mesh=jax.devices("cpu")),
            {"batch": 8},
            ("batch",),
        )
        assert_sharding(
            einx.experimental.shard("([d1] a) ([d2] b)", x, mesh=mesh24),
            {"d1": 2, "d2": 4},
            ("d1", "d2"),
        )
        assert_sharding(einx.experimental.shard("([d1] a) b", x, mesh=mesh4), {"d1": 4}, ("d1",))
        assert_sharding(
            einx.experimental.shard("b ([d1] a)", x, mesh=mesh4),
            {"d1": 4},
            (
                None,
                "d1",
            ),
        )
        assert_sharding(
            einx.experimental.shard("a ([d1] b)", x, mesh=mesh42),
            {"d1": 4, "d2": 2},
            (
                None,
                "d1",
            ),
        )

        x = jnp.ones((4, 1024, 1024))
        assert_sharding(
            einx.experimental.shard("a ([d2] b) ([d1] c)", x, mesh=mesh42),
            {"d1": 4, "d2": 2},
            (None, "d2", "d1"),
        )