// File: ShapeCanonicalization.td
//
// Package info: llvm-toolchain-17 1:17.0.6-22
//   - links: PTS, VCS
//   - area: main
//   - in suites: forky, sid, trixie
//   - size: 1,799,624 kB
//   - sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504;
//     objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684;
//     perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149;
//     xml: 888; fortran: 619; cs: 573
// File content: 52 lines | stat: -rw-r--r-- 1,737 bytes
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"

// Constraint on a range: holds when every element of the range bound to $0
// compares equal to the others, as decided by llvm::all_equal.
def AllInputShapesEq : Constraint<CPred<"llvm::all_equal($0)">>;

// Constraint on a range: holds only when the range bound to $0 has exactly
// one element.
def HasSingleElement : Constraint<CPred<"$0.size() == 1">>;

// Constraint on a value: holds when the value's type is a ShapedType whose
// shape is fully static.
//
// The type is checked with isa<> before cast<>: the previous form called
// hasStaticShape() on an unchecked dyn_cast<> result, which would operate on
// a null type if the value were ever not shaped. With the guard, such values
// simply fail the constraint.
def HasStaticShape : Constraint<CPred<[{
  ::llvm::isa<ShapedType>($0.getType()) &&
  ::llvm::cast<ShapedType>($0.getType()).hasStaticShape()
}]>>;

// Native code helper: evaluates to the first element (front()) of the range
// bound to $0.
def TakeFront : NativeCodeCall<[{$0.front()}]>;

// Canonicalization patterns.

// An AssumingAllOp whose operand list holds a single value is replaced by
// that value directly.
def AssumingAllOneOp : Pattern<
  (Shape_AssumingAllOp $inputs),
  [(replaceWithValue $inputs)],
  [(HasSingleElement $inputs)]>;

// A CstrBroadcastableOp whose shape operands are all the same value
// (AllInputShapesEq) is replaced by a ConstWitnessOp built from
// ConstBoolAttrTrue. The root op is not referenced by the result or the
// constraints, so it is left unbound.
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;

// A CstrEqOp whose shape operands are all the same value (AllInputShapesEq)
// is replaced by a ConstWitnessOp built from ConstBoolAttrTrue. The root op
// is not referenced by the result or the constraints, so it is left unbound.
def CstrEqEqOps : Pat<(Shape_CstrEqOp $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;

// Round trip index -> size -> index: SizeToIndexOp(IndexToSizeOp(x)) is
// replaced by x.
def IndexToSizeToIndexCanonicalization : Pattern<
  (Shape_SizeToIndexOp (Shape_IndexToSizeOp $value)),
  [(replaceWithValue $value)]>;

// Round trip size -> index -> size: IndexToSizeOp(SizeToIndexOp(x)) is
// replaced by x.
//
// The replacement value must have the same type as the IndexToSizeOp result,
// so x is required to be of Shape_SizeType. NOTE(review): SizeToIndexOp's
// operand is believed to also accept `index` (Shape_SizeOrIndexType in its
// ODS definition — confirm). Without this constraint, an index-typed x would
// replace a size-typed result and produce invalid IR.
def SizeToIndexToSizeCanonicalization : Pat<
  (Shape_IndexToSizeOp (Shape_SizeToIndexOp Shape_SizeType:$arg)),
  (replaceWithValue $arg)>;

// A tensor.cast whose source is a ConstShapeOp is rewritten into a fresh
// ConstShapeOp carrying the same extents attribute. The new const_shape takes
// the cast's destination type. Applied only when that destination type has a
// static shape.
def TensorCastConstShape : Pattern<
  (Tensor_CastOp:$castResult (Shape_ConstShapeOp $extents)),
  [(Shape_ConstShapeOp $extents)],
  [(HasStaticShape $castResult)]>;

// tensor.extract from shape_of -> tensor.dim. We can take the first index
// because shape_of always returns a 1D tensor.
//
// $arg is restricted to AnyNon0RankedOrUnrankedTensor, the same constraint
// tensor.dim places on its source operand. Otherwise a shape_of applied to a
// non-tensor shaped value or to a rank-0 tensor would be rewritten into a
// tensor.dim that fails verification. NOTE(review): ShapeOfOp is believed to
// accept any shaped type (e.g. memref, vector) — confirm against its ODS.
def ExtractFromShapeOfExtentTensor : Pat<
  (Tensor_ExtractOp (Shape_ShapeOfOp AnyNon0RankedOrUnrankedTensor:$arg),
                    $indices),
  (Tensor_DimOp $arg, (TakeFront $indices))>;