• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
def run(f):
  """Run one test function and verify that it leaked no MLIR contexts.

  Prints a blank line followed by a "TEST: <name>" banner (the label the
  directives in this file anchor on), calls `f`, then forces a garbage
  collection before asserting that no `Context` instance is still alive.
  """
  test_name = f.__name__
  print(f"\nTEST: {test_name}")
  f()
  # Collect first so unreachable contexts that are merely awaiting cyclic
  # garbage collection are not reported as leaks.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
13
14
15# Verify iterator based traversal of the op/region/block hierarchy.
16# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
  """Walk module -> regions -> blocks -> ops using only Python iteration."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  top = module.operation
  assert top.context is ctx

  # Route 1: the explicitly named collections (`.regions`, `.blocks`).
  named_regions = list(top.regions)
  named_blocks = list(named_regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(named_regions)} BLOCKS={len(named_blocks)}")

  # Route 2: iterating the op / region objects directly. Both routes must
  # yield equal elements.
  iter_regions = list(top)
  iter_blocks = list(iter_regions[0])
  assert iter_regions == named_regions
  assert iter_blocks == named_blocks

  # Likewise, iterating a block yields the same ops as its `.operations` list.
  entry = named_blocks[0]
  named_ops = list(entry.operations)
  iter_ops = list(entry)
  assert iter_ops == named_ops

  def walk(prefix, parent):
    # Recursively print every region/block/op nested under `parent`, adding
    # six spaces of indentation per nesting level.
    for region_idx, region in enumerate(parent):
      print(f"{prefix}REGION {region_idx}:")
      for block_idx, block in enumerate(region):
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx, child in enumerate(block):
          print(f"{prefix}    OP {op_idx}: {child}")
          walk(prefix + "      ", child)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: return
  # CHECK:    OP 1: module_terminator
  walk("", top)

run(testTraverseOpRegionBlockIterators)
67
68
69# Verify index based traversal of the op/region/block hierarchy.
70# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
  """Walk module -> regions -> blocks -> ops using len() and indexing."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def visit(prefix, parent):
    # range(len(...)) plus [] is intentional here: this test exists to
    # exercise the sequence protocol (__len__ / __getitem__) of the region,
    # block and operation lists rather than their iterators.
    for r in range(len(parent.regions)):
      region = parent.regions[r]
      print(f"{prefix}REGION {r}:")
      for b in range(len(region.blocks)):
        block = region.blocks[b]
        print(f"{prefix}  BLOCK {b}:")
        for o in range(len(block.operations)):
          child = block.operations[o]
          print(f"{prefix}    OP {o}: {child}")
          visit(prefix + "      ", child)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: return
  # CHECK:    OP 1: module_terminator
  visit("", module.operation)

