# Tests for round-tripping Keras models through distributed.protocol.
from __future__ import annotations
import pytest
keras = pytest.importorskip("keras")
np = pytest.importorskip("numpy")
from distributed.protocol import deserialize, dumps, loads, serialize, to_serialize
def test_serialize_deserialize_model():
    """Check that a trained Keras model survives protocol round trips.

    A small two-layer ``Sequential`` model is compiled and trained on one
    random batch.  The model is then sent through two paths:

    * ``serialize`` followed by ``deserialize``;
    * ``dumps`` / ``loads`` of a dict whose value is wrapped with
      ``to_serialize``.

    In both cases the recovered model's predictions on the training input
    must be close to those of the original model.
    """
    from numpy.testing import assert_allclose

    # Build the network: Dense(5) fed by 3 input features, then Dense(2).
    net = keras.models.Sequential()
    for layer in (keras.layers.Dense(5, input_dim=3), keras.layers.Dense(2)):
        net.add(layer)
    net.compile(optimizer="sgd", loss="mse")

    # One random sample: 3 features in, 2 targets out.
    x = np.random.random((1, 3))
    y = np.random.random((1, 2))
    net.train_on_batch(x, y)

    # Path 1: direct serialize -> deserialize.
    serialized = serialize(net)
    restored = deserialize(*serialized)
    restored_pred = restored.predict(x)
    reference_pred = net.predict(x)
    assert_allclose(restored_pred, reference_pred)

    # Path 2: message-level dumps -> loads with the model nested in a dict.
    message = {"model": to_serialize(net)}
    wire_frames = dumps(message)
    decoded = loads(wire_frames)
    decoded_pred = decoded["model"].predict(x)
    reference_pred = net.predict(x)
    assert_allclose(decoded_pred, reference_pred)
|