1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_SYMBOLTABLE_H 10 #define MLIR_IR_SYMBOLTABLE_H 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/OpDefinition.h" 14 #include "llvm/ADT/StringMap.h" 15 16 namespace mlir { 17 class Identifier; 18 class Operation; 19 20 /// This class allows for representing and managing the symbol table used by 21 /// operations with the 'SymbolTable' trait. Inserting into and erasing from 22 /// this SymbolTable will also insert and erase from the Operation given to it 23 /// at construction. 24 class SymbolTable { 25 public: 26 /// Build a symbol table with the symbols within the given operation. 27 SymbolTable(Operation *symbolTableOp); 28 29 /// Look up a symbol with the specified name, returning null if no such 30 /// name exists. Names never include the @ on them. 31 Operation *lookup(StringRef name) const; lookup(StringRef name)32 template <typename T> T lookup(StringRef name) const { 33 return dyn_cast_or_null<T>(lookup(name)); 34 } 35 36 /// Erase the given symbol from the table. 37 void erase(Operation *symbol); 38 39 /// Insert a new symbol into the table, and rename it as necessary to avoid 40 /// collisions. Also insert at the specified location in the body of the 41 /// associated operation. 42 void insert(Operation *symbol, Block::iterator insertPt = {}); 43 44 /// Return the name of the attribute used for symbol names. getSymbolAttrName()45 static StringRef getSymbolAttrName() { return "sym_name"; } 46 47 /// Returns the associated operation. getOp()48 Operation *getOp() const { return symbolTableOp; } 49 50 /// Return the name of the attribute used for symbol visibility. getVisibilityAttrName()51 static StringRef getVisibilityAttrName() { return "sym_visibility"; } 52 53 //===--------------------------------------------------------------------===// 54 // Symbol Utilities 55 //===--------------------------------------------------------------------===// 56 57 /// An enumeration detailing the different visibility types that a symbol may 58 /// have. 59 enum class Visibility { 60 /// The symbol is public and may be referenced anywhere internal or external 61 /// to the visible references in the IR. 62 Public, 63 64 /// The symbol is private and may only be referenced by SymbolRefAttrs local 65 /// to the operations within the current symbol table. 66 Private, 67 68 /// The symbol is visible to the current IR, which may include operations in 69 /// symbol tables above the one that owns the current symbol. `Nested` 70 /// visibility allows for referencing a symbol outside of its current symbol 71 /// table, while retaining the ability to observe all uses. 72 Nested, 73 }; 74 75 /// Returns the name of the given symbol operation. 76 static StringRef getSymbolName(Operation *symbol); 77 /// Sets the name of the given symbol operation. 78 static void setSymbolName(Operation *symbol, StringRef name); 79 80 /// Returns the visibility of the given symbol operation. 81 static Visibility getSymbolVisibility(Operation *symbol); 82 /// Sets the visibility of the given symbol operation. 83 static void setSymbolVisibility(Operation *symbol, Visibility vis); 84 85 /// Returns the nearest symbol table from a given operation `from`. Returns 86 /// nullptr if no valid parent symbol table could be found. 87 static Operation *getNearestSymbolTable(Operation *from); 88 89 /// Walks all symbol table operations nested within, and including, `op`. For 90 /// each symbol table operation, the provided callback is invoked with the op 91 /// and a boolean signifying if the symbols within that symbol table can be 92 /// treated as if all uses within the IR are visible to the caller. 93 /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols 94 /// within `op` are visible. 95 static void walkSymbolTables(Operation *op, bool allSymUsesVisible, 96 function_ref<void(Operation *, bool)> callback); 97 98 /// Returns the operation registered with the given symbol name with the 99 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 100 /// with the 'OpTrait::SymbolTable' trait. 101 static Operation *lookupSymbolIn(Operation *op, StringRef symbol); 102 static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); 103 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 104 /// by a given SymbolRefAttr. Returns failure if any of the nested references 105 /// could not be resolved. 106 static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, 107 SmallVectorImpl<Operation *> &symbols); 108 109 /// Returns the operation registered with the given symbol name within the 110 /// closest parent operation of, or including, 'from' with the 111 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 112 /// found. 113 static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); 114 static Operation *lookupNearestSymbolFrom(Operation *from, 115 SymbolRefAttr symbol); 116 template <typename T> lookupNearestSymbolFrom(Operation * from,StringRef symbol)117 static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { 118 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 119 } 120 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)121 static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 122 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 123 } 124 125 /// This class represents a specific symbol use. 126 class SymbolUse { 127 public: SymbolUse(Operation * op,SymbolRefAttr symbolRef)128 SymbolUse(Operation *op, SymbolRefAttr symbolRef) 129 : owner(op), symbolRef(symbolRef) {} 130 131 /// Return the operation user of this symbol reference. getUser()132 Operation *getUser() const { return owner; } 133 134 /// Return the symbol reference that this use represents. getSymbolRef()135 SymbolRefAttr getSymbolRef() const { return symbolRef; } 136 137 private: 138 /// The operation that this access is held by. 139 Operation *owner; 140 141 /// The symbol reference that this use represents. 142 SymbolRefAttr symbolRef; 143 }; 144 145 /// This class implements a range of SymbolRef uses. 146 class UseRange { 147 public: UseRange(std::vector<SymbolUse> && uses)148 UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} 149 150 using iterator = std::vector<SymbolUse>::const_iterator; begin()151 iterator begin() const { return uses.begin(); } end()152 iterator end() const { return uses.end(); } empty()153 bool empty() const { return uses.empty(); } 154 155 private: 156 std::vector<SymbolUse> uses; 157 }; 158 159 /// Get an iterator range for all of the uses, for any symbol, that are nested 160 /// within the given operation 'from'. This does not traverse into any nested 161 /// symbol tables. This function returns None if there are any unknown 162 /// operations that may potentially be symbol tables. 163 static Optional<UseRange> getSymbolUses(Operation *from); 164 static Optional<UseRange> getSymbolUses(Region *from); 165 166 /// Get all of the uses of the given symbol that are nested within the given 167 /// operation 'from'. This does not traverse into any nested symbol tables. 168 /// This function returns None if there are any unknown operations that may 169 /// potentially be symbol tables. 170 static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from); 171 static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from); 172 static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from); 173 static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from); 174 175 /// Return if the given symbol is known to have no uses that are nested 176 /// within the given operation 'from'. This does not traverse into any nested 177 /// symbol tables. This function will also return false if there are any 178 /// unknown operations that may potentially be symbol tables. This doesn't 179 /// necessarily mean that there are no uses, we just can't conservatively 180 /// prove it. 181 static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); 182 static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); 183 static bool symbolKnownUseEmpty(StringRef symbol, Region *from); 184 static bool symbolKnownUseEmpty(Operation *symbol, Region *from); 185 186 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 187 /// provided symbol 'newSymbol' that are nested within the given operation 188 /// 'from'. This does not traverse into any nested symbol tables. If there are 189 /// any unknown operations that may potentially be symbol tables, no uses are 190 /// replaced and failure is returned. 191 LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, 192 StringRef newSymbol, 193 Operation *from); 194 LLVM_NODISCARD static LogicalResult 195 replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, 196 Operation *from); 197 LLVM_NODISCARD static LogicalResult 198 replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from); 199 LLVM_NODISCARD static LogicalResult 200 replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, 201 Region *from); 202 203 private: 204 Operation *symbolTableOp; 205 206 /// This is a mapping from a name to the symbol with that name. 207 llvm::StringMap<Operation *> symbolTable; 208 209 /// This is used when name conflicts are detected. 210 unsigned uniquingCounter = 0; 211 }; 212 213 //===----------------------------------------------------------------------===// 214 // SymbolTableCollection 215 //===----------------------------------------------------------------------===// 216 217 /// This class represents a collection of `SymbolTable`s. This simplifies 218 /// certain algorithms that run recursively on nested symbol tables. Symbol 219 /// tables are constructed lazily to reduce the upfront cost of constructing 220 /// unnecessary tables. 221 class SymbolTableCollection { 222 public: 223 /// Look up a symbol with the specified name within the specified symbol table 224 /// operation, returning null if no such name exists. 225 Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol); 226 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); 227 template <typename T, typename NameT> lookupSymbolIn(Operation * symbolTableOp,NameT && name)228 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const { 229 return dyn_cast_or_null<T>( 230 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); 231 } 232 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 233 /// by a given SymbolRefAttr when resolved within the provided symbol table 234 /// operation. Returns failure if any of the nested references could not be 235 /// resolved. 236 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, 237 SmallVectorImpl<Operation *> &symbols); 238 239 /// Returns the operation registered with the given symbol name within the 240 /// closest parent operation of, or including, 'from' with the 241 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 242 /// found. 243 Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); 244 Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); 245 template <typename T> lookupNearestSymbolFrom(Operation * from,StringRef symbol)246 T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { 247 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 248 } 249 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)250 T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 251 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 252 } 253 254 /// Lookup, or create, a symbol table for an operation. 255 SymbolTable &getSymbolTable(Operation *op); 256 257 private: 258 /// The constructed symbol tables nested within this table. 259 DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables; 260 }; 261 262 //===----------------------------------------------------------------------===// 263 // SymbolTable Trait Types 264 //===----------------------------------------------------------------------===// 265 266 namespace detail { 267 LogicalResult verifySymbolTable(Operation *op); 268 LogicalResult verifySymbol(Operation *op); 269 } // namespace detail 270 271 namespace OpTrait { 272 /// A trait used to provide symbol table functionalities to a region operation. 273 /// This operation must hold exactly 1 region. Once attached, all operations 274 /// that are directly within the region, i.e not including those within child 275 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will 276 /// be verified to ensure that the names are uniqued. These operations must also 277 /// adhere to the constraints defined by the `Symbol` trait, even if they do not 278 /// inherit from it. 279 template <typename ConcreteType> 280 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { 281 public: verifyTrait(Operation * op)282 static LogicalResult verifyTrait(Operation *op) { 283 return ::mlir::detail::verifySymbolTable(op); 284 } 285 286 /// Look up a symbol with the specified name, returning null if no such 287 /// name exists. Symbol names never include the @ on them. Note: This 288 /// performs a linear scan of held symbols. lookupSymbol(StringRef name)289 Operation *lookupSymbol(StringRef name) { 290 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 291 } lookupSymbol(StringRef name)292 template <typename T> T lookupSymbol(StringRef name) { 293 return dyn_cast_or_null<T>(lookupSymbol(name)); 294 } lookupSymbol(SymbolRefAttr symbol)295 Operation *lookupSymbol(SymbolRefAttr symbol) { 296 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); 297 } 298 template <typename T> lookupSymbol(SymbolRefAttr symbol)299 T lookupSymbol(SymbolRefAttr symbol) { 300 return dyn_cast_or_null<T>(lookupSymbol(symbol)); 301 } 302 }; 303 304 } // end namespace OpTrait 305 306 //===----------------------------------------------------------------------===// 307 // Visibility parsing implementation. 308 //===----------------------------------------------------------------------===// 309 310 namespace impl { 311 /// Parse an optional visibility attribute keyword (i.e., public, private, or 312 /// nested) without quotes in a string attribute named 'attrName'. 313 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, 314 NamedAttrList &attrs); 315 } // end namespace impl 316 317 } // end namespace mlir 318 319 /// Include the generated symbol interfaces. 320 #include "mlir/IR/SymbolInterfaces.h.inc" 321 322 #endif // MLIR_IR_SYMBOLTABLE_H 323