1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
|
from mock import MagicMock
from thinc.api import Linear, with_debug
def test_with_debug():
    """Check that ``with_debug`` fires each hook exactly once, at the right moment.

    ``on_init`` must fire during ``model.initialize`` with ``(model, X, Y)``,
    ``on_forward`` during a training-mode call with ``(model, X, True)``, and
    ``on_backprop`` only once the returned backprop callback is invoked, with
    the gradient passed to it. Before its trigger point, every hook must still
    be untouched.
    """
    init_hook = MagicMock()
    forward_hook = MagicMock()
    backprop_hook = MagicMock()
    model = with_debug(
        Linear(),
        on_init=init_hook,
        on_forward=forward_hook,
        on_backprop=backprop_hook,
    )

    # Merely wrapping the layer must not trigger any hook.
    for hook in (init_hook, forward_hook, backprop_hook):
        hook.assert_not_called()

    inputs = model.ops.alloc2f(1, 1)
    targets = model.ops.alloc2f(1, 1)

    # Initialization fires only on_init, passing through the model and sample data.
    model.initialize(X=inputs, Y=targets)
    init_hook.assert_called_once_with(model, inputs, targets)
    for hook in (forward_hook, backprop_hook):
        hook.assert_not_called()

    # The forward pass fires on_forward; on_backprop stays pending until the
    # returned callback is actually invoked. The layer output itself is unused.
    _, backprop = model(inputs, is_train=True)
    forward_hook.assert_called_once_with(model, inputs, True)
    backprop_hook.assert_not_called()

    backprop(targets)
    backprop_hook.assert_called_once_with(targets)
|