File: CooperativeMatrixOps.cpp

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (138 lines) | stat: -rw-r--r-- 5,705 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops  -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Cooperative Matrix operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdint>

using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {

/// Shared verification logic for the cooperative matrix load and store ops.
///
/// Checks performed, in order:
///  1. The pointee type of `pointer` must be a scalar or vector type.
///  2. If `memoryOperand` is set:
///     - a load op must not carry 'MakePointerAvailable';
///     - a store op must not carry 'MakePointerVisible';
///     - 'Aligned' is rejected for any op as unhandled.
///
/// On the first violation an op error is emitted on `op` and returned.
/// `coopMatrix` is part of the signature but is not inspected here.
static LogicalResult
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
                       spirv::MemoryAccessAttr memoryOperand) {
  Type pointeeType = cast<PointerType>(pointer).getPointeeType();
  if (!isa<ScalarType, VectorType>(pointeeType)) {
    return op->emitOpError(
               "Pointer must point to a scalar or vector type but provided ")
           << pointeeType;
  }

  // TODO: Verify the memory object behind the pointer:
  // > If the Shader capability was declared, Pointer must point into an array
  // > and any ArrayStride decoration on Pointer is ignored.

  // Without a memory operand attribute there is nothing else to validate.
  if (!memoryOperand)
    return success();

  const spirv::MemoryAccess accessBits = memoryOperand.getValue();
  auto hasAccessBit = [accessBits](spirv::MemoryAccess bit) {
    return spirv::bitEnumContainsAll(accessBits, bit);
  };

  if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
      hasAccessBit(spirv::MemoryAccess::MakePointerAvailable)) {
    return op->emitOpError(
        "not compatible with memory operand 'MakePointerAvailable'");
  }

  if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
      hasAccessBit(spirv::MemoryAccess::MakePointerVisible)) {
    return op->emitOpError(
        "not compatible with memory operand 'MakePointerVisible'");
  }

  // The 'Aligned' memory operand requires an alignment literal to follow,
  // which needs to be implemented on the level of op parsing and
  // (de-)serialization.
  // TODO: Consider adding support for this attribute value.
  if (hasAccessBit(spirv::MemoryAccess::Aligned))
    return op->emitOpError("has unhandled memory operand 'Aligned'");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//

/// Verifies the load op by forwarding its pointer type, result type and
/// memory operand attribute to the shared access verifier.
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
  Type pointerType = getPointer().getType();
  Type loadedMatrixType = getResult().getType();
  return verifyCoopMatrixAccess(*this, pointerType, loadedMatrixType,
                                getMemoryOperandAttr());
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixStore
//===----------------------------------------------------------------------===//

/// Verifies the store op by forwarding its pointer type, stored object type
/// and memory operand attribute to the shared access verifier.
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
  Type pointerType = getPointer().getType();
  Type storedMatrixType = getObject().getType();
  return verifyCoopMatrixAccess(*this, pointerType, storedMatrixType,
                                getMemoryOperandAttr());
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixMulAdd
//===----------------------------------------------------------------------===//

/// Verifies the multiply-add op `A * B + C`.
///
/// Checks, in order: the 'use' of each operand type (MatrixA, MatrixB,
/// MatrixAcc), that all three operand types share one scope, that the shapes
/// follow 'MxK * KxN + MxN', and, when Cooperative Matrix Operands are
/// present, that every operand element type is an integer type. The first
/// failing check emits an op error, which is returned.
LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
  auto lhsType = cast<spirv::CooperativeMatrixType>(getA().getType());
  auto rhsType = cast<spirv::CooperativeMatrixType>(getB().getType());
  auto accType = cast<spirv::CooperativeMatrixType>(getC().getType());

  // ODS enforces that `type(c) == type(result)`, so the result type does not
  // need to be examined separately.

  // Each operand must carry the matching 'use' in its type.
  if (lhsType.getUse() != CooperativeMatrixUseKHR::MatrixA)
    return emitOpError("operand #0 must be of use 'MatrixA'");
  if (rhsType.getUse() != CooperativeMatrixUseKHR::MatrixB)
    return emitOpError("operand #1 must be of use 'MatrixB'");
  if (accType.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
    return emitOpError("operand #2 must be of use 'MatrixAcc'");

  // All operands must agree on the 'scope' part of the type.
  const bool scopesAgree = lhsType.getScope() == rhsType.getScope() &&
                           rhsType.getScope() == accType.getScope();
  if (!scopesAgree)
    return emitOpError("matrix scope mismatch");

  // Shapes must satisfy 'MxK * KxN + MxN -> MxN'.
  if (lhsType.getRows() != accType.getRows())
    return emitOpError("matrix size mismatch on dimension 'M'");
  if (rhsType.getColumns() != accType.getColumns())
    return emitOpError("matrix size mismatch on dimension 'N'");
  if (lhsType.getColumns() != rhsType.getRows())
    return emitOpError("matrix size mismatch on dimension 'K'");

  // The spec does not restrict the element types:
  //  > A, B, C, and Result Type need not necessarily have the same component
  //  > type, this is defined by the client API.

  // Any further requirements need to be checked against VCE.
  if (!getMatrixOperands())
    return success();

  // With Cooperative Matrix Operands present, every operand element type has
  // to be an integer type.
  const Type operandElementTypes[] = {lhsType.getElementType(),
                                      rhsType.getElementType(),
                                      accType.getElementType()};
  for (Type elementType : operandElementTypes) {
    if (!isa<IntegerType>(elementType)) {
      return emitOpError("Matrix Operands require all matrix element types to "
                         "be Integer Types");
    }
  }

  return success();
}

} // namespace mlir::spirv