File: test_tensorflow_integration.py

package info (click to toggle)
keras 2.3.1+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 9,288 kB
  • sloc: python: 48,266; javascript: 1,794; xml: 297; makefile: 36; sh: 30
file content (55 lines) | stat: -rw-r--r-- 1,652 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import print_function

import os
import tempfile
import pytest
import keras
from keras import layers
from keras.utils.test_utils import get_test_data


@pytest.mark.skipif(keras.backend.backend() != 'tensorflow',
                    reason='Requires TF backend')
def test_tf_optimizer():
    """Train a small Keras model with a native TensorFlow optimizer.

    Builds a two-layer dense classifier on synthetic data from
    ``get_test_data`` and compiles it with a TensorFlow Adadelta optimizer:
    ``tf.train.AdadeltaOptimizer`` on TF 1.x, ``tf.keras.optimizers.Adadelta``
    otherwise. It then checks that the last epoch's validation accuracy
    reaches ``target``.

    The trained model is saved to a temporary HDF5 file and reloaded. The
    test checks that the reloaded model still exposes four weight tensors.
    The temporary file is always removed, even if saving, loading, or the
    assertion fails.
    """
    import tensorflow as tf

    num_hidden = 10
    output_dim = 2
    input_dim = 10
    target = 0.8

    # The optimizer's location in the TF API differs between 1.x and later.
    if tf.__version__.startswith('1.'):
        optimizer = tf.train.AdadeltaOptimizer(
            learning_rate=1., rho=0.95, epsilon=1e-08)
    else:
        optimizer = tf.keras.optimizers.Adadelta(
            learning_rate=1., rho=0.95, epsilon=1e-08)

    (x_train, y_train), (x_test, y_test) = get_test_data(
        num_train=1000, num_test=200,
        input_shape=(input_dim,),
        classification=True, num_classes=output_dim)

    model = keras.Sequential()
    model.add(layers.Dense(num_hidden,
                           activation='relu',
                           input_shape=(input_dim,)))
    model.add(layers.Dense(output_dim, activation='softmax'))

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])
    history = model.fit(x_train, y_train, epochs=8, batch_size=16,
                        validation_data=(x_test, y_test), verbose=2)
    assert history.history['val_accuracy'][-1] >= target

    # Test saving.
    # mkstemp returns an already-open OS-level file descriptor together with
    # the path. Close it immediately so it is not leaked; model.save opens
    # the path on its own.
    fd, fname = tempfile.mkstemp('.h5')
    os.close(fd)
    try:
        model.save(fname)
        model = keras.models.load_model(fname)
        assert len(model.weights) == 4
    finally:
        # Clean up even if save/load raises or the assertion fails.
        os.remove(fname)


if __name__ == '__main__':
    # Running this file directly hands it to pytest for collection and
    # execution, the same as `pytest <this file>`.
    pytest.main(args=[__file__])