File: parse_rnnrf.py

package info (click to toggle)
scrappie 1.4.2-9
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 61,724 kB
  • sloc: ansic: 114,526; python: 1,586; makefile: 160; sh: 122
file content (129 lines) | stat: -rwxr-xr-x 5,602 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
import argparse
import pickle
import math
import numpy as np
import re
import sys

# Command-line interface: an optional identifier embedded in every emitted C
# name, plus the path of the pickled model to convert.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--id',
    default='',
    help='Identifier for model names',
)
parser.add_argument(
    'model',
    help='Pickle to read model from',
)

# Matches the run of zeros immediately before the exponent marker in
# float.hex() output, e.g. the zeros in '0x1.8000000000000p+1'.
trim_trailing_zeros = re.compile(r'0+p')

def small_hex(f):
    """Return ``float(f)`` as a hexadecimal float string with the trailing
    zero digits of the mantissa stripped, e.g. 1.5 -> '0x1.8p+0'.

    NOTE: float.hex() renders non-finite values as 'inf' / 'nan', which are
    passed through unchanged and are not C float literals.
    """
    return trim_trailing_zeros.sub('p', float(f).hex())


def process_column(v, pad):
    """Return the hex-float strings for every value in ``v``, followed by
    ``pad`` copies of the string for 0.0 (zero padding)."""
    column = list(map(small_hex, v))
    column.extend([small_hex(0.0)] * pad)
    return column


def cformatM(fh, name, X, nr=None, nc=None):
    """Write a 2-D array to ``fh`` as C source for a padded ``scrappie_matrix``.

    Emits a ``float __<name>[]`` data array, a ``_Mat _<name>`` struct that
    describes it, and a ``const scrappie_matrix <name>`` pointing at that
    struct. Each row of ``X`` is written as one run of floats, zero-padded
    to a multiple of 4; the runs are laid out back to back.

    Args:
        fh: writable text stream that receives the C source.
        name: C identifier of the emitted ``scrappie_matrix``.
        X: 2-D array. By default ``X.shape[1]`` becomes ``.nr`` and
            ``X.shape[0]`` becomes ``.nc``.
        nr: value written to ``.nr``; defaults to ``X.shape[1]``. When given,
            ``.nrq`` and ``.stride`` are recomputed from it rather than from
            ``X.shape[1]``, so it must agree with the emitted data layout.
        nc: value written to ``.nc``; defaults to ``X.shape[0]``.

    Raises:
        ValueError: if ``nc * stride`` exceeds the number of floats emitted,
            i.e. the struct would describe more elements than ``__<name>``
            actually holds. Nothing is written in that case.
    """
    ndata_rows = X.shape[1]
    # Zero padding that rounds each emitted run up to a multiple of 4 floats.
    nrq = int(math.ceil(ndata_rows / 4.0))
    pad = nrq * 4 - ndata_rows
    lines = [', '.join(process_column(v, pad)) for v in X]
    nfloats = X.shape[0] * (ndata_rows + pad)

    if nr is None:
        nr = ndata_rows
    else:
        # Caller-supplied logical row count (e.g. the convolution filter
        # layout) determines the declared quad count and stride.
        nrq = int(math.ceil(nr / 4.0))
    if nc is None:
        nc = X.shape[0]

    stride = nrq * 4
    # Refuse to emit a matrix whose declared extent overruns its backing
    # array; the C side would otherwise read past the end of __<name>.
    if nc * stride > nfloats:
        raise ValueError(
            'cformatM({}): nc * stride = {} * {} exceeds the {} floats emitted'
            .format(name, nc, stride, nfloats))

    fh.write('float {}[] = {}\n'.format('__' + name, '{'))
    fh.write('\t' + ',\n\t'.join(lines))
    fh.write('};\n')
    fh.write('_Mat {} = {}\n\t.nr = {},\n\t.nrq = {},\n\t.nc = {},\n\t.stride = {},\n\t.data.f = {}\n{};\n'.format('_' + name, '{', nr, nrq, nc, stride, '__' + name, '}'))
    fh.write('const scrappie_matrix {} = &{};\n\n'.format(name, '_' + name))


