# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects.pdl import *


def constructAndPrintInModule(f):
  print("\nTEST:", f.__name__)
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      f()
    print(module)
  return f


# CHECK: module  {
# CHECK:   pdl.pattern @operations : benefit(1)  {
# CHECK:     %0 = attribute
# CHECK:     %1 = type
# CHECK:     %2 = operation  {"attr" = %0} -> (%1 : !pdl.type)
# CHECK:     %3 = result 0 of %2
# CHECK:     %4 = operand
# CHECK:     %5 = operation(%3, %4 : !pdl.value, !pdl.value)
# CHECK:     rewrite %5 with "rewriter"
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_operations():
  pattern = PatternOp(1, "operations")
  with InsertionPoint(pattern.body):
    attr = AttributeOp()
    ty = TypeOp()
    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
    op0_result = ResultOp(op0, 0)
    input = OperandOp()
    root = OperationOp(args=[op0_result, input])
    RewriteOp(root, "rewriter")


# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_with_args : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operation(%0 : !pdl.value)
# CHECK:     rewrite %1 with "rewriter"(%0 : !pdl.value)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
  pattern = PatternOp(1, "rewrite_with_args")
  with InsertionPoint(pattern.body):
    input = OperandOp()
    root = OperationOp(args=[input])
    RewriteOp(root, "rewriter", args=[input])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_with_params : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0 with "rewriter" ["I am param"]
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_params():
  pattern = PatternOp(1, "rewrite_with_params")
  with InsertionPoint(pattern.body):
    op = OperationOp()
    RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_with_args_and_params : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operation(%0 : !pdl.value)
# CHECK:     rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args_and_params():
  pattern = PatternOp(1, "rewrite_with_args_and_params")
  with InsertionPoint(pattern.body):
    input = OperandOp()
    root = OperationOp(args=[input])
    RewriteOp(root, "rewriter", params=[StringAttr.get("I am param")], args=[input])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operand
# CHECK:     %2 = type
# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %4 = result 0 of %3
# CHECK:     %5 = operation(%4 : !pdl.value)
# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %7 = result 0 of %6
# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK:     rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
  pattern = PatternOp(1, "rewrite_multi_root_optimal")
  with InsertionPoint(pattern.body):
    input1 = OperandOp()
    input2 = OperandOp()
    ty = TypeOp()
    op1 = OperationOp(args=[input1], types=[ty])
    val1 = ResultOp(op1, 0)
    root1 = OperationOp(args=[val1])
    op2 = OperationOp(args=[input2], types=[ty])
    val2 = ResultOp(op2, 0)
    root2 = OperationOp(args=[val1, val2])
    RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
# CHECK:     %0 = operand
# CHECK:     %1 = operand
# CHECK:     %2 = type
# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %4 = result 0 of %3
# CHECK:     %5 = operation(%4 : !pdl.value)
# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
# CHECK:     %7 = result 0 of %6
# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK:     rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation)
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
  pattern = PatternOp(1, "rewrite_multi_root_forced")
  with InsertionPoint(pattern.body):
    input1 = OperandOp()
    input2 = OperandOp()
    ty = TypeOp()
    op1 = OperationOp(args=[input1], types=[ty])
    val1 = ResultOp(op1, 0)
    root1 = OperationOp(args=[val1])
    op2 = OperationOp(args=[input2], types=[ty])
    val2 = ResultOp(op2, 0)
    root2 = OperationOp(args=[val1, val2])
    RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
# CHECK:     %0 = type : i32
# CHECK:     %1 = type
# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK:     rewrite %2  {
# CHECK:       %3 = type
# CHECK:       %4 = operation "foo.op"  -> (%0, %3 : !pdl.type, !pdl.type)
# CHECK:       replace %2 with %4
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
  pattern = PatternOp(1, "rewrite_add_body")
  with InsertionPoint(pattern.body):
    ty1 = TypeOp(IntegerType.get_signless(32))
    ty2 = TypeOp()
    root = OperationOp(types=[ty1, ty2])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      ty3 = TypeOp()
      newOp = OperationOp(name="foo.op", types=[ty1, ty3])
      ReplaceOp(root, with_op=newOp)

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
# CHECK:     %0 = type : i32
# CHECK:     %1 = type
# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK:     rewrite %2  {
# CHECK:       %3 = operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
  pattern = PatternOp(1, "rewrite_type")
  with InsertionPoint(pattern.body):
    ty1 = TypeOp(IntegerType.get_signless(32))
    ty2 = TypeOp()
    root = OperationOp(types=[ty1, ty2])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      newOp = OperationOp(name="foo.op", types=[ty1, ty2])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
# CHECK:     %0 = types
# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
# CHECK:     rewrite %1  {
# CHECK:       %2 = types : [i32, i64]
# CHECK:       %3 = operation "foo.op"  -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
  pattern = PatternOp(1, "rewrite_types")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    root = OperationOp(types=[types])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)])
      newOp = OperationOp(name="foo.op", types=[types, otherTypes])

# CHECK: module  {
# CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
# CHECK:     %0 = types
# CHECK:     %1 = operands : %0
# CHECK:     %2 = operation(%1 : !pdl.range<value>)
# CHECK:     rewrite %2  {
# CHECK:       %3 = operation "foo.op"  -> (%0 : !pdl.range<type>)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
  pattern = PatternOp(1, "rewrite_operands")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    operands = OperandsOp(types)
    root = OperationOp(args=[operands])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      newOp = OperationOp(name="foo.op", types=[types])

# CHECK: module  {
# CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0  {
# CHECK:       apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
  pattern = PatternOp(1, "native_rewrite")
  with InsertionPoint(pattern.body):
    root = OperationOp()
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      ApplyNativeRewriteOp([], "NativeRewrite", args=[root])

# CHECK: module  {
# CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0  {
# CHECK:       %1 = attribute "value"
# CHECK:       apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
  pattern = PatternOp(1, "attribute_with_value")
  with InsertionPoint(pattern.body):
    root = OperationOp()
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      attr = AttributeOp(value=Attribute.parse('"value"'))
      ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])

# CHECK: module  {
# CHECK:   pdl.pattern @erase : benefit(1)  {
# CHECK:     %0 = operation
# CHECK:     rewrite %0  {
# CHECK:       erase %0
# CHECK:     }
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_erase():
  pattern = PatternOp(1, "erase")
  with InsertionPoint(pattern.body):
    root = OperationOp()
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      EraseOp(root)

# CHECK: module  {
# CHECK:   pdl.pattern @operation_results : benefit(1)  {
# CHECK:     %0 = types
# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
# CHECK:     %2 = results of %1
# CHECK:     %3 = operation(%2 : !pdl.range<value>)
# CHECK:     rewrite %3 with "rewriter"
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
  valueRange = RangeType.get(ValueType.get())
  pattern = PatternOp(1, "operation_results")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    inputOp = OperationOp(types=[types])
    results = ResultsOp(valueRange, inputOp)
    root = OperationOp(args=[results])
    RewriteOp(root, name="rewriter")

# CHECK: module  {
# CHECK:   pdl.pattern : benefit(1)  {
# CHECK:     %0 = type
# CHECK:     apply_native_constraint "typeConstraint" [](%0 : !pdl.type)
# CHECK:     %1 = operation  -> (%0 : !pdl.type)
# CHECK:     rewrite %1 with "rewrite"
# CHECK:   }
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
  pattern = PatternOp(1)
  with InsertionPoint(pattern.body):
    resultType = TypeOp()
    ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[])
    root = OperationOp(types=[resultType])
    RewriteOp(root, name="rewrite")
