File: test_openclxfm2.py

package info (click to toggle)
python-dtcwt 0.14.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 8,588 kB
  • sloc: python: 6,287; sh: 29; makefile: 13
file content (92 lines) | stat: -rw-r--r-- 2,521 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from pytest import raises

import numpy as np
from dtcwt.coeffs import biort, qshift
from dtcwt.compat import dtwavexfm2 as dtwavexfm2_np, dtwaveifm2
from dtcwt.opencl.transform2d import dtwavexfm2 as dtwavexfm2_cl

from .util import assert_almost_equal, skip_if_no_cl
import tests.datasets as datasets

# Tight tolerance. NOTE(review): not referenced by any test in this file —
# confirm whether it is still needed.
TOLERANCE = 1e-12
# Tolerance passed to assert_almost_equal by _compare_transforms when checking
# the OpenCL transform (dtwavexfm2_cl) against dtcwt.compat's dtwavexfm2.
GOLD_TOLERANCE = 1e-5

def setup_module():
    """Load the shared mandrill image into the module-global ``mandrill``.

    pytest runs this once before the tests in this module, so every test
    reads the same array.
    """
    global mandrill
    image = datasets.mandrill()
    mandrill = image

def test_mandrill_loaded():
    """The fixture image is 512x512, float32, with values within [0, 1]."""
    image = mandrill
    assert image.shape == (512, 512)
    low, high = image.min(), image.max()
    assert low >= 0
    assert high <= 1
    assert image.dtype == np.float32
    assert mandrill.dtype == np.float32

def _compare_transforms(A, B):
    """Assert that two ``(Yl, Yh)`` transform results agree within GOLD_TOLERANCE.

    Args:
        A: first result, unpacked as ``(Yl, Yh)`` where ``Yh`` is a sequence
            of per-level highpass outputs.
        B: second result, same layout as ``A``.

    Raises:
        AssertionError: if the lowpass outputs differ, the number of
            highpass levels differs, or any pair of highpass levels differs.
    """
    Yl_A, Yh_A = A
    Yl_B, Yh_B = B
    assert_almost_equal(Yl_A, Yl_B, tolerance=GOLD_TOLERANCE)
    # zip() stops at the shorter sequence, so without this check a result
    # with missing or extra levels would still be reported as matching.
    assert len(Yh_A) == len(Yh_B), \
        'number of highpass levels differs: {} vs {}'.format(len(Yh_A), len(Yh_B))
    for x, y in zip(Yh_A, Yh_B):
        assert_almost_equal(x, y, tolerance=GOLD_TOLERANCE)

@skip_if_no_cl
def test_simple():
    """Default-filter OpenCL transform matches the compat implementation."""
    expected = dtwavexfm2_np(mandrill)
    actual = dtwavexfm2_cl(mandrill)
    _compare_transforms(expected, actual)

@skip_if_no_cl
def test_specific_wavelet():
    """Both transforms agree with the 'antonini' / 'qshift_06' filter pair."""
    def filters():
        # Fresh filter objects per call so the two transforms never share them.
        return {'biort': biort('antonini'), 'qshift': qshift('qshift_06')}

    expected = dtwavexfm2_np(mandrill, **filters())
    actual = dtwavexfm2_cl(mandrill, **filters())
    _compare_transforms(expected, actual)

@skip_if_no_cl
def test_1d():
    """Both transforms agree when given a single image row (a 1D slice)."""
    first_row = mandrill[0, :]
    expected = dtwavexfm2_np(first_row)
    actual = dtwavexfm2_cl(first_row)
    _compare_transforms(expected, actual)

@skip_if_no_cl
def test_3d():
    """The OpenCL transform rejects a 3D (depth-stacked) input with ValueError."""
    # Build the input outside the `raises` block so only the transform call
    # can satisfy the expected ValueError.
    stacked = np.dstack((mandrill, mandrill))
    with raises(ValueError):
        # Deliberately not tuple-unpacking the result: an unpacking mismatch
        # also raises ValueError and would make this test pass spuriously.
        dtwavexfm2_cl(stacked)

@skip_if_no_cl
def test_simple_w_scale():
    """With include_scale=True the third result is non-empty with no None entries."""
    _, _, Yscale = dtwavexfm2_cl(mandrill, include_scale=True)

    assert len(Yscale) > 0
    assert all(scale is not None for scale in Yscale)

@skip_if_no_cl
def test_odd_rows():
    """Both transforms agree on an image with an odd row count (509 rows).

    The skip decorator is applied once; the previous duplicate was redundant.
    """
    a = dtwavexfm2_np(mandrill[:509,:])
    b = dtwavexfm2_cl(mandrill[:509,:])
    _compare_transforms(a, b)

@skip_if_no_cl
def test_odd_cols():
    """Both transforms agree on an image with an odd column count (509 columns)."""
    narrow = mandrill[:, :509]
    expected = dtwavexfm2_np(narrow)
    actual = dtwavexfm2_cl(narrow)
    _compare_transforms(expected, actual)

@skip_if_no_cl
def test_odd_rows_and_cols():
    """Both transforms agree when rows and columns are both odd (509x509)."""
    cropped = mandrill[:509, :509]
    expected = dtwavexfm2_np(cropped)
    actual = dtwavexfm2_cl(cropped)
    _compare_transforms(expected, actual)

@skip_if_no_cl
def test_0_levels():
    """Both transforms agree when asked for zero decomposition levels."""
    options = {'nlevels': 0}
    expected = dtwavexfm2_np(mandrill, **options)
    actual = dtwavexfm2_cl(mandrill, **options)
    _compare_transforms(expected, actual)

@skip_if_no_cl
def test_modified():
    """Both transforms agree with the 'near_sym_b_bp' / 'qshift_b_bp' filters."""
    def filters():
        # Fresh filter objects per call so the two transforms never share them.
        return {'biort': biort('near_sym_b_bp'), 'qshift': qshift('qshift_b_bp')}

    expected = dtwavexfm2_np(mandrill, **filters())
    actual = dtwavexfm2_cl(mandrill, **filters())
    _compare_transforms(expected, actual)

# vim:sw=4:sts=4:et