File: normalize-memrefs-ops.mlir

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (151 lines) | stat: -rw-r--r-- 7,610 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s

// For all these cases, we test whether MemRef normalization works with the
// test operations.
// * test.op_norm: this operation has the MemRefsNormalizable trait. The tests
//   that include this operation are constructed so that the normalization should
//   happen.
// * test.op_nonnorm: this operation does not have the MemRefsNormalizable
//   trait. The tests that include this operation are constructed so that the
//   normalization should not happen.
// * test.op_norm_ret: like test.op_norm, but it also returns memrefs, so
//   normalization has to update its result types as well.

// Tiles d2 by 32 and d3 by 64, turning a 4-D index space into a 6-D one:
// (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64).
// For the 1x16x14x14 memrefs below, normalization yields 1x16x1x1x32x64
// (each tile count rounds up to 1).
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>

// Test with op_norm and maps in arguments and in the operations in the function.

// CHECK-LABEL: test_norm
// CHECK-SAME: (%[[IN:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_norm(%in : memref<1x16x14x14xf32, #map0>) -> () {
  // Both the argument and the local buffer carry #map0 and are only used by
  // test.op_norm, so both are expected to take the tiled 6-D shape.
  %buf = memref.alloc() : memref<1x16x14x14xf32, #map0>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
  "test.op_norm"(%in, %buf) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
  // CHECK: "test.op_norm"(%[[IN]], %[[BUF]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32, #map0>
  // CHECK: memref.dealloc %[[BUF]] : memref<1x16x1x1x32x64xf32>
  return
}

// Same test with op_nonnorm, with maps in the arguments and the operations in the function.

// CHECK-LABEL: test_nonnorm
// CHECK-SAME: (%[[IN:.*]]: memref<1x16x14x14xf32, #[[LAYOUT:.*]]>)
func.func @test_nonnorm(%in : memref<1x16x14x14xf32, #map0>) -> () {
  // test.op_nonnorm is not marked normalizable, so the argument, the local
  // buffer and every op touching them are expected to keep the #map0 layout.
  %buf = memref.alloc() : memref<1x16x14x14xf32, #map0>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x16x14x14xf32, #[[LAYOUT]]>
  "test.op_nonnorm"(%in, %buf) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
  // CHECK: "test.op_nonnorm"(%[[IN]], %[[BUF]]) : (memref<1x16x14x14xf32, #[[LAYOUT]]>, memref<1x16x14x14xf32, #[[LAYOUT]]>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32, #map0>
  // CHECK: memref.dealloc %[[BUF]] : memref<1x16x14x14xf32, #[[LAYOUT]]>
  return
}

// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
// does not block the normalization of other operations.

// CHECK-LABEL: test_nonnorm_identity_layout
// CHECK-SAME: (%[[IN:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_nonnorm_identity_layout(%in : memref<1x16x14x14xf32, #map0>) -> () {
  // %plain has the identity layout, so passing it to test.op_nonnorm is not
  // expected to block normalization: %in (used only by test.op_norm) still
  // becomes 1x16x1x1x32x64, while %plain keeps its 4-D type.
  %plain = memref.alloc() : memref<1x16x14x14xf32>
  // CHECK: %[[PLAIN:.*]] = memref.alloc() : memref<1x16x14x14xf32>
  "test.op_nonnorm"(%plain, %plain) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
  // CHECK: "test.op_nonnorm"(%[[PLAIN]], %[[PLAIN]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
  "test.op_norm"(%in, %plain) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
  // CHECK: "test.op_norm"(%[[IN]], %[[PLAIN]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
  memref.dealloc %plain : memref<1x16x14x14xf32>
  // CHECK: memref.dealloc %[[PLAIN]] : memref<1x16x14x14xf32>
  return
}

// Test with op_norm, with maps in the operations in the function.

// CHECK-LABEL: test_norm_mix
// CHECK-SAME: (%[[IN:.*]]: memref<1x16x1x1x32x64xf32>
func.func @test_norm_mix(%in : memref<1x16x1x1x32x64xf32>) -> () {
  // The argument is already in the tiled 6-D form; only the local buffer
  // carries #map0 and is expected to be rewritten to match it.
  %buf = memref.alloc() : memref<1x16x14x14xf32, #map0>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
  "test.op_norm"(%in, %buf) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
  // CHECK: "test.op_norm"(%[[IN]], %[[BUF]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32, #map0>
  // CHECK: memref.dealloc %[[BUF]] : memref<1x16x1x1x32x64xf32>
  return
}

