File: huffman.py

package info (click to toggle)
python-bitarray 3.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,288 kB
  • sloc: python: 11,456; ansic: 7,657; makefile: 73; sh: 6
file content (195 lines) | stat: -rw-r--r-- 5,293 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
This library contains useful functionality for working with Huffman trees
and codes.

Note:
There is a function for directly creating a Huffman code from a frequency
map in the bitarray library itself: bitarray.util.huffman_code()
"""
from bitarray import bitarray


class Node:
    """
    A node of a binary code tree.

    ``child`` holds the two subtrees, indexed by the bit (0 or 1) that leads
    to them; a missing subtree is None.  ``freq`` starts out as None and is
    only meaningful if whoever builds the tree assigns it.  Leaf nodes are
    recognized throughout this module by having a ``symbol`` attribute.
    """
    def __init__(self):
        self.freq = None
        self.child = [None, None]

    def __lt__(self, other):
        # Order nodes by frequency so they can live in a heapq.
        return self.freq < other.freq


def huff_code(tree):
    """
    Given a Huffman tree, traverse the tree and return the Huffman code, i.e.
    a dictionary mapping symbols to bitarrays.

    A node is treated as a leaf when it has a ``symbol`` attribute; its code
    is the path of child indices (0/1) from the root.  Missing children
    (``None``, as left by make_tree() for an incomplete code) are skipped.
    Every bitarray in the result is a fresh object, so callers may modify
    the returned codes freely.
    """
    result = {}

    # Named ``walk`` (not ``traverse``) to avoid shadowing the module-level
    # traverse() function.
    def walk(nd, prefix):
        if nd is None:  # absent subtree in an incomplete tree
            return
        try:  # leaf
            result[nd.symbol] = prefix
        except AttributeError:
            walk(nd.child[0], prefix + bitarray([0]))
            walk(nd.child[1], prefix + bitarray([1]))

    # A fresh empty prefix per call (the previous mutable default argument
    # was shared between calls and ended up in the result for a leaf root).
    walk(tree, bitarray())
    return result


def insert_symbol(tree, ba, sym):
    """
    Insert symbol into a tree at the position described by the bitarray,
    creating nodes as necessary.

    ``ba`` is iterated bit by bit; each bit (0 or 1) selects the child to
    descend into.  The node reached at the end becomes a leaf by getting a
    ``symbol`` attribute.

    Raises ValueError:
      - "ambiguity" if the new code equals an existing code, has an existing
        code as a prefix, or is itself a prefix of an existing code;
      - if ``ba`` is empty, because the root cannot act as a leaf: decoding
        (see traverse) always consumes at least one bit before checking for
        a symbol.
    """
    nd = tree
    for k in ba:
        prev = nd
        nd = nd.child[k]

        # An existing leaf on the path means an existing code is a prefix
        # of the new one.
        if hasattr(nd, 'symbol'):
            raise ValueError("ambiguity")

        if nd is None:
            nd = Node()
            prev.child[k] = nd

    # Still at the root: the code had no bits at all.
    if nd is tree:
        raise ValueError("empty code for symbol %r" % (sym,))

    # The target node must be fresh: an existing symbol means a duplicate
    # code, existing children mean the new code is a prefix of another.
    if hasattr(nd, 'symbol') or nd.child[0] or nd.child[1]:
        raise ValueError("ambiguity")

    nd.symbol = sym


def make_tree(codedict):
    """
    Build a decoding tree from ``codedict`` (symbol -> bitarray) and return
    its root node.  Each entry is added with insert_symbol(), so ambiguous
    codes raise ValueError.  All nodes are plain Node() instances, so their
    ``.freq`` stays None (unlike trees built from a frequency map, e.g. by
    bitarray.util._huffman_tree).
    """
    root = Node()
    for symbol, code in codedict.items():
        insert_symbol(root, code, symbol)
    return root


def traverse(tree, it):
    """
    Traverse tree until a leaf node is reached, and return its symbol.
    This function consumes an iterator on which next() is called during each
    step of traversing.

    Raises:
      StopIteration -- if ``it`` is exhausted before any bit of a new code
          was read (a clean end of data; iterdecode relies on this).
      ValueError -- if the bits lead to a missing child ("prefix code does
          not match data in bitarray"), or if ``it`` runs out in the middle
          of a code ("decoding not terminated").
    """
    nd = tree
    while True:
        try:
            bit = next(it)
        except StopIteration:
            if nd is tree:
                # Nothing consumed yet: propagate the clean end of data.
                raise
            # Bits ran out part-way down the tree: the data ends with an
            # incomplete code, which must not be silently dropped.
            raise ValueError("decoding not terminated") from None

        nd = nd.child[bit]
        if nd is None:
            raise ValueError("prefix code does not match data in bitarray")

        try:
            return nd.symbol
        except AttributeError:  # internal node: keep descending
            pass


def iterdecode(tree, bitsequence):
    """
    Decode ``bitsequence`` (an iterable of 0/1 bits) by repeatedly walking
    ``tree`` with traverse(), yielding one symbol per decoded code.
    Decoding ends when traverse() lets StopIteration from the bit iterator
    propagate; any ValueError from traverse() reaches the caller.
    """
    bits = iter(bitsequence)
    try:
        while True:
            yield traverse(tree, bits)
    except StopIteration:
        return


def write_dot(tree, fn, binary=False):
    """
    Given a tree (which may or may not contain frequencies), write
    a graphviz '.dot' file with a visual representation of the tree.

    Symbols may be integers (e.g. byte values) or single-character strings.
    With ``binary=True`` every such symbol is labelled by its hex value.
    Otherwise, printable characters are shown as-is, a few whitespace and
    quoting characters by name or escaped, and any other character (control
    characters, remaining whitespace) by its hex value.  Symbols of other
    types are shown via str(), with backslashes and double quotes escaped.
    The file is written as UTF-8, graphviz's default input encoding.

    Render with e.g.:  dot -Tpng tree.dot -O
    """
    special_ascii = {' ': 'SPACE', '\n': 'LF', '\r': 'CR', '\t': 'TAB',
                     '\\': r'\\', '"': r'\"'}

    def disp_sym(sym):
        # Accept 1-character strings as well as integer code points.
        if isinstance(sym, str) and len(sym) == 1:
            sym = ord(sym)
        if not isinstance(sym, int):
            # Arbitrary symbol: escape what is special inside a DOT string.
            return str(sym).replace('\\', r'\\').replace('"', r'\"')
        if binary:
            return '0x%02x' % sym
        c = chr(sym)
        if c in special_ascii:
            return special_ascii[c]
        if c.isprintable() and not c.isspace():
            return c
        # Invisible or control characters would give an empty/broken label
        # (previously an assert), so show the code point instead.
        return '0x%02x' % sym

    def disp_freq(f):
        if f is None:
            return ''
        return '%d' % f

    def write_nd(fo, nd):
        if hasattr(nd, 'symbol'):  # leaf node
            a, b = disp_freq(nd.freq), disp_sym(nd.symbol)
            fo.write('  %d  [label="%s%s%s"];\n' %
                     (id(nd), a, ': ' if a and b else '', b))
            return

        assert hasattr(nd, 'child')
        fo.write('  %d  [shape=circle, style=filled, '
                 'fillcolor=grey, label="%s"];\n' %
                 (id(nd), disp_freq(nd.freq)))

        # Emit all edges of this node first, then recurse into the children.
        for k in range(2):
            if nd.child[k]:
                fo.write('  %d->%d;\n' % (id(nd), id(nd.child[k])))

        for k in range(2):
            if nd.child[k]:
                write_nd(fo, nd.child[k])

    with open(fn, 'w', encoding='utf-8') as fo:
        fo.write('digraph BT {\n')
        fo.write('  node [shape=box, fontsize=20, fontname="Arial"];\n')
        write_nd(fo, tree)
        fo.write('}\n')


def print_code(freq, codedict):
    """
    Given a frequency map (dictionary mapping symbols to their frequency)
    and a codedict, print them in a readable form.

    Symbols may be integers (e.g. byte values) or single-character strings;
    the latter are shown via their code point in the 'char' and 'hex'
    columns.  Rows are sorted by descending frequency (ties by descending
    symbol).  Every symbol of ``codedict`` must have an entry in ``freq``,
    and each code value must provide ``.to01()``.
    """
    special_ascii = {0: 'NUL', 9: 'TAB', 10: 'LF', 13: 'CR', 127: 'DEL'}

    def ordinal(sym):
        # 1-character strings are mapped to their code point; anything
        # else (e.g. int byte values) is used unchanged.
        if isinstance(sym, str) and len(sym) == 1:
            return ord(sym)
        return sym

    def disp_char(i):
        if 32 <= i < 127:
            return repr(chr(i))
        return special_ascii.get(i, '')

    print(' symbol     char    hex   frequency     Huffman code')
    print(70 * '-')
    for sym in sorted(codedict, key=lambda c: (freq[c], c), reverse=True):
        i = ordinal(sym)
        print('%7r     %-4s    0x%02x %10i     %s' % (
            sym, disp_char(i), i, freq[sym], codedict[sym].to01()))


def test():
    """Self-test: build a code from frequencies and round-trip a message."""
    from bitarray.util import _huffman_tree

    # Code derived from frequencies: the dominant symbol gets the shortest
    # code word, the two rare ones share the next length.
    huffman = huff_code(_huffman_tree({'a': 10, 'b': 2, 'c': 1}))
    assert len(huffman['a']) == 1
    assert len(huffman['b']) == len(huffman['c']) == 2

    # Encode with a hand-written prefix code, then decode it via the tree.
    prefix_code = {
        'a': bitarray('0'),
        'b': bitarray('10'),
        'c': bitarray('11'),
    }
    decode_tree = make_tree(prefix_code)
    message = 'abca'
    bits = bitarray()
    bits.encode(prefix_code, message)
    assert bits == bitarray('010110')
    assert ''.join(iterdecode(decode_tree, bits)) == message


if __name__ == "__main__":
    # Executed as a script: run the self-test.
    test()