File: test_classifier.py

package info (click to toggle)
gamera 1:3.4.2+git20160808.1725654-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 22,312 kB
  • ctags: 24,991
  • sloc: xml: 122,324; ansic: 52,869; cpp: 50,664; python: 35,034; makefile: 118; sh: 101
file content (129 lines) | stat: -rw-r--r-- 7,133 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from gamera.core import *
from gamera import knn, classify, gamera_xml
# Runs once at import time, before any test function in this module is called.
# NOTE(review): presumably Gamera must be initialised before load_image/knn
# are used -- confirm against the Gamera documentation.
init_gamera()

# Expected main id of each connected component of the test image, in the
# order the components are passed to the classifier.  _test_classification
# compares this list against [cc.get_main_id() for cc in ccs].
correct_classes = [
   'latin.lower.letter.h',
   'latin.lower.ligature.ft',
   'latin.capital.letter.m',
   '_group._part.latin.capital.letter.m',
   '_group._part.latin.capital.letter.m',
   '_group._part.latin.lower.letter.i',
   'latin.lower.letter.d',
   'latin.capital.letter.m',
   'latin.capital.letter.t',
   '_group._part.latin.lower.letter.h',
   'latin.lower.ligature.fi',
   '_group._part.latin.lower.ligature.ft',
   '_group._part.latin.lower.letter.h',
   '_group._part.latin.lower.letter.i',
   'latin.lower.letter.h',
   'latin.lower.letter.d',
   'latin.lower.letter.d',
   'latin.capital.letter.m',
   'latin.capital.letter.c',
   'latin.lower.letter.t',
   'latin.lower.letter.t',
   '_group._part.latin.lower.letter.n',
   'latin.lower.letter.e',
   'latin.lower.letter.a',
   'latin.lower.letter.r',
   'latin.lower.letter.r',
   'latin.lower.letter.a',
   '_group._part.latin.lower.letter.n',
   '_group._part.latin.lower.letter.i',
   'latin.lower.letter.a',
   'latin.lower.letter.r',
   'latin.lower.letter.r',
   '_group._part.latin.lower.letter.h',
   'latin.lower.letter.e',
   'latin.lower.letter.r',
   'latin.lower.letter.e',
   'latin.lower.letter.n',
   'latin.lower.letter.n',
   'latin.lower.letter.o',
   'latin.lower.letter.e',
   'latin.lower.letter.s',
   'latin.lower.letter.e',
   '_group._part.latin.lower.letter.h',
   '_group._part.latin.lower.letter.i',
   '_group._part.latin.lower.letter.g',
   'latin.lower.letter.a',
   'latin.lower.letter.r',
   'latin.lower.letter.r',
   'latin.lower.letter.o-',
   'latin.lower.letter.r',
   'hyphen-minus',
   'comma',
   'full.stop',
   'comma',
   '_group._part.latin.lower.ligature.ft',
   'noise',
   '_group._part.latin.lower.letter.g',
]


# Expected main ids of the glyphs added by group_list_automatic, ordered by
# offset_x.  Row i belongs to case i of _test_grouping.
results = [
   # case 0: ShapedGroupingFunction(4), criterion 'min'
   [
      'latin.lower.letter.n',
      'latin.capital.letter.m',
      'latin.lower.letter.h',
      'latin.lower.ligature.ft',
      'latin.capital.letter.t',
      'latin.lower.letter.g',
   ],
   # case 1: ShapedGroupingFunction(4), criterion 'avg'
   [
      'latin.lower.letter.n',
      'latin.capital.letter.m',
      'latin.lower.letter.h',
      'latin.lower.ligature.ft',
      'latin.lower.letter.h',
      'latin.lower.letter.g',
   ],
   # case 2: no grouping function, criterion 'min'
   [
      'latin.lower.letter.n',
      'latin.capital.letter.m',
      'latin.capital.letter.t',
      'latin.lower.letter.h',
      'latin.lower.ligature.ft',
      'latin.lower.letter.h',
      'latin.lower.letter.i',
      'latin.lower.letter.g',
   ],
   # case 3: no grouping function, criterion 'avg'
   [
      'latin.lower.letter.n',
      'latin.capital.letter.m',
      'latin.capital.letter.t',
      'latin.lower.letter.h',
      'latin.lower.ligature.ft',
      'latin.lower.letter.h',
      'latin.lower.letter.i',
      'latin.lower.letter.g',
   ],
]

