File: test_gapi_infer_onnx.py

package info (click to toggle)
opencv 4.10.0+dfsg-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 282,092 kB
  • sloc: cpp: 1,178,079; xml: 682,621; python: 49,092; lisp: 31,150; java: 25,469; ansic: 11,039; javascript: 6,085; sh: 1,214; cs: 601; perl: 494; objc: 210; makefile: 173
file content (68 lines) | stat: -rw-r--r-- 2,037 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python

import numpy as np
import cv2 as cv
import os
import sys
import unittest

from tests_common import NewOpenCVTests


try:

    # Module setup and test definitions run inside this try, so a
    # unittest.SkipTest raised here is caught by the
    # `except unittest.SkipTest` clause at the bottom of the module, which
    # then defines a placeholder TestSkip case instead.
    if sys.version_info[:2] < (3, 0):
        raise unittest.SkipTest('Python 2.x is not supported')

    # SqueezeNet 1.0 ONNX model, looked up relative to the directory named by
    # the OPENCV_GAPI_ONNX_MODEL_PATH environment variable.
    CLASSIFICATION_MODEL_PATH = "vision/classification/squeezenet/model/squeezenet1.0-9.onnx"
    class test_gapi_infer(NewOpenCVTests):
        """Smoke test for G-API generic inference through the ONNX backend.

        The model is searched under OPENCV_GAPI_ONNX_MODEL_PATH and the input
        image under OPENCV_TEST_DATA_PATH.
        """

        def find_dnn_file(self, filename):
            """Return the path of *filename* under OPENCV_GAPI_ONNX_MODEL_PATH.

            The third argument to find_file is presumably ``required=False``
            (TODO confirm against tests_common). Callers treat a ``None``
            result as "model not available" and skip.
            """
            return self.find_file(filename, [os.environ.get('OPENCV_GAPI_ONNX_MODEL_PATH')], False)

        def test_onnx_classification(self):
            """Run SqueezeNet on a sample image and check the output shape.

            Skips when the model file is missing or when G-API was built
            without ONNX support. Fails if the input image cannot be read.
            """
            no_onnx_msg = "G-API has been compiled without ONNX support"

            model_path = self.find_dnn_file(CLASSIFICATION_MODEL_PATH)
            if model_path is None:
                raise unittest.SkipTest("Missing DNN test file")

            img_path = self.find_file("cv/dpm/cat.png",
                                      [os.environ.get('OPENCV_TEST_DATA_PATH')])
            in_mat = cv.imread(img_path)
            # cv.imread reports an unreadable file by returning None rather
            # than raising. Fail here with a clear message instead of deep
            # inside comp.apply().
            self.assertIsNotNone(in_mat, "Failed to read image: " + str(img_path))

            # One-input / one-output graph around a generic infer node.
            # "data_0" / "softmaxout_1" are the input/output names the node is
            # wired with (presumably the tensor names inside the ONNX model).
            # The "squeeze-net" tag ties this node to the params built below.
            g_in = cv.GMat()
            g_infer_inputs = cv.GInferInputs()
            g_infer_inputs.setInput("data_0", g_in)
            g_infer_out = cv.gapi.infer("squeeze-net", g_infer_inputs)
            g_out = g_infer_out.at("softmaxout_1")

            comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out))

            net = cv.gapi.onnx.params("squeeze-net", model_path)
            # NOTE(review): presumably turns off the backend's default input
            # normalization for "data_0"; confirm against the ONNX params API.
            net.cfgNormalize("data_0", False)
            try:
                out_gapi = comp.apply(cv.gin(in_mat), cv.gapi.compile_args(cv.gapi.networks(net)))
            except cv.error as err:
                # Match by substring. Depending on how the native error is
                # raised, the Python message is either exactly this text or a
                # longer OpenCV-formatted message that embeds it. str() also
                # avoids IndexError when the exception has no args.
                if no_onnx_msg in str(err):
                    raise unittest.SkipTest(no_onnx_msg)
                raise

            self.assertEqual((1, 1000, 1, 1), out_gapi.shape)


except unittest.SkipTest as e:

    # Copy the skip reason into a plain module-level name. Python 3 unbinds
    # `e` when this except clause ends, but TestSkip.setUp reads `message`
    # later, when the test actually runs.
    message = str(e)

    class TestSkip(unittest.TestCase):
        def setUp(self):
            self.skipTest('Skip tests: ' + message)

        def test_skip():
            pass

    pass  # redundant no-op: this except body already defines TestSkip


if __name__ == '__main__':
    # Script entry: delegate to the NewOpenCVTests.bootstrap entry point
    # imported from tests_common.
    run_tests = NewOpenCVTests.bootstrap
    run_tests()