# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    m1 = Module.parse("""
      func private @foo()
      func private @bar()""")
    m2 = Module.parse("""
      func private @qux()
      func private @foo()
      "foo.bar"() : () -> ()""")

    symbol_table = SymbolTable(m1.operation)

    # CHECK: func private @foo
    # CHECK: func private @bar
    assert "foo" in symbol_table
    print(symbol_table["foo"])
    assert "bar" in symbol_table
    bar = symbol_table["bar"]
    print(symbol_table["bar"])

    assert "qux" not in symbol_table

    del symbol_table["bar"]
    try:
      symbol_table.erase(symbol_table["bar"])
    except KeyError:
      pass
    else:
      assert False, "expected KeyError"

    # CHECK: module
    # CHECK:   func private @foo()
    print(m1)
    assert "bar" not in symbol_table

    try:
      print(bar)
    except RuntimeError as e:
      if "the operation has been invalidated" not in str(e):
        raise
    else:
      assert False, "expected RuntimeError due to invalidated operation"

    qux = m2.body.operations[0]
    m1.body.append(qux)
    symbol_table.insert(qux)
    assert "qux" in symbol_table

    # Check that insertion actually renames this symbol in the symbol table.
    foo2 = m2.body.operations[0]
    m1.body.append(foo2)
    updated_name = symbol_table.insert(foo2)
    assert foo2.name.value != "foo"
    assert foo2.name == updated_name

    # CHECK: module
    # CHECK:   func private @foo()
    # CHECK:   func private @qux()
    # CHECK:   func private @foo{{.*}}
    print(m1)

    try:
      symbol_table.insert(m2.body.operations[0])
    except ValueError as e:
      if "Expected operation to have a symbol name" not in str(e):
        raise
    else:
      assert False, "exepcted ValueError when adding a non-symbol"


# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
  with Context() as ctx:
    m = Module.parse("""
      func private @foo() {
        call @bar() : () -> ()
        return
      }
      func private @bar()
      """)
    foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
    SymbolTable.set_symbol_name(bar, "bam")
    # Note that module.operation counts as a "nested symbol table" which won't
    # be traversed into, so it is necessary to traverse its children.
    SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
    # CHECK: call @bam()
    # CHECK: func private @bam
    print(m)
    # CHECK: Foo symbol: "foo"
    # CHECK: Bar symbol: "bam"
    print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
    print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")


# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
  with Context() as ctx:
    m = Module.parse("""
      func private @foo() {
        return
      }
      """)
    foo = m.operation.regions[0].blocks[0].operations[0]
    # CHECK: Existing visibility: "private"
    print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
    SymbolTable.set_visibility(foo, "public")
    # CHECK: func public @foo
    print(m)


# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
  with Context() as ctx:
    m = Module.parse("""
      module @outer {
        module @inner{
        }
      }
      """)
    def callback(symbol_table_op, uses_visible):
      print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
    # CHECK: SYMBOL TABLE: True: module @inner
    # CHECK: SYMBOL TABLE: True: module @outer
    SymbolTable.walk_symbol_tables(m.operation, True, callback)

    # Make sure exceptions in the callback are handled.
    def error_callback(symbol_table_op, uses_visible):
      assert False, "Raised from python"
    try:
      SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
    except RuntimeError as e:
      # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
      print(f"GOT EXCEPTION: {e}")

