File: memref_abi.c

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (173 lines) | stat: -rw-r--r-- 5,819 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// Split the MLIR string: this will produce %t/input.mlir
// RUN: split-file %s %t

// Compile the MLIR file to LLVM:
// RUN: mlir-opt %t/input.mlir \
// RUN:  -lower-affine  -convert-scf-to-cf  -finalize-memref-to-llvm \
// RUN:  -convert-func-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir -o %t.ll

// Generate an object file for the MLIR code
// RUN: llc %t.ll -o %t.o -filetype=obj

// Compile the current C file and link it to the MLIR code:
// RUN: %host_cc %s %t.o -o %t.exe

// Exec
// RUN: %t.exe | FileCheck %s

/* MLIR_BEGIN
//--- input.mlir
// Performs: arg0[i, j] = arg0[i, j] + arg1[i, j]
func.func private @add_memref(%arg0: memref<?x?xf64>, %arg1: memref<?x?xf64>) -> i64
   attributes {llvm.emit_c_interface} {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dimI = memref.dim %arg0, %c0 : memref<?x?xf64>
  %dimJ = memref.dim %arg0, %c1 : memref<?x?xf64>
  affine.for %i = 0 to %dimI {
    affine.for %j = 0 to %dimJ {
      %load0 = memref.load %arg0[%i, %j] : memref<?x?xf64>
      %load1 = memref.load %arg1[%i, %j] : memref<?x?xf64>
      %add = arith.addf %load0, %load1 : f64
      affine.store %add, %arg0[%i, %j] : memref<?x?xf64>
    }
  }
  %c42 = arith.constant 42 : i64
  return %c42 : i64
}

//--- end_input.mlir

MLIR_END */

#include <stdint.h>
#include <stdio.h>

// Define the API for the MLIR function, see
// https://mlir.llvm.org/docs/TargetLLVMIR/#calling-conventions for details.
//
// The function takes two 2D memrefs; its signature in the MLIR LLVM dialect
// will be:
// llvm.func @add_memref(
//   // First Memref (%arg0)
//      %allocated_ptr0: !llvm.ptr<f64>, %aligned_ptr0: !llvm.ptr<f64>,
//      %offset0: i64, %size0_d0: i64, %size0_d1: i64, %stride0_d0: i64,
//      %stride0_d1: i64,
//   // Second Memref (%arg1)
//      %allocated_ptr1: !llvm.ptr<f64>, %aligned_ptr1: !llvm.ptr<f64>,
//      %offset1: i64, %size1_d0: i64, %size1_d1: i64, %stride1_d0: i64,
//      %stride1_d1: i64) -> i64
//
// Direct entry point: each memref<?x?xf64> operand is passed unpacked as seven
// scalars (allocated pointer, aligned pointer, offset, two sizes, two
// strides). The returned value is the i64 result of the MLIR function.
long long add_memref(double *allocated_ptr0, double *aligned_ptr0,
                     intptr_t offset0, intptr_t size0_d0, intptr_t size0_d1,
                     intptr_t stride0_d0, intptr_t stride0_d1,
                     // Second Memref (%arg1)
                     double *allocated_ptr1, double *aligned_ptr1,
                     intptr_t offset1, intptr_t size1_d0, intptr_t size1_d1,
                     intptr_t stride1_d0, intptr_t stride1_d1);

// The llvm.emit_c_interface will also trigger emission of another wrapper:
// llvm.func @_mlir_ciface_add_memref(
//   %arg0: !llvm.ptr<struct<(ptr<f64>, ptr<f64>, i64,
//                            array<2 x i64>, array<2 x i64>)>>,
//   %arg1: !llvm.ptr<struct<(ptr<f64>, ptr<f64>, i64,
//                            array<2 x i64>, array<2 x i64>)>>)
// -> i64
typedef struct {
  // Field order follows the struct<(ptr, ptr, i64, array<2 x i64>,
  // array<2 x i64>)> layout that the wrapper receives by pointer.
  double *allocated;  // counterpart of %allocated_ptr in the unpacked form
  double *aligned;    // counterpart of %aligned_ptr
  intptr_t offset;    // counterpart of %offset
  intptr_t size[2];   // counterparts of %size_d0 and %size_d1
  intptr_t stride[2]; // counterparts of %stride_d0 and %stride_d1
} memref_2d_descriptor;
// Wrapper entry point: takes each memref as a pointer to its descriptor and
// returns the i64 result of the MLIR function.
long long _mlir_ciface_add_memref(memref_2d_descriptor *arg0,
                                  memref_2d_descriptor *arg1);

// Dimensions of the test matrices: N rows by M columns.
#define N 4
#define M 8
// Operands handed to the MLIR function; arg0 also receives the element-wise
// sum, arg1 is only read.
double arg0[N][M];
double arg1[N][M];

