• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// This is a symbol test pass that tests the symbol uselist functionality
17 /// provided by the symbol table along with erasing from the symbol table.
18 struct SymbolUsesPass
19     : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> {
operateOnSymbol__anon64d9922f0111::SymbolUsesPass20   WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
21                              SmallVectorImpl<FuncOp> &deadFunctions) {
22     // Test computing uses on a non symboltable op.
23     Optional<SymbolTable::UseRange> symbolUses =
24         SymbolTable::getSymbolUses(symbol);
25 
26     // Test the conservative failure case.
27     if (!symbolUses) {
28       symbol->emitRemark()
29           << "symbol contains an unknown nested operation that "
30              "'may' define a new symbol table";
31       return WalkResult::interrupt();
32     }
33     if (unsigned numUses = llvm::size(*symbolUses))
34       symbol->emitRemark() << "symbol contains " << numUses
35                            << " nested references";
36 
37     // Test the functionality of symbolKnownUseEmpty.
38     if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) {
39       FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
40       if (funcSymbol && funcSymbol.isExternal())
41         deadFunctions.push_back(funcSymbol);
42 
43       symbol->emitRemark() << "symbol has no uses";
44       return WalkResult::advance();
45     }
46 
47     // Test the functionality of getSymbolUses.
48     symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion());
49     assert(symbolUses.hasValue() && "expected no unknown operations");
50     for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
51       // Check that we can resolve back to our symbol.
52       if (SymbolTable::lookupNearestSymbolFrom(
53               symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) {
54         symbolUse.getUser()->emitRemark()
55             << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
56             << symbol->getAttr(SymbolTable::getSymbolAttrName());
57       }
58     }
59     symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";
60     return WalkResult::advance();
61   }
62 
runOnOperation__anon64d9922f0111::SymbolUsesPass63   void runOnOperation() override {
64     auto module = getOperation();
65 
66     // Walk nested symbols.
67     SmallVector<FuncOp, 4> deadFunctions;
68     module.getBodyRegion().walk([&](Operation *nestedOp) {
69       if (isa<SymbolOpInterface>(nestedOp))
70         return operateOnSymbol(nestedOp, module, deadFunctions);
71       return WalkResult::advance();
72     });
73 
74     SymbolTable table(module);
75     for (Operation *op : deadFunctions) {
76       // In order to test the SymbolTable::erase method, also erase completely
77       // useless functions.
78       auto name = SymbolTable::getSymbolName(op);
79       assert(table.lookup(name) && "expected no unknown operations");
80       table.erase(op);
81       assert(!table.lookup(name) &&
82              "expected erased operation to be unknown now");
83       module.emitRemark() << name << " function successfully erased";
84     }
85   }
86 };
87 
88 /// This is a symbol test pass that tests the symbol use replacement
89 /// functionality provided by the symbol table.
90 struct SymbolReplacementPass
91     : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
runOnOperation__anon64d9922f0111::SymbolReplacementPass92   void runOnOperation() override {
93     auto module = getOperation();
94 
95     // Walk nested functions and modules.
96     module.getBodyRegion().walk([&](Operation *nestedOp) {
97       StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
98       if (!newName)
99         return;
100       if (succeeded(SymbolTable::replaceAllSymbolUses(
101               nestedOp, newName.getValue(), &module.getBodyRegion())))
102         SymbolTable::setSymbolName(nestedOp, newName.getValue());
103     });
104   }
105 };
106 } // end anonymous namespace
107 
108 namespace mlir {
registerSymbolTestPasses()109 void registerSymbolTestPasses() {
110   PassRegistration<SymbolUsesPass>("test-symbol-uses",
111                                    "Test detection of symbol uses");
112 
113   PassRegistration<SymbolReplacementPass>("test-symbol-rauw",
114                                           "Test replacement of symbol uses");
115 }
116 } // namespace mlir
117