// Test that loads and stores on identity-layout memrefs are left untouched
// while a #map_tile-laid-out allocation in the same function is normalized.
// Note that the load reads %1 and the store writes %arg0; neither of those
// memrefs carries a layout map.

#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>

// CHECK-LABEL: test_load_store
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32>
func.func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
    // %0 carries #map_tile and is rewritten to the tiled 6-D shape.
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
    // %1 has the identity layout and keeps its 4-D type.
    %1 = memref.alloc() : memref<1x16x14x14xf32>
    // CHECK: %[[v1:.*]] = memref.alloc() : memref<1x16x14x14xf32>
    "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
    // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
    %cst = arith.constant 3.0 : f32
    affine.for %i = 0 to 1 {
      affine.for %j = 0 to 16 {
        affine.for %k = 0 to 14 {
          affine.for %l = 0 to 14 {
            // Pin the op, the memref it reads and its (unchanged) type rather
            // than only a bare type string, which could match unrelated text.
            %2 = memref.load %1[%i, %j, %k, %l] : memref<1x16x14x14xf32>
            // CHECK: memref.load %[[v1]][{{.*}}] : memref<1x16x14x14xf32>
            %3 = arith.addf %2, %cst : f32
            memref.store %3, %arg0[%i, %j, %k, %l] : memref<1x16x14x14xf32>
            // CHECK: memref.store %{{.*}}, %[[ARG0]][{{.*}}] : memref<1x16x14x14xf32>
          }
        }
      }
    }
    memref.dealloc %0 :  memref<1x16x14x14xf32, #map_tile>
    // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
    memref.dealloc %1 :  memref<1x16x14x14xf32>
    // CHECK: memref.dealloc %[[v1]] : memref<1x16x14x14xf32>
    return
}

// Test with op_norm_ret, with maps in the results of a normalizable operation.

// The #map_tile argument, the first function result, the local buffer and
// the first result of test.op_norm_ret are all expected to become
// 1x16x1x1x32x32, while the identity-layout second result keeps 1x16x14x14.
// The directives below use the -NEXT form, so the printed ops must appear on
// consecutive lines in exactly this order.
// CHECK-LABEL: test_norm_ret
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
func.func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
    // CHECK-NEXT: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
    // Only the first result carries #map_tile; the second is identity-layout.
    %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
    // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret"
    // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
    "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> ()
    // CHECK-NEXT: "test.op_norm"
    // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
    memref.dealloc %0 : memref<1x16x14x14xf32, #map_tile>
    // CHECK-NEXT: memref.dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
    return %1, %2 : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>
    // CHECK-NEXT: return %[[v1]], %[[v2]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>
}

// Test with an arbitrary (non-call) op that references a function symbol.
// @test_norm_mix is referenced here only by symbol. No directive inspects this
// op, so it only has to make it through the pass without an error.
// NOTE(review): presumably this guards against the pass assuming that every
// symbol user of a function is a call; confirm against NormalizeMemRefs.
"test.op_funcref"() {func = @test_norm_mix} : () -> ()


// -----

#map_1d_tile = affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>

// Test with memref.reinterpret_cast: only the #map_1d_tile argument is
// normalized (3 elements in tiles of 32 give 1x32). The identity-layout
// buffer and its reinterpret_cast are expected to come through unchanged.

// CHECK-LABEL: test_norm_reinterpret_cast
// CHECK-SAME: (%[[IN:.*]]: memref<1x32xf32>) -> memref<3x1x1xf32> {
func.func @test_norm_reinterpret_cast(%in : memref<3xf32, #map_1d_tile>) -> (memref<3x1x1xf32>) {
  %flat = memref.alloc() : memref<3xf32>
  // CHECK: %[[FLAT:.*]] = memref.alloc() : memref<3xf32>
  "test.op_norm"(%in, %flat) : (memref<3xf32, #map_1d_tile>, memref<3xf32>) -> ()
  // CHECK: "test.op_norm"(%[[IN]], %[[FLAT]]) : (memref<1x32xf32>, memref<3xf32>) -> ()
  %cube = memref.reinterpret_cast %flat to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
  // CHECK: memref.reinterpret_cast %[[FLAT]] to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
  return %cube : memref<3x1x1xf32>
}