1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
|
//===-- runtime/dot-product.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "cpp-type.h"
#include "descriptor.h"
#include "reduction.h"
#include "terminator.h"
#include "tools.h"
#include <cinttypes>
namespace Fortran::runtime {
// Running DOT_PRODUCT accumulator. Each call to Accumulate() folds the pair
// x(xAt), y(yAt) into a RESULT-typed running value:
//   - COMPLEX VECTOR_A: sum += CONJG(x) * y
//   - LOGICAL VECTOR_A: sum = sum .OR. (x .AND. y)
//   - otherwise:        sum += x * y
// Both operands are converted to RESULT before they are combined.
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
class Accumulator {
public:
  using Result = RESULT;
  Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}
  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
    if constexpr (XCAT == TypeCategory::Logical) {
      // Once the result is true it stays true; skip further element reads.
      if (!sum_ && IsLogicalElementTrue(x_, &xAt) &&
          IsLogicalElementTrue(y_, &yAt)) {
        sum_ = true;
      }
    } else {
      Result lhs{static_cast<Result>(*x_.Element<XT>(&xAt))};
      const Result rhs{static_cast<Result>(*y_.Element<YT>(&yAt))};
      if constexpr (XCAT == TypeCategory::Complex) {
        // VECTOR_A is conjugated when it is COMPLEX.
        lhs = std::conj(lhs);
      }
      sum_ += lhs * rhs;
    }
  }
  Result GetResult() const { return sum_; }

