| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 
 | import einx
import einx.op.util as util
import numpy as np
from functools import partial
from typing import Callable, Union, Any
import numpy.typing as npt
# Handles created through ``einx.tracer.import_`` for the Jax/NumPy names that
# ``shard_stage3`` calls below (``PartitionSpec``, ``NamedSharding``, ``Mesh``, the
# ``jax`` module and the ``numpy`` module).
# NOTE(review): presumably these are resolved when a traced graph is executed rather
# than when this module is imported -- confirm against ``einx.tracer.import_``.
tP = einx.tracer.import_("PartitionSpec", "P", from_="jax.sharding")
tNamedSharding = einx.tracer.import_("NamedSharding", from_="jax.sharding")
tMesh = einx.tracer.import_("Mesh", from_="jax.sharding")
tjax = einx.tracer.import_("jax")
tnp = einx.tracer.import_("numpy", as_="np")
def _is_composed(expr):
    node = expr
    while node is not None:
        if isinstance(node, einx.expr.stage3.Composition):
            return True
        node = node.parent
    return False
@einx.jit(
    # ``mesh`` is accepted and forwarded here (as in the trace signature of ``shard``)
    # so that a mesh given by keyword reaches the function body. Only ``tensor_in`` is
    # wrapped with ``t``.
    trace=lambda t, c: lambda expr_in, tensor_in, expr_out, mesh=None, backend=None: c(
        expr_in,
        t(tensor_in),
        expr_out,
        mesh=mesh,
    )
)
def shard_stage3(expr_in, tensor_in, expr_out, mesh=None, backend=None):
    """Shard ``tensor_in`` along the marked axes of ``expr_in`` using ``jax.device_put``.

    Args:
        expr_in: Stage-3 expression describing the input tensor. Its ``shape`` is used
            when calling tensor factories.
        tensor_in: Input tensor or tensor factory.
        expr_out: Stage-3 expression that the sharded tensor is unflattened to.
        mesh: Where to shard to. ``None`` builds a mesh from ``jax.devices()``; a
            ``jax.sharding.Mesh`` is used as given after checking that every marked
            axis name is one of its axis names; any other value is treated as a
            collection of devices that is turned into an array and reshaped to the
            sizes of the marked axes.
        backend: Backend used for factory calls and for (un)flattening.

    Returns:
        A tuple ``(tensor, expr)`` holding the sharded tensor (after
        ``util.unflatten`` towards ``expr_out``) and the flattened input expression.

    Raises:
        ValueError: If either expression contains a concatenation, if a marked
            sub-expression inside a composition is not the left-most member of its
            list, or if a given mesh lacks one of the marked axis names.
    """
    import jax

    # Validate both expressions before touching the tensor.
    for root in (expr_in, expr_out):
        for expr in root.all():
            if isinstance(expr, einx.expr.stage3.Concatenation):
                raise ValueError("Concatenation not allowed")
            if isinstance(expr, einx.expr.stage3.Marker):
                # Walk from the marker up to the root. Whenever the current node sits
                # in a list that is part of a composition, it has to be that list's
                # first child.
                child = expr
                while child.parent is not None:
                    parent = child.parent
                    if (
                        isinstance(parent, einx.expr.stage3.List)
                        and _is_composed(parent)
                        and child is not parent.children[0]
                    ):
                        raise ValueError(
                            "If device axes are used within a composition they "
                            "must appear as the left-most member of the composition"
                        )
                    child = parent

    # Call tensor factories
    tensor_in = einx.tracer.call_factory(tensor_in, expr_in.shape, backend=backend)
    (tensor_in,) = backend.all_to_tensor([tensor_in])

    # Flatten expressions
    (expr_in,), (tensor_in,) = util.flatten([expr_in], [tensor_in], backend=backend)

    marked_axes = tuple(
        axis
        for axis in expr_in
        if isinstance(axis, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(axis)
    )

    if isinstance(mesh, jax.sharding.Mesh):
        # Got mesh -> check that marked axes match mesh
        marked_names = {a.name for a in marked_axes}
        mesh_names = {str(a) for a in mesh.axis_names}
        if not marked_names.issubset(mesh_names):
            raise ValueError(
                f"Marked axes must be subset of mesh axes. Got marked axes {marked_names} and mesh axes {mesh_names}"
            )
    else:
        # No mesh (use all devices) or a collection of devices -> construct new mesh
        # whose axes are exactly the marked axes, in expression order.
        device_source = tjax.devices() if mesh is None else mesh
        devices = tnp.array(device_source).reshape(tuple(a.value for a in marked_axes))
        mesh = tMesh(devices, axis_names=tuple(a.name for a in marked_axes))

    # Construct partition spec: one entry per flattened axis, the axis name for marked
    # axes and None for all others.
    axes = tuple(axis for axis in expr_in if isinstance(axis, einx.expr.stage3.Axis))
    partition_spec = tP(
        *[axis.name if einx.expr.stage3.is_marked(axis) else None for axis in axes]
    )

    # Shard tensor
    sharding = tNamedSharding(mesh, partition_spec)
    tensor_in = tjax.device_put(tensor_in, sharding)

    # Unflatten output expressions
    (tensor_in,) = util.unflatten([expr_in], [tensor_in], [expr_out], backend=backend)

    # NOTE(review): the expression returned here is the flattened input expression,
    # not ``expr_out`` -- confirm this is intended; the caller in this file ignores it.
    return tensor_in, expr_in
@einx.lru_cache
def parse(description, tensor_shape, cse=True, mesh=None, jax_devices=None, **parameters):
    """Parse and solve the expression of a ``shard`` operation.

    Args:
        description: Description string; it must contain exactly one expression.
        tensor_shape: Shape of the input tensor that the expression is matched against.
        cse: Forwarded to ``einx.expr.solve`` as ``cse``. Defaults to ``True``.
        mesh: ``None``, a ``jax.sharding.Mesh`` or a list/tuple of devices. It decides
            which additional constraint is handed to the solver (see comments below).
        jax_devices: Accepted for signature compatibility; not read by this function.
        **parameters: Additional axis constraints, e.g. ``a=4``.

    Returns:
        A tuple ``(expr_in, expr_out)`` where ``expr_out`` is a deep copy of the solved
        input expression.

    Raises:
        ValueError: If the description does not contain exactly one expression.
        TypeError: If ``mesh`` is none of the supported kinds.
        einx.expr.SolveException: If the expression cannot be solved.
    """
    import jax

    description, parameters = einx.op.util._clean_description_and_parameters(
        description, parameters
    )

    op = einx.expr.stage1.parse_op(description)
    if len(op) != 1:
        raise ValueError(f"Expected exactly one expression, got {len(op)}")

    def solve(eqs):
        # Solve the input expression against the tensor shape, the extra equations
        # ``eqs`` and the user-supplied axis parameters; only the first solved
        # expression (the input expression) is returned.
        return einx.expr.solve(
            [einx.expr.Equation(op[0][0], tensor_shape)]
            + eqs
            + [
                einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None)
                for k, v in parameters.items()
            ],
            cse=cse,
        )[0]

    def device_count_equation(num_devices):
        # Constraint: the composition of all marked sub-expressions of the input
        # expression equals the given number of devices.
        expr_mesh = einx.expr.stage1.Composition(einx.expr.stage1.get_marked(op[0][0]))
        return einx.expr.Equation(expr_mesh, [num_devices])

    if mesh is None:
        # If no mesh is given, create new mesh of all devices
        try:
            expr_in = solve([])
        except einx.expr.SolveException as e:
            # Try with additional constraint of total number of devices
            mesh_eq = device_count_equation(len(jax.devices()))
            try:
                expr_in = solve([mesh_eq])
            except einx.expr.SolveException:
                # If it still fails, reraise original exception
                raise e
    elif isinstance(mesh, jax.sharding.Mesh):
        # Add constraints for existing mesh axes
        expr_mesh = einx.expr.stage1.Marker(
            einx.expr.stage1.List.maybe([
                einx.expr.stage1.NamedAxis(name) for name in mesh.axis_names
            ])
        )
        mesh_eq = einx.expr.Equation(expr_mesh, mesh.devices.shape)
        expr_in = solve([mesh_eq])
    elif isinstance(mesh, (list, tuple)):
        # Add constraint for number of devices
        expr_in = solve([device_count_equation(len(mesh))])
    else:
        # Without this branch ``expr_in`` would stay unassigned and the function would
        # fail below with an unrelated UnboundLocalError.
        raise TypeError(
            "mesh must be None, a jax.sharding.Mesh or a list/tuple of devices, "
            f"got {type(mesh).__name__}"
        )

    expr_out = expr_in.__deepcopy__()
    return expr_in, expr_out
