File: vector-reduction-to-spirv-dot-prod.mlir

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (192 lines) | stat: -rw-r--r-- 9,391 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
// RUN: mlir-opt --split-input-file --verify-diagnostics \
// RUN:   --test-vector-reduction-to-spirv-dot-prod %s -o - | FileCheck %s

// Positive tests.

// Sign-extended i8 operands, multiplied and add-reduced into i32, are
// expected to become a single spirv.SDot on the original narrow vectors.
// CHECK-LABEL: func.func @to_sdot
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDot [[A]], [[B]] : (vector<4xi8>, vector<4xi8>) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_sdot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}

// Same signed pattern as @to_sdot, but the reduction carries an initial
// accumulator; the expected lowering is spirv.SDotAccSat with that accumulator.
// CHECK-LABEL: func.func @to_sdot_acc
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDotAccSat [[A]], [[B]], [[INIT]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_sdot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}

// Signed i8 operands widened all the way to i64 are expected to lower to a
// spirv.SDot whose result type is i64.
// CHECK-LABEL: func.func @to_sdot_i64
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDot [[A]], [[B]] : (vector<4xi8>, vector<4xi8>) -> i64
//  CHECK-NEXT:   return [[RES]] : i64
func.func @to_sdot_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i64 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi64>
  %sum = vector.reduction <add>, %prod : vector<4xi64> into i64
  return %sum : i64
}

// i64 variant of @to_sdot_acc: an i64 accumulator is expected to feed a
// spirv.SDotAccSat producing i64.
// CHECK-LABEL: func.func @to_sdot_acc_i64
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>, [[INIT:%.+]]: i64)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SDotAccSat [[A]], [[B]], [[INIT]] : (vector<4xi8>, vector<4xi8>, i64) -> i64
//  CHECK-NEXT:   return [[RES]] : i64
func.func @to_sdot_acc_i64(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i64) -> i64 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi64>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi64>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi64>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi64> into i64
  return %sum : i64
}

// Zero-extended (unsigned) i8 operands, multiplied and add-reduced into i32,
// are expected to become a single spirv.UDot.
// CHECK-LABEL: func.func @to_udot
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.UDot [[A]], [[B]] : (vector<4xi8>, vector<4xi8>) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_udot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a_wide = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}

// Unsigned pattern with an initial accumulator; the expected lowering is
// spirv.UDotAccSat.
// CHECK-LABEL: func.func @to_udot_acc
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.UDotAccSat [[A]], [[B]], [[INIT]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_udot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %a_wide = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}

// Mixed signedness: a sign-extended left operand and a zero-extended right
// operand are expected to lower to spirv.SUDot with operands in source order.
// CHECK-LABEL: func.func @to_signed_unsigned_dot
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDot [[A]], [[B]] : (vector<4xi8>, vector<4xi8>) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_signed_unsigned_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}

// Signed-left / unsigned-right pattern with an initial accumulator; the
// expected lowering is spirv.SUDotAccSat with operands in source order.
// CHECK-LABEL: func.func @to_signed_unsigned_dot_acc
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDotAccSat [[A]], [[B]], [[INIT]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_signed_unsigned_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extui %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}

// Mixed signedness with the unsigned operand on the left: the expected
// spirv.SUDot takes the operands swapped, so the signed one comes first.
// CHECK-LABEL: func.func @to_unsigned_signed_dot
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDot [[B]], [[A]] : (vector<4xi8>, vector<4xi8>) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_unsigned_signed_dot(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a_wide = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod : vector<4xi32> into i32
  return %sum : i32
}

// Unsigned-left / signed-right pattern with an initial accumulator: the
// expected spirv.SUDotAccSat takes the two vector operands swapped.
// CHECK-LABEL: func.func @to_unsigned_signed_dot_acc
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>, [[INIT:%.+]]: i32)
//  CHECK-NEXT:   [[RES:%.+]] = spirv.SUDotAccSat [[B]], [[A]], [[INIT]] : (vector<4xi8>, vector<4xi8>, i32) -> i32
//  CHECK-NEXT:   return [[RES]] : i32
func.func @to_unsigned_signed_dot_acc(%arg0: vector<4xi8>, %arg1: vector<4xi8>, %acc: i32) -> i32 {
  %a_wide = arith.extui %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <add>, %prod, %acc : vector<4xi32> into i32
  return %sum : i32
}

