File: features.mlir

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (123 lines) | stat: -rw-r--r-- 5,610 bytes parent folder | download | duplicates (11)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics

// Matmul as a named operation: the payload uses `linalg.matmul` directly.
// NOTE(review): %bias is not used here; presumably it is kept so that both
// payload functions share one signature — confirm.
func.func @named(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                 %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %product = linalg.matmul
      ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  return %product : tensor<512x512xf32>
}

// Matmul as a generic operation: the computation of @named, spelled out as a
// `linalg.generic` with explicit indexing maps and a multiply-accumulate body.
// Iteration space (d0, d1, d2): d0 and d1 are parallel and index the result,
// d2 is the reduction dimension.
func.func @generic(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
    -> tensor<512x512xf32> {
  // expected-remark @below {{matmul}}
  %result = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,  // lhs
      affine_map<(d0, d1, d2) -> (d2, d1)>,  // rhs
      affine_map<(d0, d1, d2) -> (d0, d1)>   // output / accumulator
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  } ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    // (a * b) + acc
    %prod = arith.mulf %a, %b : f32
    %sum = arith.addf %prod, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<512x512xf32>
  func.return %result : tensor<512x512xf32>
}

// A module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. This takes as the only argument the root operation (typically
  // pass root) given to the transform interpreter.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {

    // Traverses the payload IR associated with the operand handle, invoking
    // @match_generic_matmul on each of the operations. If the named sequence
    // succeeds, i.e., if none of the nested match (transform) operations
    // produced a silenceable failure, invokes @print_generic_matmul and
    // forwards the values yielded as arguments of the new invocation. If the
    // named sequence fails with a silenceable failure, silences it (the message
    // is forwarded to the debug stream). Definite failures are propagated
    // immediately and unconditionally, as usual. The result handle of
    // foreach_match is not used here.
    transform.foreach_match in %root
      @match_generic_matmul -> @print_generic_matmul
      : (!transform.any_op) -> !transform.any_op

    transform.yield
  }

  // This is an action sequence. It receives the operation yielded by
  // @match_generic_matmul and emits a "matmul" remark at it; these remarks are
  // what the `expected-remark` annotations in the payload functions check.
  transform.named_sequence @print_generic_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // This is a matcher sequence. It succeeds only when the candidate is a
  // structured op shaped like a plain matmul, and then yields the candidate
  // unchanged. Any failing check below makes the match fail silenceably.
  transform.named_sequence @match_generic_matmul(
      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // Match a structured linear algebra operation.
    transform.match.structured %candidate : !transform.any_op {
    ^bb0(%c: !transform.any_op):
      // With a rank equal to 3.
      %rank = transform.match.structured.rank %c
        : (!transform.any_op) -> !transform.param<i64>
      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>

      // With 2 inputs.
      %n_ins = transform.match.structured.num_inputs %c
        : (!transform.any_op) -> !transform.param<i64>
      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>

      // With 1 output (note that structured ops in destination-passing style
      // have as many inits as outputs).
      %n_inits = transform.match.structured.num_inits %c
        : (!transform.any_op) -> !transform.param<i64>
      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>

      // All inputs and inits are accessed with a projected permutation.
      transform.match.structured.input %c[all] {projected_permutation}
        : !transform.any_op
      transform.match.structured.init %c[0] {projected_permutation}
        : !transform.any_op

      // The body is a mulf/addf contraction with appropriate dimensions.
      transform.match.structured.body %c
        { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch, %lhs, %rhs, %reduction =
      transform.match.structured.classify_contraction_dims %c
        : (!transform.any_op)
        -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
            !transform.param<i64>)

      // Exactly one dimension in each of the lhs, rhs and reduction groups,
      // and no batch dimensions. %c1 from the init check above is reused.
      %n_batch = transform.num_associations %batch
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_lhs = transform.num_associations %lhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_rhs = transform.num_associations %rhs
        : (!transform.param<i64>) -> !transform.param<i64>
      %n_reduction = transform.num_associations %reduction
        : (!transform.param<i64>) -> !transform.param<i64>
      %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
      transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
      transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
    }
    transform.yield %candidate : !transform.any_op
  }
}