File: fuse_sparse_convert_into_producer.mlir

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (122 lines) | stat: -rw-r--r-- 5,241 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map  | FileCheck %s --check-prefix=CHECK-FOLD
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s

// Identity map over a 4-D iteration space.
#id4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Elementwise trait: four operands, all indexed by the 4-D identity map,
// over an all-parallel iteration space.
#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel"],
  indexing_maps = [#id4, #id4, #id4, #id4]
}

// Identity map over a 3-D iteration space.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// 3-D COO-style encoding: a non-unique compressed outer level followed by two
// singleton levels, both carrying the `soa` property.
#COO = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))
}>

// 4-D encoding: three compressed levels followed by a dense innermost level.
#CCCD = #sparse_tensor.encoding<{
  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense)
}>

// A dense elementwise kernel whose result is converted to #CCCD. After the
// pre-sparsification rewrite no sparse_tensor.convert may remain; after full
// sparsification the output is expected to be a three-deep scf.for nest with a
// guarded tensor.insert, followed by a sparse_tensor.load.
// CHECK-LABEL:   func.func @fold_convert(
// CHECK:           scf.for
// CHECK:             scf.for
// CHECK:               scf.for
// CHECK:                 scf.if
// CHECK-NEXT:               tensor.insert
// CHECK-NEXT:               scf.yield
// CHECK-NEXT:             else
// CHECK-NEXT:               scf.yield
// CHECK:                 scf.yield
// CHECK:               scf.yield
// CHECK:             scf.yield
// CHECK:           sparse_tensor.load

// CHECK-FOLD-LABEL:   func.func @fold_convert(
// CHECK-FOLD-NOT:     sparse_tensor.convert
func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
  %zero = arith.constant 0.000000e+00 : f32
  %one = arith.constant 1.000000e+00 : f32
  %scale = arith.constant 1.000000e+00 : f32
  %init = tensor.empty() : tensor<128x32x32x1xf32>
  %dense = linalg.generic #trait
      ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
      outs(%init : tensor<128x32x32x1xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32, %acc: f32):
      // Yields 1.0 when (a * (1 - b)) * scale + c - 1 >= 0 (or the comparison is
      // unordered, since `uge` is used), and 0.0 otherwise.
      %one_minus_b = arith.subf %one, %b : f32
      %ab = arith.mulf %a, %one_minus_b : f32
      %scaled = arith.mulf %ab, %scale : f32
      %sum = arith.addf %scaled, %c : f32
      %shifted = arith.subf %sum, %one : f32
      %ge = arith.cmpf uge, %shifted, %zero : f32
      %flag = arith.uitofp %ge : i1 to f32
      linalg.yield %flag : f32
  } -> tensor<128x32x32x1xf32>
  // This convert is the op the rewrite is expected to fold into the producer.
  %sparse = sparse_tensor.convert %dense : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
  return %sparse : tensor<128x32x32x1xf32, #CCCD>
}

// Elementwise trait: three operands, all indexed by the 4-D identity map,
// over an all-parallel iteration space.
#trait_bin = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel"],
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ]
}

// Two kernels share the same dense init operand %0; only the second kernel's
// result is converted to #CCCD. The expected output keeps the dense
// tensor.empty for the first kernel and introduces a separate sparse
// tensor.empty for the second, with no sparse_tensor.convert left over.
// CHECK-FOLD-LABEL:   func.func @fold_convert_multi_use(
// CHECK-FOLD:           tensor.empty() : tensor<128x32x32x1xf32>
// CHECK-FOLD:           linalg.generic
// CHECK-FOLD:           tensor.empty() : tensor<128x32x32x1xf32, #sparse>
// CHECK-FOLD:           linalg.generic
// CHECK-FOLD-NOT:       sparse_tensor.convert
func.func @fold_convert_multi_use(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
                        %arg2: tensor<128x32x32x1xf32>, %arg3: tensor<128x32x32x1xf32>) -> (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>) {
  // Dense init operand with two users (both kernels below).
  %0 = tensor.empty() : tensor<128x32x32x1xf32>
  %1 = linalg.generic #trait_bin
  ins(%arg0, %arg1 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
  outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %prod = arith.mulf %in, %in_1 : f32
      linalg.yield %prod : f32
    } -> tensor<128x32x32x1xf32>

  // A second kernel that uses %0 as the init operand; its result is the one
  // converted to sparse below.
  %2 = linalg.generic #trait_bin
  ins(%arg2, %arg3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
  outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %prod2 = arith.mulf %in, %in_1 : f32
      linalg.yield %prod2 : f32
    } -> tensor<128x32x32x1xf32>
  %3 = sparse_tensor.convert %2 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>

  return %1, %3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>
}



// FIXME: The following kernel is not sparsifiable because `arith.select`
// operations are not handled by the sparse compiler at the moment.
//
// An elementwise select(x > 0, x, 0) over a COO input (the `ugt` comparison is
// also true for unordered operands), computed into a dense tensor that is then
// tensor.cast to COO. The rewrite must not turn that cast into a
// sparse_tensor.convert.
// CHECK-FOLD-LABEL:   func.func @fold_cast(
// CHECK-FOLD-NOT:     sparse_tensor.convert
func.func @fold_cast(%src: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
  %zero = arith.constant 0.000000e+00 : f64
  %init = tensor.empty() : tensor<10x20x30xf64>
  %dense = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel", "parallel", "parallel"]
    }
    ins(%src : tensor<10x20x30xf64, #COO>)
    outs(%init : tensor<10x20x30xf64>) {
    ^bb0(%x: f64, %acc: f64):
      %pos = arith.cmpf ugt, %x, %zero : f64
      %clamped = arith.select %pos, %x, %zero : f64
      linalg.yield %clamped : f64
  } -> tensor<10x20x30xf64>
  %cast = tensor.cast %dense : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
  return %cast : tensor<10x20x30xf64, #COO>
}