// Three-element inputs are expected to be padded to four lanes with an i8
// zero (via spirv.CompositeConstruct) before being fed to spirv.SDot.
// CHECK-LABEL: func.func @to_sdot_vector3
//  CHECK-SAME: ([[A:%.+]]: vector<3xi8>, [[B:%.+]]: vector<3xi8>)
//       CHECK:   [[PAD:%.+]] = spirv.Constant 0 : i8
//       CHECK:   [[A4:%.+]] = spirv.CompositeConstruct [[A]], [[PAD]] : (vector<3xi8>, i8) -> vector<4xi8>
//       CHECK:   [[B4:%.+]] = spirv.CompositeConstruct [[B]], [[PAD]] : (vector<3xi8>, i8) -> vector<4xi8>
//       CHECK:   [[RES:%.+]] = spirv.SDot [[A4]], [[B4]] : (vector<4xi8>, vector<4xi8>) -> i32
//       CHECK:   return [[RES]]
func.func @to_sdot_vector3(%arg0: vector<3xi8>, %arg1: vector<3xi8>) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<3xi8> to vector<3xi32>
  %b_wide = arith.extsi %arg1 : vector<3xi8> to vector<3xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<3xi32>
  %sum = vector.reduction <add>, %prod : vector<3xi32> into i32
  return %sum : i32
}

// -----

// Negative tests.

// Two-element vectors are not rewritten: the vector.reduction must survive
// and still be what the function returns.
// CHECK-LABEL: func.func @too_short
//  CHECK-SAME:   ([[A:%.+]]: vector<2xi8>, [[B:%.+]]: vector<2xi8>)
//  CHECK:        [[SUM:%.+]] = vector.reduction
//  CHECK-NEXT:   return [[SUM]] : i32
func.func @too_short(%arg0: vector<2xi8>, %arg1: vector<2xi8>) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<2xi8> to vector<2xi32>
  %b_wide = arith.extsi %arg1 : vector<2xi8> to vector<2xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<2xi32>
  %sum = vector.reduction <add>, %prod : vector<2xi32> into i32
  return %sum : i32
}

// Six-element vectors are not rewritten: the vector.reduction must survive
// and still be what the function returns.
// CHECK-LABEL: func.func @too_long
//  CHECK-SAME:   ([[A:%.+]]: vector<6xi8>, [[B:%.+]]: vector<6xi8>)
//  CHECK:        [[SUM:%.+]] = vector.reduction
//  CHECK-NEXT:   return [[SUM]] : i32
func.func @too_long(%arg0: vector<6xi8>, %arg1: vector<6xi8>) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<6xi8> to vector<6xi32>
  %b_wide = arith.extsi %arg1 : vector<6xi8> to vector<6xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<6xi32>
  %sum = vector.reduction <add>, %prod : vector<6xi32> into i32
  return %sum : i32
}

// A <mul> reduction over the elementwise product is not a dot product, so the
// vector.reduction <mul> must be left in place.
// CHECK-LABEL: func.func @wrong_reduction_kind
//  CHECK-SAME:   ([[A:%.+]]: vector<4xi8>, [[B:%.+]]: vector<4xi8>)
//  CHECK:        [[SUM:%.+]] = vector.reduction <mul>
//  CHECK-NEXT:   return [[SUM]] : i32
func.func @wrong_reduction_kind(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %a_wide = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %b_wide = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %prod = arith.muli %a_wide, %b_wide : vector<4xi32>
  %sum = vector.reduction <mul>, %prod : vector<4xi32> into i32
  return %sum : i32
}

// An elementwise arith.addi (instead of arith.muli) feeding the reduction is
// not a dot product, so nothing may be rewritten. The reduction kind is
// deliberately <add> so that the arithmetic op is the only thing that
// disqualifies this input; with a <mul> reduction the input would already be
// rejected on reduction kind (the case covered by @wrong_reduction_kind), and
// this test would not exercise the arith-op check at all.
// CHECK-LABEL: func.func @wrong_arith_op
//  CHECK-SAME:   ([[ARG0:%.+]]: vector<4xi8>, [[ARG1:%.+]]: vector<4xi8>)
//  CHECK:        [[ADD:%.+]] = arith.addi
//  CHECK:        [[RED:%.+]] = vector.reduction <add>, [[ADD]]
//  CHECK-NEXT:   return [[RED]] : i32
func.func @wrong_arith_op(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> i32 {
  %lhs = arith.extsi %arg0 : vector<4xi8> to vector<4xi32>
  %rhs = arith.extsi %arg1 : vector<4xi8> to vector<4xi32>
  %add = arith.addi %lhs, %rhs : vector<4xi32>
  %red = vector.reduction <add>, %add : vector<4xi32> into i32
  return %red : i32
}