// Writes `count` consecutive values, each truncated to int and followed by a
// comma and a tab.
static void dump_row(const double *cells, int count) {
  for (const double *it = cells, *end = cells + count; it != end; ++it)
    printf("%d,\t", (int)*it);
}

// Prints both global matrices side by side: every output line holds one
// bracketed row of arg0 followed by the matching bracketed row of arg1.
void dump() {
  for (int row = 0; row != N; ++row) {
    printf("[");
    dump_row(arg0[row], M);
    printf("] [");
    dump_row(arg1[row], M);
    printf("]\n");
  }
}

// Fills both matrices in row-major order from a single running counter, so
// arg0 receives the even values and arg1 the odd ones.
static void init_inputs(void) {
  int count = 0;
  for (int i = 0; i < N; i++) {
    for (int j = 0; j < M; j++) {
      arg0[i][j] = count++;
      arg1[i][j] = count++;
    }
  }
}

// Runs add_memref twice on identical inputs, first through the unpacked entry
// point and then through the _mlir_ciface_ wrapper, printing the matrices and
// the returned value after each call so the output can be verified.
// Always returns 0.
int main(void) {
  init_inputs();
  printf("Before:\n");
  dump();
  // clang-format off
  // CHECK-LABEL: Before:
  // CHECK: [0,	  2,	4,	6,	8,	10,	12,	14,	] [1,	  3,	5, 7, 9,	11,	13,	15,	]
  // CHECK: [16,	18,	20,	22, 24, 26,	28,	30,	] [17,	19,	21,	23,	25,	27,	29, 31, ]
  // CHECK: [32,	34,	36,	38,	40,	42,	44,	46,	] [33,	35, 37, 39,	41,	43,	45,	47,	]
  // CHECK: [48,	50,	52, 54, 56,	58,	60,	62,	] [49,	51,	53,	55,	57,	59, 61, 63,	]
  // clang-format on

  // Call into MLIR. Both arrays are contiguous and row-major, so the strides
  // are M for the row dimension and 1 for the column dimension.
  long long result = add_memref((double *)arg0, (double *)arg0, 0, N, M, M, 1,
                                //
                                (double *)arg1, (double *)arg1, 0, N, M, M, 1);

  // CHECK-LABEL: Result:
  // CHECK: 42
  printf("Result: %d\n", (int)result);

  printf("After:\n");
  dump();

  // clang-format off
  // CHECK-LABEL: After:
  // CHECK: [1,	  5,	  9,	  13,	 17,	21,	  25,	  29,	  ] [1, 3,	5,	7,	9,	11,	13,	15,	] 
  // CHECK: [33,	37,  41,	  45,	 49,	53,	  57,	  61,	  ] [17,	19,	21, 23, 25,	27,	29,	31,	]
  // CHECK: [65,	69,	  73,   77,	 81,	85,	  89,	  93,	  ] [33,	35,	37,	39, 41, 43,	45,	47,	]
  // CHECK: [97,	101,	105,	109, 113,	117,	121,	125,	] [49,	51,	53,	55,	57,	59, 61, 63,	]
  // clang-format on

  // Reset the inputs and re-apply the same function using the C interface
  // wrapper.
  init_inputs();

  // Call into MLIR. The nested braces initialize the size and stride arrays
  // explicitly: sizes {N, M}, strides {M, 1}.
  memref_2d_descriptor arg0_descriptor = {
      (double *)arg0, (double *)arg0, 0, {N, M}, {M, 1}};
  memref_2d_descriptor arg1_descriptor = {
      (double *)arg1, (double *)arg1, 0, {N, M}, {M, 1}};
  result = _mlir_ciface_add_memref(&arg0_descriptor, &arg1_descriptor);

  // CHECK-LABEL: Result2:
  // CHECK: 42
  printf("Result2: %d\n", (int)result);

  printf("After2:\n");
  dump();

  // clang-format off
  // CHECK-LABEL: After2:
  // CHECK: [1,	  5,	  9,	  13,	 17,	21,	  25,	  29,	  ] [1, 3,	5,	7,	9,	11,	13,	15,	] 
  // CHECK: [33,	37,  41,	  45,	 49,	53,	  57,	  61,	  ] [17,	19,	21, 23, 25,	27,	29,	31,	]
  // CHECK: [65,	69,	  73,   77,	 81,	85,	  89,	  93,	  ] [33,	35,	37,	39, 41, 43,	45,	47,	]
  // CHECK: [97,	101,	105,	109, 113,	117,	121,	125,	] [49,	51,	53,	55,	57,	59, 61, 63,	]
  // clang-format on

  return 0;
}