File: shard.py

import einx
import einx.op.util as util
import numpy as np
from functools import partial
from typing import Callable, Union, Any
import numpy.typing as npt

tP = einx.tracer.import_("PartitionSpec", "P", from_="jax.sharding")
tNamedSharding = einx.tracer.import_("NamedSharding", from_="jax.sharding")
tMesh = einx.tracer.import_("Mesh", from_="jax.sharding")
tjax = einx.tracer.import_("jax")
tnp = einx.tracer.import_("numpy", as_="np")


def _is_composed(expr):
    node = expr
    while node is not None:
        if isinstance(node, einx.expr.stage3.Composition):
            return True
        node = node.parent
    return False
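
# Illustrative note (not from the original source): _is_composed reports whether an
# expression node lies inside an axis composition. For "([batch] _) c" the marked axis
# "batch" is composed; for "[d1 d2] c" the marked axes are not.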


@einx.jit(
    trace=lambda t, c: lambda expr_in, tensor_in, expr_out, backend=None: c(
        expr_in,
        t(tensor_in),
        expr_out,
    )
)
def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None):
    import jax

    for root in [expr_in, expr_out]:
        for expr in root.all():
            if isinstance(expr, einx.expr.stage3.Concatenation):
                raise ValueError("Concatenation not allowed")
            if isinstance(expr, einx.expr.stage3.Marker):
                child = expr
                while child.parent is not None:
                    if (
                        isinstance(child.parent, einx.expr.stage3.List)
                        and _is_composed(child.parent)
                        and child is not child.parent.children[0]
                    ):
                        raise ValueError(
                            "If device axes are used within a composition they "
                            "must appear as the left-most member of the composition"
                        )
                    child = child.parent

    # Call tensor factories
    tensor_in = einx.tracer.call_factory(tensor_in, expr_in.shape, backend=backend)
    (tensor_in,) = backend.all_to_tensor([tensor_in])

    # Flatten expressions
    (expr_in,), (tensor_in,) = util.flatten([expr_in], [tensor_in], backend=backend)
    marked_axes = tuple(
        axis
        for axis in expr_in
        if isinstance(axis, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(axis)
    )

    if mesh is None:
        # Construct new mesh
        devices = tnp.array(tjax.devices()).reshape(tuple(a.value for a in marked_axes))
        mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))
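        # e.g. (illustrative) with 8 devices and marked axes d1=2, d2=4, the devices
        # array is reshaped to (2, 4) and the mesh gets axis names ("d1", "d2").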
    elif isinstance(mesh, jax.sharding.Mesh):
        # Got mesh -> check that marked axes match mesh
        marked_names = set(a.name for a in marked_axes)
        mesh_names = set(str(a) for a in mesh.axis_names)
        if not marked_names.issubset(mesh_names):
            raise ValueError(
                f"Marked axes must be subset of mesh axes. Got marked axes {marked_names} and mesh axes {mesh_names}"
            )
    else:
        # Got list of devices -> construct new mesh
        devices = tnp.array(mesh).reshape(tuple(a.value for a in marked_axes))
        mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))

    # Construct partition spec
    axes = tuple(axis for axis in expr_in if isinstance(axis, einx.expr.stage3.Axis))
    partition_spec = [axis.name if einx.expr.stage3.is_marked(axis) else None for axis in axes]
    partition_spec = tP(*partition_spec)
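    # For example (illustrative): a flattened input expression "a [d1] b" yields
    # PartitionSpec(None, "d1", None), i.e. unmarked axes are replicated.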

    # Shard tensor
    sharding = tNamedSharding(mesh, partition_spec)
    tensor_in = tjax.device_put(tensor_in, sharding)

    # Unflatten output expressions
    (tensor_in,) = util.unflatten([expr_in], [tensor_in], [expr_out], backend=backend)

    return tensor_in, expr_in


@einx.lru_cache
def parse(description, tensor_shape, cse=True, mesh=None, jax_devices=None, **parameters):
    import jax

    description, parameters = einx.op.util._clean_description_and_parameters(
        description, parameters
    )

    op = einx.expr.stage1.parse_op(description)

    if len(op) != 1:
        raise ValueError(f"Expected exactly one expression, got {len(op)}")

    def solve(eqs):
        return einx.expr.solve(
            [einx.expr.Equation(op[0][0], tensor_shape)]
            + eqs
            + [
                einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None)
                for k, v in parameters.items()
            ],
            cse=cse,
        )[0]

    if mesh is None:
        # If no mesh is given, create new mesh of all devices
        try:
            expr_in = solve([])
        except einx.expr.SolveException as e:
            # Try with additional constraint of total number of devices
            expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
            mesh_eq = einx.expr.Equation(expr_mesh, [len(jax.devices())])
            try:
                expr_in = solve([mesh_eq])
            except einx.expr.SolveException:
                # If it still fails, reraise original exception
                raise e
    elif isinstance(mesh, jax.sharding.Mesh):
        # Add constraints for existing mesh axes
        expr_mesh = einx.expr.stage1.Marker(
            einx.expr.stage1.List.maybe([
                einx.expr.stage1.NamedAxis(name) for name in mesh.axis_names
            ])
        )
        mesh_eq = einx.expr.Equation(expr_mesh, mesh.devices.shape)

        expr_in = solve([mesh_eq])
    elif isinstance(mesh, (list, tuple)):
        # Add constraint for number of devices
        expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
        mesh_eq = einx.expr.Equation(expr_mesh, [len(mesh)])
        expr_in = solve([mesh_eq])
    else:
        # Guard against unsupported mesh arguments (otherwise expr_in would be unbound below)
        raise ValueError(f"Invalid mesh argument of type {type(mesh)}")

    expr_out = expr_in.__deepcopy__()

    return expr_in, expr_out
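
# Illustrative example (not from the original source): parse("[d1 d2] c", (2, 4, 128))
# resolves the marked axes to d1=2 and d2=4 from the tensor shape; the output
# expression is a deep copy of the input expression.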


@einx.traceback_util.filter
@einx.jit(
    trace=lambda t, c: lambda description, tensor, mesh=None, backend=None, **kwargs: c(
        description, t(tensor), mesh=mesh, **kwargs
    )
)
def shard(
    description: str,
    tensor: einx.Tensor,
    mesh: Any = None,
    backend: Union[einx.Backend, str, None] = "jax",
    cse: bool = True,
    **parameters: npt.ArrayLike,
) -> einx.Tensor:
    """Shards a tensor over a mesh of devices.

    *This function is currently experimental and will likely change in future versions.*

    *This function is currently only supported for Jax: A sharding is created
    based on the given expression, and applied to the tensor using* ``jax.device_put``.

    The tensor is sharded across the marked axes in the input expression. The names and
    sizes of the marked axes correspond to the axis names and shape of the device mesh:

    >>> x = jnp.ones((2, 4, 128))
    >>> x = einx.experimental.shard("[d1 d2] c", x)
    >>> x.sharding
    NamedSharding(mesh=Mesh('d1': 2, 'd2': 4), spec=PartitionSpec('d1', 'd2', None))

    Axis compositions can be used to apply the
    `sharding rules of Jax <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_,
    where tensor axes are evenly divided by the number of shards:

    >>> x = jnp.ones((128, 640, 480, 3))
    >>> x = einx.experimental.shard("([batch] _) ...", x)
    >>> x.sharding
    NamedSharding(mesh=Mesh('batch': 8), spec=PartitionSpec('batch',))

    If possible, the sharding is created over all devices. ``_`` is a regular axis name,
    and its value is determined by :doc:`einx's expression solver </faq/solver>`.

    Optionally, an existing mesh can be passed:

    >>> from jax.sharding import Mesh
    >>> devices = np.asarray(jax.devices()).reshape(4, 2)
    >>> mesh = Mesh(devices, axis_names=("d1", "d2"))
    >>> x = jnp.ones((4, 1024, 1024))
    >>> x = einx.experimental.shard("a ([d2] b) ([d1] c)", x, mesh=mesh)
    >>> x.sharding
    NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd2', 'd1'))

    The array is replicated over all mesh axes that are not part of the expression:

    >>> x = jnp.ones((1024, 1024))
    >>> x = einx.experimental.shard("a ([d1] b)", x, mesh=mesh)
    >>> x.sharding
    NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd1'))

    Args:
        description: Description string in Einstein notation (see above).
        tensor: Input tensor or tensor factory matching the description string.
        mesh: Mesh or list of devices to shard the tensor over. If not given, a new mesh over all
            available devices will be created matching the axes in the given expression.
            Defaults to ``None``.
        backend: Backend to use. Currently only ``"jax"`` is supported. Defaults to ``"jax"``.
        cse: Whether to apply common subexpression elimination to the expressions. Defaults
            to True.
        graph: Whether to return the graph representation of the operation instead of
            computing the result. Defaults to False.
        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.

    Returns:
        The sharded tensor if ``graph=False``, otherwise the graph
        representation of the operation.
    """
    if backend.name != "jax":
        raise NotImplementedError("einx.experimental.shard is currently only supported for Jax")
    expr_in, expr_out = parse(
        description, einx.tracer.get_shape(tensor), mesh=mesh, cse=cse, **parameters
    )
    tensor, expr_out = shard_stage3(expr_in, tensor, expr_out, mesh=mesh, backend=backend)
    return tensor


shard.parse = parse
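

# A minimal usage sketch (an illustrative addition, not part of the original module):
# it exercises the public API documented in shard() above and assumes a working Jax
# installation with at least one device.
if __name__ == "__main__":
    import jax
    import jax.numpy as jnp

    # Shard the first axis over all available devices; the mesh axis "d" is solved
    # from the tensor shape and matched against the device count.
    x = jnp.ones((len(jax.devices()), 128))
    x = shard("[d] c", x)
    print(x.sharding)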