1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
|
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/tests/TensorTestsBase.cs
// Original license statement below -
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System;
namespace Microsoft.ML.OnnxRuntime.Tensors.Tests
{
public class TensorTestsBase
{
/// <summary>
/// Identifies which tensor implementation a <see cref="TensorConstructor"/> builds.
/// Only <see cref="Dense"/> is defined in this file.
/// </summary>
public enum TensorType
{
    /// <summary>Selects the dense tensor implementation.</summary>
    Dense
}
/// <summary>
/// Describes how a tensor should be built for a test case: which
/// implementation to use and which reverse-stride flag to pass along.
/// </summary>
public class TensorConstructor
{
    /// <summary>The tensor implementation to instantiate.</summary>
    public TensorType TensorType { get; set; }

    /// <summary>
    /// Flag forwarded unchanged as the reverse-stride argument of the
    /// underlying tensor factory / constructor.
    /// </summary>
    public bool IsReversedStride { get; set; }

    /// <summary>
    /// Builds a tensor from <paramref name="array"/> via the <c>ToTensor</c>
    /// extension, passing <see cref="IsReversedStride"/> through.
    /// </summary>
    /// <typeparam name="T">Element type of the resulting tensor.</typeparam>
    /// <param name="array">Source array handed to <c>ToTensor</c>.</param>
    /// <returns>The tensor produced by <c>ToTensor</c>.</returns>
    /// <exception cref="ArgumentException">
    /// <see cref="TensorType"/> is not <see cref="TensorType.Dense"/>.
    /// </exception>
    public Tensor<T> CreateFromArray<T>(Array array)
    {
        if (TensorType == TensorType.Dense)
        {
            return array.ToTensor<T>(IsReversedStride);
        }

        // Any value other than Dense has no factory here.
        throw new ArgumentException(nameof(TensorType));
    }

    /// <summary>
    /// Builds a new tensor with the given <paramref name="dimensions"/>,
    /// passing <see cref="IsReversedStride"/> through to the constructor.
    /// </summary>
    /// <typeparam name="T">Element type of the resulting tensor.</typeparam>
    /// <param name="dimensions">Dimension sizes handed to the tensor constructor.</param>
    /// <returns>A newly constructed <see cref="DenseTensor{T}"/>.</returns>
    /// <exception cref="ArgumentException">
    /// <see cref="TensorType"/> is not <see cref="TensorType.Dense"/>.
    /// </exception>
    public Tensor<T> CreateFromDimensions<T>(ReadOnlySpan<int> dimensions)
    {
        if (TensorType == TensorType.Dense)
        {
            return new DenseTensor<T>(dimensions, IsReversedStride);
        }

        // Any value other than Dense has no constructor mapping here.
        throw new ArgumentException(nameof(TensorType));
    }

