# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  """Run a single test function and verify it leaked no MLIR contexts.

  Prints a "TEST: <name>" banner (matched by the per-test label directives
  below), invokes `f`, forces a garbage-collection pass, and then asserts that
  the bindings report zero live `Context` objects.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
  """Walk a parsed module using Python iteration over regions/blocks/ops."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  op = module.operation
  assert op.context is ctx
  # Get the block using iterators off of the named collections.
  regions = list(op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # Get the regions and blocks from the default collections.
  default_regions = list(op)
  default_blocks = list(default_regions[0])
  # They should compare equal regardless of how obtained.
  assert default_regions == regions
  assert default_blocks == blocks

  # Should be able to get the operations from either the named collection
  # or the block.
  operations = list(blocks[0].operations)
  default_operations = list(blocks[0])
  assert default_operations == operations

  def walk_operations(indent, op):
    """Recursively print regions, blocks and ops, nesting by `indent`."""
    for i, region in enumerate(op):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent}  BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent}    OP {k}: {child_op}")
          walk_operations(indent + "      ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: return
  # CHECK: OP 1: module_terminator
  walk_operations("", op)

run(testTraverseOpRegionBlockIterators)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
  """Walk a parsed module using len() and integer indexing only."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # The range(len(...)) loops are deliberate: this test exercises the
  # __len__/__getitem__ protocol of the collections, so do not rewrite it
  # with enumerate().
  def walk_operations(indent, op):
    """Recursively print regions, blocks and ops via index-based access."""
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent}  BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent}    OP {k}: {child_op}")
          walk_operations(indent + "      ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: return
  # CHECK: OP 1: module_terminator
  walk_operations("", module.operation)

run(testTraverseOpRegionBlockIndices)


# CHECK-LABEL: TEST: testBlockArgumentList
def testBlockArgumentList():
  """Read block arguments, then retype them in place with set_type()."""
  with Context() as ctx:
    module = Module.parse(r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      # Argument N becomes a signless integer of width 8 * (N + 1).
      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
      arg.set_type(new_type)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")


run(testBlockArgumentList)


# CHECK-LABEL: TEST: testOperationOperands
def testOperationOperands():
  """Enumerate the operands of an op and print their types."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")


run(testOperationOperands)


# CHECK-LABEL: TEST: testOperationOperandsSlice
def testOperationOperandsSlice():
  """Exercise full, bounded, strided and chained slicing of op operands."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must give back the same operands in the same order.
    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
      assert left == right

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    full_slice = consumer.operands[:]
    for operand in full_slice:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer1
    first_two = consumer.operands[0:2]
    for operand in first_two:
      print(operand)

    # CHECK: test.producer3
    # CHECK: test.producer4
    last_two = consumer.operands[3:]
    for operand in last_two:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    even = consumer.operands[::2]
    for operand in even:
      print(operand)

    # Slicing a slice: [::2] keeps indices 0, 2, 4 and [1::2] of that keeps
    # index 2, i.e. every fourth operand starting at index 2.
    # CHECK: test.producer2
    fourth = consumer.operands[::2][1::2]
    for operand in fourth:
      print(operand)


run(testOperationOperandsSlice)


# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
  """Create and print an operation that is not attached to any block."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create(
        "custom.op1", results=[i32, i32], regions=1, attributes={
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        })
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)

  # TODO: Check successors once enough infra exists to do it properly.

run(testDetachedOperation)


# CHECK-LABEL: TEST: testOperationInsertionPoint
def testOperationInsertionPoint():
  """Insert detached ops at a block start; re-inserting must raise."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # Create test op.
  with Location.unknown(ctx):
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    ip.insert(op2)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

    # Trying to add a previously added op should raise.
    try:
      ip.insert(op1)
    except ValueError:
      pass
    else:
      assert False, "expected insert of attached op to raise"

run(testOperationInsertionPoint)


# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
  """Build an op with a region, then move it into a parsed module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    block = op1.regions[0].blocks.append(i32, i32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
    # CHECK: "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    ip = InsertionPoint(block)
    ip.insert(terminator)
    print(op1)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    # No explicit context is passed here; this relies on the enclosing
    # `with Location.unknown(ctx)` to provide the current context.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)

run(testOperationWithRegion)


# CHECK-LABEL: TEST: testOperationResultList
def testOperationResultList():
  """Enumerate the results of a multi-result op and print their types."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  assert len(call.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in call.results:
    print(f"Result {res.result_number}, type {res.type}")


run(testOperationResultList)


# CHECK-LABEL: TEST: testOperationResultListSlice
def testOperationResultListSlice():
  """Exercise full, bounded, strided and negative-step result slicing."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer = entry_block.operations[0]

    assert len(producer.results) == 5
    # Reversing twice must preserve both identity and result numbering.
    for left, right in zip(producer.results, producer.results[::-1][::-1]):
      assert left == right
      assert left.result_number == right.result_number

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    full_slice = producer.results[:]
    for res in full_slice:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    middle = producer.results[1:4]
    for res in middle:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    odd = producer.results[1::2]
    for res in odd:
      print(f"Result {res.result_number}, type {res.type}")

    # Negative step: starts at index 3 (-2) and stops before index 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    inverted_middle = producer.results[-2:0:-2]
    for res in inverted_middle:
      print(f"Result {res.result_number}, type {res.type}")


run(testOperationResultListSlice)


# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
  """Read typed attributes by name, iterate them, and probe bad lookups."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  assert len(op.attributes) == 3
  iattr = IntegerAttr(op.attributes["some.attribute"])
  fattr = FloatAttr(op.attributes["other.attribute"])
  sattr = StringAttr(op.attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for attr in op.attributes:
    print(str(attr))

  # Check that exceptions are raised as expected.
  try:
    op.attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  try:
    op.attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"


run(testOperationAttributes)


# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
  """Print an operation to stdout, to text/binary streams, and with options.

  The options call covers elided large element attributes, pretty debug
  locations, the generic op form and local-scope printing.
  """
  ctx = Context()
  # NOTE: the leading whitespace of this IR is significant. The debug-location
  # expectation at the end of this test pins the `return` op to line 4,
  # column 7 of this string (6 spaces of indentation).
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)

  # Test print to stdout.
  # CHECK: return %arg0 : i32
  module.operation.print()

  # Test print to text file.
  f = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f)
  str_value = f.getvalue()
  print(str_value.__class__)
  print(str_value)

  # Test print to binary file.
  f = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f, binary=True)
  bytes_value = f.getvalue()
  print(bytes_value.__class__)
  print(bytes_value)

  # Test get_asm with options.
  # CHECK: value = opaque<"", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  module.operation.print(large_elements_limit=2, enable_debug_info=True,
      pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)

