File: test_tfrowdfilt.py

package info (click to toggle)
python-dtcwt 0.14.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 8,588 kB
  • sloc: python: 6,287; sh: 29; makefile: 13
file content (102 lines) | stat: -rw-r--r-- 2,760 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from pytest import raises

import numpy as np
from importlib import import_module
from dtcwt.coeffs import qshift
from dtcwt.numpy.lowlevel import coldfilt as np_coldfilt

from tests.util import skip_if_no_tf
import tests.datasets as datasets


@skip_if_no_tf
def test_setup():
    """Load TensorFlow, the TF ``rowdfilt`` op and the mandrill test image.

    Everything is published as module globals (``tf``, ``rowdfilt``,
    ``mandrill``, ``mandrill_t``) so the remaining tests in this module can
    share it; those tests therefore rely on this one having run first.
    """
    global mandrill, mandrill_t, rowdfilt, tf
    # Imported here rather than at module top level -- presumably so the
    # module still imports cleanly when TensorFlow is not installed.
    tf = import_module('tensorflow')
    tf_lowlevel = import_module('dtcwt.tf.lowlevel')
    rowdfilt = tf_lowlevel.rowdfilt
    mandrill = datasets.mandrill()
    # Prepend a singleton leading axis; the other tests index the tensor
    # as [batch, row, col].
    image_tensor = tf.constant(mandrill, dtype=tf.float32)
    mandrill_t = tf.expand_dims(image_tensor, axis=0)


@skip_if_no_tf
def test_mandrill_loaded():
    """Sanity-check the image and tensor prepared by ``test_setup``."""
    # NumPy image: 512x512, values within [0, 1], stored as float32.
    assert mandrill.shape == (512, 512)
    assert 0 <= mandrill.min()
    assert 1 >= mandrill.max()
    assert mandrill.dtype == np.float32
    # TF tensor: the same image with a leading axis of size one.
    assert mandrill_t.get_shape() == (1, 512, 512)


@skip_if_no_tf
def test_odd_filter():
    """A pair of odd-length (3-tap) filters is expected to raise ValueError."""
    first_filter = (-1, 2, -1)
    second_filter = (-1, 2, 1)
    with raises(ValueError):
        rowdfilt(mandrill_t, first_filter, second_filter)


@skip_if_no_tf
def test_different_size():
    """Filters of different lengths (5 vs 3 taps) are expected to raise."""
    five_taps = (-0.5, -1, 2, 1, 0.5)
    three_taps = (-1, 2, -1)
    with raises(ValueError):
        rowdfilt(mandrill_t, five_taps, three_taps)


@skip_if_no_tf
def test_bad_input_size():
    """An input whose last axis has odd length (511) is expected to raise."""
    with raises(ValueError):
        odd_width = mandrill_t[:, :, :511]
        rowdfilt(odd_width, (-1, 1), (1, -1))


@skip_if_no_tf
def test_good_input_size():
    """An odd row count (511) is accepted; passing means no exception."""
    odd_height = mandrill_t[:, :511, :]
    rowdfilt(odd_height, (-1, 1), (1, -1))


@skip_if_no_tf
def test_good_input_size_non_orthogonal():
    """An odd row count (511) is accepted with two identical 2-tap filters."""
    ones_filter = (1, 1)
    rowdfilt(mandrill_t[:, :511, :], ones_filter, (1, 1))


@skip_if_no_tf
def test_output_size():
    """Row filtering keeps the row count and halves the column count."""
    y_op = rowdfilt(mandrill_t, (-1, 1), (1, -1))
    rows, cols = mandrill.shape
    # Floor division: ``cols / 2`` is a float (256.0) on Python 3, and recent
    # TensorFlow releases reject non-integer dimension values when turning
    # the tuple into a TensorShape, which makes the comparison come out False.
    assert y_op.shape[1:] == (rows, cols // 2)


@skip_if_no_tf
def test_equal_small_in():
    """On a 4x4 crop, TF ``rowdfilt`` agrees with the NumPy reference.

    The reference is NumPy ``coldfilt`` applied to the transposed crop and
    transposed back; results must match to 4 decimal places.
    """
    # This test used to carry a skip marker noting that padding by more
    # than half the input dimension was not supported.
    coeffs = qshift('qshift_b')
    ha, hb = coeffs[0], coeffs[1]
    crop = mandrill[0:4, 0:4]
    crop_t = tf.expand_dims(tf.constant(crop, tf.float32), axis=0)
    expected = np_coldfilt(crop.T, ha, hb).T
    y_op = rowdfilt(crop_t, ha, hb)
    with tf.Session() as sess:
        actual = sess.run(y_op)[0]
    np.testing.assert_array_almost_equal(actual, expected, decimal=4)


@skip_if_no_tf
def test_equal_numpy_qshift1():
    """On the full image with 'qshift_c' filters, TF matches NumPy.

    The reference is NumPy ``coldfilt`` on the transposed image, transposed
    back; results must match to 4 decimal places.
    """
    coeffs = qshift('qshift_c')
    ha, hb = coeffs[0], coeffs[1]
    expected = np_coldfilt(mandrill.T, ha, hb).T
    y_op = rowdfilt(mandrill_t, ha, hb)
    with tf.Session() as sess:
        actual = sess.run(y_op)[0]
    np.testing.assert_array_almost_equal(actual, expected, decimal=4)


@skip_if_no_tf
def test_equal_numpy_qshift2():
    """On a non-square 508x504 crop with 'qshift_c', TF matches NumPy.

    The reference is NumPy ``coldfilt`` on the transposed crop, transposed
    back; results must match to 4 decimal places.
    """
    coeffs = qshift('qshift_c')
    ha, hb = coeffs[0], coeffs[1]
    crop = mandrill[:508, :504]
    crop_t = tf.expand_dims(tf.constant(crop, tf.float32), axis=0)
    expected = np_coldfilt(crop.T, ha, hb).T
    y_op = rowdfilt(crop_t, ha, hb)
    with tf.Session() as sess:
        actual = sess.run(y_op)[0]
    np.testing.assert_array_almost_equal(actual, expected, decimal=4)

# vim:sw=4:sts=4:et