# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *

def run(f):
  """Run test function *f* and assert it left no MLIR Context alive.

  Prints a blank line followed by "TEST: <name>", which the label
  directives in this file anchor on, then calls *f*.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect unreachable reference cycles so Context objects that are no
  # longer referenced are freed before the live count is read.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
  """Insert through an InsertionPoint built from a block, which appends at its end.

  The new op must print after the op already present in the function body.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    append_point = InsertionPoint(body_block)
    append_point.insert(Operation.create("custom.op2"))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
  """An InsertionPoint built from an operation inserts immediately before it."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """)
    body_block = module.body.operations[0].regions[0].blocks[0]
    # Anchor on the second op of the body ("custom.op2").
    anchor_op = body_block.operations[1]
    InsertionPoint(anchor_op).insert(Operation.create("custom.op3"))
    # CHECK: "custom.op1"
    # CHECK: "custom.op3"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
  """InsertionPoint.at_block_begin places the new op ahead of existing ones."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    front_point = InsertionPoint.at_block_begin(body_block)
    front_point.insert(Operation.create("custom.op1"))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
  """Placeholder for inserting at the beginning of an empty block.

  Currently does nothing. It is still passed to `run`, so its
  "TEST: <name>" label is printed for the matching label directive.
  """
  # TODO: Write this test case when we can create such a situation.
  pass

run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
  """InsertionPoint.at_block_terminator inserts just before the terminator.

  The function body ends in `return`. The new op must appear after
  "custom.op1" but still before that `return`.
  """
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_terminator(entry_block)
    ip.insert(Operation.create("custom.op2"))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # Also match the terminator, so an op placed after `return` fails the test.
    # CHECK: return
    module.operation.print()

run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
  """at_block_terminator raises ValueError for a block lacking a terminator.

  The parsed function body holds only an unregistered op. Requesting a
  terminator-anchored insertion point is expected to fail, and the error
  message is printed for matching.
  """
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx:
    module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    try:
      # Only the raise matters here; the returned insertion point is unused.
      InsertionPoint.at_block_terminator(entry_block)
    except ValueError as e:
      # CHECK: Block has no terminator
      print(e)
    else:
      assert False, "Expected exception"

run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
  """Creating an op at the end of a block that ends in a terminator raises.

  The block holds only `return`. With an end-of-block insertion point
  active, Operation.create is expected to raise IndexError, whose message
  is printed for matching.
  """
  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @foo() -> () {
        return
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    with InsertionPoint(entry_block):
      try:
        Operation.create("custom.op1", results=[], operands=[])
      except IndexError as e:
        # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
        print(f"ERROR: {e}")
      else:
        # Fail loudly here instead of relying only on a missing output line.
        assert False, "Expected exception"

run(test_insert_at_end_with_terminator_errors)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
  """An entered InsertionPoint routes Operation.create, and nesting overrides it.

  Ops created under the outer end-of-block point are appended. Ops created
  under the nested at_block_begin point go to the front. Appending resumes
  once the nested scope exits.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
    body_block = module.body.operations[0].regions[0].blocks[0]
    append_point = InsertionPoint(body_block)
    with append_point:
      Operation.create("custom.op2")
      with InsertionPoint.at_block_begin(body_block):
        for op_name in ("custom.opa", "custom.opb"):
          Operation.create(op_name)
      Operation.create("custom.op3")
    # CHECK: "custom.opa"
    # CHECK: "custom.opb"
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: "custom.op3"
    module.operation.print()

run(test_insertion_point_context)
