using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;

namespace Microsoft.ML.OnnxRuntime.InferenceSample
{
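    /// <summary>
    /// Sample API that loads the SqueezeNet ONNX model and its test input from embedded resources,
    /// creates an InferenceSession, and runs inference over the sample input.
    /// Typical usage (a minimal sketch; the calling code is illustrative and not part of this sample):
    /// <code>
    /// using (var sample = new InferenceSampleApi())
    /// {
    ///     // the session is created once; InferenceSession.Run can be called multiple times, and concurrently
    ///     sample.Execute();
    /// }
    /// </code>
    /// </summary>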
    public class InferenceSampleApi : IDisposable
    {
        public InferenceSampleApi()
        {
            _model = LoadModelFromEmbeddedResource("TestData.squeezenet.onnx");

            // this is the data for only one input tensor for this model
            var inputData = LoadTensorFromEmbeddedResource("TestData.bench.in");

            // Create a session with the default session options.
            // Creating an InferenceSession and loading the model is an expensive operation, so generally you would
            // do this once. InferenceSession.Run can be called multiple times, and concurrently.
            CreateInferenceSession();

            // setup sample input data
            var inputMeta = _inferenceSession.InputMetadata;
            _inputData = new List<OrtValue>(inputMeta.Count);
            _orderedInputNames = new List<string>(inputMeta.Count);

            foreach (var name in inputMeta.Keys)
            {
                // We create an OrtValue over the same buffer for each input, even though the shapes may differ.
                // This is okay as long as the specified shape does not exceed the actual length of the buffer.
                var shape = Array.ConvertAll<int, long>(inputMeta[name].Dimensions, Convert.ToInt64);
                Debug.Assert(ShapeUtils.GetSizeForShape(shape) <= inputData.LongLength);

                var ortValue = OrtValue.CreateTensorValueFromMemory(inputData, shape);
                _inputData.Add(ortValue);
                _orderedInputNames.Add(name);
            }
        }
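
        // Creates the inference session from the in-memory model bytes; when no options are supplied,
        // a default SessionOptions with a "Sample" log id is used.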
        public void CreateInferenceSession(SessionOptions options = null)
        {
            // Optional: create session options and set any relevant values.
            // If an additional execution provider is needed it should be added to the SessionOptions prior to
            // creating the InferenceSession. The CPU Execution Provider is always added by default.
            if (options == null)
            {
                options = new SessionOptions { LogId = "Sample" };
            }

            _inferenceSession = new InferenceSession(_model, options);
        }
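
        // Runs inference on the sample input data and dumps every output tensor to the console.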
        public void Execute()
        {
            // Run the inference. The first argument is the optional RunOptions (null here for defaults).
            // 'results' is an IDisposableReadOnlyCollection<OrtValue> container.
            using (var results = _inferenceSession.Run(null, _orderedInputNames, _inputData, _inferenceSession.OutputNames))
            {
                // dump the results
                for (int i = 0; i < results.Count; ++i)
                {
                    var name = _inferenceSession.OutputNames[i];
                    Console.WriteLine("Output for {0}", name);

                    // We can now access the native buffer directly from the OrtValue, no copy is involved.
                    // Spans are structs and are stack allocated. They do not add any GC pressure.
                    ReadOnlySpan<float> span = results[i].GetTensorDataAsSpan<float>();
                    Console.Write($"Output {i} results:");
                    for (int k = 0; k < span.Length; ++k)
                    {
                        Console.Write($" {span[k]}");
                    }

                    Console.WriteLine();
                }
            }
        }
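
        // Standard dispose pattern: releases the InferenceSession and the OrtValues wrapping the input buffers.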
        protected virtual void Dispose(bool disposing)
        {
            if (disposing && !_disposed)
            {
                _inferenceSession?.Dispose();
                _inferenceSession = null;

                if (_inputData != null)
                {
                    foreach (var v in _inputData)
                    {
                        v?.Dispose();
                    }
                }

                _disposed = true;
            }
        }

        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }
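
        // Loads tensor data from an embedded text resource: the first line holds the input name (skipped here),
        // the second line holds the values, separated by commas and optionally wrapped in '[' ']'.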
        static float[] LoadTensorFromEmbeddedResource(string path)
        {
            var tensorData = new List<float>();
            var assembly = typeof(InferenceSampleApi).Assembly;

            using (var inputFile =
                new StreamReader(assembly.GetManifestResourceStream($"{assembly.GetName().Name}.{path}")))
            {
                inputFile.ReadLine(); // skip the input name
                string[] dataStr = inputFile.ReadLine().Split(new char[] { ',', '[', ']' },
                                                              StringSplitOptions.RemoveEmptyEntries);
                for (int i = 0; i < dataStr.Length; i++)
                {
                    // parse with the invariant culture so the '.' decimal separator is handled on all locales
                    tensorData.Add(Single.Parse(dataStr[i], CultureInfo.InvariantCulture));
                }
            }

            return tensorData.ToArray();
        }
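
        // Reads the ONNX model bytes from an embedded resource in this assembly.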
        static byte[] LoadModelFromEmbeddedResource(string path)
        {
            var assembly = typeof(InferenceSampleApi).Assembly;
            byte[] model = null;

            using (Stream stream = assembly.GetManifestResourceStream($"{assembly.GetName().Name}.{path}"))
            {
                using (MemoryStream memoryStream = new MemoryStream())
                {
                    stream.CopyTo(memoryStream);
                    model = memoryStream.ToArray();
                }
            }

            return model;
        }
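
        // Model bytes, ordered input names and their OrtValues, and the session they feed.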
        private bool _disposed = false;
        private readonly byte[] _model;
        private readonly List<string> _orderedInputNames;
        private readonly List<OrtValue> _inputData;
        private InferenceSession _inferenceSession;
    }
}