1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
import pytest
from pytest import raises
import numpy as np
from dtcwt.coeffs import qshift
from dtcwt.numpy.lowlevel import colifilt as np_colifilt
from importlib import import_module
from tests.util import skip_if_no_tf
import tests.datasets as datasets
@skip_if_no_tf
def test_setup():
    """Import TensorFlow lazily and publish the shared module-level fixtures.

    Binds ``tf``, ``colifilt`` (the TensorFlow implementation),
    ``mandrill`` (the NumPy image) and ``mandrill_t`` (the image as a
    float32 tensor with a leading batch axis) as module globals. The other
    tests in this module read these globals, so they rely on this test
    having run first.
    """
    global mandrill, mandrill_t, tf, colifilt
    tf = import_module('tensorflow')
    tf_lowlevel = import_module('dtcwt.tf.lowlevel')
    colifilt = tf_lowlevel.colifilt
    mandrill = datasets.mandrill()
    image_tensor = tf.constant(mandrill, dtype=tf.float32)
    # Prepend a batch dimension of size 1.
    mandrill_t = tf.expand_dims(image_tensor, axis=0)
@skip_if_no_tf
def test_mandrill_loaded():
    """Sanity-check the fixture image and its batched tensor form.

    The image must be a 512x512 float32 array with values in [0, 1]. The
    tensor must have the same spatial shape plus a leading batch axis of
    size 1.
    """
    image_shape = (512, 512)
    assert mandrill.shape == image_shape
    assert mandrill.min() >= 0
    assert mandrill.max() <= 1
    assert mandrill.dtype == np.float32
    assert mandrill_t.get_shape() == (1,) + image_shape
@skip_if_no_tf
def test_odd_filter():
    """Passing these odd-length (3-tap) filters must raise ValueError."""
    three_tap_a = (-1, 2, -1)
    three_tap_b = (-1, 2, 1)
    with raises(ValueError):
        colifilt(mandrill_t, three_tap_a, three_tap_b)
@skip_if_no_tf
def test_different_size_h():
    """Filters of unequal length (3 taps vs 5 taps) must raise ValueError."""
    short_filter = (-1, 2, 1)
    long_filter = (-0.5, -1, 2, -1, 0.5)
    with raises(ValueError):
        colifilt(mandrill_t, short_filter, long_filter)
@skip_if_no_tf
def test_zero_input():
    """An all-zero image must filter to an all-zero output.

    Zeros are substituted for ``mandrill_t`` through ``feed_dict``. The
    batch axis is then dropped from the result before checking it.
    """
    Y = colifilt(mandrill_t, (-1, 1), (1, -1))
    with tf.Session() as sess:
        y = sess.run(Y, {mandrill_t: [np.zeros_like(mandrill)]})[0]
    # Check every element. The previous ``y[:0]`` selected an empty slice,
    # which makes np.all() vacuously True, so the test could never fail.
    assert y.size > 0
    assert np.all(y == 0)
@skip_if_no_tf
def test_bad_input_size():
    """Cropping the image to 511 rows is expected to raise ValueError."""
    cropped_rows = mandrill_t[:, :511, :]
    with raises(ValueError):
        colifilt(cropped_rows, (-1, 1), (1, -1))
@skip_if_no_tf
def test_good_input_size():
    """Cropping the image to 511 columns is accepted without an exception."""
    cropped_cols = mandrill_t[:, :, :511]
    colifilt(cropped_cols, (-1, 1), (1, -1))
@skip_if_no_tf
def test_output_size():
    """The output (batch axis excluded) has doubled rows and unchanged columns."""
    n_rows = mandrill.shape[0]
    n_cols = mandrill.shape[1]
    Y = colifilt(mandrill_t, (-1, 1), (1, -1))
    assert Y.shape[1:] == (2 * n_rows, n_cols)
@skip_if_no_tf
def test_non_orthogonal_input():
    """Filters (1, 1)/(1, 1) still give an output with doubled rows, same columns."""
    n_rows = mandrill.shape[0]
    n_cols = mandrill.shape[1]
    Y = colifilt(mandrill_t, (1, 1), (1, 1))
    assert Y.shape[1:] == (2 * n_rows, n_cols)
@skip_if_no_tf
def test_output_size_non_mult_4():
    """Check the output shape with the 4-tap filters (-1, 0, 0, 1)/(1, 0, 0, -1)."""
    n_rows = mandrill.shape[0]
    n_cols = mandrill.shape[1]
    Y = colifilt(mandrill_t, (-1, 0, 0, 1), (1, 0, 0, -1))
    assert Y.shape[1:] == (2 * n_rows, n_cols)
@skip_if_no_tf
def test_non_orthogonal_input_non_mult_4():
    """Check the output shape with the symmetric 4-tap filters (1, 0, 0, 1)."""
    n_rows = mandrill.shape[0]
    n_cols = mandrill.shape[1]
    Y = colifilt(mandrill_t, (1, 0, 0, 1), (1, 0, 0, 1))
    assert Y.shape[1:] == (2 * n_rows, n_cols)
@skip_if_no_tf
def test_equal_small_in():
    """TF colifilt matches the NumPy reference on a 4x4 crop (qshift_b).

    The comparison is made to 4 decimal places.
    """
    ha, hb = qshift('qshift_b')[:2]
    crop = mandrill[0:4, 0:4]
    crop_t = tf.expand_dims(tf.constant(crop, tf.float32), axis=0)
    expected = np_colifilt(crop, ha, hb)
    op = colifilt(crop_t, ha, hb)
    with tf.Session() as sess:
        batched = sess.run(op)
    np.testing.assert_array_almost_equal(batched[0], expected, decimal=4)
@skip_if_no_tf
def test_equal_numpy_qshift1():
    """TF colifilt matches the NumPy reference on the full image (qshift_c).

    The comparison is made to 4 decimal places.
    """
    ha, hb = qshift('qshift_c')[:2]
    expected = np_colifilt(mandrill, ha, hb)
    op = colifilt(mandrill_t, ha, hb)
    with tf.Session() as sess:
        batched = sess.run(op)
    np.testing.assert_array_almost_equal(batched[0], expected, decimal=4)
@skip_if_no_tf
def test_equal_numpy_qshift2():
    """TF colifilt matches the NumPy reference on a 508x502 crop (qshift_c).

    The comparison is made to 4 decimal places.
    """
    ha, hb = qshift('qshift_c')[:2]
    crop = mandrill[:508, :502]
    crop_t = tf.expand_dims(tf.constant(crop, tf.float32), axis=0)
    expected = np_colifilt(crop, ha, hb)
    op = colifilt(crop_t, ha, hb)
    with tf.Session() as sess:
        batched = sess.run(op)
    np.testing.assert_array_almost_equal(batched[0], expected, decimal=4)
# vim:sw=4:sts=4:et
|