# Tests for the TensorFlowWrapper layer (conversion, training, serialization).
import numpy
import pytest
from thinc.api import (
Adam,
ArgsKwargs,
Linear,
Model,
TensorFlowWrapper,
get_current_ops,
keras_subclass,
tensorflow2xp,
xp2tensorflow,
)
from thinc.compat import has_cupy_gpu, has_tensorflow
from thinc.util import to_categorical
from ..util import check_input_converters, make_tempdir
@pytest.fixture
def n_hidden():
    """Unit count passed to the hidden Dense layers of the `tf_model` fixture."""
    return 12
@pytest.fixture
def input_size():
    """Number of input features per example (used by `X` and `tf_model`)."""
    return 784
@pytest.fixture
def n_classes():
    """Number of target classes used when one-hot encoding the label."""
    return 10
@pytest.fixture
def answer():
    """Class index that the training tests expect the model to predict."""
    return 1
@pytest.fixture
def X(input_size):
    """A single input example of shape (1, input_size) from the current ops."""
    return get_current_ops().alloc(shape=(1, input_size))
@pytest.fixture
def Y(answer, n_classes):
    """Categorical encoding of `answer` over `n_classes` classes."""
    labels = get_current_ops().asarray1i([answer])
    return to_categorical(labels, n_classes=n_classes)
@pytest.fixture
def tf_model(n_hidden, input_size, n_classes):
    """Build the Keras Sequential model that the wrapper tests operate on.

    The network is Dense -> LayerNormalization -> Dense(relu) ->
    LayerNormalization -> Dense(softmax).

    n_hidden: units in the two hidden Dense layers.
    input_size: number of input features expected by the first layer.
    n_classes: units in the softmax output layer.
    """
    import tensorflow as tf

    layers = tf.keras.layers
    return tf.keras.Sequential(
        [
            layers.Dense(n_hidden, input_shape=(input_size,)),
            layers.LayerNormalization(),
            layers.Dense(n_hidden, activation="relu"),
            layers.LayerNormalization(),
            # Size the output from the fixture rather than a literal so it
            # always matches the width of the one-hot `Y` target.
            layers.Dense(n_classes, activation="softmax"),
        ]
    )
