# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *

def run(f):
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0


def add_dummy_value():
  return Operation.create(
      "custom.value",
      results=[IntegerType.get_signless(32)]).result


def testOdsBuildDefaultImplicitRegions():

  class TestFixedRegionsOp(OpView):
    OPERATION_NAME = "custom.test_op"
    _ODS_REGIONS = (2, True)

  class TestVariadicRegionsOp(OpView):
    OPERATION_NAME = "custom.test_any_regions_op"
    _ODS_REGIONS = (2, False)

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      op = TestFixedRegionsOp.build_generic(results=[], operands=[])
      # CHECK: NUM_REGIONS: 2
      print(f"NUM_REGIONS: {len(op.regions)}")
      # Including a regions= that matches should be fine.
      op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
      print(f"NUM_REGIONS: {len(op.regions)}")
      # Reject greater than.
      try:
        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
        print(f"ERROR:{e}")
      # Reject less than.
      try:
        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
        print(f"ERROR:{e}")

      # If no regions specified for a variadic region op, build the minimum.
      op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
      # CHECK: DEFAULT_NUM_REGIONS: 2
      print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
      # Should also accept an explicit regions= that matches the minimum.
      op = TestVariadicRegionsOp.build_generic(
          results=[], operands=[], regions=2)
      # CHECK: EQ_NUM_REGIONS: 2
      print(f"EQ_NUM_REGIONS: {len(op.regions)}")
      # And accept greater than minimum.
      # Should also accept an explicit regions= that matches the minimum.
      op = TestVariadicRegionsOp.build_generic(
          results=[], operands=[], regions=3)
      # CHECK: GT_NUM_REGIONS: 3
      print(f"GT_NUM_REGIONS: {len(op.regions)}")
      # Should reject less than minimum.
      try:
        op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
        print(f"ERROR:{e}")



run(testOdsBuildDefaultImplicitRegions)


def testOdsBuildDefaultNonVariadic():

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
      # CHECK: %[[V0:.+]] = "custom.value"
      # CHECK: %[[V1:.+]] = "custom.value"
      # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
      # CHECK-NOT: operand_segment_sizes
      # CHECK-NOT: result_segment_sizes
      # CHECK-SAME: : (i32, i32) -> (i8, i16)
      print(m)

run(testOdsBuildDefaultNonVariadic)


def testOdsBuildDefaultSizedVariadic():

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"
    _ODS_OPERAND_SEGMENTS = [1, -1, 0]
    _ODS_RESULT_SEGMENTS = [-1, 0, 1]

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      v2 = add_dummy_value()
      v3 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      t2 = IntegerType.get_signless(32)
      t3 = IntegerType.get_signless(64)
      # CHECK: %[[V0:.+]] = "custom.value"
      # CHECK: %[[V1:.+]] = "custom.value"
      # CHECK: %[[V2:.+]] = "custom.value"
      # CHECK: %[[V3:.+]] = "custom.value"
      # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
      # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>
      # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi32>
      # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
      op = TestOp.build_generic(
          results=[[t0, t1], t2, t3],
          operands=[v0, [v1, v2], v3])

      # Now test with optional omitted.
      # CHECK: "custom.test_op"(%[[V0]])
      # CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]>
      # CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]>
      # CHECK-SAME: (i32) -> i64
      op = TestOp.build_generic(
          results=[None, None, t3],
          operands=[v0, None, None])
      print(m)

      # And verify that errors are raised for None in a required operand.
      try:
        op = TestOp.build_generic(
            results=[None, None, t3],
            operands=[None, None, None])
      except ValueError as e:
        # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
        print(f"OPERAND_CAST_ERROR:{e}")

      # And verify that errors are raised for None in a required result.
      try:
        op = TestOp.build_generic(
            results=[None, None, None],
            operands=[v0, None, None])
      except ValueError as e:
        # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
        print(f"RESULT_CAST_ERROR:{e}")

      # Variadic lists with None elements should reject.
      try:
        op = TestOp.build_generic(
            results=[None, None, t3],
            operands=[v0, [None], None])
      except ValueError as e:
        # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
        print(f"OPERAND_LIST_CAST_ERROR:{e}")
      try:
        op = TestOp.build_generic(
            results=[[None], None, t3],
            operands=[v0, None, None])
      except ValueError as e:
        # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
        print(f"RESULT_LIST_CAST_ERROR:{e}")

run(testOdsBuildDefaultSizedVariadic)


def testOdsBuildDefaultCastError():

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      try:
        op = TestOp.build_generic(
            results=[t0, t1],
            operands=[None, v1])
      except ValueError as e:
        # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
        print(f"ERROR: {e}")
      try:
        op = TestOp.build_generic(
            results=[t0, None],
            operands=[v0, v1])
      except ValueError as e:
        # CHECK: Result 1 of operation "custom.test_op" must be a Type
        print(f"ERROR: {e}")

run(testOdsBuildDefaultCastError)
