1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
|
import mockito
from mockito import when, patch
import pytest
import numpy as np
from . import module
# Apply the "unstub" fixture to every test in this module. The fixture itself
# is defined elsewhere (presumably a conftest that undoes mockito stubs after
# each test — confirm there).
pytestmark = [pytest.mark.usefixtures("unstub")]
def xcompare(a, b):
    """Compare two invocation arguments in a numpy-array-aware way.

    When ``a`` is a mockito matcher, the result is ``a.matches(b)``.
    Otherwise ``a`` and ``b`` are compared with ``np.array_equal``. That
    treats numpy arrays as whole values rather than comparing them
    element-wise.

    This function is patched over ``MatchingInvocation.compare`` in this
    module's tests. ``a`` is presumably the stubbed (expected) argument and
    ``b`` the actual one — confirm against mockito's call site.
    """
    if not isinstance(a, mockito.matchers.Matcher):
        return np.array_equal(a, b)
    return a.matches(b)
class TestEnsureNumpyWorks:
    """Stubbing and calling mocked functions with numpy array arguments."""

    def testEnsureNumpyArrayAllowedWhenStubbing(self):
        """A stub registered for a specific array answers a call with it.

        The call is made while ``MatchingInvocation.compare`` is replaced by
        ``xcompare``, which compares non-matcher arguments with
        ``np.array_equal``.
        """
        values = np.array([1, 2, 3])
        when(module).one_arg(values).thenReturn('yep')

        compare_target = mockito.invocation.MatchingInvocation.compare
        with patch(compare_target, xcompare):
            assert module.one_arg(values) == 'yep'

    def testEnsureNumpyArrayAllowedWhenCalling(self):
        """A stub registered with ``...`` accepts an array argument.

        No comparison patch is installed for this test.
        """
        values = np.array([1, 2, 3])
        # `...` is the same object as `Ellipsis`. Per mockito's documentation,
        # it acts as an "any arguments" wildcard when used in a stub.
        when(module).one_arg(...).thenReturn('yep')
        assert module.one_arg(values) == 'yep'
|