File: TensorExtensions.cs

package info (click to toggle)
onnxruntime 1.21.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 333,732 kB
  • sloc: cpp: 3,153,079; python: 179,219; ansic: 109,131; asm: 37,791; cs: 34,424; perl: 13,070; java: 11,047; javascript: 6,330; pascal: 4,126; sh: 3,277; xml: 598; objc: 281; makefile: 59
file content (42 lines) | stat: -rw-r--r-- 1,640 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/tests/TensorExtensions.cs
// Original license statement below -

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;

namespace Microsoft.ML.OnnxRuntime.Tensors
{
    /// <summary>
    /// Extension helpers for <see cref="Tensor{T}"/>.
    /// </summary>
    public static partial class TensorExtensions
    {
        // Axis lists handed to TensorOperations.Contract by MatrixMultiply.
        // They are shared across calls and never reassigned, hence readonly.
        private static readonly int[] s_zeroArray = new[] { 0 };
        private static readonly int[] s_oneArray = new[] { 1 };

        /// <summary>
        /// Multiplies two rank-2 tensors by delegating to
        /// <c>TensorOperations.Contract</c> with axis 1 of <paramref name="left"/>
        /// and axis 0 of <paramref name="right"/>.
        /// </summary>
        /// <typeparam name="T">Element type of both tensors.</typeparam>
        /// <param name="left">Left operand; must have rank 2.</param>
        /// <param name="right">
        /// Right operand; must have rank 2 and a first dimension equal to the
        /// second dimension of <paramref name="left"/>.
        /// </param>
        /// <returns>The tensor returned by <c>TensorOperations.Contract</c>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="left"/> or <paramref name="right"/> is <c>null</c>.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// <paramref name="left"/> does not have rank 2.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="right"/> does not have rank 2, or its first dimension
        /// differs from the second dimension of <paramref name="left"/>.
        /// </exception>
        internal static Tensor<T> MatrixMultiply<T>(this Tensor<T> left, Tensor<T> right)
        {
            // Fail with a descriptive exception instead of a NullReferenceException
            // from the Rank/Dimensions accesses below.
            if (left == null)
            {
                throw new ArgumentNullException(nameof(left));
            }

            if (right == null)
            {
                throw new ArgumentNullException(nameof(right));
            }

            if (left.Rank != 2)
            {
                throw new InvalidOperationException($"{nameof(MatrixMultiply)} is only valid for a {nameof(Tensor<T>)} of {nameof(left.Rank)} 2.");
            }

            if (right.Rank != 2)
            {
                throw new ArgumentException($"{nameof(Tensor<T>)} {nameof(right)} must have {nameof(left.Rank)} 2.", nameof(right));
            }

            // The inner dimensions must agree: columns of left == rows of right.
            if (left.Dimensions[1] != right.Dimensions[0])
            {
                throw new ArgumentException($"{nameof(Tensor<T>)} {nameof(right)} must have first dimension of {left.Dimensions[1]}.", nameof(right));
            }

            return TensorOperations.Contract(left, right, s_oneArray, s_zeroArray);
        }
    }
}