File: IntegerDotProductOps.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (158 lines) | stat: -rw-r--r-- 6,516 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops  ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Integer Dot Product operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"

#include "llvm/Support/FormatVariadic.h"

using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {

//===----------------------------------------------------------------------===//
// Integer Dot Product ops
//===----------------------------------------------------------------------===//

/// Shared verifier for every SPIR-V integer dot product op instantiated via
/// SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP in this file.
///
/// Checks that:
///  - both factor operands (operands 0 and 1) have the same type;
///  - a scalar integer factor type carries the packed-vector-format attribute
///    (kPackedVectorFormatAttrName) and is exactly 32 bits wide, while a
///    non-integer (vector) factor type carries no such attribute;
///  - the op has no attributes beyond that optional format attribute;
///  - the accumulator operand (operand 2), when present, has the result type;
///  - the result type is at least as wide as the factor type, as reported by
///    getBitWidth().
///
/// Returns failure (after emitting an op diagnostic) on the first violation.
static LogicalResult verifyIntegerDotProduct(Operation *op) {
  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
         "Not an integer dot product op?");
  assert(op->getNumResults() == 1 && "Expected a single result");

  // Operands 0 and 1 are the two factors; operand 2, if present, is the
  // accumulator and is checked separately below.
  Type factorTy = op->getOperand(0).getType();
  if (op->getOperand(1).getType() != factorTy)
    return op->emitOpError("requires the same type for both vector operands");

  // Number of attributes the op may legitimately carry: only the format
  // attribute, and only when the factors are packed scalar integers.
  unsigned expectedNumAttrs = 0;
  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
    ++expectedNumAttrs;
    auto packedVectorFormat =
        llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
            op->getAttr(kPackedVectorFormatAttrName));
    if (!packedVectorFormat)
      return op->emitOpError("requires Packed Vector Format attribute for "
                             "integer vector operands");

    assert(packedVectorFormat.getValue() ==
               spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
           "Unknown Packed Vector Format");
    // The only supported packing (4x8-bit) requires a 32-bit scalar.
    if (intTy.getWidth() != 32)
      return op->emitOpError(
          llvm::formatv("with specified Packed Vector Format ({0}) requires "
                        "integer vector operands to be 32-bits wide",
                        packedVectorFormat.getValue()));
  } else if (op->hasAttr(kPackedVectorFormatAttrName)) {
    // A packed format is only meaningful for scalar integer factors.
    return op->emitOpError(llvm::formatv(
        "with invalid format attribute for vector operands of type '{0}'",
        factorTy));
  }

  // Use emitOpError like every other diagnostic here so the message is
  // prefixed with the op name ("'<op>' op only supports ...").
  if (op->getAttrs().size() > expectedNumAttrs)
    return op->emitOpError(
        "only supports the 'format' #spirv.packed_vector_format attribute");

  Type resultTy = op->getResultTypes().front();
  bool hasAccumulator = op->getNumOperands() == 3;
  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
    return op->emitOpError(
        "requires the same accumulator operand and result types");

  unsigned factorBitWidth = getBitWidth(factorTy);
  unsigned resultBitWidth = getBitWidth(resultTy);
  if (factorBitWidth > resultBitWidth)
    return op->emitOpError(
        llvm::formatv("result type has insufficient bit-width ({0} bits) "
                      "for the specified vector operand type ({1} bits)",
                      resultBitWidth, factorBitWidth));

  return success();
}

static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
}

static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}

/// Returns the extension requirements of the integer dot product ops: a
/// single entry holding SPV_KHR_integer_dot_product.
static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions() {
  // The extension may be enabled explicitly, or be implied by a target env
  // whose SPIR-V version is >= 1.6. It lives in a function-local static
  // because the returned ArrayRef refers to it and must not dangle.
  static const spirv::Extension kIntDotProductExt =
      spirv::Extension::SPV_KHR_integer_dot_product;

  SmallVector<ArrayRef<spirv::Extension>, 1> requiredExts;
  requiredExts.emplace_back(kIntDotProductExt);
  return requiredExts;
}

/// Returns the capability requirements of the integer dot product op `op`.
///
/// The first entry is always DotProduct. A second entry depends on the
/// factor (operand 0) type:
///  - scalar integer with 4x8-bit packed format -> DotProductInput4x8BitPacked
///  - vector of exactly four 8-bit integers     -> DotProductInput4x8Bit
///  - any other integer vector                  -> DotProductInputAll
static SmallVector<ArrayRef<spirv::Capability>, 1>
getIntegerDotProductCapabilities(Operation *op) {
  // Requires the DotProduct capability and capabilities that depend on
  // exact op types. These are function-local statics because the returned
  // ArrayRefs point at them and must outlive this call.
  static const auto dotProductCap = spirv::Capability::DotProduct;
  static const auto dotProductInput4x8BitPackedCap =
      spirv::Capability::DotProductInput4x8BitPacked;
  static const auto dotProductInput4x8BitCap =
      spirv::Capability::DotProductInput4x8Bit;
  static const auto dotProductInputAllCap =
      spirv::Capability::DotProductInputAll;

  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};

  Type factorTy = op->getOperand(0).getType();
  if (llvm::isa<IntegerType>(factorTy)) {
    // verifyIntegerDotProduct rejects scalar integer factors without the
    // format attribute, so this cast assumes an already-verified op.
    auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
        op->getAttr(kPackedVectorFormatAttrName));
    if (formatAttr.getValue() ==
        spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
      capabilities.push_back(dotProductInput4x8BitPackedCap);

    return capabilities;
  }

  // DotProductInput4x8Bit only covers vectors of exactly four 8-bit
  // components. Every other integer vector, e.g. vector<3xi8> or
  // vector<4xi16>, needs DotProductInputAll.
  auto vecTy = llvm::cast<VectorType>(factorTy);
  if (vecTy.getElementTypeBitWidth() == 8 && vecTy.getNumElements() == 4) {
    capabilities.push_back(dotProductInput4x8BitCap);
    return capabilities;
  }

  capabilities.push_back(dotProductInputAllCap);
  return capabilities;
}

#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); }    \
  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
    return getIntegerDotProductExtensions();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
    return getIntegerDotProductCapabilities(*this);                            \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMinVersion() {                      \
    return getIntegerDotProductMinVersion();                                   \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
    return getIntegerDotProductMaxVersion();                                   \
  }

SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)

#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP

} // namespace mlir::spirv