1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
'''
Runs unit tests for linear transforms of spectral data & data files.

The unit tests in this module assume the example file "92AV3C.lan" is in the
spectral data path. After the file is opened, the unit tests verify that
LinearTransform objects applied to SpyFile and numpy.ndarray objects yield
the correct values for known image data values.

To run the unit tests, type the following from the system command line:

    # python -m spectral.tests.transforms
'''
from __future__ import division, print_function, unicode_literals
import numpy as np
from numpy.testing import assert_almost_equal
import spectral as spy
from spectral.algorithms.transforms import LinearTransform
from spectral.io.spyfile import SpyFile
from spectral.tests.spytest import SpyTest
class LinearTransformTest(SpyTest):
    '''Checks LinearTransform results against one known image datum.'''

    def __init__(self, file, datum, value):
        '''
        Arguments:

            `file` (str, `SpyFile` or `numpy.ndarray`):

                The image to be tested. A SpyFile or ndarray is used
                directly; anything else is passed to `spy.open_image`
                when `setup` is called.

            `datum` (3-tuple of ints):

                (i, j, k) are the row, column and band of the datum to be
                tested. 'i' and 'j' should be at least 10 pixels away from
                the edge of the associated image and `k` should have at
                least 10 bands above and below it in the image.

            `value` (int or float):

                The scalar value associated with location (i, j, k) in
                the image.
        '''
        self.file, self.datum, self.value = file, datum, value

    def setup(self):
        '''Resolves the image and creates the transform parameters.'''
        source = self.file
        if isinstance(source, (SpyFile, np.ndarray)):
            # Already usable image data; no file needs to be opened.
            self.image = source
        else:
            self.image = spy.open_image(source)
        self.scalar = 10.
        # Diagonal matrix equivalent to multiplying every band by `scalar`,
        # so the matrix tests expect the same values as the scalar tests.
        n_bands = self.image.shape[2]
        self.matrix = self.scalar * np.identity(n_bands, dtype='f8')
        self.pre = 37.
        self.post = 51.

    def _transformed_datum(self, transform):
        '''Applies `transform` to the datum's pixel and returns its band.'''
        (row, col, band) = self.datum
        pixel = self.image[row, col]
        return transform(pixel)[band]

    def test_scalar_multiply(self):
        '''Scalar transform: expects scalar * value.'''
        xform = LinearTransform(self.scalar)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * self.value)

    def test_pre_scalar_multiply(self):
        '''Scalar transform with `pre`: expects scalar * (pre + value).'''
        xform = LinearTransform(self.scalar, pre=self.pre)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * (self.pre + self.value))

    def test_scalar_multiply_post(self):
        '''Scalar transform with `post`: expects scalar * value + post.'''
        xform = LinearTransform(self.scalar, post=self.post)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * self.value + self.post)

    def test_pre_scalar_multiply_post(self):
        '''Scalar transform with both `pre` and `post` offsets.'''
        xform = LinearTransform(self.scalar, pre=self.pre, post=self.post)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * (self.pre + self.value)
                            + self.post)

    def test_matrix_multiply(self):
        '''Matrix transform: expects scalar * value.'''
        xform = LinearTransform(self.matrix)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * self.value)

    def test_pre_matrix_multiply(self):
        '''Matrix transform with `pre`: expects scalar * (pre + value).'''
        xform = LinearTransform(self.matrix, pre=self.pre)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * (self.pre + self.value))

    def test_matrix_multiply_post(self):
        '''Matrix transform with `post`: expects scalar * value + post.'''
        xform = LinearTransform(self.matrix, post=self.post)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * self.value + self.post)

    def test_pre_matrix_multiply_post(self):
        '''Matrix transform with both `pre` and `post` offsets.'''
        xform = LinearTransform(self.matrix, pre=self.pre, post=self.post)
        assert_almost_equal(self._transformed_datum(xform),
                            self.scalar * (self.pre + self.value)
                            + self.post)
def run():
    '''Runs the LinearTransform tests on three views of "92AV3C.lan".

    The tests are run on the opened image, on the data returned by its
    `load` method, and on the opened image again after its `scale_factor`
    is set to 10000 (with the expected value divided by the same factor).
    '''
    rule = '-' * 72

    def banner(message):
        '''Prints `message` between two horizontal rules.'''
        print('\n' + rule)
        print(message)
        print(rule)

    fname = '92AV3C.lan'
    datum = (99, 99, 99)
    value = 2057.0
    image = spy.open_image(fname)

    banner('Running LinearTransform tests on SpyFile object.')
    file_test = LinearTransformTest(image, datum, value)
    file_test.run()

    data = image.load()
    banner('Running LinearTransform tests on ImageArray object.')
    array_test = LinearTransformTest(data, datum, value)
    array_test.run()

    # With a scale factor set, the expected datum value shrinks by the
    # same factor.
    image.scale_factor = 10000.0
    banner('Running LinearTransform tests on SpyFile object with scale factor.')
    scaled_test = LinearTransformTest(image, datum, value / 10000.0)
    scaled_test.run()
if __name__ == '__main__':
    # Script entry point: the runner helpers are imported only when this
    # module is executed directly.
    from spectral.tests.run import parse_args
    from spectral.tests.run import reset_stats
    from spectral.tests.run import print_summary

    parse_args()
    reset_stats()
    run()
    print_summary()
|