@einx.traceback_util.filter
@einx.jit(
    trace=lambda t, c: lambda description, tensor, mesh=None, backend=None, **kwargs: c(
        description, t(tensor), mesh=mesh, **kwargs
    )
)
def shard(
    description: str,
    tensor: einx.Tensor,
    mesh: Any = None,
    backend: Union[einx.Backend, str, None] = "jax",
    cse: bool = True,
    **parameters: npt.ArrayLike,
) -> einx.Tensor:
    """Shard a tensor over a mesh of devices.

    *This function is currently experimental and will likely change in future versions.*

    *This function is currently only supported for Jax: A sharding is created
    based on the given expression, and applied to the tensor using* ``jax.device_put``.

    The tensor is sharded across the marked axes in the input expression. The marked axes
    match the axis names and shape of the mesh:

    >>> x = jnp.ones((2, 4, 128))
    >>> x = einx.experimental.shard("[d1 d2] c", x)
    >>> x.sharding
    NamedSharding(mesh=Mesh('d1': 2, 'd2': 4), spec=PartitionSpec('d1', 'd2', None))

    Axis compositions can be used to apply the
    `sharding rules of Jax <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_,
    where tensor axes are evenly divided by the number of shards:

    >>> x = jnp.ones((128, 640, 480, 3))
    >>> x = einx.experimental.shard("([batch] _) ...", x)
    >>> x.sharding
    NamedSharding(mesh=Mesh('batch': 8), spec=PartitionSpec('batch',))

    If possible, the sharding is created over all devices. ``_`` is a regular axis name,
    and its value is determined by :doc:`einx's expression solver </faq/solver>`.

    Optionally, an existing mesh can be passed:

    >>> from jax.sharding import Mesh
    >>> devices = np.asarray(jax.devices()).reshape(4, 2)
    >>> mesh = Mesh(devices, axis_names=("d1", "d2"))
    >>> x = jnp.ones((4, 1024, 1024))
    >>> x = einx.experimental.shard("a ([d2] b) ([d1] c)", x, mesh=mesh)
    >>> x.sharding
    NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd2', 'd1'))

    The array is replicated over all mesh axes that are not part of the expression:

    >>> x = jnp.ones((1024, 1024))
    >>> x = einx.experimental.shard("a ([d1] b)", x, mesh=mesh)
    >>> x.sharding
    NamedSharding(mesh=Mesh('d1': 4, 'd2': 2), spec=PartitionSpec(None, 'd1',))

    Args:
        description: Description string in Einstein notation (see above).
        tensor: Input tensor or tensor factory matching the description string.
        mesh: Mesh or list of devices to shard the tensor over. If not given, a new mesh over all
            available devices will be created matching the axes in the given expression.
            Defaults to ``None``.
        backend: Backend to use. Only the Jax backend is accepted. Defaults to ``"jax"``.
        cse: Whether to apply common subexpression elimination to the expressions. Defaults
            to True.
        graph: Whether to return the graph representation of the operation instead of
            computing the result. Defaults to False.
        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.

    Returns:
        The sharded tensor if ``graph=False``, otherwise the graph
        representation of the operation.

    Raises:
        NotImplementedError: If ``backend`` is not the Jax backend.
    """
    if backend.name != "jax":
        raise NotImplementedError("einx.experimental.shard is currently only supported for Jax")

    # Solve the expression against the tensor's shape, then apply the sharding.
    tensor_shape = einx.tracer.get_shape(tensor)
    expr_in, expr_out = parse(description, tensor_shape, mesh=mesh, cse=cse, **parameters)
    sharded, _ = shard_stage3(expr_in, tensor, expr_out, mesh=mesh, backend=backend)
    return sharded
# Expose the expression parser as an attribute of the public function (``shard.parse``).
shard.parse = parse
 |