File: multiple.mlir

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (131 lines) | stat: -rw-r--r-- 6,116 bytes parent folder | download | duplicates (8)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics

// Matmul + bias addition + ReLU. The biased tensor is the first input of the
// max operation.
func.func @fc_relu_operands_00(%lhs: tensor<512x512xf32>,
                               %rhs: tensor<512x512xf32>,
                               %bias: tensor<512x512xf32>,
                               %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Multiply %lhs by %rhs, using %output as the `outs` operand.
  // expected-remark @below {{matmul # 0}}
  %product = linalg.matmul
      ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Add %bias to the product, element by element.
  // expected-remark @below {{add # 0}}
  %with_bias = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%product, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // ReLU expressed as an elementwise max against a scalar zero. The biased
  // tensor is the first input here (position 0), matching the `# 0` suffix of
  // the expected remarks in this function.
  %zero = arith.constant 0.0 : f32
  // expected-remark @below {{max # 0}}
  %activated = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
      ins(%with_bias, %zero : tensor<512x512xf32>, f32)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  func.return %activated : tensor<512x512xf32>
}

// Matmul + bias addition + ReLU, with the operands of the max operation
// swapped: the biased tensor is its second input.
func.func @fc_relu_operands_01(%lhs: tensor<512x512xf32>,
                               %rhs: tensor<512x512xf32>,
                               %bias: tensor<512x512xf32>,
                               %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // Multiply %lhs by %rhs, using %output as the `outs` operand.
  // expected-remark @below {{matmul # 1}}
  %product = linalg.matmul
      ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Add %bias to the product, element by element.
  // expected-remark @below {{add # 1}}
  %with_bias = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%product, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // ReLU expressed as an elementwise max against a scalar zero. Unlike in
  // @fc_relu_operands_00, the scalar comes first and the biased tensor is the
  // second input (position 1), matching the `# 1` suffix of the expected
  // remarks in this function.
  %zero = arith.constant 0.0 : f32
  // expected-remark @below {{max # 1}}
  %activated = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
      ins(%zero, %with_bias : f32, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  func.return %activated : tensor<512x512xf32>
}

// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. This takes as the only argument the root operation (typically
  // pass root) given to the transform interpreter. The handle is declared
  // `transform.consumed` and is passed to `transform.foreach_match` below.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {

    // Traverses the payload IR associated with the operand handle, invoking
    // @match_matmul_elemwise on each of the operations. If the named sequence
    // succeeds, i.e., if none of the nested match (transform) operations
    // produced a silenceable failure, invokes @print_matmul_elemwise and
    // forwards the values yielded as arguments of the new invocation. If the
    // named sequence fails with a silenceable failure, silences it (the message
    // is forwarded to the debug stream). Definite failures are propagated
    // immediately and unconditionally, as usual.
    //
    // The handle produced by this operation is not bound to a name because
    // nothing in this sequence uses it afterwards.
    transform.foreach_match in %root
      @match_matmul_elemwise -> @print_matmul_elemwise
      : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // This is an action sequence. It receives the three operation handles and
  // the parameter yielded by @match_matmul_elemwise, in the same order, and
  // emits one remark at each operation: the fixed prefix ("matmul #", "add #"
  // or "max #") followed by the value of %pos. These are the remarks that the
  // `expected-remark` annotations in the payload functions check for. All
  // arguments are read-only; the sequence yields nothing.
  transform.named_sequence @print_matmul_elemwise(
      %matmul: !transform.any_op {transform.readonly},
      %add: !transform.any_op {transform.readonly},
      %max: !transform.any_op {transform.readonly},
      %pos: !transform.param<i32> {transform.readonly}) {
    transform.debug.emit_param_as_remark %pos, "matmul #" at %matmul
      : !transform.param<i32>, !transform.any_op
    transform.debug.emit_param_as_remark %pos, "add #" at %add
      : !transform.param<i32>, !transform.any_op
    transform.debug.emit_param_as_remark %pos, "max #" at %max
      : !transform.param<i32>, !transform.any_op
    transform.yield
  }

  // This is a matcher sequence. It is given an operation to match and nested
  // operations must succeed in order for a match to be deemed successful. It
  // starts matching from the last operation in the use-def chain and goes back
  // because each operand (use) has exactly one definition.
  //
  // On success it yields, in order: the matmul, the middle elementwise
  // operation, the last elementwise operation, and an i32 parameter describing
  // which operand of the last operation led to the middle one.
  transform.named_sequence @match_matmul_elemwise(
      %last: !transform.any_op {transform.readonly}) 
      -> (!transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i32>) {
    // The last operation must be an elementwise binary.
    transform.match.operation_name %last ["linalg.elemwise_binary"]
      : !transform.any_op

    // One of its operands must be defined by another operation, to which we
    // will get a handle here. This is achieved thanks to a custom operation
    // (defined outside of this file) that tries to match operands one by one
    // using the match operations nested in its region. Its results are the
    // %pos parameter and the handle yielded from the region. Judging by the
    // remarks expected in the payload functions (0 when the biased tensor is
    // the first input of the max, 1 when it is the second), %pos is the index
    // of the operand for which the region succeeded.
    %pos, %middle = transform.match.my.has_operand_satisfying %last
        : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
    ^bb0(%operand: !transform.any_value):
      // The operand must be defined by an operation.
      %def = transform.get_defining_op %operand 
        : (!transform.any_value) -> !transform.any_op
      // The defining operation must itself be an elementwise binary.
      transform.match.operation_name %def ["linalg.elemwise_binary"]
        : !transform.any_op
      // Hand the defining operation back to the enclosing operation; it
      // becomes %middle.
      transform.yield %def : !transform.any_op
    }
    
    // And the first operand of that operation must be defined by yet another
    // operation.
    %matmul = transform.get_producer_of_operand %middle[0]
      : (!transform.any_op) -> !transform.any_op
    // And that operation is a matmul.
    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
    // We will yield the handles to the matmul and the two elementwise
    // operations separately, followed by the operand position parameter. The
    // order matches the argument list of @print_matmul_elemwise.
    transform.yield %matmul, %middle, %last, %pos
      : !transform.any_op, !transform.any_op, !transform.any_op,
        !transform.param<i32>
  }
}