File: tf_model.py

package info (click to toggle)
opencv 4.10.0%2Bdfsg-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 282,092 kB
  • sloc: cpp: 1,178,079; xml: 682,621; python: 49,092; lisp: 31,150; java: 25,469; ansic: 11,039; javascript: 6,085; sh: 1,214; cs: 601; perl: 494; objc: 210; makefile: 173
file content (112 lines) | stat: -rw-r--r-- 3,762 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import cv2
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from ..common.abstract_model import AbstractModel, Framework
from ..common.utils import DNN_LIB, get_full_model_path

# Short framework tag: returned by TFModelProcessor.get_name(), used as the
# key prefix in TFModelPreparer.get_prepared_models(), and passed (lowercased)
# to get_full_model_path() when a model has to be saved.
CURRENT_LIB = "TF"
# File extension appended to the model name to build the saved graph file name.
MODEL_FORMAT = ".pb"


class TFModelPreparer(AbstractModel):
    """ Class for the preparation of the TF models: original and converted OpenCV Net.

    On construction the model paths are resolved and the OpenCV Net is created.
    When ``is_ready_graph`` is False, the original model is first frozen and
    written to disk as a ``.pb`` file; otherwise the existing graph file given
    by ``tf_model_graph_path`` is read directly.

    Args:
        model_name: TF model name; used for the saved file name and result keys
        original_model: callable TF model exposing ``inputs`` (its first input's
            shape and dtype define the traced signature); required when
            ``is_ready_graph`` is False
        is_ready_graph: indicates whether ready .pb file already exists
        tf_model_graph_path: path to the existing frozen TF graph

    Raises:
        ValueError: if ``is_ready_graph`` is False and ``original_model`` is None.
    """

    def __init__(
            self,
            model_name="default",
            original_model=None,
            is_ready_graph=False,
            tf_model_graph_path=""
    ):
        # Fail early with a clear message, before any path is resolved:
        # without a ready graph the model itself is needed for freezing.
        if not is_ready_graph and original_model is None:
            raise ValueError(
                "original_model must be provided when is_ready_graph is False"
            )

        self._model_name = model_name
        self._original_model = original_model
        self._model_to_save = ""

        self._is_ready_to_transfer_graph = is_ready_graph
        self.model_path = self._set_model_path(tf_model_graph_path)
        self._dnn_model = self._set_dnn_model()

    def _set_dnn_model(self):
        """ Method for obtaining the OpenCV Net.

        If no ready graph exists, the original model is traced, its variables
        are converted to constants and the resulting graph is written to
        ``self.model_path["path"]`` under the name ``self._model_to_save``.

        Returns:
            Net read by ``cv2.dnn.readNetFromTensorflow`` from
            ``self.model_path["full_path"]``.
        """
        if not self._is_ready_to_transfer_graph:
            # get model TF graph: trace the model for its first declared input
            model_input = self._original_model.inputs[0]
            tf_model_graph = tf.function(lambda x: self._original_model(x))
            tf_model_graph = tf_model_graph.get_concrete_function(
                tf.TensorSpec(model_input.shape, model_input.dtype))

            # obtain frozen concrete function
            frozen_tf_func = convert_variables_to_constants_v2(tf_model_graph)

            # save full TF model (as_text=False: not the text format)
            tf.io.write_graph(graph_or_graph_def=frozen_tf_func.graph,
                              logdir=self.model_path["path"],
                              name=self._model_to_save,
                              as_text=False)

        return cv2.dnn.readNetFromTensorflow(self.model_path["full_path"])

    def _set_model_path(self, tf_pb_file_path):
        """ Method for setting model paths.

        When no ready graph exists, this also sets ``self._model_to_save`` to
        the model name plus the ``.pb`` extension.

        Args:
            tf_pb_file_path: path to the existing TF .pb

        Returns:
            dictionary, where full_path key means saved model path and its full name.
            For a ready graph, "path" is empty and "full_path" is ``tf_pb_file_path``;
            otherwise the result of ``get_full_model_path`` is returned as is
            (it is expected to provide the "path" and "full_path" keys).
        """
        if self._is_ready_to_transfer_graph:
            return {
                "path": "",
                "full_path": tf_pb_file_path
            }

        self._model_to_save = self._model_name + MODEL_FORMAT
        return get_full_model_path(CURRENT_LIB.lower(), self._model_to_save)

    def get_prepared_models(self):
        """ Method for getting both prepared models.

        Returns:
            dictionary mapping "<CURRENT_LIB> <model name>" to the original
            model and "<DNN_LIB> <model name>" to the OpenCV Net.
        """
        original_lib_name = CURRENT_LIB + " " + self._model_name
        dnn_lib_name = DNN_LIB + " " + self._model_name
        return {
            original_lib_name: self._original_model,
            dnn_lib_name: self._dnn_model
        }


class TFModelProcessor(Framework):
    """ Class for running inference with the original TF model.

    Args:
        prepared_model: callable invoked on the transposed input batch
        model_name: model name stored on the instance
    """

    def __init__(self, prepared_model, model_name):
        self._name = model_name
        self._prepared_model = prepared_model

    def get_output(self, input_blob):
        """ Method for obtaining the model output.

        Args:
            input_blob: 4-dimensional array-like with ``shape`` and ``transpose``

        Returns:
            result of calling the prepared model on the blob with axes
            reordered to (0, 2, 3, 1).
        """
        blob_rank = len(input_blob.shape)
        assert blob_rank == 4
        # move axis 1 to the last position
        # (presumably NCHW -> NHWC; confirm against the callers)
        return self._prepared_model(input_blob.transpose(0, 2, 3, 1))

    def get_name(self):
        """ Method for getting the framework tag (``CURRENT_LIB``). """
        return CURRENT_LIB


class TFDnnModelProcessor(Framework):
    """ Class for running inference with the converted OpenCV Net.

    Args:
        prepared_dnn_model: object providing ``setInput`` and ``forward``
        model_name: model name stored on the instance
    """

    def __init__(self, prepared_dnn_model, model_name):
        self._name = model_name
        self._prepared_dnn_model = prepared_dnn_model

    def get_output(self, input_blob):
        """ Method for obtaining the net output.

        Args:
            input_blob: input passed unchanged to the net's ``setInput``

        Returns:
            value returned by the net's ``forward()``.
        """
        net = self._prepared_dnn_model
        net.setInput(input_blob)
        return net.forward()

    def get_name(self):
        """ Method for getting the framework tag (``DNN_LIB``). """
        return DNN_LIB