File: InferenceSampleApi.cs

Package: onnxruntime 1.23.2+dfsg-3 (145 lines, 5,640 bytes)
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;

namespace Microsoft.ML.OnnxRuntime.InferenceSample
{
    public class InferenceSampleApi : IDisposable
    {
        public InferenceSampleApi()
        {
            _model = LoadModelFromEmbeddedResource("TestData.squeezenet.onnx");

            // This is the data for a single input tensor of this model
            var inputData = LoadTensorFromEmbeddedResource("TestData.bench.in");

            // create default session with default session options
            // Creating an InferenceSession and loading the model is an expensive operation, so generally you would
            // do this once. InferenceSession.Run can be called multiple times, and concurrently.
            CreateInferenceSession();

            // setup sample input data
            var inputMeta = _inferenceSession.InputMetadata;
            _inputData = new List<OrtValue>(inputMeta.Count);
            _orderedInputNames = new List<string>(inputMeta.Count);

            foreach (var name in inputMeta.Keys)
            {
                // Create an OrtValue over the same input buffer for each input. The shapes may differ
                // between inputs; that is fine as long as the element count implied by the shape does
                // not exceed the actual length of the buffer.
                var shape = Array.ConvertAll<int, long>(inputMeta[name].Dimensions, Convert.ToInt64);
                Debug.Assert(ShapeUtils.GetSizeForShape(shape) <= inputData.LongLength);

                var ortValue = OrtValue.CreateTensorValueFromMemory(inputData, shape);
                _inputData.Add(ortValue);

                _orderedInputNames.Add(name);
            }
        }

        public void CreateInferenceSession(SessionOptions options = null)
        {
            // Optional: create session options and set any relevant values.
            // If an additional execution provider is needed, it should be added to the SessionOptions
            // prior to creating the InferenceSession. The CPU Execution Provider is always added by default.
            if (options == null)
            {
                options = new SessionOptions { LogId = "Sample" };
            }
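
            // If a hardware execution provider is needed, it would be appended to the options here,
            // before the session is created. A hedged sketch, not part of the original sample; it
            // assumes the CUDA execution provider package is referenced and a compatible GPU exists:
            //
            //     options.AppendExecutionProvider_CUDA(0);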

            _inferenceSession = new InferenceSession(_model, options);
        }

        public void Execute()
        {
            // Run the inference
            // 'results' is an IDisposableReadOnlyCollection<OrtValue> container
            using (var results = _inferenceSession.Run(null, _orderedInputNames, _inputData, _inferenceSession.OutputNames))
            {
                // dump the results
                for (int i = 0; i < results.Count; ++i)
                {
                    var name = _inferenceSession.OutputNames[i];
                    Console.WriteLine("Output for {0}", name);
                    // We can now access the native buffer directly from the OrtValue, no copy is involved.
                    // Spans are structs and are stack allocated. They do not add any GC pressure.
                    ReadOnlySpan<float> span = results[i].GetTensorDataAsSpan<float>();
                    Console.Write($"Input {i} results:");
                    for(int k = 0; k < span.Length; ++k)
                    {
                        Console.Write($" {span[k]}");
                    }
                    Console.WriteLine();
                }
            }
        }
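
        // A minimal illustration, not part of the original sample: for a classifier such as
        // squeezenet, the highest-scoring index in an output span could be located like this.
        // The helper name is hypothetical and nothing in the file calls it.
        private static int ArgMax(ReadOnlySpan<float> scores)
        {
            int best = 0;
            for (int i = 1; i < scores.Length; ++i)
            {
                if (scores[i] > scores[best])
                {
                    best = i;
                }
            }
            return best;
        }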

        protected virtual void Dispose(bool disposing)
        {
            if (disposing && !_disposed)
            {
                _inferenceSession?.Dispose();
                _inferenceSession = null;

                if (_inputData != null)
                    foreach (var v in _inputData)
                    {
                        v?.Dispose();
                    }

                _disposed = true;
            }
        }

        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        static float[] LoadTensorFromEmbeddedResource(string path)
        {
            var tensorData = new List<float>();
            var assembly = typeof(InferenceSampleApi).Assembly;

            using (var inputFile = 
                new StreamReader(assembly.GetManifestResourceStream($"{assembly.GetName().Name}.{path}")))
            {
                inputFile.ReadLine(); // skip the input name
                string[] dataStr = inputFile.ReadLine().Split(new char[] { ',', '[', ']' }, 
                                                              StringSplitOptions.RemoveEmptyEntries);
                for (int i = 0; i < dataStr.Length; i++)
                {
                    // Use the invariant culture so '.' is always treated as the decimal separator.
                    tensorData.Add(Single.Parse(dataStr[i], System.Globalization.CultureInfo.InvariantCulture));
                }
            }

            return tensorData.ToArray();
        }
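
        // Based on the parsing above, the embedded "TestData.bench.in" resource is expected to hold
        // the input name on its first line, followed by a bracketed, comma-separated list of float
        // values on the next line, roughly (an illustration, not the literal file contents):
        //
        //     <input name>
        //     [0.1, 0.2, 0.3, ...]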

        static byte[] LoadModelFromEmbeddedResource(string path)
        {
            var assembly = typeof(InferenceSampleApi).Assembly;
            byte[] model = null;

            using (Stream stream = assembly.GetManifestResourceStream($"{assembly.GetName().Name}.{path}"))
            {
                using (MemoryStream memoryStream = new MemoryStream())
                {
                    stream.CopyTo(memoryStream);
                    model = memoryStream.ToArray();
                }
            }

            return model;
        }

        private bool _disposed = false;
        private readonly byte[] _model;
        private readonly List<string> _orderedInputNames;
        private readonly List<OrtValue> _inputData;
        private InferenceSession _inferenceSession;
    }
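
    // A minimal usage sketch, an assumption rather than part of the original file: the session is
    // built once in the constructor and Execute can then be called any number of times before the
    // object is disposed, which matches the comment in the constructor above. The class and method
    // names below are hypothetical.
    internal static class InferenceSampleUsage
    {
        internal static void RunSample()
        {
            using (var sample = new InferenceSampleApi())
            {
                sample.Execute(); // runs inference on the bundled input and prints the output values
                sample.Execute(); // the same session can be reused for additional calls
            }
        }
    }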
}