# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.
import numpy
from hypothesis import given, strategies as st
from hypothesis.extra.numpy import arrays, mutually_broadcastable_shapes
@given(
    data=st.data(),
    shapes=mutually_broadcastable_shapes(signature="(n?,k),(k,m?)->(n?,m?)"),
    types=st.sampled_from([sig for sig in numpy.matmul.types if "O" not in sig]),
)
def test_gufunc_matmul(data, shapes, types):
    """Check the shape and dtype of ``numpy.matmul`` against its declared signature.

    Args:
        data: Interactive draw handle supplied by ``st.data()``; used to draw
            the two operand arrays once their strategies are known.
        shapes: Pair ``(input_shapes, result_shape)`` generated for the
            signature ``"(n?,k),(k,m?)->(n?,m?)"``.
        types: One entry of ``numpy.matmul.types`` that does not contain
            ``"O"``, in the form ``"<input codes>-><output code>"``.
    """
    operand_shapes, out_shape = shapes
    operand_codes, out_code = types.split("->")

    # Iterating the input half of the type string yields one dtype character
    # per operand; pair each with the shape generated for that operand.
    operand_strategies = []
    for dtype_code, operand_shape in zip(operand_codes, operand_shapes):
        operand_strategies.append(
            arrays(
                dtype=dtype_code,
                shape=operand_shape,
                elements={"allow_nan": True},
            )
        )

    # Both operands are drawn together as a single tuple.
    lhs, rhs = data.draw(st.tuples(*operand_strategies))
    product = numpy.matmul(lhs, rhs)

    assert product.shape == out_shape
    assert product.dtype.char == out_code
|