File: tensor-ops-to-spirv.mlir

package info (click to toggle)
llvm-toolchain-19 1:19.1.7-3~deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,998,492 kB
  • sloc: cpp: 6,951,680; ansic: 1,486,157; asm: 913,598; python: 232,024; f90: 80,126; objc: 75,281; lisp: 37,276; pascal: 16,990; sh: 10,009; ml: 5,058; perl: 4,724; awk: 3,523; makefile: 3,167; javascript: 2,504; xml: 892; fortran: 664; cs: 573
file content (63 lines) | stat: -rw-r--r-- 2,609 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// RUN: mlir-opt --split-input-file --convert-tensor-to-spirv \
// RUN:   --verify-diagnostics %s | FileCheck %s

//===----------------------------------------------------------------------===//
// tensor.extract
//===----------------------------------------------------------------------===//

// tensor.extract on a constant rank-3 tensor with dynamic indices. The
// constant is flattened into a 12-element !spirv.array, stored into a
// Function-storage spirv.Variable, and read back via spirv.AccessChain using a
// linearized index. The index-typed arguments are expected to become i32.
// CHECK-LABEL: func @tensor_extract_constant
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32, %[[C:.+]]: i32)
func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
  // The 2x2x3 constant is emitted as one flat array in row-major order.
  // CHECK: %[[CST:.+]] = spirv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]>
  %cst = arith.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
  // The flat constant is copied into a local variable; presumably this is
  // what allows the dynamic-index AccessChain further down.
  // CHECK: %[[VAR:.+]] = spirv.Variable : !spirv.ptr<!spirv.array<12 x i32>, Function>
  // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.array<12 x i32>
  // Linear index = a * 6 + b * 3 + c, i.e. the row-major strides of
  // tensor<2x2x3>. The constants 0 and 1 are matched but not referenced by
  // any later pattern (presumably the base offset and the unit innermost
  // stride created during linearization; confirm against the lowering).
  // CHECK: %[[C0:.+]] = spirv.Constant 0 : i32
  // CHECK: %[[C6:.+]] = spirv.Constant 6 : i32
  // CHECK: %[[MUL0:.+]] = spirv.IMul %[[A]], %[[C6]] : i32
  // CHECK: %[[C3:.+]] = spirv.Constant 3 : i32
  // CHECK: %[[MUL1:.+]] = spirv.IMul %[[B]], %[[C3]] : i32
  // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[MUL1]], %[[MUL0]] : i32
  // CHECK: %[[C1:.+]] = spirv.Constant 1 : i32
  // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[C]], %[[ADD1]] : i32
  // The element is read through a pointer into the variable at the
  // linearized index, then returned.
  // CHECK: %[[AC:.+]] = spirv.AccessChain %[[VAR]][%[[ADD2]]]
  // CHECK: %[[VAL:.+]] = spirv.Load "Function" %[[AC]] : i32
  %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32>
  // CHECK: spirv.ReturnValue %[[VAL]]
  return %extract : i32
}

// -----

//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//

// A 0-d tensor constant lowers to a plain scalar spirv.Constant, which must
// immediately follow the function header. Although %x is unused, the
// converted constant is still expected in the output.
// CHECK-LABEL: func @tensor_0d
// CHECK-NEXT:    spirv.Constant 1 : i32
func.func @tensor_0d() -> () {
  %x = arith.constant dense<1> : tensor<i32>
  return
}

// A rank-1 tensor constant lowers to a spirv.Constant whose result type is
// !spirv.array<3 x i32>; the dense attribute keeps its tensor<3xi32> type.
// CHECK-LABEL: func @tensor_1d
// CHECK-NEXT:    spirv.Constant dense<[1, 2, 3]> : tensor<3xi32> : !spirv.array<3 x i32>
func.func @tensor_1d() -> () {
  %x = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
  return
}

// A rank-2 constant is flattened in row-major order: both the dense
// attribute (now tensor<6xi32>) and the result type (!spirv.array<6 x i32>)
// are one-dimensional.
// CHECK-LABEL: func @tensor_2d
// CHECK-NEXT:    spirv.Constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
func.func @tensor_2d() -> () {
  %x = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
  return
}

// We do not handle zero-element tensors yet. Just make sure we do not crash on them.
// The zero-element constant is expected to remain an arith.constant, i.e. it
// is left unconverted, and the pass must still succeed.
// CHECK-LABEL: func @tensor_2d_empty
// CHECK-NEXT:    arith.constant dense<>
func.func @tensor_2d_empty() -> () {
  %x = arith.constant dense<> : tensor<2x0xi32>
  return
}