File: test_tensorflow.py

package info (click to toggle)
nanobind 2.10.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,100 kB
  • sloc: cpp: 12,131; python: 6,190; ansic: 4,785; makefile: 22; sh: 15
file content (97 lines) | stat: -rw-r--r-- 2,631 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import test_ndarray_ext as t
import test_tensorflow_ext as ttf
import pytest
import warnings
import importlib
from common import collect

try:
    import tensorflow as tf
    import tensorflow.config

    def needs_tensorflow(x):
        """Identity decorator: TensorFlow imported fine, so run the test."""
        return x
except Exception:
    # If TensorFlow cannot be imported for any ordinary reason, every test
    # decorated with ``@needs_tensorflow`` is skipped instead of erroring.
    # ``Exception`` (not a bare ``except``) keeps KeyboardInterrupt and
    # SystemExit propagating.
    needs_tensorflow = pytest.mark.skip(reason="TensorFlow is required")


@needs_tensorflow
def test01_constrain_order():
    """A TensorFlow tensor created by ``tf.zeros`` is reported as C-ordered."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Creating a tensor doubles as a probe that TensorFlow is usable;
        # skip (rather than fail) if it is not.
        try:
            c = tf.zeros((3, 5))
        except Exception:
            pytest.skip('tensorflow is missing')

    assert t.check_order(c) == 'C'


@needs_tensorflow
def test02_implicit_conversion():
    """Check which TensorFlow tensors are accepted by ``implicit``/``noimplicit``.

    ``t.implicit`` must accept an int32 tensor and sliced float32/int32/bool
    tensors (the test fails if any call raises). ``t.noimplicit`` must raise
    ``TypeError`` for int32 and bool tensors.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Probe that TensorFlow can actually create a tensor; skip otherwise.
        try:
            tf.zeros((3, 5))
        except Exception:
            pytest.skip('tensorflow is missing')

        t.implicit(tf.zeros((2, 2), dtype=tf.int32))
        # Slices taken along the last axis of a larger tensor.
        t.implicit(tf.zeros((2, 2, 10), dtype=tf.float32)[:, :, 4])
        t.implicit(tf.zeros((2, 2, 10), dtype=tf.int32)[:, :, 4])
        t.implicit(tf.zeros((2, 2, 10), dtype=tf.bool)[:, :, 4])

        with pytest.raises(TypeError):
            t.noimplicit(tf.zeros((2, 2), dtype=tf.int32))

        with pytest.raises(TypeError):
            t.noimplicit(tf.zeros((2, 2), dtype=tf.bool))


@needs_tensorflow
def test03_return_tensorflow():
    """Receive a TensorFlow tensor from the extension, then release it.

    The tensor from ``ttf.ret_tensorflow`` must have shape [2, 4] and equal
    [[1, 2, 3, 4], [5, 6, 7, 8]]. Dropping it and collecting garbage must
    raise ``ttf.destruct_count()`` by exactly one.
    """
    collect()
    baseline = ttf.destruct_count()

    tensor = ttf.ret_tensorflow()
    assert tensor.get_shape().as_list() == [2, 4]
    expected = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
    assert tf.math.reduce_all(tensor == expected)

    del tensor
    collect()
    assert ttf.destruct_count() == baseline + 1


@needs_tensorflow
def test04_check():
    """``t.check`` returns a truthy result for a one-element TensorFlow tensor."""
    # ``(1)`` would just be the int 1; write the shape as a real 1-tuple,
    # consistent with the other shape literals in this file.
    assert t.check(tf.zeros((1,)))


@needs_tensorflow
def test05_passthrough():
    """``t.passthrough`` hands back the very same TensorFlow object it got.

    ``None`` is rejected by ``t.passthrough`` with a ``TypeError`` mentioning
    incompatible arguments, but ``t.passthrough_arg_none`` returns it as is.
    """
    from_ext = ttf.ret_tensorflow()
    assert t.passthrough(from_ext) is from_ext

    from_tf = tf.constant([1, 2, 3])
    assert t.passthrough(from_tf) is from_tf

    with pytest.raises(TypeError) as excinfo:
        t.passthrough(None)
    assert 'incompatible function arguments' in str(excinfo.value)

    assert t.passthrough_arg_none(None) is None


@needs_tensorflow
def test06_ro_array():
    """Pass an immutable ``tf.constant`` to read-only and read-write params.

    Skipped for TensorFlow older than 2.19. Currently both ``t.accept_ro``
    and ``t.accept_rw`` are expected to accept the tensor and return 1.
    """
    import re

    # Compare (major, minor) numerically. A plain string comparison would
    # rank '2.2' above '2.19' and '10.0' below it.
    match = re.match(r'(\d+)\.(\d+)', tf.__version__)
    if match is None:
        pytest.skip(f'cannot parse tensorflow version {tf.__version__!r}')
    if (int(match.group(1)), int(match.group(2))) < (2, 19):
        pytest.skip('tensorflow version is too old')
    a = tf.constant([1, 2], dtype=tf.float32)  # immutable
    assert t.accept_ro(a) == 1
    # If the next line fails, delete it, update the version above,
    # and uncomment the three lines below.
    assert t.accept_rw(a) == 1
    # with pytest.raises(TypeError) as excinfo:
    #     t.accept_rw(a)
    # assert 'incompatible function arguments' in str(excinfo.value)