File: test_tfcolifilt.py

package info (click to toggle)
python-dtcwt 0.12.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 8,404 kB
  • sloc: python: 6,253; sh: 29; makefile: 13
file content (124 lines) | stat: -rw-r--r-- 3,291 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import pytest
from pytest import raises

import numpy as np
from dtcwt.coeffs import qshift
from dtcwt.numpy.lowlevel import colifilt as np_colifilt
from importlib import import_module

from tests.util import skip_if_no_tf
import tests.datasets as datasets


@skip_if_no_tf
def test_setup():
    """Lazily import TensorFlow and publish shared fixtures as module globals.

    The remaining tests in this module read ``mandrill``, ``mandrill_t``,
    ``tf`` and ``colifilt`` as globals, so this test must run before them.
    TensorFlow is imported here rather than at module level, presumably so
    that the module can still be collected when TensorFlow is absent.
    """
    global mandrill, mandrill_t, tf, colifilt

    tf = import_module('tensorflow')
    colifilt = import_module('dtcwt.tf.lowlevel').colifilt

    mandrill = datasets.mandrill()
    # Add a leading batch axis: the TF implementation is fed (1, H, W) tensors.
    as_float = tf.constant(mandrill, dtype=tf.float32)
    mandrill_t = tf.expand_dims(as_float, axis=0)


@skip_if_no_tf
def test_mandrill_loaded():
    """Sanity-check the test image and its batched TensorFlow counterpart.

    The image must be a 512x512 float32 array with values in [0, 1].
    The tensor built from it in ``test_setup`` must have shape (1, 512, 512).
    """
    image = mandrill
    assert image.shape == (512, 512)
    assert 0 <= image.min()
    assert image.max() <= 1
    assert image.dtype == np.float32

    batched = mandrill_t
    assert batched.get_shape() == (1, 512, 512)


@skip_if_no_tf
def test_odd_filter():
    """colifilt is expected to reject odd-length (here 3-tap) filters with ValueError."""
    three_tap_a = (-1, 2, -1)
    three_tap_b = (-1, 2, 1)
    with pytest.raises(ValueError):
        colifilt(mandrill_t, three_tap_a, three_tap_b)


@skip_if_no_tf
def test_different_size_h():
    """colifilt is expected to raise ValueError when ha and hb differ in length.

    NOTE(review): both filters below (3 and 5 taps) are also odd-length.
    The ValueError may therefore come from the odd-length check rather than
    the length-mismatch check. Consider even-length filters of different
    sizes to isolate the mismatch case.
    """
    short_h = (-1, 2, 1)
    long_h = (-0.5, -1, 2, -1, 0.5)
    with pytest.raises(ValueError):
        colifilt(mandrill_t, short_h, long_h)


@skip_if_no_tf
def test_zero_input():
    """Filtering an all-zero image must yield an all-zero output.

    The graph is built on ``mandrill_t``. An all-zero image of the same
    shape is then substituted for it through ``sess.run``'s feed dict, so
    the same op is evaluated on zeros.
    """
    Y = colifilt(mandrill_t, (-1, 1), (1, -1))
    with tf.Session() as sess:
        y = sess.run(Y, {mandrill_t: [np.zeros_like(mandrill)]})[0]
    # The previous check, ``np.all(y[:0] == 0)``, sliced away every row.
    # It compared an empty array, and np.all() of an empty array is True,
    # so the test could never fail. Check the whole output instead, and make
    # sure there actually is output to check.
    assert y.size > 0
    assert np.all(y == 0)


@skip_if_no_tf
def test_bad_input_size():
    """An input with an odd number of rows (511) is expected to raise ValueError."""
    with pytest.raises(ValueError):
        odd_rows = mandrill_t[:, :511, :]
        colifilt(odd_rows, (-1, 1), (1, -1))


@skip_if_no_tf
def test_good_input_size():
    """An odd number of columns (511) is accepted.

    The test passes as long as building the op does not raise. Presumably
    only the row count is constrained, since filtering runs down the
    columns; confirm against colifilt.
    """
    odd_cols = mandrill_t[:, :, :511]
    colifilt(odd_cols, (-1, 1), (1, -1))


@skip_if_no_tf
def test_output_size():
    """With 2-tap filters the output has twice the rows and the same columns."""
    out = colifilt(mandrill_t, (-1, 1), (1, -1))
    expected = (2 * mandrill.shape[0], mandrill.shape[1])
    # Skip the leading batch axis when comparing.
    assert out.shape[1:] == expected


@skip_if_no_tf
def test_non_orthogonal_input():
    """Identical 2-tap filters (ha == hb == (1, 1)) still double the row count."""
    same_h = (1, 1)
    out = colifilt(mandrill_t, same_h, same_h)
    expected = (2 * mandrill.shape[0], mandrill.shape[1])
    assert out.shape[1:] == expected


@skip_if_no_tf
def test_output_size_non_mult_4():
    """With 4-tap filters the output has twice the rows and the same columns.

    NOTE(review): the name says "non_mult_4" but these filters have 4 taps.
    Presumably this exercises a filter-length-dependent branch different
    from the 2-tap tests; confirm against colifilt.
    """
    four_tap_a = (-1, 0, 0, 1)
    four_tap_b = (1, 0, 0, -1)
    out = colifilt(mandrill_t, four_tap_a, four_tap_b)
    expected = (2 * mandrill.shape[0], mandrill.shape[1])
    assert out.shape[1:] == expected


@skip_if_no_tf
def test_non_orthogonal_input_non_mult_4():
    """Identical 4-tap filters (ha == hb) still double the row count."""
    same_h = (1, 0, 0, 1)
    out = colifilt(mandrill_t, same_h, same_h)
    expected = (2 * mandrill.shape[0], mandrill.shape[1])
    assert out.shape[1:] == expected


@skip_if_no_tf
def test_equal_small_in():
    """The TF colifilt matches the NumPy reference on a 4x4 crop (qshift_b)."""
    ha, hb = qshift('qshift_b')[:2]

    patch = mandrill[0:4, 0:4]
    expected = np_colifilt(patch, ha, hb)

    patch_t = tf.expand_dims(tf.constant(patch, tf.float32), axis=0)
    op = colifilt(patch_t, ha, hb)
    with tf.Session() as sess:
        actual = sess.run(op)

    # Drop the batch axis; compare to 4 decimal places (float32 vs NumPy).
    np.testing.assert_array_almost_equal(actual[0], expected, decimal=4)


@skip_if_no_tf
def test_equal_numpy_qshift1():
    """The TF colifilt matches the NumPy reference on the full image (qshift_c)."""
    ha, hb = qshift('qshift_c')[:2]
    expected = np_colifilt(mandrill, ha, hb)

    op = colifilt(mandrill_t, ha, hb)
    with tf.Session() as sess:
        actual = sess.run(op)

    np.testing.assert_array_almost_equal(actual[0], expected, decimal=4)


@skip_if_no_tf
def test_equal_numpy_qshift2():
    """The TF colifilt matches the NumPy reference on a 508x502 crop (qshift_c)."""
    ha, hb = qshift('qshift_c')[:2]

    crop = mandrill[:508, :502]
    expected = np_colifilt(crop, ha, hb)

    crop_t = tf.expand_dims(tf.constant(crop, tf.float32), axis=0)
    op = colifilt(crop_t, ha, hb)
    with tf.Session() as sess:
        actual = sess.run(op)

    np.testing.assert_array_almost_equal(actual[0], expected, decimal=4)

# vim:sw=4:sts=4:et