run(testTraverseOpRegionBlockIndices)
104
105
106# CHECK-LABEL: TEST: testBlockArgumentList
def testBlockArgumentList():
  """Read and then retype the arguments of a function's entry block."""
  with Context() as ctx:
    module = Module.parse(r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func_op = module.body.operations[0]
    block = func_op.regions[0].blocks[0]
    assert len(block.arguments) == 3

    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for block_arg in block.arguments:
      print(f"Argument {block_arg.arg_number}, type {block_arg.type}")
      # Retype argument N to a signless integer of width 8 * (N + 1).
      width = 8 * (block_arg.arg_number + 1)
      replacement = IntegerType.get_signless(width)
      block_arg.set_type(replacement)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for block_arg in block.arguments:
      print(f"Argument {block_arg.arg_number}, type {block_arg.type}")


run(testBlockArgumentList)
133
134
135# CHECK-LABEL: TEST: testOperationOperands
def testOperationOperands():
  """Iterate an op's operands and print each operand's type."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func_op = module.body.operations[0]
    # Op 0 is the producer; op 1 consumes the block argument and its result.
    consumer = func_op.regions[0].blocks[0].operations[1]
    operands = consumer.operands
    assert len(operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for position, value in enumerate(operands):
      print(f"Operand {position}, type {value.type}")


run(testOperationOperands)
156
157
158# CHECK-LABEL: TEST: testOperationOperandsSlice
def testOperationOperandsSlice():
  """Slice an op's operand list and print the values in each slice."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func_op = module.body.operations[0]
    # Ops 0-4 are the producers; op 5 consumes all of their results.
    consumer = func_op.regions[0].blocks[0].operations[5]
    operands = consumer.operands
    assert len(operands) == 5
    # A double reversal must round-trip to the original sequence.
    assert all(a == b for a, b in zip(operands, operands[::-1][::-1]))

    def show(values):
      # Print every value in `values`, one per line (the printed form
      # includes the defining producer op).
      for v in values:
        print(v)

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    show(operands[:])

    # CHECK: test.producer0
    # CHECK: test.producer1
    show(operands[0:2])

    # CHECK: test.producer3
    # CHECK: test.producer4
    show(operands[3:])

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    show(operands[::2])

    # Slicing a slice: every other element of the even positions, i.e. a
    # stride of four starting at index 2.
    # CHECK: test.producer2
    show(operands[::2][1::2])


run(testOperationOperandsSlice)
214
215
216# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
  """Create an operation that is not attached to any block and print it."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    attrs = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    detached = Operation.create(
        "custom.op1", results=[si32, si32], regions=1, attributes=attrs)
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(detached)

  # TODO: Check successors once enough infra exists to do it properly.

run(testDetachedOperation)
234
235
236# CHECK-LABEL: TEST: testOperationInsertionPoint
def testOperationInsertionPoint():
  """Insert detached ops at the start of a block, then try re-inserting one."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # Create the ops to insert.
  with Location.unknown(ctx):
    first = Operation.create("custom.op1")
    second = Operation.create("custom.op2")

    entry = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry)
    # The expected output places op1, then op2, ahead of the existing body.
    for new_op in (first, second):
      ip.insert(new_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Inserting an op that is already attached to a block must raise.
  try:
    ip.insert(first)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"

run(testOperationInsertionPoint)
272
273
274# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
  """Build an op with a populated region, then move it into a parsed module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    outer = Operation.create("custom.op1", regions=1)
    # Append a block with two si32 arguments; the call returns the new block.
    body = outer.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    InsertionPoint(body).insert(terminator)
    print(outer)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    entry = module.body.operations[0].regions[0].blocks[0]
    InsertionPoint.at_block_begin(entry).insert(outer)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK:   "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)

run(testOperationWithRegion)
313
314
315# CHECK-LABEL: TEST: testOperationResultList
def testOperationResultList():
  """Iterate the results of a multi-result call op."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call_op = caller.regions[0].blocks[0].operations[0]
  results = call_op.results
  assert len(results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for value in results:
    print(f"Result {value.result_number}, type {value.type}")


run(testOperationResultList)
336
337
338# CHECK-LABEL: TEST: testOperationResultListSlice
def testOperationResultListSlice():
  """Slice an op's result list and print position and type of each result."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func_op = module.body.operations[0]
    producer = func_op.regions[0].blocks[0].operations[0]
    results = producer.results

    assert len(results) == 5
    # A double reversal must round-trip, both by equality and by position.
    for a, b in zip(results, results[::-1][::-1]):
      assert a == b
      assert a.result_number == b.result_number

    def show(values):
      # Print each result's position in the op and its type, one per line.
      for v in values:
        print(f"Result {v.result_number}, type {v.type}")

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    show(results[:])

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    show(results[1:4])

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    show(results[1::2])

    # Negative step: starts at index 3 and walks backwards, excluding index 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    show(results[-2:0:-2])


run(testOperationResultListSlice)
387
388
389# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
  """Read typed attributes by name, iterate them, and check lookup errors."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  assert len(op.attributes) == 3
  # Wrap the generic attributes in concrete classes to read .type / .value.
  iattr = IntegerAttr(op.attributes["some.attribute"])
  fattr = FloatAttr(op.attributes["other.attribute"])
  sattr = StringAttr(op.attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named in op.attributes:
    print(named)

  def expect_raises(exc_type, access, message):
    # Call `access`; fail with `message` unless it raises `exc_type`.
    try:
      access()
    except exc_type:
      pass
    else:
      assert False, message

  # Lookup by an unknown name and by an out-of-range index must both raise.
  expect_raises(KeyError, lambda: op.attributes["does_not_exist"],
                "expected KeyError on accessing a non-existent attribute")
  expect_raises(IndexError, lambda: op.attributes[42],
                "expected IndexError on accessing an out-of-bounds attribute")


run(testOperationAttributes)
434
435
436# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
  """Exercise Operation.print to stdout, text/binary buffers, and with options."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)
  top = module.operation

  # Print to stdout.
  # CHECK: return %arg0 : i32
  top.print()

  # Print into a text buffer; the captured value must be a str.
  text_buf = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  top.print(file=text_buf)
  text = text_buf.getvalue()
  print(type(text))
  print(text)

  # Print into a binary buffer; the captured value must be bytes.
  byte_buf = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  top.print(file=byte_buf, binary=True)
  data = byte_buf.getvalue()
  print(type(data))
  print(data)

  # Print with options: elide large constants, emit debug locations, use the
  # generic op form and local value numbering.
  # CHECK: value = opaque<"", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  top.print(large_elements_limit=2, enable_debug_info=True,
            pretty_debug_info=True, print_generic_op_form=True,
            use_local_scope=True)

run(testOperationPrint)
475
476
477# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
  """Ops resolve to a dialect OpView subclass when one is registered."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = addf %1, %2 : f32
    """)
    print(module)

    # addf should map to a known OpView class in the std dialect, whose
    # OpView defines an `lhs` accessor.
    addf_op = module.body.operations[2]
    # CHECK: <mlir.dialects.std._AddFOp object
    print(repr(addf_op))
    # CHECK: "custom.f32"()
    print(addf_op.lhs)

    # The unregistered custom op falls back to the default OpView. It is
    # fetched twice so the second lookup exercises negative caching.
    # CHECK: <_mlir.ir.OpView object
    # CHECK: <_mlir.ir.OpView object
    for _ in range(2):
      custom = module.body.operations[0]
      print(repr(custom))

run(testKnownOpView)
507
508
509# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
  """`.result` raises ValueError on ops with zero or with several results."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  # Ops 0 and 1 have zero and two results respectively.
  # CHECK: Cannot call .result on operation custom.no_result which has 0 results
  # CHECK: Cannot call .result on operation custom.two_result which has 2 results
  for index in (0, 1):
    try:
      module.body.operations[index].result
    except ValueError as e:
      print(e)
    else:
      assert False, "Expected exception"

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])

run(testSingleResultProperty)
540
541# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
  """Printing an op that fails verification falls back to the generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module_op = Operation.create("module", regions=1)
    # Give the region a block with no terminator. Such a block may crash the
    # custom printer, so printing must fall back to the generic printer.
    module_op.regions[0].blocks.append()
    print(module_op)
    # CHECK: // Verification failed, printing generic form
    # CHECK: "module"() ( {
    # CHECK: }) : () -> ()

run(testPrintInvalidOperation)
554
555
556# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
def testCreateWithInvalidAttributes():
  """Operation.create must reject non-string keys and non-attribute values.

  Each rejected call's error message is printed so the directives below can
  match it. A call that unexpectedly succeeds now fails the test immediately,
  instead of only surfacing later as a missing-output mismatch.
  """
  ctx = Context()
  with Location.unknown(ctx):

    def expect_create_error(attributes):
      # Try to create a "module" op with `attributes` and print the error it
      # raises. The `else` clause sits outside the `try` body, so the
      # AssertionError is not swallowed by the broad `except` above it.
      try:
        Operation.create("module", attributes=attributes)
      except Exception as e:
        print(e)
      else:
        assert False, "expected Operation.create to reject the attributes"

    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "module" (Unable to cast Python instance of type <class 'NoneType'> to C++ type
    expect_create_error({None: StringAttr.get("name")})
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "module" (Unable to cast Python instance of type <class 'int'> to C++ type
    expect_create_error({42: StringAttr.get("name")})
    # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "module" (Unable to cast Python instance of type <class '_mlir.ir.Context'> to C++ type 'mlir::python::PyAttribute')
    expect_create_error({"some_key": ctx})
    # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "module"
    expect_create_error({"some_key": None})

run(testCreateWithInvalidAttributes)
581