"""
Directly tests various parser utility functions.
"""
from typing import Any, Tuple
import pytest
from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input
from opt_einsum.testing import build_arrays_from_tuples
def test_get_symbol() -> None:
    """Pin the character that ``get_symbol`` returns for a handful of indices."""
    expected_by_index = {
        2: "c",
        200000: "\U00031540",
        # Ensure we skip surrogates '[\uD800-\uDFFF]'
        # NOTE(review): the value pinned for 55295 ("\ud88b") itself lies
        # inside that range -- confirm this is the intended expectation.
        55295: "\ud88b",
        55296: "\ue000",
        57343: "\ue7ff",
    }
    # Dicts keep insertion order, so the indices are checked in the order listed.
    for index, symbol in expected_by_index.items():
        assert get_symbol(index) == symbol
def test_parse_einsum_input() -> None:
    """Parse an equation without an explicit output alongside built operands.

    The input subscripts must come back unchanged, the output subscript must
    be ``"ad"``, and the returned operands must compare equal to the ones
    passed in.
    """
    equation = "ab,bc,cd"
    arrays = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])

    parsed = parse_einsum_input([equation, *arrays])
    inputs, output, operands = parsed

    assert inputs == equation
    assert output == "ad"
    assert operands == arrays
def test_parse_einsum_input_shapes_error() -> None:
    """Expect ``ValueError`` when built operands are passed with ``shapes=True``."""
    equation = "ab,bc,cd"
    arrays = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
    call_args = [equation, *arrays]

    with pytest.raises(ValueError):
        parse_einsum_input(call_args, shapes=True)
def test_parse_einsum_input_shapes() -> None:
    """Parse an equation given plain shape tuples with ``shapes=True``.

    The input subscripts must come back unchanged, the output subscript must
    be ``"ad"``, and the shapes must compare equal to the returned operands.
    """
    equation = "ab,bc,cd"
    shapes = [(2, 3), (3, 4), (4, 5)]

    parsed = parse_einsum_input([equation, *shapes], shapes=True)
    inputs, output, operands = parsed

    assert inputs == equation
    assert output == "ad"
    assert shapes == operands
def test_parse_with_ellisis() -> None:
    """Parse an equation whose first term starts with an ellipsis.

    With shapes ``(2, 3)`` and ``(3, 4)``, ``"...a,ab"`` must come back as
    input subscripts ``"da,ab"`` with output subscript ``"db"``, and the
    shapes must compare equal to the returned operands.
    """
    equation = "...a,ab"
    shapes = [(2, 3), (3, 4)]

    parsed = parse_einsum_input([equation, *shapes], shapes=True)
    inputs, output, operands = parsed

    assert inputs == "da,ab"
    assert output == "db"
    assert shapes == operands
@pytest.mark.parametrize(
    "array, shape",
    [
        # Sequences: one dimension per nesting level.
        ([5], (1,)),
        ([5, 5], (2,)),
        ((5, 5), (2,)),
        ([[[[[5, 2]]]]], (1, 1, 1, 1, 2)),
        # Strings inside a sequence count as leaf elements, not as a dimension.
        ([[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)),
        # Bare scalars, including str and bytes, have an empty shape.
        ("A", ()),
        (b"A", ()),
        (True, ()),
        (5, ()),
        (5.0, ()),
        (5.0 + 0j, ()),
    ],
)
def test_get_shapes(array: Any, shape: Tuple[int, ...]) -> None:
    """Check that ``get_shape`` returns the expected shape for each input.

    Args:
        array: Value handed to ``get_shape``; a nested list/tuple or a scalar.
        shape: Expected result. Its length varies per case (from ``()`` up to
            five entries), hence the variable-length ``Tuple[int, ...]``
            annotation rather than the one-element ``Tuple[int]``.
    """
    assert get_shape(array) == shape
|