run(testOperationPrint)


# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
  """Ops with a registered OpView map to it; others get the default OpView."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = addf %1, %2 : f32
    """)
    print(module)

    # addf should map to a known OpView class in the std dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf = module.body.operations[2]
    # CHECK: <mlir.dialects.std._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom = module.body.operations[0]
    # CHECK: <_mlir.ir.OpView object
    print(repr(custom))

    # Check again to make sure negative caching works.
    custom = module.body.operations[0]
    # CHECK: <_mlir.ir.OpView object
    print(repr(custom))

run(testKnownOpView)


# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
  """`.result` must raise unless the op has exactly one result."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  try:
    module.body.operations[0].result
  except ValueError as e:
    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
    print(e)
  else:
    assert False, "Expected exception"

  try:
    module.body.operations[1].result
  except ValueError as e:
    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
    print(e)
  else:
    assert False, "Expected exception"

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])

run(testSingleResultProperty)

# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
  """Printing an op that fails verification falls back to the generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Operation.create("module", regions=1)
    # This block does not have a terminator, it may crash the custom printer.
    # Verify that we fallback to the generic printer for safety.
    # The appended block is only needed for its side effect on the region.
    module.regions[0].blocks.append()
    print(module)
    # CHECK: // Verification failed, printing generic form
    # CHECK: "module"() ( {
    # CHECK: }) : () -> ()
run(testPrintInvalidOperation)


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
def testCreateWithInvalidAttributes():
  """Operation.create must reject malformed attribute dictionaries.

  Each case passes a bad key or value and prints the error message. Every
  case also asserts that an exception was raised at all, so a regression that
  silently accepts the input fails here with an explicit message.
  """
  ctx = Context()
  with Location.unknown(ctx):
    try:
      Operation.create("module", attributes={None:StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "module" (Unable to cast Python instance of type <class 'NoneType'> to C++ type
      print(e)
    else:
      assert False, "expected a None attribute key to be rejected"
    try:
      Operation.create("module", attributes={42:StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "module" (Unable to cast Python instance of type <class 'int'> to C++ type
      print(e)
    else:
      assert False, "expected an int attribute key to be rejected"
    try:
      Operation.create("module", attributes={"some_key":ctx})
    except Exception as e:
      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "module" (Unable to cast Python instance of type <class '_mlir.ir.Context'> to C++ type 'mlir::python::PyAttribute')
      print(e)
    else:
      assert False, "expected a non-attribute value to be rejected"
    try:
      Operation.create("module", attributes={"some_key":None})
    except Exception as e:
      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "module"
      print(e)
    else:
      assert False, "expected a None attribute value to be rejected"
run(testCreateWithInvalidAttributes)