File: OrtInferenceSession.cs

using Microsoft.ML.OnnxRuntime;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;

namespace MauiModelTester
{
    public enum ExecutionProviders
    {
        CPU,    // CPU execution provider is always available by default
        NNAPI,  // NNAPI is available on Android
        CoreML, // CoreML is available on iOS/macOS
        XNNPACK // XNNPACK is available on ARM/ARM64 platforms and benefits 32-bit float models
    }

    // An inference session runs an ONNX model
    internal class OrtInferenceSession
    {
        public OrtInferenceSession(ExecutionProviders provider = ExecutionProviders.CPU)
        {
            _executionProvider = provider;
            _sessionOptions = new SessionOptions();
            switch (_executionProvider)
            {
                case ExecutionProviders.CPU:
                    break;
                case ExecutionProviders.NNAPI:
                    _sessionOptions.AppendExecutionProvider_Nnapi();
                    break;
                case ExecutionProviders.CoreML:
                    _sessionOptions.AppendExecutionProvider_CoreML();
                    break;
                case ExecutionProviders.XNNPACK:
                    _sessionOptions.AppendExecutionProvider("XNNPACK");
                    break;
            }

            // enable pre/post processing custom operators from onnxruntime-extensions
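            // (RegisterOrtExtensions is an extension method supplied by the
            // Microsoft.ML.OnnxRuntime.Extensions NuGet package)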
            _sessionOptions.RegisterOrtExtensions();

            _perfStats = new PerfStats();
        }

        // Async task to create the inference session, which is an expensive operation.
        public async Task Create()
        {
            // create the InferenceSession. this is an expensive operation so only do this when necessary.
            // the InferenceSession supports multiple calls to Run, including concurrent calls.
            var modelBytes = await Utils.LoadResource("test_data/model.onnx");

            var stopwatch = new Stopwatch();
            stopwatch.Start();
            _inferenceSession = new InferenceSession(modelBytes, _sessionOptions);
            stopwatch.Stop();
            _perfStats.LoadTime = stopwatch.Elapsed;

            (_inputs, _expectedOutputs) = await Utils.LoadTestData();

            // warmup
            Run(1, true);
        }

        public void Run(int iterations = 1, bool isWarmup = false)
        {
            // do all setup outside of the timing
            // RunOptions is disposable, so scope it to this method
            using var runOptions = new RunOptions();
            var outputNames = _inferenceSession.OutputNames;

            _perfStats.ClearRunTimes();

            var stopwatch = new Stopwatch();

            for (var i = 0; i < iterations; i++)
            {
                stopwatch.Restart();

                using IDisposableReadOnlyCollection<OrtValue> results =
                    _inferenceSession.Run(runOptions, _inputs, outputNames);

                stopwatch.Stop();

                if (isWarmup)
                {
                    _perfStats.WarmupTime = stopwatch.Elapsed;

                    // validate the expected output on the first Run only
                    if (_expectedOutputs.Count > 0)
                    {
                        // create dictionary of output name to results
                        var actual = outputNames.Zip(results).ToDictionary(x => x.First, x => x.Second);

                        foreach (var expectedOutput in _expectedOutputs)
                        {
                            var outputName = expectedOutput.Key;
                            Utils.TensorComparer.VerifyTensorResults(outputName, expectedOutput.Value,
                                                                     actual[outputName]);
                        }
                    }
                }
                else
                {
                    _perfStats.AddRunTime(stopwatch.Elapsed);
                }
            }
        }

        public PerfStats PerfStats => _perfStats;

        private SessionOptions _sessionOptions;
        private InferenceSession _inferenceSession;
        private ExecutionProviders _executionProvider = ExecutionProviders.CPU;
        private Dictionary<string, OrtValue> _inputs;
        private Dictionary<string, OrtValue> _expectedOutputs;
        private PerfStats _perfStats;
    }
}
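
// A minimal usage sketch, assuming a caller such as a MAUI page or view model
// (variable names below are illustrative, not from this repo):
//
//   var session = new OrtInferenceSession(ExecutionProviders.XNNPACK);
//   await session.Create();       // load model + test data, run one warmup iteration
//   session.Run(iterations: 10);  // timed runs, recorded in session.PerfStats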