1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
|
//===- pass.c - Simple test of C APIs -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
*/
#include "mlir-c/Pass.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Transforms.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
void testRunPassOnModule() {
MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx);
MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
"func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
exit(EXIT_FAILURE);
}
// Run the print-op-stats pass on the top-level module:
// CHECK-LABEL: Operations encountered:
// CHECK: arith.addi , 1
// CHECK: builtin.func , 1
// CHECK: std.return , 1
{
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirPassManagerAddOwnedPass(pm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running pass manager.\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
}
mlirModuleDestroy(module);
mlirContextDestroy(ctx);
}
// Parses a module that holds one top-level function (@foo) and a nested
// builtin.module containing a second function (@bar). It then runs
// print-op-stats twice through nested pass managers: first anchored on
// builtin.func directly under the top-level module, then on builtin.func
// inside a nested builtin.module. Each run prints its own statistics table
// for FileCheck.
//
// Every failure path prints a diagnostic naming the failing step and exits
// with EXIT_FAILURE, the same convention the other tests in this file use.
void testRunPassOnNestedModule() {
  MlirContext ctx = mlirContextCreate();
  mlirRegisterAllDialects(ctx);
  // clang-format off
  MlirModule module = mlirModuleCreateParse(
      ctx,
      mlirStringRefCreateFromCString(
      "func @foo(%arg0 : i32) -> i32 { \n"
      " %res = arith.addi %arg0, %arg0 : i32 \n"
      " return %res : i32 \n"
      "} \n"
      "module { \n"
      " func @bar(%arg0 : f32) -> f32 { \n"
      " %res = arith.addf %arg0, %arg0 : f32 \n"
      " return %res : f32 \n"
      " } \n"
      "}"));
  // clang-format on
  if (mlirModuleIsNull(module)) {
    fprintf(stderr, "Unexpected failure parsing module.\n");
    exit(EXIT_FAILURE);
  }

  // Run the print-op-stats pass on functions under the top-level module:
  // CHECK-LABEL: Operations encountered:
  // CHECK: arith.addi , 1
  // CHECK: builtin.func , 1
  // CHECK: std.return , 1
  {
    MlirPassManager pm = mlirPassManagerCreate(ctx);
    // Nest a func-level pass manager directly under the top-level module.
    MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
        pm, mlirStringRefCreateFromCString("builtin.func"));
    MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
    mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
    MlirLogicalResult success = mlirPassManagerRun(pm, module);
    if (mlirLogicalResultIsFailure(success)) {
      fprintf(stderr, "Unexpected failure running pass manager on "
                      "top-level functions.\n");
      exit(EXIT_FAILURE);
    }
    mlirPassManagerDestroy(pm);
  }

  // Run the print-op-stats pass on functions under the nested module:
  // CHECK-LABEL: Operations encountered:
  // CHECK: arith.addf , 1
  // CHECK: builtin.func , 1
  // CHECK: std.return , 1
  {
    MlirPassManager pm = mlirPassManagerCreate(ctx);
    // Two levels of nesting: builtin.module, then builtin.func inside it.
    MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
        pm, mlirStringRefCreateFromCString("builtin.module"));
    MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
        nestedModulePm, mlirStringRefCreateFromCString("builtin.func"));
    MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
    mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
    MlirLogicalResult success = mlirPassManagerRun(pm, module);
    if (mlirLogicalResultIsFailure(success)) {
      fprintf(stderr, "Unexpected failure running pass manager on "
                      "nested-module functions.\n");
      exit(EXIT_FAILURE);
    }
    mlirPassManagerDestroy(pm);
  }
  mlirModuleDestroy(module);
  mlirContextDestroy(ctx);
}
// Callback for mlirPrintPassPipeline: copies one chunk of printed text to
// stderr. The chunk is described by a pointer plus a length, so exactly
// str.length bytes are written rather than relying on a terminator (the data
// is presumably not NUL-terminated — per MlirStringRef's contract, confirm).
// userData is unused; callers in this file pass NULL.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData;
  const size_t byteCount = str.length;
  fwrite(str.data, sizeof(char), byteCount, stderr);
}
// Writes `label`, then the textual pipeline held by `opm`, then a newline,
// all to stderr.
static void printLabeledPipeline(const char *label, MlirOpPassManager opm) {
  fputs(label, stderr);
  mlirPrintPassPipeline(opm, printToStderr, NULL);
  fputc('\n', stderr);
}

// Builds builtin.module(builtin.func(print-op-stats)) by nesting pass
// managers, then prints the pipeline as seen from each nesting level. No
// passes are run; only the pipeline descriptions are emitted.
void testPrintPassPipeline() {
  MlirContext context = mlirContextCreate();
  MlirPassManager topLevelPm = mlirPassManagerCreate(context);

  // Populate the pass manager: module -> func -> print-op-stats.
  MlirOpPassManager modulePm = mlirPassManagerGetNestedUnder(
      topLevelPm, mlirStringRefCreateFromCString("builtin.module"));
  MlirOpPassManager funcPm = mlirOpPassManagerGetNestedUnder(
      modulePm, mlirStringRefCreateFromCString("builtin.func"));
  mlirOpPassManagerAddOwnedPass(funcPm, mlirCreateTransformsPrintOpStats());

  // Whole pipeline, starting from the top-level pass manager.
  // CHECK: Top-level: builtin.module(builtin.func(print-op-stats))
  printLabeledPipeline("Top-level: ",
                       mlirPassManagerGetAsOpPassManager(topLevelPm));

  // One nesting level down.
  // CHECK: Nested Module: builtin.func(print-op-stats)
  printLabeledPipeline("Nested Module: ", modulePm);

  // Two nesting levels down.
  // CHECK: Nested Module>Func: print-op-stats
  printLabeledPipeline("Nested Module>Func: ", funcPm);

  mlirPassManagerDestroy(topLevelPm);
  mlirContextDestroy(context);
}
void testParsePassPipeline() {
MlirContext ctx = mlirContextCreate();
MlirPassManager pm = mlirPassManagerCreate(ctx);
// Try parse a pipeline.
MlirLogicalResult status = mlirParsePassPipeline(
mlirPassManagerGetAsOpPassManager(pm),
mlirStringRefCreateFromCString(
"builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsSuccess(status)) {
fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
exit(EXIT_FAILURE);
}
// Try again after registrating the pass.
mlirRegisterTransformsPrintOpStats();
status = mlirParsePassPipeline(
mlirPassManagerGetAsOpPassManager(pm),
mlirStringRefCreateFromCString(
"builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
exit(EXIT_FAILURE);
}
// CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))
fprintf(stderr, "Round-trip: ");
mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
NULL);
fprintf(stderr, "\n");
mlirPassManagerDestroy(pm);
mlirContextDestroy(ctx);
}
int main() {
testRunPassOnModule();
testRunPassOnNestedModule();
testPrintPassPipeline();
testParsePassPipeline();
return 0;
}
|