File: shrink_tf_graph_weights.py

package info (click to toggle)
opencv 4.10.0%2Bdfsg-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 282,092 kB
  • sloc: cpp: 1,178,079; xml: 682,621; python: 49,092; lisp: 31,150; java: 25,469; ansic: 11,039; javascript: 6,085; sh: 1,214; cs: 601; perl: 494; objc: 210; makefile: 173
file content (62 lines) | stat: -rw-r--r-- 2,306 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# This file is part of OpenCV project.
# It is subject to the license terms in the LICENSE file found in the top-level directory
# of this distribution and at http://opencv.org/license.html.
#
# Copyright (C) 2017, Intel Corporation, all rights reserved.
# Third party copyrights are property of their respective owners.
import tensorflow as tf
import struct
import argparse
import numpy as np

parser = argparse.ArgumentParser(description='Convert weights of a frozen TensorFlow graph to fp16.')
parser.add_argument('--input', required=True, help='Path to frozen graph.')
parser.add_argument('--output', required=True, help='Path to output graph.')
parser.add_argument('--ops', default=['Conv2D', 'MatMul'], nargs='+',
                    help='List of ops which weights are converted.')
args = parser.parse_args()

DT_FLOAT = 1
DT_HALF = 19

# For the frozen graphs, an every node that uses weights connected to Const nodes
# through an Identity node. Usually they're called in the same way with '/read' suffix.
# We'll replace all of them to Cast nodes.

# Load the model
with tf.gfile.FastGFile(args.input) as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Set of all inputs from desired nodes.
inputs = []
for node in graph_def.node:
    if node.op in args.ops:
        inputs += node.input

weightsNodes = []
for node in graph_def.node:
    # From the whole inputs we need to keep only an Identity nodes.
    if node.name in inputs and node.op == 'Identity' and node.attr['T'].type == DT_FLOAT:
        weightsNodes.append(node.input[0])

        # Replace Identity to Cast.
        node.op = 'Cast'
        node.attr['DstT'].type = DT_FLOAT
        node.attr['SrcT'].type = DT_HALF
        del node.attr['T']
        del node.attr['_class']

# Convert weights to halfs.
for node in graph_def.node:
    if node.name in weightsNodes:
        node.attr['dtype'].type = DT_HALF
        node.attr['value'].tensor.dtype = DT_HALF

        floats = node.attr['value'].tensor.tensor_content

        floats = struct.unpack('f' * (len(floats) / 4), floats)
        halfs = np.array(floats).astype(np.float16).view(np.uint16)
        node.attr['value'].tensor.tensor_content = struct.pack('H' * len(halfs), *halfs)

tf.train.write_graph(graph_def, "", args.output, as_text=False)