1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Cooperative Matrix operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdint>
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {

/// Verification shared by spirv.KHR.CooperativeMatrixLoad and
/// spirv.KHR.CooperativeMatrixStore.
///
/// Checks performed, in order:
///   1. `pointer` must point to a scalar or vector type.
///   2. If `memoryOperand` is present:
///      - a load must not request 'MakePointerAvailable';
///      - a store must not request 'MakePointerVisible';
///      - 'Aligned' is rejected, because the alignment literal that has to
///        follow it is not handled by op parsing and (de-)serialization.
///
/// \param op            The load/store op under verification; used to emit
///                      diagnostics and to tell loads apart from stores.
/// \param pointer       Type of the pointer operand; expected to be a
///                      PointerType.
/// \param coopMatrix    Type of the loaded/stored matrix value. Not inspected
///                      by any of the current checks.
/// \param memoryOperand Optional memory access attribute; may be null.
/// \returns success() when all checks pass, otherwise the emitted error.
static LogicalResult
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
                       spirv::MemoryAccessAttr memoryOperand) {
  Type elemType = cast<PointerType>(pointer).getPointeeType();
  if (!isa<ScalarType, VectorType>(elemType))
    return op->emitOpError(
               "Pointer must point to a scalar or vector type but provided ")
           << elemType;

  // TODO: Verify the memory object behind the pointer:
  // > If the Shader capability was declared, Pointer must point into an array
  // > and any ArrayStride decoration on Pointer is ignored.

  // Without a memory operand there is nothing further to validate.
  if (!memoryOperand)
    return success();

  // Decode the memory-access bit set once and test individual flags on it.
  const spirv::MemoryAccess accessBits = memoryOperand.getValue();
  auto requests = [accessBits](spirv::MemoryAccess flag) {
    return spirv::bitEnumContainsAll(accessBits, flag);
  };

  // Loads must not make the pointer available; stores must not make it
  // visible.
  if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
      requests(spirv::MemoryAccess::MakePointerAvailable))
    return op->emitOpError(
        "not compatible with memory operand 'MakePointerAvailable'");
  if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
      requests(spirv::MemoryAccess::MakePointerVisible))
    return op->emitOpError(
        "not compatible with memory operand 'MakePointerVisible'");

  // 'Aligned' has to be followed by an alignment literal, which would need
  // support at the level of op parsing and (de-)serialization.
  // TODO: Consider adding support for this attribute value.
  if (requests(spirv::MemoryAccess::Aligned))
    return op->emitOpError("has unhandled memory operand 'Aligned'");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//

/// Runs the shared access checks; for a load the matrix type is the op's
/// result type.
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
  Type ptrType = getPointer().getType();
  Type matrixType = getResult().getType();
  return verifyCoopMatrixAccess(*this, ptrType, matrixType,
                                getMemoryOperandAttr());
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixStore
//===----------------------------------------------------------------------===//

/// Runs the shared access checks; for a store the matrix type is the type of
/// the stored object operand.
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
  Type ptrType = getPointer().getType();
  Type matrixType = getObject().getType();
  return verifyCoopMatrixAccess(*this, ptrType, matrixType,
                                getMemoryOperandAttr());
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixMulAdd
//===----------------------------------------------------------------------===//

/// Verifies `A * B + C`:
///   - each operand's 'use' matches its role (MatrixA / MatrixB / MatrixAcc);
///   - all three operands share the same scope;
///   - shapes compose as (M x K) * (K x N) + (M x N) -> (M x N);
///   - when Cooperative Matrix Operands are given, every element type is an
///     IntegerType.
/// ODS already enforces `type(c) == type(result)`, so the result is not
/// re-checked here.
LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
  auto lhsTy = cast<spirv::CooperativeMatrixType>(getA().getType());
  auto rhsTy = cast<spirv::CooperativeMatrixType>(getB().getType());
  auto accTy = cast<spirv::CooperativeMatrixType>(getC().getType());

  // Each operand must carry the 'use' that matches its position.
  if (lhsTy.getUse() != CooperativeMatrixUseKHR::MatrixA)
    return emitOpError("operand #0 must be of use 'MatrixA'");
  if (rhsTy.getUse() != CooperativeMatrixUseKHR::MatrixB)
    return emitOpError("operand #1 must be of use 'MatrixB'");
  if (accTy.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
    return emitOpError("operand #2 must be of use 'MatrixAcc'");

  // All operands must live in one and the same scope.
  const bool sameScope = lhsTy.getScope() == rhsTy.getScope() &&
                         rhsTy.getScope() == accTy.getScope();
  if (!sameScope)
    return emitOpError("matrix scope mismatch");

  // Shapes: A is M x K, B is K x N, C is M x N. The dimensions are checked in
  // the order M, N, K so the first mismatch reported is deterministic.
  if (lhsTy.getRows() != accTy.getRows())
    return emitOpError("matrix size mismatch on dimension 'M'");
  if (rhsTy.getColumns() != accTy.getColumns())
    return emitOpError("matrix size mismatch on dimension 'N'");
  if (lhsTy.getColumns() != rhsTy.getRows())
    return emitOpError("matrix size mismatch on dimension 'K'");

  // The spec leaves element types to the client API:
  // > A, B, C, and Result Type need not necessarily have the same component
  // > type, this is defined by the client API.
  // Only when Cooperative Matrix Operands are present do we require all
  // element types to be integers.
  if (getMatrixOperands()) {
    const bool allIntegers = isa<IntegerType>(lhsTy.getElementType()) &&
                             isa<IntegerType>(rhsTy.getElementType()) &&
                             isa<IntegerType>(accTy.getElementType());
    if (!allIntegers)
      return emitOpError("Matrix Operands require all matrix element types to "
                         "be Integer Types");
  }

  // Any further requirements need to be checked against VCE.
  return success();
}

} // namespace mlir::spirv
|