• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ir.c - Simple test of C APIs ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/IR.h"
14 #include "mlir-c/AffineExpr.h"
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/Registration.h"
20 #include "mlir-c/StandardDialect.h"
21 
22 #include <assert.h>
23 #include <math.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27 
populateLoopBody(MlirContext ctx,MlirBlock loopBody,MlirLocation location,MlirBlock funcBody)28 void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
29                       MlirLocation location, MlirBlock funcBody) {
30   MlirValue iv = mlirBlockGetArgument(loopBody, 0);
31   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
32   MlirValue funcArg1 = mlirBlockGetArgument(funcBody, 1);
33   MlirType f32Type =
34       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));
35 
36   MlirOperationState loadLHSState = mlirOperationStateGet(
37       mlirStringRefCreateFromCString("std.load"), location);
38   MlirValue loadLHSOperands[] = {funcArg0, iv};
39   mlirOperationStateAddOperands(&loadLHSState, 2, loadLHSOperands);
40   mlirOperationStateAddResults(&loadLHSState, 1, &f32Type);
41   MlirOperation loadLHS = mlirOperationCreate(&loadLHSState);
42   mlirBlockAppendOwnedOperation(loopBody, loadLHS);
43 
44   MlirOperationState loadRHSState = mlirOperationStateGet(
45       mlirStringRefCreateFromCString("std.load"), location);
46   MlirValue loadRHSOperands[] = {funcArg1, iv};
47   mlirOperationStateAddOperands(&loadRHSState, 2, loadRHSOperands);
48   mlirOperationStateAddResults(&loadRHSState, 1, &f32Type);
49   MlirOperation loadRHS = mlirOperationCreate(&loadRHSState);
50   mlirBlockAppendOwnedOperation(loopBody, loadRHS);
51 
52   MlirOperationState addState = mlirOperationStateGet(
53       mlirStringRefCreateFromCString("std.addf"), location);
54   MlirValue addOperands[] = {mlirOperationGetResult(loadLHS, 0),
55                              mlirOperationGetResult(loadRHS, 0)};
56   mlirOperationStateAddOperands(&addState, 2, addOperands);
57   mlirOperationStateAddResults(&addState, 1, &f32Type);
58   MlirOperation add = mlirOperationCreate(&addState);
59   mlirBlockAppendOwnedOperation(loopBody, add);
60 
61   MlirOperationState storeState = mlirOperationStateGet(
62       mlirStringRefCreateFromCString("std.store"), location);
63   MlirValue storeOperands[] = {mlirOperationGetResult(add, 0), funcArg0, iv};
64   mlirOperationStateAddOperands(&storeState, 3, storeOperands);
65   MlirOperation store = mlirOperationCreate(&storeState);
66   mlirBlockAppendOwnedOperation(loopBody, store);
67 
68   MlirOperationState yieldState = mlirOperationStateGet(
69       mlirStringRefCreateFromCString("scf.yield"), location);
70   MlirOperation yield = mlirOperationCreate(&yieldState);
71   mlirBlockAppendOwnedOperation(loopBody, yield);
72 }
73 
makeAndDumpAdd(MlirContext ctx,MlirLocation location)74 MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
75   MlirModule moduleOp = mlirModuleCreateEmpty(location);
76   MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
77 
78   MlirType memrefType =
79       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
80   MlirType funcBodyArgTypes[] = {memrefType, memrefType};
81   MlirRegion funcBodyRegion = mlirRegionCreate();
82   MlirBlock funcBody = mlirBlockCreate(
83       sizeof(funcBodyArgTypes) / sizeof(MlirType), funcBodyArgTypes);
84   mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);
85 
86   MlirAttribute funcTypeAttr = mlirAttributeParseGet(
87       ctx,
88       mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
89   MlirAttribute funcNameAttr =
90       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
91   MlirNamedAttribute funcAttrs[] = {
92       mlirNamedAttributeGet(mlirStringRefCreateFromCString("type"),
93                             funcTypeAttr),
94       mlirNamedAttributeGet(mlirStringRefCreateFromCString("sym_name"),
95                             funcNameAttr)};
96   MlirOperationState funcState =
97       mlirOperationStateGet(mlirStringRefCreateFromCString("func"), location);
98   mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
99   mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
100   MlirOperation func = mlirOperationCreate(&funcState);
101   mlirBlockInsertOwnedOperation(moduleBody, 0, func);
102 
103   MlirType indexType =
104       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
105   MlirAttribute indexZeroLiteral =
106       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
107   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
108       mlirStringRefCreateFromCString("value"), indexZeroLiteral);
109   MlirOperationState constZeroState = mlirOperationStateGet(
110       mlirStringRefCreateFromCString("std.constant"), location);
111   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
112   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
113   MlirOperation constZero = mlirOperationCreate(&constZeroState);
114   mlirBlockAppendOwnedOperation(funcBody, constZero);
115 
116   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
117   MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
118   MlirValue dimOperands[] = {funcArg0, constZeroValue};
119   MlirOperationState dimState = mlirOperationStateGet(
120       mlirStringRefCreateFromCString("std.dim"), location);
121   mlirOperationStateAddOperands(&dimState, 2, dimOperands);
122   mlirOperationStateAddResults(&dimState, 1, &indexType);
123   MlirOperation dim = mlirOperationCreate(&dimState);
124   mlirBlockAppendOwnedOperation(funcBody, dim);
125 
126   MlirRegion loopBodyRegion = mlirRegionCreate();
127   MlirBlock loopBody = mlirBlockCreate(/*nArgs=*/1, &indexType);
128   mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);
129 
130   MlirAttribute indexOneLiteral =
131       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
132   MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
133       mlirStringRefCreateFromCString("value"), indexOneLiteral);
134   MlirOperationState constOneState = mlirOperationStateGet(
135       mlirStringRefCreateFromCString("std.constant"), location);
136   mlirOperationStateAddResults(&constOneState, 1, &indexType);
137   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
138   MlirOperation constOne = mlirOperationCreate(&constOneState);
139   mlirBlockAppendOwnedOperation(funcBody, constOne);
140 
141   MlirValue dimValue = mlirOperationGetResult(dim, 0);
142   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
143   MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
144   MlirOperationState loopState = mlirOperationStateGet(
145       mlirStringRefCreateFromCString("scf.for"), location);
146   mlirOperationStateAddOperands(&loopState, 3, loopOperands);
147   mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
148   MlirOperation loop = mlirOperationCreate(&loopState);
149   mlirBlockAppendOwnedOperation(funcBody, loop);
150 
151   populateLoopBody(ctx, loopBody, location, funcBody);
152 
153   MlirOperationState retState = mlirOperationStateGet(
154       mlirStringRefCreateFromCString("std.return"), location);
155   MlirOperation ret = mlirOperationCreate(&retState);
156   mlirBlockAppendOwnedOperation(funcBody, ret);
157 
158   MlirOperation module = mlirModuleGetOperation(moduleOp);
159   mlirOperationDump(module);
160   // clang-format off
161   // CHECK: module {
162   // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
163   // CHECK:     %[[C0:.*]] = constant 0 : index
164   // CHECK:     %[[DIM:.*]] = dim %[[ARG0]], %[[C0]] : memref<?xf32>
165   // CHECK:     %[[C1:.*]] = constant 1 : index
166   // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
167   // CHECK:       %[[LHS:.*]] = load %[[ARG0]][%[[I]]] : memref<?xf32>
168   // CHECK:       %[[RHS:.*]] = load %[[ARG1]][%[[I]]] : memref<?xf32>
169   // CHECK:       %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
170   // CHECK:       store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
171   // CHECK:     }
172   // CHECK:     return
173   // CHECK:   }
174   // CHECK: }
175   // clang-format on
176 
177   return moduleOp;
178 }
179 
180 struct OpListNode {
181   MlirOperation op;
182   struct OpListNode *next;
183 };
184 typedef struct OpListNode OpListNode;
185 
/// Running totals gathered while walking an operation tree in collectStats.
typedef struct ModuleStats {
  unsigned numOperations;     // Operations visited, including the root.
  unsigned numAttributes;     // Sum of attribute counts over all operations.
  unsigned numBlocks;         // Blocks found in all visited regions.
  unsigned numRegions;        // Regions attached to visited operations.
  unsigned numValues;         // Op results plus block arguments.
  unsigned numBlockArguments; // Block arguments that passed the checks.
  unsigned numOpResults;      // Op results that passed the checks.
} ModuleStats;
196 
/// Visits the operation stored in `head` and adds its counts to `stats`.
///
/// It verifies that every result reports itself as an OpResult owned by this
/// operation at the right index, and that every argument of every block in
/// the operation's regions reports itself as a BlockArgument owned by that
/// block at the right index. Each operation nested directly in those blocks
/// is pushed onto the worklist immediately after `head`.
///
/// Returns 0 on success, 1-4 when an op result fails a check, 5-8 when a
/// block argument fails a check, and 9 when a worklist node cannot be
/// allocated. Nodes pushed before an error remain linked after `head`, so
/// the caller can free them.
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  intptr_t numResults = mlirOperationGetNumResults(operation);
  intptr_t numRegions = mlirOperationGetNumRegions(operation);

  stats->numOperations += 1;
  stats->numValues += (unsigned)numResults;
  stats->numAttributes += (unsigned)mlirOperationGetNumAttributes(operation);
  stats->numRegions += (unsigned)numRegions;

  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  for (intptr_t r = 0; r < numRegions; ++r) {
    MlirRegion region = mlirOperationGetRegion(operation, r);
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += (unsigned)numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      // Queue nested operations right after the current node so the caller's
      // front-popping loop visits them next.
      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof *node);
        if (!node)
          return 9;
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
  }
  return 0;
}
253 
collectStats(MlirOperation operation)254 int collectStats(MlirOperation operation) {
255   OpListNode *head = malloc(sizeof(OpListNode));
256   head->op = operation;
257   head->next = NULL;
258 
259   ModuleStats stats;
260   stats.numOperations = 0;
261   stats.numAttributes = 0;
262   stats.numBlocks = 0;
263   stats.numRegions = 0;
264   stats.numValues = 0;
265   stats.numBlockArguments = 0;
266   stats.numOpResults = 0;
267 
268   do {
269     int retval = collectStatsSingle(head, &stats);
270     if (retval)
271       return retval;
272     OpListNode *next = head->next;
273     free(head);
274     head = next;
275   } while (head);
276 
277   if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
278     return 100;
279 
280   fprintf(stderr, "@stats\n");
281   fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
282   fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
283   fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
284   fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
285   fprintf(stderr, "Number of values: %u\n", stats.numValues);
286   fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
287   fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
288   // clang-format off
289   // CHECK-LABEL: @stats
290   // CHECK: Number of operations: 13
291   // CHECK: Number of attributes: 4
292   // CHECK: Number of blocks: 3
293   // CHECK: Number of regions: 3
294   // CHECK: Number of values: 9
295   // CHECK: Number of block arguments: 3
296   // CHECK: Number of op results: 6
297   // clang-format on
298   return 0;
299 }
300 
/// Print callback handed to the mlir*Print functions. It writes each chunk
/// of `str` verbatim to stderr and ignores `userData`.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData; // Unused.
  fwrite(str.data, sizeof(char), str.length, stderr);
}
305 
/// Exercises per-operation C APIs on a module shaped like the one built by
/// makeAndDumpAdd.
///
/// `operation` is assumed to be a module whose first operation is a function
/// with a non-empty body. The function steps to the first operation of that
/// body and checks its parent links. It prints the body block and the
/// operation, then inspects the operation: name/identifier round-trip, the
/// block terminator, attribute access by index and by name, result value and
/// type. It sets and removes a custom attribute, and prints the operation
/// with every OpPrintingFlags option enabled.
/// NOTE: the operation is left with an extra "elts" attribute attached.
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Module -> first function -> first operation of the function body.
  MlirBlock moduleBody =
      mlirRegionGetFirstBlock(mlirOperationGetRegion(operation, 0));
  MlirOperation parentOperation = mlirBlockGetFirstOperation(moduleBody);
  MlirBlock block =
      mlirRegionGetFirstBlock(mlirOperationGetRegion(parentOperation, 0));
  operation = mlirBlockGetFirstOperation(block);

  // Verify that parent operation and block report correctly.
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Parent operation eq: 1
  // CHECK: Block eq: 1

  // In the module we created, the first operation of the first function is
  // the "std.constant" producing `0 : index`. It has a single "value"
  // attribute and a single result, which we use to test the printing
  // mechanism.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK:   %[[C0:.*]] = constant 0 : index
  // CHECK:   %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK:   %[[C1:.*]] = constant 1 : index
  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:     %[[LHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[RHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
  // CHECK:     store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:   }
  // CHECK: return
  // CHECK: First operation: {{.*}} = constant 0 : index
  // clang-format on

  // Get the operation name and print it. The string ref is not guaranteed to
  // be NUL-terminated, so write exactly `length` bytes.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  fwrite(identStr.data, 1, identStr.length, stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'std.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: return

  // Get the attribute by index.
  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
  fprintf(stderr, "Get attr 0: ");
  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0: 0 : index

  // Now re-get the attribute by name.
  MlirAttribute attr0ByName =
      mlirOperationGetAttributeByName(operation, namedAttr0.name);
  fprintf(stderr, "Get attr 0 by name: ");
  mlirAttributePrint(attr0ByName, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0 by name: 0 : index

  // Looking up a non-existing attribute must yield a null attribute; the
  // result is printed and checked by FileCheck.
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a custom attribute.
  mlirOperationSetAttributeByName(operation,
                                  mlirStringRefCreateFromCString("custom_attr"),
                                  mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute; a second removal must report that nothing was
  // removed.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a 4-element attribute so that eliding "large" (> 2 element)
  // attributes has a visible effect when printing with flags.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
          eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
  // clang-format on

  mlirOpPrintingFlagsDestroy(flags);
}
451 
/// Builds and dumps the @add module, verifies and prints its statistics, and
/// on success exercises per-operation accessors on it.
///
/// The module is destroyed on every path, including when collectStats
/// fails. Returns 0 on success or the non-zero error code from collectStats.
static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);

  int errcode = collectStats(module);
  if (!errcode)
    printFirstOfEach(ctx, module);

  // Destroy the module on the error path too, so it is never leaked.
  mlirModuleDestroy(moduleOp);
  return errcode;
}
467 
468 /// Creates an operation with a region containing multiple blocks with
469 /// operations and dumps it. The blocks and operations are inserted using
470 /// block/operation-relative API and their final order is checked.
buildWithInsertionsAndPrint(MlirContext ctx)471 static void buildWithInsertionsAndPrint(MlirContext ctx) {
472   MlirLocation loc = mlirLocationUnknownGet(ctx);
473 
474   MlirRegion owningRegion = mlirRegionCreate();
475   MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
476   MlirOperationState state = mlirOperationStateGet(
477       mlirStringRefCreateFromCString("insertion.order.test"), loc);
478   mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
479   MlirOperation op = mlirOperationCreate(&state);
480   MlirRegion region = mlirOperationGetRegion(op, 0);
481 
482   // Use integer types of different bitwidth as block arguments in order to
483   // differentiate blocks.
484   MlirType i1 = mlirIntegerTypeGet(ctx, 1);
485   MlirType i2 = mlirIntegerTypeGet(ctx, 2);
486   MlirType i3 = mlirIntegerTypeGet(ctx, 3);
487   MlirType i4 = mlirIntegerTypeGet(ctx, 4);
488   MlirBlock block1 = mlirBlockCreate(1, &i1);
489   MlirBlock block2 = mlirBlockCreate(1, &i2);
490   MlirBlock block3 = mlirBlockCreate(1, &i3);
491   MlirBlock block4 = mlirBlockCreate(1, &i4);
492   // Insert blocks so as to obtain the 1-2-3-4 order,
493   mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
494   mlirRegionInsertOwnedBlockBefore(region, block3, block2);
495   mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
496   mlirRegionInsertOwnedBlockAfter(region, block3, block4);
497 
498   MlirOperationState op1State =
499       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
500   MlirOperationState op2State =
501       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
502   MlirOperationState op3State =
503       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
504   MlirOperationState op4State =
505       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
506   MlirOperationState op5State =
507       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
508   MlirOperationState op6State =
509       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
510   MlirOperationState op7State =
511       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
512   MlirOperation op1 = mlirOperationCreate(&op1State);
513   MlirOperation op2 = mlirOperationCreate(&op2State);
514   MlirOperation op3 = mlirOperationCreate(&op3State);
515   MlirOperation op4 = mlirOperationCreate(&op4State);
516   MlirOperation op5 = mlirOperationCreate(&op5State);
517   MlirOperation op6 = mlirOperationCreate(&op6State);
518   MlirOperation op7 = mlirOperationCreate(&op7State);
519 
520   // Insert operations in the first block so as to obtain the 1-2-3-4 order.
521   MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
522   assert(mlirOperationIsNull(nullOperation));
523   mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
524   mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
525   mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
526   mlirBlockInsertOwnedOperationAfter(block1, op3, op4);
527 
528   // Append operations to the rest of blocks to make them non-empty and thus
529   // printable.
530   mlirBlockAppendOwnedOperation(block2, op5);
531   mlirBlockAppendOwnedOperation(block3, op6);
532   mlirBlockAppendOwnedOperation(block4, op7);
533 
534   mlirOperationDump(op);
535   mlirOperationDestroy(op);
536   // clang-format off
537   // CHECK-LABEL:  "insertion.order.test"
538   // CHECK:      ^{{.*}}(%{{.*}}: i1
539   // CHECK:        "dummy.op1"
540   // CHECK-NEXT:   "dummy.op2"
541   // CHECK-NEXT:   "dummy.op3"
542   // CHECK-NEXT:   "dummy.op4"
543   // CHECK:      ^{{.*}}(%{{.*}}: i2
544   // CHECK:        "dummy.op5"
545   // CHECK:      ^{{.*}}(%{{.*}}: i3
546   // CHECK:        "dummy.op6"
547   // CHECK:      ^{{.*}}(%{{.*}}: i4
548   // CHECK:        "dummy.op7"
549   // clang-format on
550 }
551 
552 /// Dumps instances of all builtin types to check that C API works correctly.
553 /// Additionally, performs simple identity checks that a builtin type
554 /// constructed with C API can be inspected and has the expected type. The
555 /// latter achieves full coverage of C API for builtin types. Returns 0 on
556 /// success and a non-zero error code on failure.
printBuiltinTypes(MlirContext ctx)557 static int printBuiltinTypes(MlirContext ctx) {
558   // Integer types.
559   MlirType i32 = mlirIntegerTypeGet(ctx, 32);
560   MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
561   MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
562   if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
563     return 1;
564   if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
565     return 2;
566   if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
567     return 3;
568   if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
569     return 4;
570   if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
571     return 5;
572   fprintf(stderr, "@types\n");
573   mlirTypeDump(i32);
574   fprintf(stderr, "\n");
575   mlirTypeDump(si32);
576   fprintf(stderr, "\n");
577   mlirTypeDump(ui32);
578   fprintf(stderr, "\n");
579   // CHECK-LABEL: @types
580   // CHECK: i32
581   // CHECK: si32
582   // CHECK: ui32
583 
584   // Index type.
585   MlirType index = mlirIndexTypeGet(ctx);
586   if (!mlirTypeIsAIndex(index))
587     return 6;
588   mlirTypeDump(index);
589   fprintf(stderr, "\n");
590   // CHECK: index
591 
592   // Floating-point types.
593   MlirType bf16 = mlirBF16TypeGet(ctx);
594   MlirType f16 = mlirF16TypeGet(ctx);
595   MlirType f32 = mlirF32TypeGet(ctx);
596   MlirType f64 = mlirF64TypeGet(ctx);
597   if (!mlirTypeIsABF16(bf16))
598     return 7;
599   if (!mlirTypeIsAF16(f16))
600     return 9;
601   if (!mlirTypeIsAF32(f32))
602     return 10;
603   if (!mlirTypeIsAF64(f64))
604     return 11;
605   mlirTypeDump(bf16);
606   fprintf(stderr, "\n");
607   mlirTypeDump(f16);
608   fprintf(stderr, "\n");
609   mlirTypeDump(f32);
610   fprintf(stderr, "\n");
611   mlirTypeDump(f64);
612   fprintf(stderr, "\n");
613   // CHECK: bf16
614   // CHECK: f16
615   // CHECK: f32
616   // CHECK: f64
617 
618   // None type.
619   MlirType none = mlirNoneTypeGet(ctx);
620   if (!mlirTypeIsANone(none))
621     return 12;
622   mlirTypeDump(none);
623   fprintf(stderr, "\n");
624   // CHECK: none
625 
626   // Complex type.
627   MlirType cplx = mlirComplexTypeGet(f32);
628   if (!mlirTypeIsAComplex(cplx) ||
629       !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
630     return 13;
631   mlirTypeDump(cplx);
632   fprintf(stderr, "\n");
633   // CHECK: complex<f32>
634 
635   // Vector (and Shaped) type. ShapedType is a common base class for vectors,
636   // memrefs and tensors, one cannot create instances of this class so it is
637   // tested on an instance of vector type.
638   int64_t shape[] = {2, 3};
639   MlirType vector =
640       mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
641   if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
642     return 14;
643   if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
644       !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
645       mlirShapedTypeGetDimSize(vector, 0) != 2 ||
646       mlirShapedTypeIsDynamicDim(vector, 0) ||
647       mlirShapedTypeGetDimSize(vector, 1) != 3 ||
648       !mlirShapedTypeHasStaticShape(vector))
649     return 15;
650   mlirTypeDump(vector);
651   fprintf(stderr, "\n");
652   // CHECK: vector<2x3xf32>
653 
654   // Ranked tensor type.
655   MlirType rankedTensor =
656       mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
657   if (!mlirTypeIsATensor(rankedTensor) ||
658       !mlirTypeIsARankedTensor(rankedTensor))
659     return 16;
660   mlirTypeDump(rankedTensor);
661   fprintf(stderr, "\n");
662   // CHECK: tensor<2x3xf32>
663 
664   // Unranked tensor type.
665   MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
666   if (!mlirTypeIsATensor(unrankedTensor) ||
667       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
668       mlirShapedTypeHasRank(unrankedTensor))
669     return 17;
670   mlirTypeDump(unrankedTensor);
671   fprintf(stderr, "\n");
672   // CHECK: tensor<*xf32>
673 
674   // MemRef type.
675   MlirType memRef = mlirMemRefTypeContiguousGet(
676       f32, sizeof(shape) / sizeof(int64_t), shape, 2);
677   if (!mlirTypeIsAMemRef(memRef) ||
678       mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
679       mlirMemRefTypeGetMemorySpace(memRef) != 2)
680     return 18;
681   mlirTypeDump(memRef);
682   fprintf(stderr, "\n");
683   // CHECK: memref<2x3xf32, 2>
684 
685   // Unranked MemRef type.
686   MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4);
687   if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
688       mlirTypeIsAMemRef(unrankedMemRef) ||
689       mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4)
690     return 19;
691   mlirTypeDump(unrankedMemRef);
692   fprintf(stderr, "\n");
693   // CHECK: memref<*xf32, 4>
694 
695   // Tuple type.
696   MlirType types[] = {unrankedMemRef, f32};
697   MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
698   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
699       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
700       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
701     return 20;
702   mlirTypeDump(tuple);
703   fprintf(stderr, "\n");
704   // CHECK: tuple<memref<*xf32, 4>, f32>
705 
706   // Function type.
707   MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
708   MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
709                              mlirIntegerTypeGet(ctx, 32),
710                              mlirIntegerTypeGet(ctx, 64)};
711   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
712   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
713     return 21;
714   if (mlirFunctionTypeGetNumResults(funcType) != 3)
715     return 22;
716   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
717       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
718     return 23;
719   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
720       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
721       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
722     return 24;
723   mlirTypeDump(funcType);
724   fprintf(stderr, "\n");
725   // CHECK: (index, i1) -> (i16, i32, i64)
726 
727   return 0;
728 }
729 
/// String callback that copies `len` bytes from `data` into the buffer passed
/// as `userData`, with strncpy semantics.
///
/// Copying stops early at a NUL in `data`, and the rest of the `len` bytes
/// is zero-filled. Nothing is written beyond `len` bytes, so the result is
/// NOT NUL-terminated unless the caller pre-zeroed a buffer with room for one
/// extra byte. `userData` must point to at least `len` writable bytes. A
/// non-positive `len` copies nothing.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  // Check before converting intptr_t to size_t: a negative length would
  // otherwise turn into a huge size and overflow the destination buffer.
  if (len <= 0)
    return;
  strncpy((char *)userData, data, (size_t)len);
}
734 
// Returns true when the NUL-terminated string `lhs` has exactly the same
// length and bytes as the length-delimited string reference `rhs`.
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  if (strlen(lhs) != rhs.length)
    return false;
  // An empty reference may carry a NULL data pointer; handing that to a
  // <string.h> function is undefined even with a zero length.
  if (rhs.length == 0)
    return true;
  return memcmp(lhs, rhs.data, rhs.length) == 0;
}
741 
// Builds one attribute of each builtin kind exposed through the C API,
// round-trips it through the matching accessors, and dumps it to stderr so
// the FileCheck directives below can match the printed form.
//
// Returns 0 when every check passes; otherwise a code from 1 to 18 that
// identifies the first failing check.
int printBuiltinAttributes(MlirContext ctx) {
  MlirAttribute floating =
      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
  if (!mlirAttributeIsAFloat(floating) ||
      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
    return 1;
  fprintf(stderr, "@attrs\n");
  mlirAttributeDump(floating);
  // CHECK-LABEL: @attrs
  // CHECK: 2.000000e+00 : f64

  // Exercise mlirAttributeGetType() just for the first one.
  MlirType floatingType = mlirAttributeGetType(floating);
  mlirTypeDump(floatingType);
  // CHECK: f64

  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
  if (!mlirAttributeIsAInteger(integer) ||
      mlirIntegerAttrGetValueInt(integer) != 42)
    return 2;
  mlirAttributeDump(integer);
  // CHECK: 42 : i32

  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
    return 3;
  mlirAttributeDump(boolean);
  // CHECK: true

  // Backing storage for the string-like attributes below. Each of them uses
  // an explicit-length slice of it: "abc", "de", "fgh" and "ij".
  const char data[] = "abcdefghijklmnopqestuvwxyz";
  MlirAttribute opaque =
      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
                        mlirNoneTypeGet(ctx));
  if (!mlirAttributeIsAOpaque(opaque) ||
      !stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
    return 4;

  MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
  if (opaqueData.length != 3 ||
      strncmp(data, opaqueData.data, opaqueData.length))
    return 5;
  mlirAttributeDump(opaque);
  // CHECK: #std.abc

  MlirAttribute string =
      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
  if (!mlirAttributeIsAString(string))
    return 6;

  MlirStringRef stringValue = mlirStringAttrGetValue(string);
  if (stringValue.length != 2 ||
      strncmp(data + 3, stringValue.data, stringValue.length))
    return 7;
  mlirAttributeDump(string);
  // CHECK: "de"

  MlirAttribute flatSymbolRef =
      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
    return 8;

  MlirStringRef flatSymbolRefValue =
      mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
  if (flatSymbolRefValue.length != 3 ||
      strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
    return 9;
  mlirAttributeDump(flatSymbolRef);
  // CHECK: @fgh

  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
  MlirAttribute symbolRef =
      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
  if (!mlirAttributeIsASymbolRef(symbolRef) ||
      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
                          flatSymbolRef) ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
                          flatSymbolRef))
    return 10;

  MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
  MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
  if (symbolRefLeaf.length != 3 ||
      strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
      symbolRefRoot.length != 2 ||
      strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
    return 11;
  mlirAttributeDump(symbolRef);
  // CHECK: @ij::@fgh::@fgh

  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
  if (!mlirAttributeIsAType(type) ||
      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
    return 12;
  mlirAttributeDump(type);
  // CHECK: f32

  MlirAttribute unit = mlirUnitAttrGet(ctx);
  if (!mlirAttributeIsAUnit(unit))
    return 13;
  mlirAttributeDump(unit);
  // CHECK: unit

  // All dense/splat/sparse attributes below use 1x2 tensor types.
  int64_t shape[] = {1, 2};

  int bools[] = {0, 1};
  uint32_t uints32[] = {0u, 1u};
  int32_t ints32[] = {0, 1};
  uint64_t uints64[] = {0u, 1u};
  int64_t ints64[] = {0, 1};
  float floats[] = {0.0f, 1.0f};
  double doubles[] = {0.0, 1.0};
  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
      uints32);
  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
      ints32);
  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
      uints64);
  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
      ints64);
  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);

  if (!mlirAttributeIsADenseElements(boolElements) ||
      !mlirAttributeIsADenseElements(uint32Elements) ||
      !mlirAttributeIsADenseElements(int32Elements) ||
      !mlirAttributeIsADenseElements(uint64Elements) ||
      !mlirAttributeIsADenseElements(int64Elements) ||
      !mlirAttributeIsADenseElements(floatElements) ||
      !mlirAttributeIsADenseElements(doubleElements))
    return 14;

  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
    return 15;

  mlirAttributeDump(boolElements);
  mlirAttributeDump(uint32Elements);
  mlirAttributeDump(int32Elements);
  mlirAttributeDump(uint64Elements);
  mlirAttributeDump(int64Elements);
  mlirAttributeDump(floatElements);
  mlirAttributeDump(doubleElements);
  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>

  // Note: the unsigned splats are built on signless integer types, which is
  // what the printed output below expects.
  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);

  if (!mlirAttributeIsADenseElements(splatBool) ||
      !mlirDenseElementsAttrIsSplat(splatBool) ||
      !mlirAttributeIsADenseElements(splatUInt32) ||
      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
      !mlirAttributeIsADenseElements(splatInt32) ||
      !mlirDenseElementsAttrIsSplat(splatInt32) ||
      !mlirAttributeIsADenseElements(splatUInt64) ||
      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
      !mlirAttributeIsADenseElements(splatInt64) ||
      !mlirDenseElementsAttrIsSplat(splatInt64) ||
      !mlirAttributeIsADenseElements(splatFloat) ||
      !mlirDenseElementsAttrIsSplat(splatFloat) ||
      !mlirAttributeIsADenseElements(splatDouble) ||
      !mlirDenseElementsAttrIsSplat(splatDouble))
    return 16;

  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
    return 17;

  // The raw buffers are returned as `const void *`; they are only read here,
  // so keep the const qualifier instead of casting it away.
  const uint32_t *uint32RawData =
      mlirDenseElementsAttrGetRawData(uint32Elements);
  const int32_t *int32RawData = mlirDenseElementsAttrGetRawData(int32Elements);
  const uint64_t *uint64RawData =
      mlirDenseElementsAttrGetRawData(uint64Elements);
  const int64_t *int64RawData = mlirDenseElementsAttrGetRawData(int64Elements);
  const float *floatRawData = mlirDenseElementsAttrGetRawData(floatElements);
  const double *doubleRawData =
      mlirDenseElementsAttrGetRawData(doubleElements);
  if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
      int32RawData[0] != 0 || int32RawData[1] != 1 ||
      uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
      int64RawData[0] != 0 || int64RawData[1] != 1 ||
      floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
    return 18;

  mlirAttributeDump(splatBool);
  mlirAttributeDump(splatUInt32);
  mlirAttributeDump(splatInt32);
  mlirAttributeDump(splatUInt64);
  mlirAttributeDump(splatInt64);
  mlirAttributeDump(splatFloat);
  mlirAttributeDump(splatDouble);
  // CHECK: dense<true> : tensor<1x2xi1>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>

  // Position [0, 1] in the 1x2 tensors, i.e. the element holding 1.0. Use a
  // dedicated index array rather than reusing one of the data arrays.
  uint64_t elementPos[] = {0, 1};
  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, elementPos));
  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, elementPos));
  // CHECK: 1.000000e+00 : f32
  // CHECK: 1.000000e+00 : f64

  int64_t indices[] = {4, 7};
  int64_t two = 2;
  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
      indices);
  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
      valuesAttr);
  mlirAttributeDump(sparseAttr);
  // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>

  return 0;
}
1003 
// Exercises the affine map C API: builds one map with each constructor,
// dumps them to stderr for FileCheck, then compares every structural query
// against a per-map table of expected answers.
//
// Returns 0 when all checks pass; otherwise a code from 1 to 10 naming the
// first group of checks that failed.
int printAffineMap(MlirContext ctx) {
  MlirAffineMap empty = mlirAffineMapEmptyGet(ctx);
  MlirAffineMap dimsAndSyms = mlirAffineMapGet(ctx, 3, 2);
  MlirAffineMap constant = mlirAffineMapConstantGet(ctx, 2);
  MlirAffineMap identity = mlirAffineMapMultiDimIdentityGet(ctx, 3);
  MlirAffineMap minorIdentity = mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  unsigned perm[] = {1, 2, 0};
  MlirAffineMap permutation =
      mlirAffineMapPermutationGet(ctx, sizeof perm / sizeof perm[0], perm);

  // Every expectation table below is indexed in this order.
  MlirAffineMap maps[] = {empty,    dimsAndSyms,   constant,
                          identity, minorIdentity, permutation};
  const size_t numMaps = sizeof maps / sizeof maps[0];

  fprintf(stderr, "@affineMap\n");
  for (size_t i = 0; i < numMaps; ++i)
    mlirAffineMapDump(maps[i]);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  static const bool expectIdentity[] = {true, false, false,
                                        true, false, false};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsIdentity(maps[i]) != expectIdentity[i])
      return 1;

  // The constant map is intentionally left out of the minor-identity group.
  // NOTE(review): presumably because the query is not meaningful for a map
  // with more results than dims -- confirm before adding it.
  if (!(mlirAffineMapIsMinorIdentity(empty) &&
        !mlirAffineMapIsMinorIdentity(dimsAndSyms) &&
        mlirAffineMapIsMinorIdentity(identity) &&
        mlirAffineMapIsMinorIdentity(minorIdentity) &&
        !mlirAffineMapIsMinorIdentity(permutation)))
    return 2;

  static const bool expectEmpty[] = {true, false, false, false, false, false};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsEmpty(maps[i]) != expectEmpty[i])
      return 3;

  static const bool expectSingleConst[] = {false, false, true,
                                           false, false, false};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsSingleConstant(maps[i]) != expectSingleConst[i])
      return 4;

  if (mlirAffineMapGetSingleConstantResult(constant) != 2)
    return 5;

  // Expected {dims, symbols, results, inputs} for each map. Each column is
  // verified across all maps before moving to the next one.
  static const intptr_t counts[][4] = {
      {0, 0, 0, 0}, // empty
      {3, 2, 0, 5}, // dimsAndSyms
      {0, 0, 1, 0}, // constant
      {3, 0, 3, 3}, // identity
      {3, 0, 2, 3}, // minorIdentity
      {3, 0, 3, 3}, // permutation
  };
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumDims(maps[i]) != counts[i][0])
      return 6;
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumSymbols(maps[i]) != counts[i][1])
      return 7;
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumResults(maps[i]) != counts[i][2])
      return 8;
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumInputs(maps[i]) != counts[i][3])
      return 9;

  static const bool expectProjPerm[] = {true, false, false, true, true, true};
  static const bool expectPerm[] = {true, false, false, true, false, true};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsProjectedPermutation(maps[i]) != expectProjPerm[i] ||
        mlirAffineMapIsPermutation(maps[i]) != expectPerm[i])
      return 10;

  intptr_t keptResults[] = {1};
  MlirAffineMap subMap = mlirAffineMapGetSubMap(
      identity, sizeof keptResults / sizeof keptResults[0], keptResults);
  MlirAffineMap majorSubMap = mlirAffineMapGetMajorSubMap(identity, 1);
  MlirAffineMap minorSubMap = mlirAffineMapGetMinorSubMap(identity, 1);

  MlirAffineMap subMaps[] = {subMap, majorSubMap, minorSubMap};
  for (size_t i = 0; i < sizeof subMaps / sizeof subMaps[0]; ++i)
    mlirAffineMapDump(subMaps[i]);
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  return 0;
}
1128 
// Exercises the affine expression C API. It builds a dimension, a symbol, a
// constant and one expression per binary operator, and dumps them to stderr
// for FileCheck. It then compares each query against a per-expression table
// of expected answers.
//
// Returns 0 when all checks pass; otherwise a code from 1 to 13 naming the
// first failing check.
int printAffineExpr(MlirContext ctx) {
  MlirAffineExpr d5 = mlirAffineDimExprGet(ctx, 5);
  MlirAffineExpr s5 = mlirAffineSymbolExprGet(ctx, 5);
  MlirAffineExpr c5 = mlirAffineConstantExprGet(ctx, 5);
  MlirAffineExpr sumExpr = mlirAffineAddExprGet(d5, s5);
  MlirAffineExpr prodExpr = mlirAffineMulExprGet(d5, s5);
  MlirAffineExpr modExpr = mlirAffineModExprGet(d5, s5);
  MlirAffineExpr floorDivExpr = mlirAffineFloorDivExprGet(d5, s5);
  MlirAffineExpr ceilDivExpr = mlirAffineCeilDivExprGet(d5, s5);

  // Every expectation table below is indexed in this order.
  MlirAffineExpr exprs[] = {d5,      s5,           c5,
                            sumExpr, prodExpr,     modExpr,
                            floorDivExpr, ceilDivExpr};
  const size_t numExprs = sizeof exprs / sizeof exprs[0];

  // Tests mlirAffineExprDump.
  fprintf(stderr, "@affineExpr\n");
  for (size_t i = 0; i < numExprs; ++i)
    mlirAffineExprDump(exprs[i]);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Binary-operation accessors, using the add expression as the example.
  mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(sumExpr));
  mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(sumExpr));
  // CHECK: d5
  // CHECK: s5

  // Position/value accessors of the leaf expressions.
  if (mlirAffineDimExprGetPosition(d5) != 5)
    return 1;
  if (mlirAffineSymbolExprGetPosition(s5) != 5)
    return 2;
  if (mlirAffineConstantExprGetValue(c5) != 5)
    return 3;

  static const bool symbolicOrConstant[] = {false, true,  true,  false,
                                            false, false, false, false};
  static const bool pureAffine[] = {true,  true,  true,  true,
                                    false, false, false, false};
  // Largest known divisor of each expression; each expression is also
  // expected to be a multiple of that same value.
  static const int64_t divisor[] = {1, 1, 5, 1, 1, 1, 1, 1};
  static const bool usesDim5[] = {true, false, false, true,
                                  true, true,  true,  true};

  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsSymbolicOrConstant(exprs[i]) != symbolicOrConstant[i])
      return 4;
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsPureAffine(exprs[i]) != pureAffine[i])
      return 5;
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprGetLargestKnownDivisor(exprs[i]) != divisor[i])
      return 6;
  for (size_t i = 0; i < numExprs; ++i)
    if (!mlirAffineExprIsMultipleOf(exprs[i], divisor[i]))
      return 7;
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsFunctionOfDim(exprs[i], 5) != usesDim5[i])
      return 8;

  // 'IsA' predicates of the binary operation expressions.
  if (!mlirAffineExprIsAAdd(sumExpr))
    return 9;
  if (!mlirAffineExprIsAMul(prodExpr))
    return 10;
  if (!mlirAffineExprIsAMod(modExpr))
    return 11;
  if (!mlirAffineExprIsAFloorDiv(floorDivExpr))
    return 12;
  if (!mlirAffineExprIsACeilDiv(ceilDivExpr))
    return 13;

  return 0;
}
1252 
registerOnlyStd()1253 int registerOnlyStd() {
1254   MlirContext ctx = mlirContextCreate();
1255   // The built-in dialect is always loaded.
1256   if (mlirContextGetNumLoadedDialects(ctx) != 1)
1257     return 1;
1258 
1259   MlirDialect std =
1260       mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
1261   if (!mlirDialectIsNull(std))
1262     return 2;
1263 
1264   mlirContextRegisterStandardDialect(ctx);
1265   if (mlirContextGetNumRegisteredDialects(ctx) != 1)
1266     return 3;
1267   if (mlirContextGetNumLoadedDialects(ctx) != 1)
1268     return 4;
1269 
1270   std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
1271   if (mlirDialectIsNull(std))
1272     return 5;
1273   if (mlirContextGetNumLoadedDialects(ctx) != 2)
1274     return 6;
1275 
1276   MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
1277   if (!mlirDialectEqual(std, alsoStd))
1278     return 7;
1279 
1280   MlirStringRef stdNs = mlirDialectGetNamespace(std);
1281   MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
1282   if (stdNs.length != alsoStdNs.length ||
1283       strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1284     return 8;
1285 
1286   fprintf(stderr, "@registration\n");
1287   // CHECK-LABEL: @registration
1288 
1289   return 0;
1290 }
1291 
1292 // Wraps a diagnostic into additional text we can match against.
errorHandler(MlirDiagnostic diagnostic,void * userData)1293 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
1294   fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
1295   mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
1296   fprintf(stderr, "\n");
1297   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1298   mlirLocationPrint(loc, printToStderr, NULL);
1299   assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
1300   fprintf(stderr, ">> end of diagnostic (userData: %ld)\n", (long)userData);
1301   return mlirLogicalResultSuccess();
1302 }
1303 
// Logs when the delete user data callback is called.
// `userData` is an integer tag smuggled through a pointer. Convert it via
// intptr_t so the cast stays well-formed where long is narrower than a
// pointer.
static void deleteUserData(void *userData) {
  fprintf(stderr, "deleting user data (userData: %ld)\n",
          (long)(intptr_t)userData);
}
1308 
testDiagnostics()1309 void testDiagnostics() {
1310   MlirContext ctx = mlirContextCreate();
1311   MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
1312       ctx, errorHandler, (void *)42, deleteUserData);
1313   MlirLocation loc = mlirLocationUnknownGet(ctx);
1314   fprintf(stderr, "@test_diagnostics\n");
1315   mlirEmitError(loc, "test diagnostics");
1316   mlirContextDetachDiagnosticHandler(ctx, id);
1317   mlirEmitError(loc, "more test diagnostics");
1318   // CHECK-LABEL: @test_diagnostics
1319   // CHECK: processing diagnostic (userData: 42) <<
1320   // CHECK:   test diagnostics
1321   // CHECK:   loc(unknown)
1322   // CHECK: >> end of diagnostic (userData: 42)
1323   // CHECK: deleting user data (userData: 42)
1324   // CHECK-NOT: processing diagnostic
1325   // CHECK:     more test diagnostics
1326 }
1327 
main()1328 int main() {
1329   MlirContext ctx = mlirContextCreate();
1330   mlirRegisterAllDialects(ctx);
1331   if (constructAndTraverseIr(ctx))
1332     return 1;
1333   buildWithInsertionsAndPrint(ctx);
1334 
1335   if (printBuiltinTypes(ctx))
1336     return 2;
1337   if (printBuiltinAttributes(ctx))
1338     return 3;
1339   if (printAffineMap(ctx))
1340     return 4;
1341   if (printAffineExpr(ctx))
1342     return 5;
1343   if (registerOnlyStd())
1344     return 6;
1345 
1346   mlirContextDestroy(ctx);
1347 
1348   testDiagnostics();
1349   return 0;
1350 }
1351