1 //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/Pass/Pass.h"
12
13 using namespace mlir;
14
15 namespace {
16 /// This is a symbol test pass that tests the symbol uselist functionality
17 /// provided by the symbol table along with erasing from the symbol table.
18 struct SymbolUsesPass
19 : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> {
operateOnSymbol__anon64d9922f0111::SymbolUsesPass20 WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
21 SmallVectorImpl<FuncOp> &deadFunctions) {
22 // Test computing uses on a non symboltable op.
23 Optional<SymbolTable::UseRange> symbolUses =
24 SymbolTable::getSymbolUses(symbol);
25
26 // Test the conservative failure case.
27 if (!symbolUses) {
28 symbol->emitRemark()
29 << "symbol contains an unknown nested operation that "
30 "'may' define a new symbol table";
31 return WalkResult::interrupt();
32 }
33 if (unsigned numUses = llvm::size(*symbolUses))
34 symbol->emitRemark() << "symbol contains " << numUses
35 << " nested references";
36
37 // Test the functionality of symbolKnownUseEmpty.
38 if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) {
39 FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
40 if (funcSymbol && funcSymbol.isExternal())
41 deadFunctions.push_back(funcSymbol);
42
43 symbol->emitRemark() << "symbol has no uses";
44 return WalkResult::advance();
45 }
46
47 // Test the functionality of getSymbolUses.
48 symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion());
49 assert(symbolUses.hasValue() && "expected no unknown operations");
50 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
51 // Check that we can resolve back to our symbol.
52 if (SymbolTable::lookupNearestSymbolFrom(
53 symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) {
54 symbolUse.getUser()->emitRemark()
55 << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
56 << symbol->getAttr(SymbolTable::getSymbolAttrName());
57 }
58 }
59 symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";
60 return WalkResult::advance();
61 }
62
runOnOperation__anon64d9922f0111::SymbolUsesPass63 void runOnOperation() override {
64 auto module = getOperation();
65
66 // Walk nested symbols.
67 SmallVector<FuncOp, 4> deadFunctions;
68 module.getBodyRegion().walk([&](Operation *nestedOp) {
69 if (isa<SymbolOpInterface>(nestedOp))
70 return operateOnSymbol(nestedOp, module, deadFunctions);
71 return WalkResult::advance();
72 });
73
74 SymbolTable table(module);
75 for (Operation *op : deadFunctions) {
76 // In order to test the SymbolTable::erase method, also erase completely
77 // useless functions.
78 auto name = SymbolTable::getSymbolName(op);
79 assert(table.lookup(name) && "expected no unknown operations");
80 table.erase(op);
81 assert(!table.lookup(name) &&
82 "expected erased operation to be unknown now");
83 module.emitRemark() << name << " function successfully erased";
84 }
85 }
86 };
87
88 /// This is a symbol test pass that tests the symbol use replacement
89 /// functionality provided by the symbol table.
90 struct SymbolReplacementPass
91 : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
runOnOperation__anon64d9922f0111::SymbolReplacementPass92 void runOnOperation() override {
93 auto module = getOperation();
94
95 // Walk nested functions and modules.
96 module.getBodyRegion().walk([&](Operation *nestedOp) {
97 StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
98 if (!newName)
99 return;
100 if (succeeded(SymbolTable::replaceAllSymbolUses(
101 nestedOp, newName.getValue(), &module.getBodyRegion())))
102 SymbolTable::setSymbolName(nestedOp, newName.getValue());
103 });
104 }
105 };
106 } // end anonymous namespace
107
108 namespace mlir {
registerSymbolTestPasses()109 void registerSymbolTestPasses() {
110 PassRegistration<SymbolUsesPass>("test-symbol-uses",
111 "Test detection of symbol uses");
112
113 PassRegistration<SymbolReplacementPass>("test-symbol-rauw",
114 "Test replacement of symbol uses");
115 }
116 } // namespace mlir
117