File: IntegerDotProductOps.cpp

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (154 lines) | stat: -rw-r--r-- 6,521 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops  ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Integer Dot Product operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"

#include "llvm/Support/FormatVariadic.h"

using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {

//===----------------------------------------------------------------------===//
// Integer Dot Product ops
//===----------------------------------------------------------------------===//

/// Shared verifier for every integer dot product op.
///
/// `IntegerDotProductOpTy` is only used to look up the name of the op's
/// packed vector format attribute. The checks performed are:
///  * an integer-typed (packed) factor must carry a Packed Vector Format
///    attribute and must be 32 bits wide;
///  * any other factor type must not carry that attribute;
///  * the result type must be at least as wide as the factor type.
///
/// Returns `success()` when all checks pass; otherwise emits an op error on
/// `op` and returns the resulting failure.
template <typename IntegerDotProductOpTy>
static LogicalResult verifyIntegerDotProduct(Operation *op) {
  assert((op->getNumOperands() == 2u || op->getNumOperands() == 3u) &&
         "Not an integer dot product op?");
  assert(op->getNumResults() == 1 && "Expected a single result");

  // Only the first factor is inspected: ODS already guarantees that both
  // vectors share one type and that the result matches the accumulator.
  Type operandTy = op->getOperand(0).getType();
  StringAttr formatAttrName =
      IntegerDotProductOpTy::getFormatAttrName(op->getName());

  auto scalarTy = llvm::dyn_cast<IntegerType>(operandTy);
  if (!scalarTy) {
    // Non-integer factors must not specify a packed vector format.
    if (op->hasAttr(formatAttrName))
      return op->emitOpError(llvm::formatv(
          "with invalid format attribute for vector operands of type '{0}'",
          operandTy));
  } else {
    // Integer factors stand for packed vectors and need an explicit format.
    auto formatAttr = llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
        op->getAttr(formatAttrName));
    if (!formatAttr)
      return op->emitOpError("requires Packed Vector Format attribute for "
                             "integer vector operands");

    assert(formatAttr.getValue() ==
               spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
           "Unknown Packed Vector Format");
    if (scalarTy.getWidth() != 32)
      return op->emitOpError(
          llvm::formatv("with specified Packed Vector Format ({0}) requires "
                        "integer vector operands to be 32-bits wide",
                        formatAttr.getValue()));
  }

  // The result must be wide enough to hold the factor type.
  Type producedTy = op->getResultTypes().front();
  unsigned operandBits = getBitWidth(operandTy);
  unsigned producedBits = getBitWidth(producedTy);
  if (producedBits < operandBits)
    return op->emitOpError(
        llvm::formatv("result type has insufficient bit-width ({0} bits) "
                      "for the specified vector operand type ({1} bits)",
                      producedBits, operandBits));

  return success();
}

static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
}

static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}

/// Returns the extension requirement shared by all integer dot product ops.
static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions() {
  // SPV_KHR_integer_dot_product is required; it is either listed explicitly
  // or implied by a target environment whose SPIR-V version is >= 1.6.
  // The enumerator is kept in static storage so that the ArrayRef placed in
  // the result still refers to a live object after this function returns.
  static const spirv::Extension requiredExtension =
      spirv::Extension::SPV_KHR_integer_dot_product;

  SmallVector<ArrayRef<spirv::Extension>, 1> requirements;
  requirements.push_back(requiredExtension);
  return requirements;
}

/// Computes the capability requirements of the integer dot product op `op`.
///
/// `IntegerDotProductOpTy` is only used to look up the name of the op's
/// packed vector format attribute. The result always starts with DotProduct
/// and is extended based on the type of the first operand:
///  * integer type with the 4x8Bit packed format:
///    DotProductInput4x8BitPacked;
///  * vector type with 8-bit elements: DotProductInput4x8Bit;
///  * any other vector type: DotProductInputAll.
template <typename IntegerDotProductOpTy>
static SmallVector<ArrayRef<spirv::Capability>, 1>
getIntegerDotProductCapabilities(Operation *op) {
  // The capabilities live in static storage because the returned ArrayRefs
  // refer to these objects rather than copying them.
  static const spirv::Capability baseCap = spirv::Capability::DotProduct;
  static const spirv::Capability packed4x8BitCap =
      spirv::Capability::DotProductInput4x8BitPacked;
  static const spirv::Capability vector4x8BitCap =
      spirv::Capability::DotProductInput4x8Bit;
  static const spirv::Capability anyInputCap =
      spirv::Capability::DotProductInputAll;

  // Requires the DotProduct capability, plus capabilities that depend on the
  // exact operand type.
  SmallVector<ArrayRef<spirv::Capability>, 1> required = {baseCap};

  Type operandTy = op->getOperand(0).getType();
  StringAttr formatAttrName =
      IntegerDotProductOpTy::getFormatAttrName(op->getName());

  if (llvm::isa<IntegerType>(operandTy)) {
    // Packed vector input: the format attribute is expected to be present.
    auto packedFormat = llvm::cast<spirv::PackedVectorFormatAttr>(
        op->getAttr(formatAttrName));
    if (packedFormat.getValue() ==
        spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
      required.push_back(packed4x8BitCap);
  } else if (llvm::cast<VectorType>(operandTy).getElementTypeBitWidth() == 8) {
    required.push_back(vector4x8BitCap);
  } else {
    required.push_back(anyInputCap);
  }

  return required;
}

#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
  LogicalResult OpName::verify() {                                             \
    return verifyIntegerDotProduct<OpName>(*this);                             \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
    return getIntegerDotProductExtensions();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
    return getIntegerDotProductCapabilities<OpName>(*this);                    \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMinVersion() {                      \
    return getIntegerDotProductMinVersion();                                   \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
    return getIntegerDotProductMaxVersion();                                   \
  }

SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)

#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP

} // namespace mlir::spirv