File: tf_frozen_model_extractor.py

package info (click to toggle)
arm-compute-library 20.08%2Bdfsg-5
  • links: PTS
  • area: main
  • in suites: bullseye
  • size: 40,020 kB
  • sloc: cpp: 430,531; lisp: 44,077; ansic: 25,855; cs: 5,724; python: 1,030; sh: 348; makefile: 42
file content (62 lines) | stat: -rw-r--r-- 2,843 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python
""" Extract trainable parameters from a frozen model and stores them in numpy arrays.
Usage:
    python tf_frozen_model_extractor.py -m path_to_frozen_model -d path_to_store_the_parameters

Saves each variable to a {variable_name}.npy binary file.

Note that the script permutes the trainable parameters to NCHW format. The permutation is a manual, hard-coded step and has not been thoroughly tested.
"""
import argparse
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Substrings stripped (in this order) from a tensor name to build the output
# file name, e.g. "conv1/weights/read:0" -> "conv1/weights".
strings_to_remove = ["read", "/:0"]

# Axis permutation applied to a parameter, keyed by its rank. Ranks 1-3 simply
# reverse the axes; rank 4 moves the last two axes to the front.
# NOTE(review): the rank-4 entry presumably maps HWIO kernels to OIHW - confirm.
permutations = {1: [0], 2: [1, 0], 3: [2, 1, 0], 4: [3, 2, 0, 1]}


def permute_parameter(t):
    """Return `t` with its axes reordered according to `permutations`.

    Args:
        t: Array-like parameter value (NumPy array or NumPy scalar).

    Returns:
        A NumPy array holding the permuted data. For ranks listed in
        `permutations` the result is C-contiguous. Ranks without an entry
        (scalars, rank > 4) are returned with their axes unchanged instead of
        raising KeyError.
    """
    t = np.asarray(t)
    perm = permutations.get(t.ndim)
    if perm is None:
        # np.ascontiguousarray would promote a 0-d value to shape (1,), so
        # only use it for ranks >= 1 to keep the stored shape faithful.
        return t if t.ndim == 0 else np.ascontiguousarray(t)
    return np.ascontiguousarray(t.transpose(perm))


def sanitize_variable_name(tensor_name):
    """Turn a graph tensor name into a flat file name.

    Removes every entry of `strings_to_remove`, then replaces scope separators
    with '_'. Graph names are split on '/' regardless of the host OS, so '/'
    is always replaced; `os.path.sep` is replaced as well so the result can
    never point into a sub-directory.

    Args:
        tensor_name: Full tensor name, e.g. "conv1/weights/read:0".

    Returns:
        The flattened name, e.g. "conv1_weights".
    """
    stripped = tensor_name
    for s in strings_to_remove:
        stripped = stripped.replace(s, "")

    flat = stripped.replace("/", "_")
    if os.path.sep != "/":
        flat = flat.replace(os.path.sep, "_")

    if flat != stripped:
        print("Renaming variable {0} to {1}".format(tensor_name, flat))
    return flat


def main():
    """Parse the command line, load the frozen graph and dump its parameters."""
    # Parse arguments
    parser = argparse.ArgumentParser('Extract TensorFlow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to TensorFlow frozen graph file (.pb)')
    parser.add_argument('-d', dest='dumpPath', type=str, required=False, default='./', help='Path to store the resulting files.')
    parser.add_argument('--nostore', dest='storeRes', action='store_false', help='Specify if files should not be stored. Used for debugging.')
    parser.set_defaults(storeRes=True)
    args = parser.parse_args()

    # Create directory if not present. Checking isdir (not exists) makes a
    # plain file at dumpPath fail here rather than at the first np.save.
    if not os.path.isdir(args.dumpPath):
        os.makedirs(args.dumpPath)

    # Extract parameters
    with tf.Graph().as_default() as graph:
        with tf.Session() as sess:
            print("Loading model.")
            graph_def = tf.GraphDef()
            # Only keep the file open while reading the serialized graph.
            with gfile.FastGFile(args.modelFile, 'rb') as f:
                graph_def.ParseFromString(f.read())

            # name="" imports the nodes without adding a name prefix.
            tf.import_graph_def(graph_def, name="")

            for op in graph.get_operations():
                for op_val in op.values():
                    # Only tensors whose name contains "read" are extracted;
                    # everything else is skipped.
                    # NOTE(review): presumably these are the variable read ops
                    # left behind by graph freezing - confirm.
                    if "read" not in op_val.name:
                        continue

                    value = op_val.eval(session=sess)
                    rank = np.ndim(value)
                    if rank > 0 and rank not in permutations:
                        print("Warning: no permutation defined for rank {0}; keeping original layout of {1}".format(rank, op_val.name))
                    t = permute_parameter(value)

                    varname = sanitize_variable_name(op_val.name)

                    # Store files
                    if args.storeRes:
                        print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
                        np.save(os.path.join(args.dumpPath, varname), t)


if __name__ == "__main__":
    main()