@pytest.fixture
def model(tf_model):
    """The `tf_model` fixture wrapped in a TensorFlowWrapper."""
    wrapped = TensorFlowWrapper(tf_model)
    return wrapped
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_roundtrip_conversion():
    """Convert an xp array to a TF tensor and back, and compare the values."""
    import tensorflow as tf

    backend = get_current_ops()
    original = backend.alloc2f(2, 3, zeros=True)

    as_tensor = xp2tensorflow(original)
    assert isinstance(as_tensor, tf.Tensor)

    restored = tensorflow2xp(as_tensor, ops=backend)
    assert backend.xp.array_equal(original, restored)
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_construction_requires_keras_model():
    """A Keras model can be wrapped; wrapping a thinc layer raises ValueError."""
    import tensorflow as tf

    dense = tf.keras.layers.Dense(12, input_shape=(12,))
    wrapped = TensorFlowWrapper(tf.keras.Sequential([dense]))
    assert isinstance(wrapped, Model)

    # A thinc layer is not a Keras model, so construction must fail.
    with pytest.raises(ValueError):
        TensorFlowWrapper(Linear(2, 3))
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_built_model(model, X, Y):
    """A built model can predict, print a summary and round-trip via bytes."""
    # Built models are validated more and can perform useful operations.
    prediction = model.predict(X)
    assert prediction is not None

    # The shim can be rendered as a (non-empty) Keras summary.
    summary = str(model.shims[0])
    assert summary != ""

    # The model can be serialized and deserialized.
    payload = model.to_bytes()
    assert model.from_bytes(payload) is not None
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_predict(model, X):
    """Smoke test: calling predict on the wrapped model does not raise."""
    model.predict(X)
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_train_overfits(model, X, Y, answer):
    """Train on one example for 100 steps and expect it to be predicted."""
    adam = Adam()
    backend = get_current_ops()

    for _ in range(100):
        outputs, finish_backward = model(X, is_train=True)
        # Ensure that the tensor is type-compatible with the current backend.
        outputs = backend.asarray(outputs)

        batch_size = outputs.shape[0]
        finish_backward((outputs - Y) / batch_size)
        model.finish_update(adam)

    assert model.predict(X).argmax() == answer
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_accumulate_gradients(model, X, Y, answer):
    """Check that shim gradients change across backward passes.

    Runs three forward/backward passes without applying an update, keeping
    a copy of the shim's gradients after each pass. After `finish_update`
    the shim's gradients must be reset to None, and every consecutive pair
    of copies must differ in at least one tensor.
    """
    import tensorflow as tf

    optimizer = Adam()
    ops = get_current_ops()
    gradients = []
    for _ in range(3):
        guesses, backprop = model(X, is_train=True)
        # Ensure that the tensor is type-compatible with the current backend.
        guesses = ops.asarray(guesses)

        d_guesses = (guesses - Y) / guesses.shape[0]
        backprop(d_guesses)
        # Keep a separate tensor per step via tf.identity (presumably because
        # the shim keeps accumulating into its own gradients -- confirm).
        gradients.append([tf.identity(var) for var in model.shims[0].gradients])

    # Apply the gradients
    model.finish_update(optimizer)
    assert model.shims[0].gradients is None

    # Compare prev/next pairs and ensure their gradients have changed
    for prev_grads, curr_grads in zip(gradients, gradients[1:]):
        assert any(
            (prev != curr).numpy().any()
            for curr, prev in zip(curr_grads, prev_grads)
        )
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_serialize_model_subclass(
    X, Y, input_size, n_classes, answer
):
    """Train a decorated Keras subclass, then reload it from bytes.

    The prediction for `X` must equal `answer` both after training and
    after `from_bytes(to_bytes())`.
    """
    import tensorflow as tf

    ops = get_current_ops()
    input_shape = (1, input_size)
    example_X = ops.alloc2f(*input_shape)
    example_Y = to_categorical(ops.asarray1i([1]), n_classes=n_classes)

    @keras_subclass(
        "foo.v1",
        X=example_X,
        Y=example_Y,
        input_shape=input_shape,
    )
    class CustomKerasModel(tf.keras.Model):
        def __init__(self, **kwargs):
            super(CustomKerasModel, self).__init__(**kwargs)
            dense = tf.keras.layers.Dense
            self.in_dense = dense(12, name="in_dense", input_shape=input_shape)
            self.out_dense = dense(
                n_classes, name="out_dense", activation="softmax"
            )

        def call(self, inputs) -> tf.Tensor:
            hidden = self.in_dense(inputs)
            return self.out_dense(hidden)

    model = TensorFlowWrapper(CustomKerasModel())

    # Train the model to predict the right single answer
    adam = Adam()
    for _ in range(50):
        outputs, finish_backward = model(X, is_train=True)
        # Ensure that the tensor is type-compatible with the current backend.
        outputs = ops.asarray(outputs)

        finish_backward((outputs - Y) / outputs.shape[0])
        model.finish_update(adam)
    assert model.predict(X).argmax() == answer

    # Save then load the model from bytes
    payload = model.to_bytes()
    model.from_bytes(payload)

    # The from_bytes model gets the same answer
    assert model.predict(X).argmax() == answer
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_keras_subclass_decorator_compile_args():
    """The `compile_args` loss is present on the model reloaded from bytes."""
    import tensorflow as tf

    class UndecoratedModel(tf.keras.Model):
        def call(self, inputs):
            return inputs

    # Can't wrap an undecorated keras subclass model
    with pytest.raises(ValueError):
        TensorFlowWrapper(UndecoratedModel())

    example_X = numpy.array([0.0, 0.0])
    example_Y = numpy.array([0.5])

    @keras_subclass(
        "TestModel",
        X=example_X,
        Y=example_Y,
        input_shape=(2,),
        compile_args={"loss": "binary_crossentropy"},
    )
    class TestModel(tf.keras.Model):
        def call(self, inputs):
            return inputs

    wrapped = TensorFlowWrapper(TestModel())
    restored = wrapped.from_bytes(wrapped.to_bytes())
    assert restored.shims[0]._model.loss == "binary_crossentropy"
    assert isinstance(restored, Model)
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_keras_subclass_decorator():
    """Only Keras subclasses decorated with `keras_subclass` can be wrapped."""
    import tensorflow as tf

    class UndecoratedModel(tf.keras.Model):
        def call(self, inputs):
            return inputs

    # Can't wrap an undecorated keras subclass model
    with pytest.raises(ValueError):
        TensorFlowWrapper(UndecoratedModel())

    example_X = numpy.array([0.0, 0.0])
    example_Y = numpy.array([0.5])

    @keras_subclass("TestModel", X=example_X, Y=example_Y, input_shape=(2,))
    class TestModel(tf.keras.Model):
        def call(self, inputs):
            return inputs

    # Can wrap a decorated keras subclass model
    wrapped = TensorFlowWrapper(TestModel())
    assert isinstance(wrapped, Model)
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_keras_subclass_decorator_capture_args_kwargs(
    X, Y, input_size, n_classes, answer
):
    """Constructor args/kwargs of a decorated Keras subclass are captured.

    Checks that the wrapped model exposes `eg_args` holding the positional
    and keyword arguments, that non-serializable arguments raise ValueError,
    and that a bytes round trip rebuilds the model (whose `__init__` asserts
    that the captured arguments were supplied again).
    """
    import tensorflow as tf

    @keras_subclass(
        "TestModel", X=numpy.array([0.0, 0.0]), Y=numpy.array([0.5]), input_shape=(2,)
    )
    class TestModel(tf.keras.Model):
        def __init__(self, custom=False, **kwargs):
            # `super()` already binds the instance; passing `self` again
            # would hand it to the Keras constructor as a positional argument.
            super().__init__()
            # This is to force the model to pass the captured arguments
            # or fail.
            assert custom is True
            assert kwargs.get("other", None) is not None

        def call(self, inputs):
            return inputs

    # Can wrap a decorated keras subclass model
    model = TensorFlowWrapper(TestModel(True, other=1337))

    assert hasattr(model.shims[0]._model, "eg_args")
    args_kwargs = model.shims[0]._model.eg_args
    assert True in args_kwargs.args
    assert "other" in args_kwargs.kwargs

    # Raises an error if the args/kwargs is not serializable
    # (a dict that contains itself).
    obj = {}
    obj["key"] = obj
    with pytest.raises(ValueError):
        TensorFlowWrapper(TestModel(True, other=obj))

    # Provides the same arguments when copying a capture model
    model = model.from_bytes(model.to_bytes())
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_can_copy_model(model):
    """Copying the wrapped model returns an object."""
    duplicate = model.copy()
    assert duplicate is not None
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_print_summary(model, X):
    """The shim's string form names the layers and the parameter counts."""
    summary = str(model.shims[0])
    expected_fragments = (
        # Summary includes the layers of our model
        "layer_normalization",
        "dense",
        # And counts of params
        "Total params",
        "Trainable params",
        "Non-trainable params",
    )
    for fragment in expected_fragments:
        assert fragment in summary
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_to_bytes(model, X):
    """`to_bytes` returns a payload that `from_bytes` accepts."""
    # And can be serialized
    payload = model.to_bytes()
    assert payload is not None
    model.from_bytes(payload)
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_to_from_disk(model, X, Y, answer):
    """Write the model to a temporary file and load it back."""
    with make_tempdir() as tmp_path:
        target = tmp_path / "model.h5"
        model.to_disk(target)
        reloaded = model.from_disk(target)
    assert reloaded is not None
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_from_bytes(model, X):
    """After a prediction, the model still round-trips through bytes."""
    model.predict(X)
    payload = model.to_bytes()
    reloaded = model.from_bytes(payload)
    assert reloaded is not None
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_use_params(model, X, Y, answer):
    """Predict under `use_params(optimizer.averages)` before and after training."""
    adam = Adam()
    backend = get_current_ops()

    with model.use_params(adam.averages):
        assert model.predict(X).argmax() is not None

    for _ in range(10):
        outputs, finish_backward = model.begin_update(X)
        # Ensure that the tensor is type-compatible with the current backend.
        outputs = backend.asarray(outputs)

        finish_backward((outputs - Y) / outputs.shape[0])
        model.finish_update(adam)

    with model.use_params(adam.averages):
        predicted = model.predict(X).argmax()
    assert predicted == answer
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_to_cpu(tf_model):
    """Smoke test: `to_cpu` on a freshly wrapped model does not raise."""
    wrapped = TensorFlowWrapper(tf_model)
    wrapped.to_cpu()
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
@pytest.mark.skipif(not has_cupy_gpu, reason="needs GPU/cupy")
def test_tensorflow_wrapper_to_gpu(model, X):
    """Smoke test: `to_gpu(0)` on the wrapped model does not raise."""
    device_id = 0
    model.to_gpu(device_id)
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
@pytest.mark.parametrize(
    "data,n_args,kwargs_keys",
    [
        # fmt: off
        (numpy.zeros((2, 3), dtype="f"), 1, []),
        ([numpy.zeros((2, 3), dtype="f"), numpy.zeros((2, 3), dtype="f")], 2, []),
        ((numpy.zeros((2, 3), dtype="f"), numpy.zeros((2, 3), dtype="f")), 2, []),
        ({"a": numpy.zeros((2, 3), dtype="f"), "b": numpy.zeros((2, 3), dtype="f")}, 0, ["a", "b"]),
        (ArgsKwargs((numpy.zeros((2, 3), dtype="f"), numpy.zeros((2, 3), dtype="f")), {"c": numpy.zeros((2, 3), dtype="f")}), 2, ["c"]),
        # fmt: on
    ],
)
def test_tensorflow_wrapper_convert_inputs(data, n_args, kwargs_keys):
    """Run the wrapper's `convert_inputs` on several input container types.

    Covers a single array, a list, a tuple, a dict and an ArgsKwargs; the
    result is validated by `check_input_converters` against `tf.Tensor`.
    """
    import tensorflow as tf

    dense = tf.keras.layers.Dense(12, input_shape=(12,))
    wrapper = TensorFlowWrapper(tf.keras.Sequential([dense]))

    converter = wrapper.attrs["convert_inputs"]
    converted, convert_back = converter(wrapper, data, is_train=True)
    check_input_converters(
        converted, convert_back, data, n_args, kwargs_keys, tf.Tensor
    )
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_thinc_model_subclass(tf_model):
    """`model_class` makes the wrapper return an instance of that subclass."""

    class CustomModel(Model):
        def fn(self):
            return 1337

    wrapped = TensorFlowWrapper(tf_model, model_class=CustomModel)
    assert isinstance(wrapped, CustomModel)
    assert wrapped.fn() == 1337
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_thinc_set_model_name(tf_model):
    """`model_name` is used as the wrapped model's name."""
    wrapped = TensorFlowWrapper(tf_model, model_name="cool")
    assert wrapped.name == "cool"
# End of TensorFlowWrapper tests.