1 //===- ir.c - Simple test of C APIs ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9
10 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
11 */
12
13 #include "mlir-c/IR.h"
14 #include "mlir-c/AffineExpr.h"
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/Registration.h"
20 #include "mlir-c/StandardDialect.h"
21
22 #include <assert.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27
/// Fills `loopBody` (the body block of an `scf.for`) with, in order:
///   %lhs = std.load %arg0[%iv]
///   %rhs = std.load %arg1[%iv]
///   %sum = std.addf %lhs, %rhs : f32
///   std.store %sum, %arg0[%iv]
///   scf.yield
/// where %iv is argument 0 of `loopBody` and %arg0/%arg1 are arguments 0 and 1
/// of `funcBody`. Every operation is created at `location` and appended to
/// (and owned by) `loopBody`.
void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                      MlirLocation location, MlirBlock funcBody) {
  MlirValue inductionVar = mlirBlockGetArgument(loopBody, 0);
  MlirValue lhsMemref = mlirBlockGetArgument(funcBody, 0);
  MlirValue rhsMemref = mlirBlockGetArgument(funcBody, 1);
  MlirType elemType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));

  // %lhs = std.load %arg0[%iv]
  MlirOperationState lhsLoadState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.load"), location);
  MlirValue lhsLoadArgs[] = {lhsMemref, inductionVar};
  mlirOperationStateAddOperands(&lhsLoadState, 2, lhsLoadArgs);
  mlirOperationStateAddResults(&lhsLoadState, 1, &elemType);
  MlirOperation lhsLoad = mlirOperationCreate(&lhsLoadState);
  mlirBlockAppendOwnedOperation(loopBody, lhsLoad);

  // %rhs = std.load %arg1[%iv]
  MlirOperationState rhsLoadState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.load"), location);
  MlirValue rhsLoadArgs[] = {rhsMemref, inductionVar};
  mlirOperationStateAddOperands(&rhsLoadState, 2, rhsLoadArgs);
  mlirOperationStateAddResults(&rhsLoadState, 1, &elemType);
  MlirOperation rhsLoad = mlirOperationCreate(&rhsLoadState);
  mlirBlockAppendOwnedOperation(loopBody, rhsLoad);

  // %sum = std.addf %lhs, %rhs
  MlirValue lhsValue = mlirOperationGetResult(lhsLoad, 0);
  MlirValue rhsValue = mlirOperationGetResult(rhsLoad, 0);
  MlirOperationState sumState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.addf"), location);
  MlirValue sumArgs[] = {lhsValue, rhsValue};
  mlirOperationStateAddOperands(&sumState, 2, sumArgs);
  mlirOperationStateAddResults(&sumState, 1, &elemType);
  MlirOperation sum = mlirOperationCreate(&sumState);
  mlirBlockAppendOwnedOperation(loopBody, sum);

  // std.store %sum, %arg0[%iv] -- produces no results.
  MlirValue sumValue = mlirOperationGetResult(sum, 0);
  MlirOperationState storeOpState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.store"), location);
  MlirValue storeArgs[] = {sumValue, lhsMemref, inductionVar};
  mlirOperationStateAddOperands(&storeOpState, 3, storeArgs);
  MlirOperation storeOp = mlirOperationCreate(&storeOpState);
  mlirBlockAppendOwnedOperation(loopBody, storeOp);

  // Loop-body terminator.
  MlirOperationState terminatorState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.yield"), location);
  MlirOperation terminator = mlirOperationCreate(&terminatorState);
  mlirBlockAppendOwnedOperation(loopBody, terminator);
}
73
/// Builds a module holding a single function
///   func @add(%arg0: memref<?xf32>, %arg1: memref<?xf32>)
/// whose body runs an `scf.for` from 0 to `dim %arg0, 0` with step 1 and, in
/// each iteration, stores `%arg0[i] + %arg1[i]` back into `%arg0[i]`.
/// All operations are created generically by name via MlirOperationState and
/// located at `location`. The module is dumped (mlirOperationDump) and then
/// returned; the caller in this file later releases it with mlirModuleDestroy.
MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
  MlirModule moduleOp = mlirModuleCreateEmpty(location);
  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);

  // Function body: one region with a single block taking two memref<?xf32>
  // arguments (%arg0, %arg1).
  MlirType memrefType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
  MlirType funcBodyArgTypes[] = {memrefType, memrefType};
  MlirRegion funcBodyRegion = mlirRegionCreate();
  MlirBlock funcBody = mlirBlockCreate(
      sizeof(funcBodyArgTypes) / sizeof(MlirType), funcBodyArgTypes);
  mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);

  // The `func` op is built generically: its signature and symbol name are
  // passed as the `type` and `sym_name` attributes. The signature must agree
  // with the entry block arguments created above.
  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
      ctx,
      mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
  MlirAttribute funcNameAttr =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
  MlirNamedAttribute funcAttrs[] = {
      mlirNamedAttributeGet(mlirStringRefCreateFromCString("type"),
                            funcTypeAttr),
      mlirNamedAttributeGet(mlirStringRefCreateFromCString("sym_name"),
                            funcNameAttr)};
  MlirOperationState funcState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("func"), location);
  mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
  mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
  MlirOperation func = mlirOperationCreate(&funcState);
  mlirBlockInsertOwnedOperation(moduleBody, 0, func);

  // %c0 = constant 0 : index -- used as the `dim` index and as the loop's
  // lower bound.
  MlirType indexType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirStringRefCreateFromCString("value"), indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.constant"), location);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
  mlirBlockAppendOwnedOperation(funcBody, constZero);

  // %dim = dim %arg0, %c0 -- the loop's upper bound.
  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
  MlirValue dimOperands[] = {funcArg0, constZeroValue};
  MlirOperationState dimState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.dim"), location);
  mlirOperationStateAddOperands(&dimState, 2, dimOperands);
  mlirOperationStateAddResults(&dimState, 1, &indexType);
  MlirOperation dim = mlirOperationCreate(&dimState);
  mlirBlockAppendOwnedOperation(funcBody, dim);

  // Loop body: a single block whose only argument is the index-typed
  // induction variable. It is filled in after the loop op is inserted.
  MlirRegion loopBodyRegion = mlirRegionCreate();
  MlirBlock loopBody = mlirBlockCreate(/*nArgs=*/1, &indexType);
  mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);

  // %c1 = constant 1 : index -- the loop step.
  MlirAttribute indexOneLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
      mlirStringRefCreateFromCString("value"), indexOneLiteral);
  MlirOperationState constOneState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.constant"), location);
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
  MlirOperation constOne = mlirOperationCreate(&constOneState);
  mlirBlockAppendOwnedOperation(funcBody, constOne);

  // scf.for %i = %c0 to %dim step %c1, taking ownership of the loop region.
  MlirValue dimValue = mlirOperationGetResult(dim, 0);
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
  MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
  MlirOperationState loopState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.for"), location);
  mlirOperationStateAddOperands(&loopState, 3, loopOperands);
  mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
  MlirOperation loop = mlirOperationCreate(&loopState);
  mlirBlockAppendOwnedOperation(funcBody, loop);

  populateLoopBody(ctx, loopBody, location, funcBody);

  // Function terminator.
  MlirOperationState retState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("std.return"), location);
  MlirOperation ret = mlirOperationCreate(&retState);
  mlirBlockAppendOwnedOperation(funcBody, ret);

  MlirOperation module = mlirModuleGetOperation(moduleOp);
  mlirOperationDump(module);
  // clang-format off
  // CHECK: module {
  // CHECK: func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[DIM:.*]] = dim %[[ARG0]], %[[C0]] : memref<?xf32>
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK: %[[LHS:.*]] = load %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK: %[[RHS:.*]] = load %[[ARG1]][%[[I]]] : memref<?xf32>
  // CHECK: %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
  // CHECK: store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK: }
  // CHECK: return
  // CHECK: }
  // CHECK: }
  // clang-format on

  return moduleOp;
}
179
180 struct OpListNode {
181 MlirOperation op;
182 struct OpListNode *next;
183 };
184 typedef struct OpListNode OpListNode;
185
/// Counters accumulated by collectStatsSingle() over an operation tree and
/// printed by collectStats().
typedef struct ModuleStats {
  unsigned numOperations;     // Operations visited, root included.
  unsigned numAttributes;     // Attributes summed over all operations.
  unsigned numBlocks;         // Blocks found in all regions.
  unsigned numRegions;        // Regions attached to all operations.
  unsigned numValues;         // Op results plus block arguments.
  unsigned numBlockArguments; // Block arguments that passed the checks.
  unsigned numOpResults;      // Op results that passed the checks.
} ModuleStats;
196
/// Visits the operation stored in `head`. Its counts are added to `stats`,
/// and its op results and the arguments of every block in its regions are
/// checked against the OpResult/BlockArgument C APIs. Every operation nested
/// directly in those blocks is then pushed onto the worklist immediately
/// after `head`.
///
/// Returns 0 on success or a non-zero error code:
///   1-4  an op result failed a consistency check,
///   5-8  a block argument failed a consistency check,
///   9    allocating a worklist node failed.
/// On error, nodes already pushed remain linked after `head`. The caller owns
/// the whole list and is responsible for freeing it.
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  stats->numOperations += 1;
  stats->numValues += mlirOperationGetNumResults(operation);
  stats->numAttributes += mlirOperationGetNumAttributes(operation);

  // Keep the count in the API's own type to avoid a silent truncation.
  intptr_t numRegions = mlirOperationGetNumRegions(operation);

  stats->numRegions += numRegions;

  intptr_t numResults = mlirOperationGetNumResults(operation);
  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  for (intptr_t i = 0; i < numRegions; ++i) {
    MlirRegion region = mlirOperationGetRegion(operation, i);
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      // Push each child right after `head` so the caller visits it next.
      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof *node);
        if (!node)
          return 9; // Out of memory; caller frees what is already linked.
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
  }
  return 0;
}
253
/// Walks `operation` and everything nested in it using a heap-allocated
/// worklist, validating the Value C APIs on the way. Then prints the
/// accumulated counts to stderr under an "@stats" label.
///
/// Returns 0 on success, or one of these non-zero codes:
///   - the code from collectStatsSingle() when a check fails,
///   - 100 if the value count differs from block arguments + op results,
///   - 101 if the initial worklist node cannot be allocated.
/// Every worklist node is freed on all paths, including the failure paths.
int collectStats(MlirOperation operation) {
  OpListNode *head = malloc(sizeof *head);
  if (!head)
    return 101;
  head->op = operation;
  head->next = NULL;

  ModuleStats stats = {0};

  int retval = 0;
  while (head) {
    // Once a check has failed, keep looping only to release the remaining
    // nodes instead of leaking them.
    if (retval == 0)
      retval = collectStatsSingle(head, &stats);
    OpListNode *next = head->next;
    free(head);
    head = next;
  }
  if (retval)
    return retval;

  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
    return 100;

  fprintf(stderr, "@stats\n");
  fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
  fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
  fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
  fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
  fprintf(stderr, "Number of values: %u\n", stats.numValues);
  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
  // clang-format off
  // CHECK-LABEL: @stats
  // CHECK: Number of operations: 13
  // CHECK: Number of attributes: 4
  // CHECK: Number of blocks: 3
  // CHECK: Number of regions: 3
  // CHECK: Number of values: 9
  // CHECK: Number of block arguments: 3
  // CHECK: Number of op results: 6
  // clang-format on
  return 0;
}
300
/// Print callback that copies the `str.length` bytes starting at `str.data`
/// to stderr verbatim. `userData` is not used.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData; // Required by the callback signature only.
  const char *bytes = str.data;
  size_t byteCount = str.length;
  fwrite(bytes, 1, byteCount, stderr);
}
305
/// Exercises the inspection and printing C APIs on the first operation of the
/// first function nested in `operation`. `operation` is expected to be a
/// module such as the one built by makeAndDumpAdd(). All output goes to stderr,
/// where the FileCheck directives interleaved with the code match it.
///
/// Note: this mutates the IR. `custom_attr` is added and removed again, but
/// the `elts` attribute added near the end stays on the operation.
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Assuming we are given a module, go to the first operation of the first
  // function.
  MlirRegion region = mlirOperationGetRegion(operation, 0);
  MlirBlock block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  region = mlirOperationGetRegion(operation, 0);
  MlirOperation parentOperation = operation;
  block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  // From here on, `parentOperation` is the function, `block` is its entry
  // block and `operation` is the first operation in that block.

  // Verify that parent operation and block report correctly.
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Parent operation eq: 1
  // CHECK: Block eq: 1

  // In the module we created, the first operation of the first function is
  // an "std.constant" (see the expected output below). It has an attribute
  // and a single result that we can use to test the printing mechanism.
  // Print the whole entry block first, then that operation alone.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK: %[[LHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK: %[[RHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK: %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
  // CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK: }
  // CHECK: return
  // CHECK: First operation: {{.*}} = constant 0 : index
  // clang-format on

  // Get the operation name and print it. The identifier string is
  // length-delimited, so print it character by character.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  for (size_t i = 0; i < identStr.length; ++i)
    fputc(identStr.data[i], stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'std.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator (the function's `return`) and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: return

  // Get the attribute by index.
  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
  fprintf(stderr, "Get attr 0: ");
  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0: 0 : index

  // Now re-get the attribute by name.
  MlirAttribute attr0ByName =
      mlirOperationGetAttributeByName(operation, namedAttr0.name);
  fprintf(stderr, "Get attr 0 by name: ");
  mlirAttributePrint(attr0ByName, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0 by name: 0 : index

  // Get a non-existing attribute and assert that it is null (sanity).
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a custom attribute.
  mlirOperationSetAttributeByName(operation,
                                  mlirStringRefCreateFromCString("custom_attr"),
                                  mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute. The first removal reports 1 (it existed); the
  // second reports 0 because there is nothing left to remove.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a large attribute to verify printing flags. With the elision
  // threshold set to 2, the 4-element `elts` attribute is expected to print
  // in elided (opaque) form. The attribute is not removed afterwards.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
          eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
  // clang-format on

  mlirOpPrintingFlagsDestroy(flags);
}
451
/// Builds the `@add` module (makeAndDumpAdd), validates and prints its
/// statistics (collectStats), then runs the inspection/printing checks
/// (printFirstOfEach) if the statistics passed.
///
/// Returns 0 on success or the non-zero error code from collectStats().
/// The module is destroyed on both the success and the failure path.
static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);

  int errcode = collectStats(module);
  if (errcode == 0)
    printFirstOfEach(ctx, module);

  // Release the module regardless of the outcome so the error path does not
  // leak it.
  mlirModuleDestroy(moduleOp);
  return errcode;
}
467
/// Creates an operation with a region containing multiple blocks with
/// operations and dumps it. The blocks and operations are inserted using the
/// block/operation-relative insertion API, and their final order is checked
/// against the expected output below. The operation is never inserted into a
/// block, so it is destroyed explicitly before returning.
static void buildWithInsertionsAndPrint(MlirContext ctx) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);

  // The freshly created region has no blocks yet, so `nullBlock` is expected
  // to be a null block. It serves as the "no reference block" argument below.
  MlirRegion owningRegion = mlirRegionCreate();
  MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
  MlirOperationState state = mlirOperationStateGet(
      mlirStringRefCreateFromCString("insertion.order.test"), loc);
  mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
  MlirOperation op = mlirOperationCreate(&state);
  // The region was handed to the operation state, so work with the region
  // fetched back from the created op.
  MlirRegion region = mlirOperationGetRegion(op, 0);

  // Use integer types of different bitwidth as block arguments in order to
  // differentiate blocks.
  MlirType i1 = mlirIntegerTypeGet(ctx, 1);
  MlirType i2 = mlirIntegerTypeGet(ctx, 2);
  MlirType i3 = mlirIntegerTypeGet(ctx, 3);
  MlirType i4 = mlirIntegerTypeGet(ctx, 4);
  MlirBlock block1 = mlirBlockCreate(1, &i1);
  MlirBlock block2 = mlirBlockCreate(1, &i2);
  MlirBlock block3 = mlirBlockCreate(1, &i3);
  MlirBlock block4 = mlirBlockCreate(1, &i4);
  // Insert blocks so as to obtain the 1-2-3-4 order. block3 goes into the
  // still-empty region and block2 right before it. Judging by the expected
  // order, inserting *after* a null reference puts block1 at the start of the
  // region, and block4 then lands after block3.
  mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
  mlirRegionInsertOwnedBlockBefore(region, block3, block2);
  mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
  mlirRegionInsertOwnedBlockAfter(region, block3, block4);

  MlirOperationState op1State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
  MlirOperationState op2State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
  MlirOperationState op3State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
  MlirOperationState op4State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
  MlirOperationState op5State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
  MlirOperationState op6State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
  MlirOperationState op7State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
  MlirOperation op1 = mlirOperationCreate(&op1State);
  MlirOperation op2 = mlirOperationCreate(&op2State);
  MlirOperation op3 = mlirOperationCreate(&op3State);
  MlirOperation op4 = mlirOperationCreate(&op4State);
  MlirOperation op5 = mlirOperationCreate(&op5State);
  MlirOperation op6 = mlirOperationCreate(&op6State);
  MlirOperation op7 = mlirOperationCreate(&op7State);

  // Insert operations in the first block so as to obtain the 1-2-3-4 order.
  // This mirrors the block insertion above: the null-reference "after"
  // insertion is expected to place op1 first.
  MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
  assert(mlirOperationIsNull(nullOperation));
  mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
  mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
  mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
  mlirBlockInsertOwnedOperationAfter(block1, op3, op4);

  // Append operations to the rest of blocks to make them non-empty and thus
  // printable.
  mlirBlockAppendOwnedOperation(block2, op5);
  mlirBlockAppendOwnedOperation(block3, op6);
  mlirBlockAppendOwnedOperation(block4, op7);

  mlirOperationDump(op);
  mlirOperationDestroy(op);
  // clang-format off
  // CHECK-LABEL: "insertion.order.test"
  // CHECK:      ^{{.*}}(%{{.*}}: i1
  // CHECK:        "dummy.op1"
  // CHECK-NEXT:   "dummy.op2"
  // CHECK-NEXT:   "dummy.op3"
  // CHECK-NEXT:   "dummy.op4"
  // CHECK:      ^{{.*}}(%{{.*}}: i2
  // CHECK:        "dummy.op5"
  // CHECK:      ^{{.*}}(%{{.*}}: i3
  // CHECK:        "dummy.op6"
  // CHECK:      ^{{.*}}(%{{.*}}: i4
  // CHECK:        "dummy.op7"
  // clang-format on
}
551
/// Dumps an instance of each builtin type kind covered here (integer, index,
/// floating-point, none, complex, vector, ranked/unranked tensor,
/// ranked/unranked memref, tuple and function) to check that the C API works
/// correctly. Each type constructed through the C API is also inspected with
/// the corresponding `mlirTypeIsA*` / getter functions to confirm it has the
/// expected kind and parameters. Returns 0 on success and a non-zero error
/// code identifying the first failed check otherwise (code 8 is unused).
static int printBuiltinTypes(MlirContext ctx) {
  // Integer types.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
    return 1;
  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
    return 2;
  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
    return 3;
  // Signless, signed and unsigned 32-bit integers are distinct types that
  // share the same width.
  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
    return 4;
  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
    return 5;
  fprintf(stderr, "@types\n");
  mlirTypeDump(i32);
  fprintf(stderr, "\n");
  mlirTypeDump(si32);
  fprintf(stderr, "\n");
  mlirTypeDump(ui32);
  fprintf(stderr, "\n");
  // CHECK-LABEL: @types
  // CHECK: i32
  // CHECK: si32
  // CHECK: ui32

  // Index type.
  MlirType index = mlirIndexTypeGet(ctx);
  if (!mlirTypeIsAIndex(index))
    return 6;
  mlirTypeDump(index);
  fprintf(stderr, "\n");
  // CHECK: index

  // Floating-point types. (Error code 8 is not used.)
  MlirType bf16 = mlirBF16TypeGet(ctx);
  MlirType f16 = mlirF16TypeGet(ctx);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType f64 = mlirF64TypeGet(ctx);
  if (!mlirTypeIsABF16(bf16))
    return 7;
  if (!mlirTypeIsAF16(f16))
    return 9;
  if (!mlirTypeIsAF32(f32))
    return 10;
  if (!mlirTypeIsAF64(f64))
    return 11;
  mlirTypeDump(bf16);
  fprintf(stderr, "\n");
  mlirTypeDump(f16);
  fprintf(stderr, "\n");
  mlirTypeDump(f32);
  fprintf(stderr, "\n");
  mlirTypeDump(f64);
  fprintf(stderr, "\n");
  // CHECK: bf16
  // CHECK: f16
  // CHECK: f32
  // CHECK: f64

  // None type.
  MlirType none = mlirNoneTypeGet(ctx);
  if (!mlirTypeIsANone(none))
    return 12;
  mlirTypeDump(none);
  fprintf(stderr, "\n");
  // CHECK: none

  // Complex type.
  MlirType cplx = mlirComplexTypeGet(f32);
  if (!mlirTypeIsAComplex(cplx) ||
      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
    return 13;
  mlirTypeDump(cplx);
  fprintf(stderr, "\n");
  // CHECK: complex<f32>

  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
  // memrefs and tensors, one cannot create instances of this class so it is
  // tested on an instance of vector type.
  int64_t shape[] = {2, 3};
  MlirType vector =
      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
    return 14;
  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
      mlirShapedTypeIsDynamicDim(vector, 0) ||
      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
      !mlirShapedTypeHasStaticShape(vector))
    return 15;
  mlirTypeDump(vector);
  fprintf(stderr, "\n");
  // CHECK: vector<2x3xf32>

  // Ranked tensor type.
  MlirType rankedTensor =
      mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
  if (!mlirTypeIsATensor(rankedTensor) ||
      !mlirTypeIsARankedTensor(rankedTensor))
    return 16;
  mlirTypeDump(rankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<2x3xf32>

  // Unranked tensor type.
  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
  if (!mlirTypeIsATensor(unrankedTensor) ||
      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
      mlirShapedTypeHasRank(unrankedTensor))
    return 17;
  mlirTypeDump(unrankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<*xf32>

  // MemRef type with no affine maps, in memory space 2.
  MlirType memRef = mlirMemRefTypeContiguousGet(
      f32, sizeof(shape) / sizeof(int64_t), shape, 2);
  if (!mlirTypeIsAMemRef(memRef) ||
      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
      mlirMemRefTypeGetMemorySpace(memRef) != 2)
    return 18;
  mlirTypeDump(memRef);
  fprintf(stderr, "\n");
  // CHECK: memref<2x3xf32, 2>

  // Unranked MemRef type in memory space 4; it must not be classified as a
  // ranked memref.
  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4);
  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
      mlirTypeIsAMemRef(unrankedMemRef) ||
      mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4)
    return 19;
  mlirTypeDump(unrankedMemRef);
  fprintf(stderr, "\n");
  // CHECK: memref<*xf32, 4>

  // Tuple type.
  MlirType types[] = {unrankedMemRef, f32};
  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
    return 20;
  mlirTypeDump(tuple);
  fprintf(stderr, "\n");
  // CHECK: tuple<memref<*xf32, 4>, f32>

  // Function type: (index, i1) -> (i16, i32, i64).
  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
                             mlirIntegerTypeGet(ctx, 32),
                             mlirIntegerTypeGet(ctx, 64)};
  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
    return 21;
  if (mlirFunctionTypeGetNumResults(funcType) != 3)
    return 22;
  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
    return 23;
  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
    return 24;
  mlirTypeDump(funcType);
  fprintf(stderr, "\n");
  // CHECK: (index, i1) -> (i16, i32, i64)

  return 0;
}
729
/// String callback that copies exactly `len` bytes from `data` into the
/// caller-provided buffer `userData`. The buffer must hold at least `len`
/// bytes. No NUL terminator is written, so the caller must provide one,
/// e.g. by zero-initializing a buffer of `len + 1` bytes. A non-positive
/// `len` copies nothing.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  // A negative length would convert to a huge size_t, so reject it up front.
  if (len <= 0)
    return;
  // memcpy, not strncpy: the input is length-delimited and may contain
  // embedded NULs. strncpy would replace everything after the first NUL with
  // zero padding instead of copying it.
  memcpy(userData, data, (size_t)len);
}
734
/// Returns true when the NUL-terminated string `lhs` has the same length as
/// the length-delimited `rhs` and matches it character for character.
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  size_t lhsLength = strlen(lhs);
  return lhsLength == rhs.length && strncmp(lhs, rhs.data, rhs.length) == 0;
}
741
printBuiltinAttributes(MlirContext ctx)742 int printBuiltinAttributes(MlirContext ctx) {
743 MlirAttribute floating =
744 mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
745 if (!mlirAttributeIsAFloat(floating) ||
746 fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
747 return 1;
748 fprintf(stderr, "@attrs\n");
749 mlirAttributeDump(floating);
750 // CHECK-LABEL: @attrs
751 // CHECK: 2.000000e+00 : f64
752
753 // Exercise mlirAttributeGetType() just for the first one.
754 MlirType floatingType = mlirAttributeGetType(floating);
755 mlirTypeDump(floatingType);
756 // CHECK: f64
757
758 MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
759 if (!mlirAttributeIsAInteger(integer) ||
760 mlirIntegerAttrGetValueInt(integer) != 42)
761 return 2;
762 mlirAttributeDump(integer);
763 // CHECK: 42 : i32
764
765 MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
766 if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
767 return 3;
768 mlirAttributeDump(boolean);
769 // CHECK: true
770
771 const char data[] = "abcdefghijklmnopqestuvwxyz";
772 MlirAttribute opaque =
773 mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
774 mlirNoneTypeGet(ctx));
775 if (!mlirAttributeIsAOpaque(opaque) ||
776 !stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
777 return 4;
778
779 MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
780 if (opaqueData.length != 3 ||
781 strncmp(data, opaqueData.data, opaqueData.length))
782 return 5;
783 mlirAttributeDump(opaque);
784 // CHECK: #std.abc
785
786 MlirAttribute string =
787 mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
788 if (!mlirAttributeIsAString(string))
789 return 6;
790
791 MlirStringRef stringValue = mlirStringAttrGetValue(string);
792 if (stringValue.length != 2 ||
793 strncmp(data + 3, stringValue.data, stringValue.length))
794 return 7;
795 mlirAttributeDump(string);
796 // CHECK: "de"
797
798 MlirAttribute flatSymbolRef =
799 mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
800 if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
801 return 8;
802
803 MlirStringRef flatSymbolRefValue =
804 mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
805 if (flatSymbolRefValue.length != 3 ||
806 strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
807 return 9;
808 mlirAttributeDump(flatSymbolRef);
809 // CHECK: @fgh
810
811 MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
812 MlirAttribute symbolRef =
813 mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
814 if (!mlirAttributeIsASymbolRef(symbolRef) ||
815 mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
816 !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
817 flatSymbolRef) ||
818 !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
819 flatSymbolRef))
820 return 10;
821
822 MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
823 MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
824 if (symbolRefLeaf.length != 3 ||
825 strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
826 symbolRefRoot.length != 2 ||
827 strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
828 return 11;
829 mlirAttributeDump(symbolRef);
830 // CHECK: @ij::@fgh::@fgh
831
832 MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
833 if (!mlirAttributeIsAType(type) ||
834 !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
835 return 12;
836 mlirAttributeDump(type);
837 // CHECK: f32
838
839 MlirAttribute unit = mlirUnitAttrGet(ctx);
840 if (!mlirAttributeIsAUnit(unit))
841 return 13;
842 mlirAttributeDump(unit);
843 // CHECK: unit
844
845 int64_t shape[] = {1, 2};
846
847 int bools[] = {0, 1};
848 uint32_t uints32[] = {0u, 1u};
849 int32_t ints32[] = {0, 1};
850 uint64_t uints64[] = {0u, 1u};
851 int64_t ints64[] = {0, 1};
852 float floats[] = {0.0f, 1.0f};
853 double doubles[] = {0.0, 1.0};
854 MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
855 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
856 MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
857 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
858 uints32);
859 MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
860 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
861 ints32);
862 MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
863 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
864 uints64);
865 MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
866 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
867 ints64);
868 MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
869 mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
870 MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
871 mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);
872
873 if (!mlirAttributeIsADenseElements(boolElements) ||
874 !mlirAttributeIsADenseElements(uint32Elements) ||
875 !mlirAttributeIsADenseElements(int32Elements) ||
876 !mlirAttributeIsADenseElements(uint64Elements) ||
877 !mlirAttributeIsADenseElements(int64Elements) ||
878 !mlirAttributeIsADenseElements(floatElements) ||
879 !mlirAttributeIsADenseElements(doubleElements))
880 return 14;
881
882 if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
883 mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
884 mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
885 mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
886 mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
887 fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
888 1E-6f ||
889 fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
890 return 15;
891
892 mlirAttributeDump(boolElements);
893 mlirAttributeDump(uint32Elements);
894 mlirAttributeDump(int32Elements);
895 mlirAttributeDump(uint64Elements);
896 mlirAttributeDump(int64Elements);
897 mlirAttributeDump(floatElements);
898 mlirAttributeDump(doubleElements);
899 // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
900 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
901 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
902 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
903 // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
904 // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
905 // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
906
907 MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
908 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
909 MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
910 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
911 MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
912 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
913 MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
914 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
915 MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
916 mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
917 MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
918 mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
919 MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
920 mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);
921
922 if (!mlirAttributeIsADenseElements(splatBool) ||
923 !mlirDenseElementsAttrIsSplat(splatBool) ||
924 !mlirAttributeIsADenseElements(splatUInt32) ||
925 !mlirDenseElementsAttrIsSplat(splatUInt32) ||
926 !mlirAttributeIsADenseElements(splatInt32) ||
927 !mlirDenseElementsAttrIsSplat(splatInt32) ||
928 !mlirAttributeIsADenseElements(splatUInt64) ||
929 !mlirDenseElementsAttrIsSplat(splatUInt64) ||
930 !mlirAttributeIsADenseElements(splatInt64) ||
931 !mlirDenseElementsAttrIsSplat(splatInt64) ||
932 !mlirAttributeIsADenseElements(splatFloat) ||
933 !mlirDenseElementsAttrIsSplat(splatFloat) ||
934 !mlirAttributeIsADenseElements(splatDouble) ||
935 !mlirDenseElementsAttrIsSplat(splatDouble))
936 return 16;
937
938 if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
939 mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
940 mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
941 mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
942 mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
943 fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
944 1E-6f ||
945 fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
946 return 17;
947
948 uint32_t *uint32RawData =
949 (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
950 int32_t *int32RawData =
951 (int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
952 uint64_t *uint64RawData =
953 (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
954 int64_t *int64RawData =
955 (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
956 float *floatRawData =
957 (float *)mlirDenseElementsAttrGetRawData(floatElements);
958 double *doubleRawData =
959 (double *)mlirDenseElementsAttrGetRawData(doubleElements);
960 if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
961 int32RawData[0] != 0 || int32RawData[1] != 1 ||
962 uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
963 int64RawData[0] != 0 || int64RawData[1] != 1 ||
964 floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
965 doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
966 return 18;
967
968 mlirAttributeDump(splatBool);
969 mlirAttributeDump(splatUInt32);
970 mlirAttributeDump(splatInt32);
971 mlirAttributeDump(splatUInt64);
972 mlirAttributeDump(splatInt64);
973 mlirAttributeDump(splatFloat);
974 mlirAttributeDump(splatDouble);
975 // CHECK: dense<true> : tensor<1x2xi1>
976 // CHECK: dense<1> : tensor<1x2xi32>
977 // CHECK: dense<1> : tensor<1x2xi32>
978 // CHECK: dense<1> : tensor<1x2xi64>
979 // CHECK: dense<1> : tensor<1x2xi64>
980 // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
981 // CHECK: dense<1.000000e+00> : tensor<1x2xf64>
982
983 mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
984 mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
985 // CHECK: 1.000000e+00 : f32
986 // CHECK: 1.000000e+00 : f64
987
988 int64_t indices[] = {4, 7};
989 int64_t two = 2;
990 MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
991 mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
992 indices);
993 MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
994 mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
995 MlirAttribute sparseAttr = mlirSparseElementsAttribute(
996 mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
997 valuesAttr);
998 mlirAttributeDump(sparseAttr);
999 // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
1000
1001 return 0;
1002 }
1003
// Exercises the affine map C API. Builds six representative maps, dumps them,
// and checks every query function against a per-map table of expected
// answers. It then dumps three sub-maps taken from the identity map.
// Returns 0 on success, otherwise a code from 1 to 10 naming the first group
// of checks that failed.
int printAffineMap(MlirContext ctx) {
  // Indices of the maps under test, in creation and dump order.
  enum {
    kEmptyMap,
    kPlainMap,
    kConstMap,
    kIdentityMap,
    kMinorIdentityMap,
    kPermutationMap,
    kNumMaps
  };

  unsigned perm[] = {1, 2, 0};
  MlirAffineMap maps[kNumMaps];
  maps[kEmptyMap] = mlirAffineMapEmptyGet(ctx);
  maps[kPlainMap] = mlirAffineMapGet(ctx, 3, 2);
  maps[kConstMap] = mlirAffineMapConstantGet(ctx, 2);
  maps[kIdentityMap] = mlirAffineMapMultiDimIdentityGet(ctx, 3);
  maps[kMinorIdentityMap] = mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  maps[kPermutationMap] =
      mlirAffineMapPermutationGet(ctx, sizeof perm / sizeof perm[0], perm);

  fprintf(stderr, "@affineMap\n");
  for (int i = 0; i < kNumMaps; ++i)
    mlirAffineMapDump(maps[i]);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  // Expected answers, one entry per map in the enum order above. For the
  // boolean predicates, 1 means "must hold" and 0 means "must not hold".
  // -1 means "not queried": the constant map is never asked whether it is a
  // minor identity.
  static const int wantIdentity[kNumMaps] = {1, 0, 0, 1, 0, 0};
  static const int wantMinorIdentity[kNumMaps] = {1, 0, -1, 1, 1, 0};
  static const int wantEmpty[kNumMaps] = {1, 0, 0, 0, 0, 0};
  static const int wantSingleConstant[kNumMaps] = {0, 0, 1, 0, 0, 0};
  static const intptr_t wantDims[kNumMaps] = {0, 3, 0, 3, 3, 3};
  static const intptr_t wantSymbols[kNumMaps] = {0, 2, 0, 0, 0, 0};
  static const intptr_t wantResults[kNumMaps] = {0, 0, 1, 3, 2, 3};
  static const intptr_t wantInputs[kNumMaps] = {0, 5, 0, 3, 3, 3};
  static const int wantProjectedPermutation[kNumMaps] = {1, 0, 0, 1, 1, 1};
  static const int wantPermutation[kNumMaps] = {1, 0, 0, 1, 0, 1};

  // `!!` normalizes each predicate result to 0/1 before comparing it.
  for (int i = 0; i < kNumMaps; ++i)
    if (!!mlirAffineMapIsIdentity(maps[i]) != wantIdentity[i])
      return 1;

  for (int i = 0; i < kNumMaps; ++i)
    if (wantMinorIdentity[i] >= 0 &&
        !!mlirAffineMapIsMinorIdentity(maps[i]) != wantMinorIdentity[i])
      return 2;

  for (int i = 0; i < kNumMaps; ++i)
    if (!!mlirAffineMapIsEmpty(maps[i]) != wantEmpty[i])
      return 3;

  for (int i = 0; i < kNumMaps; ++i)
    if (!!mlirAffineMapIsSingleConstant(maps[i]) != wantSingleConstant[i])
      return 4;

  if (mlirAffineMapGetSingleConstantResult(maps[kConstMap]) != 2)
    return 5;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumDims(maps[i]) != wantDims[i])
      return 6;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumSymbols(maps[i]) != wantSymbols[i])
      return 7;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumResults(maps[i]) != wantResults[i])
      return 8;

  for (int i = 0; i < kNumMaps; ++i)
    if (mlirAffineMapGetNumInputs(maps[i]) != wantInputs[i])
      return 9;

  for (int i = 0; i < kNumMaps; ++i)
    if (!!mlirAffineMapIsProjectedPermutation(maps[i]) !=
            wantProjectedPermutation[i] ||
        !!mlirAffineMapIsPermutation(maps[i]) != wantPermutation[i])
      return 10;

  // Sub-maps of the identity map: only result #1, the leading result, and the
  // trailing result.
  MlirAffineMap identity = maps[kIdentityMap];
  intptr_t keptPositions[] = {1};
  MlirAffineMap picked = mlirAffineMapGetSubMap(
      identity, sizeof keptPositions / sizeof keptPositions[0], keptPositions);
  MlirAffineMap leading = mlirAffineMapGetMajorSubMap(identity, 1);
  MlirAffineMap trailing = mlirAffineMapGetMinorSubMap(identity, 1);

  mlirAffineMapDump(picked);
  mlirAffineMapDump(leading);
  mlirAffineMapDump(trailing);
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  return 0;
}
1128
// Exercises the affine expression C API. Builds a dimension, a symbol and a
// constant (each with position/value 5), plus every binary combination of the
// dimension and the symbol. It dumps them and checks the query functions
// against a per-expression table of expected answers.
// Returns 0 on success, otherwise a code from 1 to 13 naming the first check
// that failed.
int printAffineExpr(MlirContext ctx) {
  // Indices of the expressions under test, in creation and dump order.
  enum {
    kDimExpr,
    kSymbolExpr,
    kConstantExpr,
    kAddExpr,
    kMulExpr,
    kModExpr,
    kFloorDivExpr,
    kCeilDivExpr,
    kNumExprs
  };

  MlirAffineExpr exprs[kNumExprs];
  exprs[kDimExpr] = mlirAffineDimExprGet(ctx, 5);
  exprs[kSymbolExpr] = mlirAffineSymbolExprGet(ctx, 5);
  exprs[kConstantExpr] = mlirAffineConstantExprGet(ctx, 5);

  MlirAffineExpr d5 = exprs[kDimExpr];
  MlirAffineExpr s5 = exprs[kSymbolExpr];
  exprs[kAddExpr] = mlirAffineAddExprGet(d5, s5);
  exprs[kMulExpr] = mlirAffineMulExprGet(d5, s5);
  exprs[kModExpr] = mlirAffineModExprGet(d5, s5);
  exprs[kFloorDivExpr] = mlirAffineFloorDivExprGet(d5, s5);
  exprs[kCeilDivExpr] = mlirAffineCeilDivExprGet(d5, s5);

  // Dump every expression (exercises mlirAffineExprDump).
  fprintf(stderr, "@affineExpr\n");
  for (int i = 0; i < kNumExprs; ++i)
    mlirAffineExprDump(exprs[i]);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Operand accessors of a binary expression, using the addition as the
  // representative case.
  MlirAffineExpr sum = exprs[kAddExpr];
  mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(sum));
  mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(sum));
  // CHECK: d5
  // CHECK: s5

  // Leaf accessors.
  if (mlirAffineDimExprGetPosition(d5) != 5)
    return 1;
  if (mlirAffineSymbolExprGetPosition(s5) != 5)
    return 2;
  if (mlirAffineConstantExprGetValue(exprs[kConstantExpr]) != 5)
    return 3;

  // Expected answers, one entry per expression in the enum order above.
  // wantDivisor also serves as the factor passed to mlirAffineExprIsMultipleOf,
  // since each expression must be a multiple of its own largest known divisor.
  static const int wantSymbolicOrConstant[kNumExprs] = {0, 1, 1, 0,
                                                        0, 0, 0, 0};
  static const int wantPureAffine[kNumExprs] = {1, 1, 1, 1, 0, 0, 0, 0};
  static const int64_t wantDivisor[kNumExprs] = {1, 1, 5, 1, 1, 1, 1, 1};
  static const int wantUsesDim5[kNumExprs] = {1, 0, 0, 1, 1, 1, 1, 1};

  // `!!` normalizes each predicate result to 0/1 before comparing it.
  for (int i = 0; i < kNumExprs; ++i)
    if (!!mlirAffineExprIsSymbolicOrConstant(exprs[i]) !=
        wantSymbolicOrConstant[i])
      return 4;

  for (int i = 0; i < kNumExprs; ++i)
    if (!!mlirAffineExprIsPureAffine(exprs[i]) != wantPureAffine[i])
      return 5;

  for (int i = 0; i < kNumExprs; ++i)
    if (mlirAffineExprGetLargestKnownDivisor(exprs[i]) != wantDivisor[i])
      return 6;

  for (int i = 0; i < kNumExprs; ++i)
    if (!mlirAffineExprIsMultipleOf(exprs[i], wantDivisor[i]))
      return 7;

  for (int i = 0; i < kNumExprs; ++i)
    if (!!mlirAffineExprIsFunctionOfDim(exprs[i], 5) != wantUsesDim5[i])
      return 8;

  // Kind predicates ('IsA') of the binary expressions.
  if (!mlirAffineExprIsAAdd(exprs[kAddExpr]))
    return 9;
  if (!mlirAffineExprIsAMul(exprs[kMulExpr]))
    return 10;
  if (!mlirAffineExprIsAMod(exprs[kModExpr]))
    return 11;
  if (!mlirAffineExprIsAFloorDiv(exprs[kFloorDivExpr]))
    return 12;
  if (!mlirAffineExprIsACeilDiv(exprs[kCeilDivExpr]))
    return 13;

  return 0;
}
1252
registerOnlyStd()1253 int registerOnlyStd() {
1254 MlirContext ctx = mlirContextCreate();
1255 // The built-in dialect is always loaded.
1256 if (mlirContextGetNumLoadedDialects(ctx) != 1)
1257 return 1;
1258
1259 MlirDialect std =
1260 mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
1261 if (!mlirDialectIsNull(std))
1262 return 2;
1263
1264 mlirContextRegisterStandardDialect(ctx);
1265 if (mlirContextGetNumRegisteredDialects(ctx) != 1)
1266 return 3;
1267 if (mlirContextGetNumLoadedDialects(ctx) != 1)
1268 return 4;
1269
1270 std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
1271 if (mlirDialectIsNull(std))
1272 return 5;
1273 if (mlirContextGetNumLoadedDialects(ctx) != 2)
1274 return 6;
1275
1276 MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
1277 if (!mlirDialectEqual(std, alsoStd))
1278 return 7;
1279
1280 MlirStringRef stdNs = mlirDialectGetNamespace(std);
1281 MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
1282 if (stdNs.length != alsoStdNs.length ||
1283 strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1284 return 8;
1285
1286 fprintf(stderr, "@registration\n");
1287 // CHECK-LABEL: @registration
1288
1289 return 0;
1290 }
1291
1292 // Wraps a diagnostic into additional text we can match against.
errorHandler(MlirDiagnostic diagnostic,void * userData)1293 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
1294 fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
1295 mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
1296 fprintf(stderr, "\n");
1297 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1298 mlirLocationPrint(loc, printToStderr, NULL);
1299 assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
1300 fprintf(stderr, ">> end of diagnostic (userData: %ld)\n", (long)userData);
1301 return mlirLogicalResultSuccess();
1302 }
1303
// User-data destructor passed alongside errorHandler. It only prints a marker
// line carrying the user-data value, so the test output shows when the
// callback fires. Nothing is actually freed.
static void deleteUserData(void *userData) {
  long tag = (long)userData;
  fprintf(stderr, "deleting user data (userData: %ld)\n", tag);
}
1308
testDiagnostics()1309 void testDiagnostics() {
1310 MlirContext ctx = mlirContextCreate();
1311 MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
1312 ctx, errorHandler, (void *)42, deleteUserData);
1313 MlirLocation loc = mlirLocationUnknownGet(ctx);
1314 fprintf(stderr, "@test_diagnostics\n");
1315 mlirEmitError(loc, "test diagnostics");
1316 mlirContextDetachDiagnosticHandler(ctx, id);
1317 mlirEmitError(loc, "more test diagnostics");
1318 // CHECK-LABEL: @test_diagnostics
1319 // CHECK: processing diagnostic (userData: 42) <<
1320 // CHECK: test diagnostics
1321 // CHECK: loc(unknown)
1322 // CHECK: >> end of diagnostic (userData: 42)
1323 // CHECK: deleting user data (userData: 42)
1324 // CHECK-NOT: processing diagnostic
1325 // CHECK: more test diagnostics
1326 }
1327
main()1328 int main() {
1329 MlirContext ctx = mlirContextCreate();
1330 mlirRegisterAllDialects(ctx);
1331 if (constructAndTraverseIr(ctx))
1332 return 1;
1333 buildWithInsertionsAndPrint(ctx);
1334
1335 if (printBuiltinTypes(ctx))
1336 return 2;
1337 if (printBuiltinAttributes(ctx))
1338 return 3;
1339 if (printAffineMap(ctx))
1340 return 4;
1341 if (printAffineExpr(ctx))
1342 return 5;
1343 if (registerOnlyStd())
1344 return 6;
1345
1346 mlirContextDestroy(ctx);
1347
1348 testDiagnostics();
1349 return 0;
1350 }
1351