• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3import gc, sys
4from mlir.ir import *
5from mlir.passmanager import *
6
def log(*args):
  """Print `args` (space-separated, newline-terminated) to stderr and flush.

  Everything is logged to stderr, and flushed right away, so that our own
  output lands in one unified stream together with the errors/info that MLIR
  itself emits to stderr; the FileCheck directives in this file match against
  that combined stream.
  """
  print(*args, file=sys.stderr, flush=True)
12
def run(f):
  """Execute test function `f` under a "TEST: <name>" banner.

  The banner (preceded by a blank line) is what each test's CHECK-LABEL
  directive anchors on. Once `f` returns, a full garbage collection is forced
  and we assert that no MLIR `Context` remains alive afterwards.
  """
  test_name = f.__name__
  log("\nTEST:", test_name)
  f()
  # Collect first so that unreachable-but-not-yet-freed objects do not count
  # as live contexts.
  gc.collect()
  assert Context._get_live_count() == 0
18
# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
  """Round-trip a PassManager through its C-API capsule.

  The capsule is read from `_CAPIPtr`, the original wrapper is detached with
  `_testing_release()` (presumably so the rebuilt wrapper becomes the owner —
  confirm against the bindings), and a new PassManager is rebuilt from the
  same capsule via `_CAPICreate`.
  """
  with Context():
    original = PassManager()
    capsule = original._CAPIPtr
    # The capsule's repr must mention the fully qualified capsule name.
    assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(capsule)
    original._testing_release()
    rebuilt = PassManager._CAPICreate(capsule)
    # Mostly a smoke test: re-wrapping the released pointer must not crash.
    assert rebuilt is not None
run(testCapsule)
30
31
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
  """Parse a pipeline before and after its pass gets registered.

  The first parse must fail with ValueError because `print-op-stats` is not
  registered until `mlir.transforms` is imported. After that import the same
  pipeline text must parse, and printing the PassManager must reproduce it.
  """
  with Context():
    # The first parse is expected to fail because the pass isn't registered
    # until we import mlir.transforms. The result is deliberately not bound:
    # it is never used.
    try:
      PassManager.parse("module(func(print-op-stats))")
      # TODO: this error should be propagated to Python, but the C API does
      # not support that right now; the detailed message only shows up in the
      # output stream (matched below), not in the ValueError text.
      # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
    except ValueError as e:
      # CHECK: ValueError exception: invalid pass pipeline 'module(func(print-op-stats))'.
      log("ValueError exception:", e)
    else:
      log("Exception not produced")

    # This will register the pass and round-trip should be possible now.
    import mlir.transforms
    pm = PassManager.parse("module(func(print-op-stats))")
    # CHECK: Roundtrip: module(func(print-op-stats))
    log("Roundtrip: ", pm)
run(testParseSuccess)
54
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
  """Parsing a pipeline that names an unknown pass must raise ValueError."""
  with Context():
    try:
      parsed = PassManager.parse("unknown-pass")
    except ValueError as err:
      # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
      log("ValueError exception:", err)
      return
    # Only reached when parsing unexpectedly succeeded.
    log("Exception not produced")
run(testParseFail)
67
68
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
  """A module-restricted pass nested under `func` must be rejected.

  The pass behind `print-op-graph` (reported as 'PrintOp') is restricted to
  'module', so placing it inside `func(...)` makes parsing raise ValueError;
  the nesting diagnostic is expected in the output ahead of the logged
  exception.
  """
  pipeline = "func(print-op-graph)"
  with Context():
    try:
      parsed = PassManager.parse(pipeline)
    except ValueError as err:
      # CHECK: Can't add pass 'PrintOp' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
      # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
      log("ValueError exception:", err)
    else:
      log("Exception not produced")
run(testInvalidNesting)
82
83
# Verify that a pass manager can execute on IR.
# The label must name the test exactly, like every other test in this file.
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
  """Run a `print-op-stats` pipeline over a trivial one-function module.

  The pass prints operation counts, which the CHECK lines below expect to be
  one each of `func`, `module`, `module_terminator` and `std.return`.
  """
  with Context():
    pm = PassManager.parse("print-op-stats")
    module = Module.parse(r"""func @successfulParse() { return }""")
    pm.run(module)
# CHECK: Operations encountered:
# CHECK: func              , 1
# CHECK: module            , 1
# CHECK: module_terminator , 1
# CHECK: std.return        , 1
run(testRunPipeline)
97