# Feature names handed to the classifiers built in this module.
# A longer alternative list is kept below, disabled, for reference:
#featureset = ['area', 'aspect_ratio', 'black_area', 'compactness', 'moments', 'ncols_feature', 'nholes', 'nholes_extended', 'nrows_feature', 'skeleton_features', 'top_bottom', 'volume', 'volume16regions', 'volume64regions', 'zernike_moments']
featureset = [
   'area',
   'aspect_ratio',
   'black_area',
   'moments',
   'nholes_extended',
   'skeleton_features',
   'volume64regions',
]

def _test_grouping(classifier, ccs):
   """Check automatic grouping against the module-level ``results`` table.

   Switches ``classifier`` to the module-level ``featureset``, then calls
   ``group_list_automatic`` on ``ccs`` once per (grouping function,
   criterion) case.  For case ``i`` the main ids of the added glyphs,
   ordered by ``offset_x``, must equal ``results[i]``.

   classifier -- classifier object under test
   ccs        -- glyphs passed to ``group_list_automatic``

   Raises AssertionError when a case does not produce the expected ids.
   """
   classifier.change_feature_set(featureset)
   cases = [
      (classify.ShapedGroupingFunction(4), 'min'),
      (classify.ShapedGroupingFunction(4), 'avg'),
      (None, 'min'),
      (None, 'avg'),
   ]

   for i, (func, criterion) in enumerate(cases):
      # Identity test: "== None" would go through any __eq__ the grouping
      # function object might define.
      if func is None:
         # Only the criterion is passed; every other argument is left to
         # group_list_automatic itself.
         added, _removed = classifier.group_list_automatic(
            ccs, criterion=criterion)
      else:
         added, _removed = classifier.group_list_automatic(
            ccs,
            grouping_function=func,
            max_parts_per_group=10,
            max_graph_size=64,
            criterion=criterion)

      # A key function gives the same stable left-to-right ordering as the
      # former cmp-based sort, and also runs on Python 3, where list.sort()
      # no longer accepts a comparison function and cmp() does not exist.
      added.sort(key=lambda glyph: glyph.offset_x)
      assert [cc.get_main_id() for cc in added] == results[i]


def _test_classification(classifier, ccs):
   """Exercise the automatic classification calls of ``classifier``.

   Checks single-glyph guessing and classification, list classification
   with and without update, then switches to the single feature ``'area'``
   and expects the labels to differ from ``correct_classes``.  Finishes by
   running ``_test_grouping`` on the same classifier and glyphs.

   classifier -- classifier object under test
   ccs        -- glyphs to classify; elements 0 and 1 are also used alone

   Raises AssertionError on the first expectation that does not hold.
   """
   def main_ids():
      # Current main id of every glyph, in list order.
      return [glyph.get_main_id() for glyph in ccs]

   # Guessing returns the label without storing it on the glyph.
   id_name, _confidence = classifier.guess_glyph_automatic(ccs[0])
   assert id_name == [(1.0, 'latin.lower.letter.h')]

   # Classifying stores the label on the glyph itself.
   second = ccs[1]
   classifier.classify_glyph_automatic(second)
   assert second.id_name == [(1.0, 'latin.lower.ligature.ft')]

   added, removed = classifier.classify_list_automatic(ccs)
   assert main_ids() == correct_classes
   assert added == []
   assert removed == []

   classifier.classify_and_update_list_automatic(ccs)
   assert main_ids() == correct_classes

   # With 'area' as the only feature each database glyph carries exactly
   # one feature value, and the labels must no longer match.
   classifier.change_feature_set(['area'])
   first_entry = list(classifier.database)[0]
   assert len(first_entry.features) == 1

   _added, _removed = classifier.classify_list_automatic(ccs)
   assert main_ids() != correct_classes

   _test_grouping(classifier, ccs)