private:
  const Descriptor &x_, &y_;
  Result sum_{};
};
// Computes DOT_PRODUCT(VECTOR_A=x, VECTOR_B=y) for rank-1 operands whose
// element types are XT and YT, accumulating in RESULT.
//   x, y       - rank-1 descriptors (checked via RUNTIME_CHECK)
//   terminator - used to report rank/size mismatches
// Crashes if SIZE(x) /= SIZE(y). Elements are visited pairwise starting at
// each operand's own lower bound.
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
static inline RESULT DoDotProduct(
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
  RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
  SubscriptValue n{x.GetDimension(0).Extent()};
  if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
    terminator.Crash(
        "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
        static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
  }
  // Hooks for dispatching homogeneous operand types to BLAS-1 kernels.
  if constexpr (std::is_same_v<XT, YT>) {
    if constexpr (std::is_same_v<XT, float>) {
      // TODO: call BLAS-1 SDOT or SDSDOT
    } else if constexpr (std::is_same_v<XT, double>) {
      // TODO: call BLAS-1 DDOT
    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
      // TODO: call BLAS-1 CDOTC
    } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
      // ZDOTC is the double-precision complex kernel; this arm previously
      // re-tested std::complex<float> and could never be selected.
      // TODO: call BLAS-1 ZDOTC
    }
  }
  // Generic fallback: walk both vectors in lockstep from their lower bounds.
  SubscriptValue xAt{x.GetDimension(0).LowerBound()};
  SubscriptValue yAt{y.GetDimension(0).LowerBound()};
  Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
  for (SubscriptValue j{0}; j < n; ++j) {
    accumulator.Accumulate(xAt++, yAt++);
  }
  return accumulator.GetResult();
}
// Type-dispatching functor for DOT_PRODUCT whose result has category RCAT
// and kind RKIND. Dispatch is two-level: DP1 is selected from VECTOR_A's
// runtime category/kind, and DP2 from VECTOR_B's. This yields a statically
// typed call to DoDotProduct.
template <TypeCategory RCAT, int RKIND> struct DotProduct {
  using Result = CppTypeFor<RCAT, RKIND>;
  template <TypeCategory XCAT, int XKIND> struct DP1 {
    template <TypeCategory YCAT, int YKIND> struct DP2 {
      Result operator()(const Descriptor &x, const Descriptor &y,
          Terminator &terminator) const {
        // The kernel is instantiated only when the operands' combined type
        // has category RCAT and a kind no larger than RKIND.
        constexpr auto combined{GetResultType(XCAT, XKIND, YCAT, YKIND)};
        if constexpr (combined && combined->first == RCAT &&
            combined->second <= RKIND) {
          using XType = CppTypeFor<XCAT, XKIND>;
          using YType = CppTypeFor<YCAT, YKIND>;
          return DoDotProduct<Result, XCAT, XType, YType>(x, y, terminator);
        }
        terminator.Crash(
            "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
            static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
            static_cast<int>(YCAT), YKIND);
      }
    };
    // Second dispatch level: select DP2 from VECTOR_B's category/kind.
    Result operator()(const Descriptor &x, const Descriptor &y,
        Terminator &terminator, TypeCategory yCat, int yKind) const {
      return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
    }
  };
  // Entry point: builds a Terminator from the call site and starts dispatch.
  Result operator()(const Descriptor &x, const Descriptor &y,
      const char *source, int line) const {
    Terminator terminator{source, line};
    const auto xType{x.type().GetCategoryAndKind()};
    const auto yType{y.type().GetCategoryAndKind()};
    RUNTIME_CHECK(terminator, xType.has_value() && yType.has_value());
    return ApplyType<DP1, Result>(xType->first, xType->second, terminator, x,
        y, terminator, yType->first, yType->second);
  }
};
extern "C" {
// DOT_PRODUCT with an INTEGER(1) result; sums in INTEGER(8), then narrows.
std::int8_t RTNAME(DotProductInteger1)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 8> dot{};
  return static_cast<std::int8_t>(dot(x, y, source, line));
}
// DOT_PRODUCT with an INTEGER(2) result; sums in INTEGER(8), then narrows.
std::int16_t RTNAME(DotProductInteger2)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 8> dot{};
  return static_cast<std::int16_t>(dot(x, y, source, line));
}
// DOT_PRODUCT with an INTEGER(4) result; sums in INTEGER(8), then narrows.
std::int32_t RTNAME(DotProductInteger4)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 8> dot{};
  return static_cast<std::int32_t>(dot(x, y, source, line));
}
// DOT_PRODUCT with an INTEGER(8) result.
std::int64_t RTNAME(DotProductInteger8)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 8> dot{};
  return dot(x, y, source, line);
}
#ifdef __SIZEOF_INT128__
// DOT_PRODUCT with an INTEGER(16) result (only where 128-bit ints exist).
common::int128_t RTNAME(DotProductInteger16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Integer, 16> dot{};
  return dot(x, y, source, line);
}
#endif
// TODO: REAL/COMPLEX(2 & 3)
// DOT_PRODUCT with a REAL(4) result; sums in REAL(8), then rounds to float.
float RTNAME(DotProductReal4)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 8> dot{};
  return static_cast<float>(dot(x, y, source, line));
}
// DOT_PRODUCT with a REAL(8) result.
double RTNAME(DotProductReal8)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 8> dot{};
  return dot(x, y, source, line);
}
#if LONG_DOUBLE == 80
// DOT_PRODUCT with a REAL(10) result (80-bit long double targets).
long double RTNAME(DotProductReal10)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 10> dot{};
  return dot(x, y, source, line);
}
#elif LONG_DOUBLE == 128
// DOT_PRODUCT with a REAL(16) result (128-bit long double targets).
long double RTNAME(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Real, 16> dot{};
  return dot(x, y, source, line);
}
#endif
// DOT_PRODUCT with a COMPLEX(4) result, returned through `result`. The sum
// is accumulated in COMPLEX(8) and each part is then rounded to float.
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Complex, 8> dot{};
  const auto wide{dot(x, y, source, line)};
  result.real(static_cast<float>(wide.real()));
  result.imag(static_cast<float>(wide.imag()));
}
// DOT_PRODUCT with a COMPLEX(8) result, returned through `result`.
void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Complex, 8> dot{};
  result = dot(x, y, source, line);
}
#if LONG_DOUBLE == 80
// DOT_PRODUCT with a COMPLEX(10) result, returned through `result`.
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Complex, 10> dot{};
  result = dot(x, y, source, line);
}
#elif LONG_DOUBLE == 128
// DOT_PRODUCT with a COMPLEX(16) result, returned through `result`.
void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Complex, 16> dot{};
  result = dot(x, y, source, line);
}
#endif
// DOT_PRODUCT for LOGICAL operands: true iff some x(i) .AND. y(i) holds.
bool RTNAME(DotProductLogical)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  const DotProduct<TypeCategory::Logical, 1> dot{};
  return static_cast<bool>(dot(x, y, source, line));
}
} // extern "C"
} // namespace Fortran::runtime
|