def cformatV(fh, name, X):
    """Write a vector to ``fh`` as C source for a single-column matrix.

    The values of ``X`` are flattened into one row and handed to
    ``cformatM``, which emits that row as a single column (``.nc = 1``,
    ``.nr = len(X)``) zero-padded to a multiple of 4 floats. Sharing
    ``cformatM`` keeps one definition of the emitted ``_Mat`` layout. For
    1-D input the output is exactly the same as formatting the vector
    directly, and inputs of any other rank are flattened instead of failing.

    Args:
        fh: writable text stream that receives the C source.
        name: C identifier of the emitted ``scrappie_matrix``.
        X: array-like of numbers; flattened before formatting.
    """
    cformatM(fh, name, np.asarray(X).reshape(1, -1))


if __name__ == '__main__':
    args = parser.parse_args()
    # Suffix embedded in every emitted C identifier, e.g. 'conv_rnnrf_<id>_W'.
    modelid = args.id + '_'
    out = sys.stdout

    # NOTE(review): pickle.load can execute arbitrary code; only convert model
    # files that come from a trusted source.
    with open(args.model, 'rb') as fh:
        # Presumably latin1 so Python 3 can read pickles written under Python 2.
        network = pickle.load(fh, encoding='latin1')
    network_major_version = network.version[0] if isinstance(network.version, tuple) else network.version
    # Explicit exit rather than `assert`, which is stripped under `python -O`.
    if network_major_version < 2:
        sys.exit("Sloika model must be version >= 2 but model is {}.\nPerhaps you need to run Sloika's model_upgrade.py".format(network.version))

    guard = 'NANONET_RNNRF_{}MODEL_H'.format(modelid.upper())
    out.write(('#pragma once\n'
               '#ifndef {0}\n'
               '#define {0}\n'
               '#include <assert.h>\n'
               '#include "../util.h"\n').format(guard))

    # Convolution layer. The filter weights are flattened to one value per
    # row; cformatM pads each row to 4 floats, so each filter becomes a
    # column of 4 * winlen floats with a weight at every 4th position, and
    # nr = 4 * winlen - 3 ends the column at the last real weight.
    # NOTE(review): assumes the middle (input-channel) dimension is 1 — confirm.
    conv = network.sublayers[0]
    filterW = conv.W.get_value()
    nfilter, _, winlen = filterW.shape
    cformatM(out, 'conv_rnnrf_{}W'.format(modelid), filterW.reshape(-1, 1), nr=winlen * 4 - 3, nc=nfilter)
    cformatV(out, 'conv_rnnrf_{}b'.format(modelid), conv.b.get_value().reshape(-1))
    out.write('const int conv_rnnrf_{}stride = {};\n'.format(modelid, conv.stride))
    out.write(('const size_t _conv_rnnrf_{0}nfilter = {1};\n'
               'const size_t _conv_rnnrf_{0}winlen = {2};\n').format(modelid, nfilter, winlen))

    # Five stacked GRU layers alternating backward (gruB*) and forward
    # (gruF*). The backward layers sit one level deeper in the sublayer tree
    # than the forward ones.
    gru_layers = (
        ('gruB1', network.sublayers[1].sublayers[0].sublayers[0]),
        ('gruF2', network.sublayers[2].sublayers[0]),
        ('gruB3', network.sublayers[3].sublayers[0].sublayers[0]),
        ('gruF4', network.sublayers[4].sublayers[0]),
        ('gruB5', network.sublayers[5].sublayers[0].sublayers[0]),
    )
    for prefix, gru in gru_layers:
        cformatM(out, '{}_rnnrf_{}iW'.format(prefix, modelid), gru.iW.get_value())
        cformatM(out, '{}_rnnrf_{}sW'.format(prefix, modelid), gru.sW.get_value())
        cformatM(out, '{}_rnnrf_{}sW2'.format(prefix, modelid), gru.sW2.get_value())
        cformatV(out, '{}_rnnrf_{}b'.format(prefix, modelid), gru.b.get_value().reshape(-1))

    # Global norm layer
    output_layer = network.sublayers[6]
    cformatM(out, 'FF_rnnrf_{}W'.format(modelid), output_layer.W.get_value())
    # Flattened like every other bias vector, since cformatV formats 1-D data.
    cformatV(out, 'FF_rnnrf_{}b'.format(modelid), output_layer.b.get_value().reshape(-1))

    # Terminate with a newline: C requires a non-empty source file to end in one.
    out.write('#endif /* {} */\n'.format(guard))