def _test_training(classifier, ccs):
   """Exercise the manual-training calls of ``classifier``.

   Labels glyphs as "dummy" one at a time and as a list, then adds,
   removes and re-adds ``ccs`` to the database, checking the number of
   glyphs reported by ``get_glyphs()`` after every step.

   classifier -- classifier object under test
   ccs        -- glyphs used as training material

   Raises AssertionError when a glyph count is not the expected one.
   """
   def glyph_count():
      # Number of glyphs the classifier currently reports.
      return len(classifier.get_glyphs())

   baseline = glyph_count()

   classifier.classify_glyph_manual(ccs[0], "dummy")
   assert glyph_count() == baseline + 1

   added, removed = classifier.classify_list_manual(ccs, "dummy")
   assert glyph_count() == baseline + len(ccs)
   assert added == []
   assert removed == []

   classifier.classify_and_update_list_manual(ccs, "dummy")
   assert glyph_count() == baseline + len(ccs)

   # Adding glyphs that are already present must not change the count.
   classifier.add_to_database(ccs)
   assert glyph_count() == baseline + len(ccs)

   classifier.remove_from_database(ccs)
   assert glyph_count() == baseline

   classifier.add_to_database(ccs)
   assert glyph_count() == baseline + len(ccs)

def test_interactive_classifier():
   """Run the classification, training and XML round trip on kNNInteractive.

   Starts from an empty interactive classifier, loads the training data,
   runs ``_test_classification`` and ``_test_training``, then saves,
   reloads, merges, clears and reloads the glyphs, checking the glyph
   count at each step.
   """
   # The XML reading/writing itself is assumed to be fine (covered by
   # test_xml); the wrappers in classify are still tested here.
   training_xml = "data/testline.xml"
   saved_xml = "tmp/testline_classifier.xml"
   training_size = 66

   line_image = load_image("data/testline.png")
   ccs = line_image.cc_analysis()

   classifier = knn.kNNInteractive([], features=featureset)
   assert classifier.is_interactive()
   assert len(classifier.get_glyphs()) == 0

   classifier.from_xml_filename(training_xml)
   assert len(classifier.get_glyphs()) == training_size
   _test_classification(classifier, ccs)
   _test_training(classifier, ccs)
   total = len(classifier.get_glyphs())

   # to_xml_filename() does not save "_group._part" glyphs, so they are
   # left out of the count expected after the save/load round trip.
   part_count = sum(
      1 for glyph in classifier.get_glyphs()
      if glyph.get_main_id().startswith("_group._part"))
   expected_saved = total - part_count

   classifier.to_xml_filename(saved_xml)
   classifier.from_xml_filename(saved_xml)
   assert len(classifier.get_glyphs()) == expected_saved

   classifier.merge_from_xml_filename(training_xml)
   assert len(classifier.get_glyphs()) == expected_saved + training_size

   classifier.clear_glyphs()
   assert len(classifier.get_glyphs()) == 0

   classifier.from_xml_filename(training_xml)
   assert len(classifier.get_glyphs()) == training_size
   
def test_noninteractive_classifier():
   """Run the classification checks and a serialize cycle on kNNNonInteractive.

   Builds a non-interactive classifier from the training XML, runs
   ``_test_classification``, then serializes the classifier, clears its
   glyphs and unserializes it again.
   """
   # The XML reading/writing itself is assumed to be fine (covered by
   # test_xml); the wrappers in classify are still tested here.
   serialized_path = "tmp/serialized.knn"

   line_image = load_image("data/testline.png")
   ccs = line_image.cc_analysis()

   training_glyphs = gamera_xml.glyphs_from_xml("data/testline.xml")
   classifier = knn.kNNNonInteractive(
      training_glyphs, features=featureset, normalize=False)
   assert not classifier.is_interactive()
   assert len(classifier.get_glyphs()) == 66

   _test_classification(classifier, ccs)

   classifier.serialize(serialized_path)
   classifier.clear_glyphs()
   assert len(classifier.get_glyphs()) == 0
   classifier.unserialize(serialized_path)