# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.std as std


def run(f):
  """Decorator: announce, execute, and return a test function.

  Emits a leading newline plus a `TEST: <function name>` banner so each
  test's output can be located, then calls `f()` with no arguments. The
  function object itself is handed back, so the decorated name still
  refers to it after decoration.
  """
  banner_name = f.__name__
  print("\nTEST:", banner_name)
  f()
  return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  """Exercise builtin.FuncOp.from_py_func on a range of Python functions.

  Each decorated function below is emitted as a func into module `m`, which
  is printed at the end and matched by the directives placed next to each
  definition. The Python function name becomes the symbol name unless an
  explicit `name=` is supplied.
  """
  with Context() as ctx, Location.unknown() as loc:
    # Required for the unregistered "custom.op1" operation built below.
    ctx.allow_unregistered_dialects = True
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    # Every decorated function below is emitted into the body of `m`.
    with InsertionPoint(m.body):
      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def unary_return(a):
        return a

      # Returning a Python tuple produces a multi-result function.
      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def binary_return(a, b):
        return a, b

      # A body that returns None yields a bare return and no results.
      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def none_return(a, b):
        pass

      # Calling a previously captured function from inside another captured
      # function emits a call to its symbol.
      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def call_unary(a):
        return unary_return(a)

      # Returning a multi-result call directly returns each of its results.
      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_binary(a, b):
        return binary_return(a, b)

      # We expect coercion of a single result operation to a returned value.
      # CHECK-LABEL: func @single_result_op
      # CHECK: %0 = "custom.op1"() : () -> f32
      # CHECK: return %0 : f32
      @builtin.FuncOp.from_py_func()
      def single_result_op():
        return Operation.create("custom.op1", results=[f32])

      # Returning the value of a zero-result call still yields a bare return.
      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_none(a, b):
        return none_return(a, b)

      ## Variants and optional feature tests.
      # `name=` overrides the symbol name derived from the Python function.
      # CHECK-LABEL: func @from_name_arg
      @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
      def explicit_name(a, b):
        return b

      # The next three functions assert that the wrapped function receives a
      # builtin.FuncOp (presumably the one being built) as `func_op`, whether
      # it is accepted positionally, as a keyword, or through **kwargs.
      @builtin.FuncOp.from_py_func(f32, f64)
      def positional_func_op(a, b, func_op):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kw_func_op(a, b=None, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kwargs_func_op(a, b=None, **kwargs):
        assert isinstance(kwargs["func_op"], builtin.FuncOp)
        return b

      # With explicit `results=`, the body builds its own std.ReturnOp and the
      # Python function itself returns None (see testFromPyFuncErrors).
      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
      def explicit_results(a, b):
        std.ReturnOp([b])

  print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  """Check from_py_func's error when explicit `results=` meet a non-None return.

  Supplying `results=` while the wrapped function also returns a value is
  expected to raise AssertionError; its message is printed for matching.
  """
  with Context(), Location.unknown():
    m = builtin.ModuleOp()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      try:

        @builtin.FuncOp.from_py_func(f64, results=[f64])
        def unary_return(a):
          return a
      except AssertionError as e:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(e)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  """Build FuncOps directly and inspect their name, type and visibility.

  Covers the explicit `add_entry_block()` flow, including the errors raised
  when reading the entry block of a body-less function or when adding a
  second entry block. It also covers the `body_builder` callback with the
  function type passed as an (inputs, results) tuple.
  """
  ctx = Context()
  with Location.unknown(ctx):
    m = builtin.ModuleOp()

    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint.at_block_begin(m.body):
      func = builtin.FuncOp(name="some_func",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]),
                            visibility="nested")
      # CHECK: Name is: "some_func"
      print("Name is: ", func.name)

      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
      print("Type is: ", func.type)

      # CHECK: Visibility is: "nested"
      print("Visibility is: ", func.visibility)

      # No body has been added yet, so reading the entry block is expected to
      # raise IndexError.
      try:
        _ = func.entry_block
      except IndexError as e:
        # CHECK: External function does not have a body
        print(e)

      with InsertionPoint(func.add_entry_block()):
        std.ReturnOp([func.entry_block.arguments[0]])

      # Adding a second entry block is expected to raise IndexError.
      try:
        func.add_entry_block()
      except IndexError as e:
        # CHECK: The function already has an entry block!
        print(e)

      # Try the callback builder and passing type as tuple. The lambda's
      # parameter is the FuncOp being built; it is deliberately not named
      # `func` to avoid shadowing the op above.
      builtin.FuncOp(name="some_other_func",
                     type=([tensor_type, tensor_type], [tensor_type]),
                     visibility="nested",
                     body_builder=lambda fn: std.ReturnOp(
                         [fn.entry_block.arguments[0]]))

  # CHECK: module  {
  # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK:   return %arg0 : tensor<2x3x4xf32>
  # CHECK:  }
  # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK:   return %arg0 : tensor<2x3x4xf32>
  # CHECK:  }
  print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
  """Set and print per-argument and per-result attributes on FuncOps.

  For `some_func`, `arg_attrs`/`result_attrs` are assigned from ArrayAttrs
  of DictAttrs. For `other_func`, `arg_attrs` is assigned a plain Python
  list of DictAttrs. Both the attribute arrays and the module are printed.
  """
  with Context() as ctx, Location.unknown():
    # NOTE(review): presumably needed so the custom_dialect.* attribute names
    # are accepted without a registered dialect — confirm.
    ctx.allow_unregistered_dialects = True
    module = Module.create()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(module.body):
      func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
      with InsertionPoint(func.add_entry_block()):
        # Return the two block arguments unchanged.
        std.ReturnOp(func.arguments)
      # One DictAttr per argument, in argument order.
      func.arg_attrs = ArrayAttr.get([
          DictAttr.get({
              "custom_dialect.foo": StringAttr.get("bar"),
              "custom_dialect.baz": UnitAttr.get()
          }),
          DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
      ])
      # One DictAttr per result, in result order.
      func.result_attrs = ArrayAttr.get([
          DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
          DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
      ])

      # Two f32 arguments, no results; arg attrs given as a Python list.
      # The empty DictAttr leaves the second argument without attributes.
      other = builtin.FuncOp("other_func", ([f32, f32], []))
      with InsertionPoint(other.add_entry_block()):
        std.ReturnOp([])
      other.arg_attrs = [
          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
          DictAttr.get()
      ]

  # The values below are printed after the Context/Location scope has exited.
  # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
  print(func.arg_attrs)

  # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
  print(func.result_attrs)

  # CHECK: func @some_func(
  # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
  # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
  # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
  # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
  # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
  #
  # CHECK: func @other_func(
  # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
  # CHECK: %{{.*}}: f32)
  print(module)
