File: classifiers.py

package info (click to toggle)
python-spectral 0.22.4-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 1,064 kB
  • sloc: python: 13,161; makefile: 7
file content (177 lines) | stat: -rw-r--r-- 7,461 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
'''
Runs unit tests for classification routines.

To run the unit tests, type the following from the system command line:

    # python -m spectral.tests.classifiers
'''

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import numpy as np
import spectral as spy
from numpy.testing import assert_allclose
from .spytest import SpyTest, test_method
from spectral.tests import testdir

class ClassifierTest(SpyTest):
    '''Tests various classification functions.

    `setup` loads the 92AV3C test image and its ground-truth band and builds
    training classes (with statistics) that the individual tests share.
    '''

    # Maximum number of training runs allowed for perceptron tests, whose
    # training is not guaranteed to converge on any single run.
    _MAX_TRAINING_ATTEMPTS = 10

    def setup(self):
        '''Load the test image, ground truth and shared training classes.'''
        if not os.path.isdir(testdir):
            os.mkdir(testdir)
        self.image = spy.open_image('92AV3C.lan')
        self.data = self.image.load()
        self.gt = spy.open_image('92AV3GT.GIS').read_band(0)
        self.ts = spy.create_training_classes(self.data, self.gt,
                                              calc_stats=True)
        self.class_filename = os.path.join(testdir, '92AV3C.classes')

    def _reduced_training_data(self):
        '''Return (pc_data, ximg, ts) for a 3-component PCA reduction.

        `pc_data` is the in-memory data transformed by the reduced principal
        components, `ximg` is the same transform applied to the SpyFile image
        (a TransformedImage), and `ts` holds training classes with statistics
        computed from `pc_data`.
        '''
        pc = spy.principal_components(self.data).reduce(num=3)
        ximg = pc.transform(self.image)
        pc_data = pc.transform(self.data)
        ts = spy.create_training_classes(pc_data, self.gt, calc_stats=True)
        return pc_data, ximg, ts

    def _converges(self, trial):
        '''Return True if `trial()` succeeds within the allowed attempts.

        `trial` is called at most `_MAX_TRAINING_ATTEMPTS` times; evaluation
        stops at the first truthy result.
        '''
        return any(trial() for _ in range(self._MAX_TRAINING_ATTEMPTS))

    def test_save_training_sets(self):
        '''Test that TrainingClassSet data can be saved without exception.'''
        ts = spy.create_training_classes(self.data, self.gt, calc_stats=True)
        ts.save(self.class_filename)

    def test_load_training_sets(self):
        '''Test that the data loaded is the same as was saved.'''
        ts = spy.create_training_classes(self.data, self.gt, calc_stats=True)
        ts.save(self.class_filename)
        ts2 = spy.load_training_sets(self.class_filename, image=self.data)
        for class_id in list(ts.classes.keys()):
            s1 = ts[class_id]
            s2 = ts2[class_id]
            assert s1.index == s2.index
            np.testing.assert_almost_equal(s1.class_prob, s2.class_prob)
            assert_allclose(s1.stats.mean, s2.stats.mean)
            assert_allclose(s1.stats.cov, s2.stats.cov)
            np.testing.assert_equal(s1.stats.nsamples, s2.stats.nsamples)

    def test_gmlc_spectrum_image_equal(self):
        '''Tests that classification of spectrum is same as from image.'''
        gmlc = spy.GaussianClassifier(self.ts, min_samples=600)
        data = self.data[20: 30, 30: 40, :]
        assert gmlc.classify_spectrum(data[2, 2]) == \
            gmlc.classify_image(data)[2, 2]

    def test_gmlc_classify_spyfile_runs(self):
        '''Tests that GaussianClassifier classifies a SpyFile object.'''
        gmlc = spy.GaussianClassifier(self.ts, min_samples=600)
        gmlc.classify_image(self.image)

    def test_gmlc_classify_transformedimage_runs(self):
        '''Tests that GaussianClassifier classifies a TransformedImage object.'''
        _, ximg, ts = self._reduced_training_data()
        gmlc = spy.GaussianClassifier(ts)
        gmlc.classify_image(ximg)

    def test_gmlc_classify_ndarray_transformedimage_equal(self):
        '''Gaussian classification of an ndarray and TransformedImage are equal'''
        pc_data, ximg, ts = self._reduced_training_data()
        gmlc = spy.GaussianClassifier(ts)
        cl_ximg = gmlc.classify_image(ximg)
        cl_ndarray = gmlc.classify_image(pc_data)
        assert np.all(cl_ximg == cl_ndarray)

    def test_mahalanobis_class_mean(self):
        '''Test that a class's mean spectrum is classified as that class.
        Note this assumes that class priors are equal.
        '''
        mdc = spy.MahalanobisDistanceClassifier(self.ts)
        cl = mdc.classes[0]
        assert mdc.classify(cl.stats.mean) == cl.index

    def test_mahalanobis_classify_spyfile_runs(self):
        '''Mahalanobis classifier works with a SpyFile object.'''
        mdc = spy.MahalanobisDistanceClassifier(self.ts)
        mdc.classify_image(self.image)

    def test_mahalanobis_classify_transformedimage_runs(self):
        '''Mahalanobis classifier works with a TransformedImage object.'''
        _, ximg, ts = self._reduced_training_data()
        mdc = spy.MahalanobisDistanceClassifier(ts)
        mdc.classify_image(ximg)

    def test_mahalanobis_classify_ndarray_transformedimage_equal(self):
        '''Mahalanobis classification of ndarray and TransformedImage are equal'''
        pc_data, ximg, ts = self._reduced_training_data()
        # Must be the Mahalanobis classifier (not GaussianClassifier) for this
        # test to cover what its name and docstring describe.
        mdc = spy.MahalanobisDistanceClassifier(ts)
        cl_ximg = mdc.classify_image(ximg)
        cl_ndarray = mdc.classify_image(pc_data)
        assert np.all(cl_ximg == cl_ndarray)

    def test_perceptron_learns_and(self):
        '''Test that 2x1 network can learn the logical AND function.'''
        from spectral.algorithms.perceptron import test_and
        (success, _) = test_and(stdout=None)
        assert success

    def test_perceptron_learns_xor(self):
        '''Test that 2x3x1 network can learn the logical XOR function.'''
        from spectral.algorithms.perceptron import test_xor231
        # XOR isn't guaranteed to converge, so allow several attempts.
        assert self._converges(lambda: test_xor231(3000, stdout=None)[0]), \
            '2x3x1 network failed to learn XOR'

    def test_perceptron_learns_xor_222(self):
        '''Test that 2x2x2 network can learn the logical XOR function.'''
        from spectral.algorithms.perceptron import test_xor222
        # XOR isn't guaranteed to converge, so allow several attempts.
        assert self._converges(lambda: test_xor222(3000, stdout=None)[0]), \
            '2x2x2 network failed to learn XOR'

    def test_perceptron_learns_image_classes(self):
        '''Test that perceptron can learn image class means.'''
        fld = spy.linear_discriminant(self.ts)
        xdata = fld.transform(self.data)
        classes = spy.create_training_classes(xdata, self.gt)
        nfeatures = xdata.shape[-1]
        nclasses = len(classes)

        def trial():
            # A fresh network per attempt so each run starts from new weights.
            p = spy.PerceptronClassifier([nfeatures, 20, 8, nclasses])
            return p.train(classes, 1, 5000, batch=1, momentum=0.3,
                           rate=0.3)

        assert self._converges(trial), \
            'perceptron failed to learn the image classes'

    def test_mahalanobis_spectrum_image_equal(self):
        '''Tests that classification of spectrum is same as from image.'''
        mdc = spy.MahalanobisDistanceClassifier(self.ts)
        data = self.data[20: 30, 30: 40, :]
        assert mdc.classify_spectrum(data[2, 2]) == \
            mdc.classify_image(data)[2, 2]


def run():
    '''Print a section banner, then execute every ClassifierTest test.'''
    rule = '-' * 72
    print('\n' + rule)
    print('Running classifier tests.')
    print(rule)
    ClassifierTest().run()

if __name__ == '__main__':
    # Script entry point: parse the shared runner's command-line options,
    # reset its counters, run this module's tests and report the totals.
    from spectral.tests.run import (parse_args, print_summary,
                                    reset_stats)
    parse_args()
    reset_stats()
    run()
    print_summary()