File: NvGpuSupport.h

package info (click to toggle)
llvm-toolchain-15 1%3A15.0.6-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,554,644 kB
  • sloc: cpp: 5,922,452; ansic: 1,012,136; asm: 674,362; python: 191,568; objc: 73,855; f90: 42,327; lisp: 31,913; pascal: 11,973; javascript: 10,144; sh: 9,421; perl: 7,447; ml: 5,527; awk: 3,523; makefile: 2,520; xml: 885; cs: 573; fortran: 567
file content (100 lines) | stat: -rw-r--r-- 4,011 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//===- NvGpuSupport.h - MLIR Vector to GPU lowering support ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to GPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace nvgpu {

/// Identifies which operand of a matrix multiplication a warp-level matrix
/// corresponds to.
/// NOTE(review): presumably A and B are the multiplicands and C is the
/// accumulator/result operand -- confirm against the users of this enum.
enum class MatMulOperandRole : int32_t {
  A = 0,
  B = 1,
  C = 2,
};

/// Collects information about a warp-level matrix operand represented by a
/// VectorType.
struct WarpMatrixInfo {
  /// Vector type that represents the warp-level matrix operand.
  VectorType vectorType;
  /// Role (A, B, or C) this operand plays in the matrix multiplication.
  MatMulOperandRole operandRole;
};

/// Given an op that operates on a VectorType representing a warp-level matrix
/// operand, the function returns a struct containing relevant type information.
/// The result is wrapped in FailureOr; the conditions under which failure is
/// reported are determined by the implementation, which is not part of this
/// header.
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);

/// Returns the number of bits in a single tile row. It is either 128, 256, or
/// 512 bits depending on the data type and whether the operand is an
/// accumulator/result operand.
int64_t inferTileWidthInBits(const WarpMatrixInfo &type);

/// Specifies information about the registers which compose a matrix fragment
/// according to the PTX documentation.
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names; confirm them against the implementation of getMmaSyncRegisterType.
struct FragmentElementInfo {
  /// Type of a single register (presumably an LLVM dialect type).
  Type registerLLVMType;
  /// Number of matrix elements held by one register (presumably).
  int64_t elementsPerRegister;
  /// Width of one register in bits (presumably).
  int64_t registerWidthBits;
  /// Number of registers that make up one fragment (presumably).
  int64_t numRegistersPerFragment;
};

/// Returns a FragmentElementInfo struct describing the register types for the
/// given matrix fragment type. The result is wrapped in FailureOr; which
/// fragment types are rejected is determined by the implementation, which is
/// not part of this header.
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type);

/// Returns an AffineMap which maps two dimensions representing (laneId,
/// logicalValueId) to two results representing offsets within a matrix
/// operand. The offsets point to the values the thread is responsible for
/// (AKA the matrix fragment values) during a warp-collective matrix
/// operation. For a visual reference of this LaneId -> (row, col) mapping,
/// please see NVIDIA's PTX documentation:
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
/// The result is wrapped in FailureOr, so callers must check for failure
/// before using the map.
FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType);

/// Bundles the parameters returned by getLdMatrixParams and consumed by
/// getLaneIdToLdMatrixMatrixCoord.
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names and types; confirm them against the implementation.
struct LdMatrixParams {
  /// Vector type of the fragment being loaded (presumably).
  VectorType fragmentType;
  /// Whether the operand is an accumulator operand (presumably).
  bool isAccum;
  /// Number of tiles involved in the load (presumably).
  int64_t numTiles;
  /// Iterator type of the contiguous dimension (presumably).
  IteratorType contiguousDimType;
  /// NVVM MMA layout targeted by the load (presumably).
  NVVM::MMALayout targetLayout;
};

/// Returns the LdMatrixParams for the given warp-level matrix operand. The
/// result is wrapped in FailureOr, so callers must check for failure before
/// using it.
/// NOTE(review): `transpose` presumably requests a transposed load; confirm
/// against the implementation.
FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                            bool transpose);
/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing offsets within the matrix operand that should
/// be the pointer locations a thread should pass to the ldmatrix instruction.
/// The result is wrapped in FailureOr, so callers must check for failure
/// before using the map.
FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params);

/// Rewrite pattern matching vector::ContractionOp. Transforms the contract
/// into (m, k)x(n, k)x(m, n) form so that it can be converted to MMA matmul.
struct PrepareContractToGPUMMASync
    : public OpRewritePattern<vector::ContractionOp> {
  // Inherit the constructors of the base pattern class.
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  /// Overrides the OpRewritePattern hook for vector::ContractionOp. Only the
  /// declaration lives here; the rewrite logic is defined outside this header.
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;
};

} // namespace nvgpu
} // namespace mlir

#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H