    /// <summary>
    /// Renders the tensor type followed by the reverse-stride flag,
    /// e.g. <c>Dense, IsReversedStride = False</c>.
    /// </summary>
    public override string ToString() =>
        $"{TensorType}, {nameof(IsReversedStride)} = {IsReversedStride}";
}
// Tensor implementations enumerated by the constructor generators in this
// class. Marked readonly: the reference is assigned once and never replaced.
private static readonly TensorType[] s_tensorTypes =
{
    TensorType.Dense
};

// Reverse-stride flag values enumerated by the constructor generators in this
// class, in the order false then true. Marked readonly for the same reason.
private static readonly bool[] s_reverseStrideValues =
{
    false,
    true
};
/// <summary>
/// Lazily enumerates one-element argument rows, one per combination of
/// tensor type and reverse-stride flag (tensor type varies slowest).
/// </summary>
/// <returns>
/// A sequence of arrays, each holding a single <see cref="TensorConstructor"/>.
/// </returns>
// NOTE(review): the IEnumerable<object[]> shape suggests use as xUnit member data — confirm against the test classes.
public static IEnumerable<object[]> GetSingleTensorConstructors()
{
    TensorType[] tensorTypes = s_tensorTypes;
    for (int typeIndex = 0; typeIndex < tensorTypes.Length; typeIndex++)
    {
        bool[] strideFlags = s_reverseStrideValues;
        for (int flagIndex = 0; flagIndex < strideFlags.Length; flagIndex++)
        {
            var constructor = new TensorConstructor();
            constructor.TensorType = tensorTypes[typeIndex];
            constructor.IsReversedStride = strideFlags[flagIndex];

            // Keep the row typed as TensorConstructor[]; it is returned as object[] via array covariance.
            yield return new TensorConstructor[] { constructor };
        }
    }
}
/// <summary>
/// Lazily enumerates two-element argument rows covering every combination of
/// left/right tensor type and left/right reverse-stride flag. The flags are
/// chosen independently for the two sides.
/// </summary>
/// <returns>
/// A sequence of arrays, each holding the left <see cref="TensorConstructor"/>
/// followed by the right one.
/// </returns>
// NOTE(review): the IEnumerable<object[]> shape suggests use as xUnit member data — confirm against the test classes.
public static IEnumerable<object[]> GetDualTensorConstructors()
{
    TensorType[] leftTypes = s_tensorTypes;
    for (int leftTypeIndex = 0; leftTypeIndex < leftTypes.Length; leftTypeIndex++)
    {
        TensorType[] rightTypes = s_tensorTypes;
        for (int rightTypeIndex = 0; rightTypeIndex < rightTypes.Length; rightTypeIndex++)
        {
            bool[] leftFlags = s_reverseStrideValues;
            for (int leftFlagIndex = 0; leftFlagIndex < leftFlags.Length; leftFlagIndex++)
            {
                bool[] rightFlags = s_reverseStrideValues;
                for (int rightFlagIndex = 0; rightFlagIndex < rightFlags.Length; rightFlagIndex++)
                {
                    var left = new TensorConstructor();
                    left.TensorType = leftTypes[leftTypeIndex];
                    left.IsReversedStride = leftFlags[leftFlagIndex];

                    var right = new TensorConstructor();
                    right.TensorType = rightTypes[rightTypeIndex];
                    right.IsReversedStride = rightFlags[rightFlagIndex];

                    // Row order is (left, right); the array is exposed as object[] via covariance.
                    yield return new TensorConstructor[] { left, right };
                }
            }
        }
    }
}
/// <summary>
/// Lazily enumerates two-element argument rows covering every combination of
/// two tensor types, where both constructors in a row share the same
/// reverse-stride flag.
/// </summary>
/// <returns>
/// A sequence of arrays, each holding two <see cref="TensorConstructor"/>
/// instances with identical <see cref="TensorConstructor.IsReversedStride"/>.
/// </returns>
// NOTE(review): by name, the second entry presumably builds the expected-result tensor — confirm against callers.
public static IEnumerable<object[]> GetTensorAndResultConstructor()
{
    TensorType[] firstTypes = s_tensorTypes;
    for (int firstIndex = 0; firstIndex < firstTypes.Length; firstIndex++)
    {
        TensorType[] secondTypes = s_tensorTypes;
        for (int secondIndex = 0; secondIndex < secondTypes.Length; secondIndex++)
        {
            bool[] strideFlags = s_reverseStrideValues;
            for (int flagIndex = 0; flagIndex < strideFlags.Length; flagIndex++)
            {
                // One flag value is applied to both constructors in the row.
                bool sharedFlag = strideFlags[flagIndex];

                var first = new TensorConstructor();
                first.TensorType = firstTypes[firstIndex];
                first.IsReversedStride = sharedFlag;

                var second = new TensorConstructor();
                second.TensorType = secondTypes[secondIndex];
                second.IsReversedStride = sharedFlag;

                yield return new TensorConstructor[] { first, second };
            }
        }
    }
}
/// <summary>
/// Typed convenience overload: forwards a one-dimensional array to the
/// <see cref="Array"/>-based overload of this method.
/// </summary>
/// <typeparam name="T">Element type of the array and of the returned memory.</typeparam>
/// <param name="array">The array whose elements are copied.</param>
/// <returns>The memory object produced by the <see cref="Array"/>-based overload.</returns>
public static NativeMemory<T> NativeMemoryFromArray<T>(T[] array)
{
    // Widen the static type so overload resolution picks the Array overload
    // rather than recursing into this one.
    Array untyped = array;
    return NativeMemoryFromArray<T>(untyped);
}
/// <summary>
/// Test-only helper: allocates a <c>NativeMemory</c> buffer with
/// <paramref name="array"/>.Length elements and copies every element of the
/// array into it, in the array's enumeration order. Per the original
/// author's note, the destination is unmanaged memory.
/// </summary>
/// <typeparam name="T">Element type each array item is cast to.</typeparam>
/// <param name="array">Source array; each element must be castable to <typeparamref name="T"/>.</param>
/// <returns>The newly allocated memory holding the copied elements.</returns>
public static NativeMemory<T> NativeMemoryFromArray<T>(Array array)
{
    var nativeCopy = NativeMemory<T>.Allocate(array.Length);
    var destination = nativeCopy.GetSpan();

    int writeIndex = 0;
    foreach (T element in array)
    {
        destination[writeIndex] = element;
        writeIndex++;
    }

    return nativeCopy;
}
}
}
|