1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
# RUN: %PYTHON %s | FileCheck %s
from mlir import ir
from mlir.dialects.transform import interpreter as interp
def test_in_context(f):
    """Decorator that runs ``f`` once, right away, inside a fresh MLIR context.

    ``f`` is invoked at decoration time with a new ``ir.Context`` and an unknown
    ``ir.Location`` active, so each decorated test executes as the script is
    loaded, in definition order. The function object is returned unchanged so
    its name stays bound at module level.
    """
    with ir.Context():
        # Entered after the context, exactly as in a combined ``with a, b:``.
        with ir.Location.unknown():
            f()
    return f
# Transform script whose @__transform_main entry point prints the root handle
# it receives via transform.print. Tests replace the print name
# "from interpreter" with their own label so each output section can be
# located by FileCheck.
print_root_module = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}"""
@test_in_context
def print_self():
    """Apply the printing sequence with the transform module as its own payload.

    The print op is relabeled ``print_self`` so the FileCheck label that
    follows this function can locate the printed transform IR.
    """
    labeled_source = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(labeled_source)
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)
# CHECK-LABEL: print_self
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield
@test_in_context
def print_other():
    """Apply the printing sequence to a separate payload module.

    Only the payload (tagged with the ``this.is.payload`` attribute) should be
    printed under the ``print_other`` label, not the transform IR itself.
    """
    labeled_source = print_root_module.replace("from interpreter", "print_other")
    transform_module = ir.Module.parse(labeled_source)
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    entry_point = transform_module.body.operations[0]
    interp.apply_named_sequence(payload, entry_point, transform_module)
# CHECK-LABEL: print_other
# CHECK-NOT: transform
# CHECK: this.is.payload
@test_in_context
def transform_options():
    """Apply the printing sequence with explicitly configured ``TransformOptions``.

    Expensive checks are disabled and a single top-level transform op is
    enforced; the options object is passed as the fourth argument.
    """
    opts = interp.TransformOptions()
    opts.expensive_checks = False
    opts.enforce_single_top_level_transform_op = True

    labeled_source = print_root_module.replace("from interpreter", "transform_options")
    transform_module = ir.Module.parse(labeled_source)
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    entry_point = transform_module.body.operations[0]
    interp.apply_named_sequence(payload, entry_point, transform_module, opts)
# CHECK-LABEL: transform_options
@test_in_context
def failed():
    """Using a non-transform op as the transform root must raise ``ValueError``.

    The payload module is passed as payload, transform root and transform
    module at once; the error text is expected to state that the root must
    implement ``TransformOpInterface``.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        ), str(e)
    else:
        # Without this branch a missing error would let the test pass silently.
        raise AssertionError("expected apply_named_sequence to raise ValueError")
# Entry-point script that only declares (without bodies) the private sequences
# @callee1 and @callee2, and whose @__transform_main includes @callee2. The
# tests below merge the callee definitions into it before applying.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.include @callee2 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}"""
# Defines @callee2, which forwards its root handle to @callee1. @callee1 is only
# declared here (no body); its definition lives in callee1_definition.
callee2_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
transform.include @callee1 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}
"""
# Defines @callee1, the end of the include chain: it prints its root handle
# under the name "from interpreter".
callee1_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}
"""
@test_in_context
def include():
    """Resolve a two-level include chain assembled from separately parsed modules.

    The callee modules are parsed on their own and their symbols are copied
    into ``main``. Applying ``__transform_main`` then goes through ``callee2``
    into ``callee1``, whose print op dumps the merged ``main`` module.
    """
    main = ir.Module.parse(print_root_via_include_module)
    # Parse every callee up front, then merge them in order: callee1, then callee2.
    callee_modules = [
        ir.Module.parse(source)
        for source in (callee1_definition, callee2_definition)
    ]
    for callee in callee_modules:
        interp.copy_symbols_and_merge_into(main, callee)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    interp.apply_named_sequence(main, main.body.operations[0], main)
@test_in_context
def partial_include():
    """Applying a sequence whose include chain is only partially defined must fail.

    Only ``callee2`` is merged into ``main``. It includes ``callee1``, and no
    definition of ``callee1`` is merged in, so applying ``__transform_main``
    is expected to raise ``ValueError`` mentioning "Failed to apply".
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e), str(e)
    else:
        # Without this branch a missing error would let the test pass silently.
        raise AssertionError("expected apply_named_sequence to raise ValueError")
@test_in_context
def repeated_include():
    """Merging the same callee definition twice must fail.

    The second ``copy_symbols_and_merge_into`` of ``callee2`` is expected to
    raise ``ValueError`` reporting "doubly defined symbol @callee2".
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e), str(e)
    else:
        # Without this branch a missing error would let the test pass silently.
        raise AssertionError(
            "expected copy_symbols_and_merge_into to raise ValueError"
        )
|