// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GEMMLOWP_META_BASE_H_
#define GEMMLOWP_META_BASE_H_

#include <cassert>
#include <cstdint>

#include "../internal/common.h"

namespace gemmlowp {
namespace meta {

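// Rounds value up to the nearest multiple of align (compile-time alignment).
// Intended for align > 0 and value >= 0.
template <int align>
inline int AlignTo(int value) {
  return ((value + align - 1) / align) * align;
}

// Runtime-alignment overload. Note the argument order: align first, then
// value.
inline int AlignTo(int align, int value) {
  return ((value + align - 1) / align) * align;
}
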
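// Bundles the kernel parameters with the parameters of the output stream
// that receives the kernel's results.
template <typename Kernel_, typename OutputStream_>
struct FusedKernelParams {
 public:
  typedef Kernel_ Kernel;
  typedef OutputStream_ OutputStream;

  Kernel kernel;
  OutputStream output_stream;
};
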
// Parameter bundle for a GEMM operation: problem sizes, operand and result
// pointers, scratch memory, and the stream and fused-kernel settings.
template <typename InType_, typename OutType_, typename LeftStream_,
          typename RightStream_, typename Kernel_, typename OutputStream_>
struct GemmParams {
 public:
  typedef InType_ InType;
  typedef OutType_ OutType;
  typedef LeftStream_ LeftStream;
  typedef RightStream_ RightStream;
  typedef Kernel_ Kernel;
  typedef OutputStream_ OutputStream;

  typedef FusedKernelParams<Kernel, OutputStream> FusedKernel;

  // Common parameters.

  int m;
  int n;
  int k;

  const InType* lhs;
  const InType* rhs;
  OutType* result;
  std::uint8_t* scratch;

  // Specialized parameters.

  LeftStream left_stream;
  RightStream right_stream;
  FusedKernel fused_kernel;
};

// Interface for packing an input stream. The member functions are declared
// but not defined in this header.
template <typename InType, int lanes_count, int pack_size, int leftovers,
          typename StreamParams>
class Stream {
 public:
  static void Pack(const InType* in, const StreamParams& params, InType* out);

  static int UnpackedAdvance(const StreamParams& params);

  static int PackedAdvance(const StreamParams& params);

  static int UnpackedStride(const StreamParams& params);

  static int PackedStride(const StreamParams& params);
};

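// Helper interface for stream pointer offsets and scratch sizing. The member
// functions are declared but not defined in this header.
template <typename InType, typename StreamType>
class StreamUtil {
 public:
  static const InType* Offset(const StreamType& params, const InType* source,
                              int offset_stride, int offset_advance);

  static int Scratch(const StreamType& params, int lanes);
};
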
// Interface for a multiplication kernel with a kernel_m x kernel_n tile.
// Multiply is declared but not defined in this header.
template <typename InType, typename OutType, typename Kernel,
          typename OutputStream, int kernel_m, int kernel_n, int pack_size>
class MulKernel {
 public:
  static void Multiply(const InType* lhs, const InType* rhs,
                       const FusedKernelParams<Kernel, OutputStream>& params,
                       OutType* result);
};

// Parameter bundle for a 1D transform: input and output pointers, scratch
// memory, and the kernel settings.
template <typename InType_, typename OutType_, typename Kernel_>
struct Transform1DParams {
  typedef InType_ InType;
  typedef OutType_ OutType;
  typedef Kernel_ Kernel;

  const InType* input;
  OutType* output;
  std::uint8_t* scratch;

  Kernel kernel;
};

// Interface for a 1D transform kernel. Transform is declared but not defined
// in this header.
template <typename InType, typename OutType, typename Kernel, int kernel_size,
          int leftovers>
class Transform1DKernel {
 public:
  static void Transform(const InType* input, const Kernel& params,
                        OutType* output);
};

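// Helper interface for 1D transforms: compute-cost estimation and input and
// output pointer offsets. The member functions are declared but not defined
// in this header.
template <typename InType, typename OutType, typename Transform>
class Transform1DUtil {
 public:
  static int EstimateComputeCost(const Transform& params);

  static const InType* OffsetInput(const Transform& params, const InType* input,
                                   int offset);

  static OutType* OffsetOutput(const Transform& params, OutType* output,
                               int offset);
};
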
}  // namespace meta
}  // namespace gemmlowp

#endif  // GEMMLOWP_META_BASE_H_