1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Attribute wrapper to simplify using TableGen Record defining a MLIR 10 // Attribute. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "llvm/TableGen/Record.h" 17 18 using namespace mlir; 19 using namespace mlir::tblgen; 20 21 using llvm::DefInit; 22 using llvm::Init; 23 using llvm::Record; 24 using llvm::StringInit; 25 26 // Returns the initializer's value as string if the given TableGen initializer 27 // is a code or string initializer. Returns the empty StringRef otherwise. getValueAsString(const Init * init)28static StringRef getValueAsString(const Init *init) { 29 if (const auto *str = dyn_cast<StringInit>(init)) 30 return str->getValue().trim(); 31 return {}; 32 } 33 AttrConstraint(const Record * record)34AttrConstraint::AttrConstraint(const Record *record) 35 : Constraint(Constraint::CK_Attr, record) { 36 assert(isSubClassOf("AttrConstraint") && 37 "must be subclass of TableGen 'AttrConstraint' class"); 38 } 39 isSubClassOf(StringRef className) const40bool AttrConstraint::isSubClassOf(StringRef className) const { 41 return def->isSubClassOf(className); 42 } 43 Attribute(const Record * record)44Attribute::Attribute(const Record *record) : AttrConstraint(record) { 45 assert(record->isSubClassOf("Attr") && 46 "must be subclass of TableGen 'Attr' class"); 47 } 48 Attribute(const DefInit * init)49Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} 50 isDerivedAttr() const51bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); } 52 isTypeAttr() const53bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } 54 isSymbolRefAttr() const55bool Attribute::isSymbolRefAttr() const { 56 StringRef defName = def->getName(); 57 if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr") 58 return true; 59 return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr"); 60 } 61 isEnumAttr() const62bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } 63 getStorageType() const64StringRef Attribute::getStorageType() const { 65 const auto *init = def->getValueInit("storageType"); 66 auto type = getValueAsString(init); 67 if (type.empty()) 68 return "Attribute"; 69 return type; 70 } 71 getReturnType() const72StringRef Attribute::getReturnType() const { 73 const auto *init = def->getValueInit("returnType"); 74 return getValueAsString(init); 75 } 76 77 // Return the type constraint corresponding to the type of this attribute, or 78 // None if this is not a TypedAttr. getValueType() const79llvm::Optional<Type> Attribute::getValueType() const { 80 if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType"))) 81 return Type(defInit->getDef()); 82 return llvm::None; 83 } 84 getConvertFromStorageCall() const85StringRef Attribute::getConvertFromStorageCall() const { 86 const auto *init = def->getValueInit("convertFromStorage"); 87 return getValueAsString(init); 88 } 89 isConstBuildable() const90bool Attribute::isConstBuildable() const { 91 const auto *init = def->getValueInit("constBuilderCall"); 92 return !getValueAsString(init).empty(); 93 } 94 getConstBuilderTemplate() const95StringRef Attribute::getConstBuilderTemplate() const { 96 const auto *init = def->getValueInit("constBuilderCall"); 97 return getValueAsString(init); 98 } 99 getBaseAttr() const100Attribute Attribute::getBaseAttr() const { 101 if (const auto *defInit = 102 llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) { 103 return Attribute(defInit).getBaseAttr(); 104 } 105 return *this; 106 } 107 hasDefaultValue() const108bool Attribute::hasDefaultValue() const { 109 const auto *init = def->getValueInit("defaultValue"); 110 return !getValueAsString(init).empty(); 111 } 112 getDefaultValue() const113StringRef Attribute::getDefaultValue() const { 114 const auto *init = def->getValueInit("defaultValue"); 115 return getValueAsString(init); 116 } 117 isOptional() const118bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } 119 getAttrDefName() const120StringRef Attribute::getAttrDefName() const { 121 if (def->isAnonymous()) { 122 return getBaseAttr().def->getName(); 123 } 124 return def->getName(); 125 } 126 getDerivedCodeBody() const127StringRef Attribute::getDerivedCodeBody() const { 128 assert(isDerivedAttr() && "only derived attribute has 'body' field"); 129 return def->getValueAsString("body"); 130 } 131 getDialect() const132Dialect Attribute::getDialect() const { 133 const llvm::RecordVal *record = def->getValue("dialect"); 134 if (record && record->getValue()) { 135 if (DefInit *init = dyn_cast<DefInit>(record->getValue())) 136 return Dialect(init->getDef()); 137 } 138 return Dialect(nullptr); 139 } 140 ConstantAttr(const DefInit * init)141ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { 142 assert(def->isSubClassOf("ConstantAttr") && 143 "must be subclass of TableGen 'ConstantAttr' class"); 144 } 145 getAttribute() const146Attribute ConstantAttr::getAttribute() const { 147 return Attribute(def->getValueAsDef("attr")); 148 } 149 getConstantValue() const150StringRef ConstantAttr::getConstantValue() const { 151 return def->getValueAsString("value"); 152 } 153 EnumAttrCase(const llvm::Record * record)154EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { 155 assert(isSubClassOf("EnumAttrCaseInfo") && 156 "must be subclass of TableGen 'EnumAttrInfo' class"); 157 } 158 EnumAttrCase(const llvm::DefInit * init)159EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) 160 : EnumAttrCase(init->getDef()) {} 161 isStrCase() const162bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); } 163 getSymbol() const164StringRef EnumAttrCase::getSymbol() const { 165 return def->getValueAsString("symbol"); 166 } 167 getStr() const168StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } 169 getValue() const170int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } 171 getDef() const172const llvm::Record &EnumAttrCase::getDef() const { return *def; } 173 EnumAttr(const llvm::Record * record)174EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { 175 assert(isSubClassOf("EnumAttrInfo") && 176 "must be subclass of TableGen 'EnumAttr' class"); 177 } 178 EnumAttr(const llvm::Record & record)179EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} 180 EnumAttr(const llvm::DefInit * init)181EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} 182 classof(const Attribute * attr)183bool EnumAttr::classof(const Attribute *attr) { 184 return attr->isSubClassOf("EnumAttrInfo"); 185 } 186 isBitEnum() const187bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } 188 getEnumClassName() const189StringRef EnumAttr::getEnumClassName() const { 190 return def->getValueAsString("className"); 191 } 192 getCppNamespace() const193StringRef EnumAttr::getCppNamespace() const { 194 return def->getValueAsString("cppNamespace"); 195 } 196 getUnderlyingType() const197StringRef EnumAttr::getUnderlyingType() const { 198 return def->getValueAsString("underlyingType"); 199 } 200 getUnderlyingToSymbolFnName() const201StringRef EnumAttr::getUnderlyingToSymbolFnName() const { 202 return def->getValueAsString("underlyingToSymbolFnName"); 203 } 204 getStringToSymbolFnName() const205StringRef EnumAttr::getStringToSymbolFnName() const { 206 return def->getValueAsString("stringToSymbolFnName"); 207 } 208 getSymbolToStringFnName() const209StringRef EnumAttr::getSymbolToStringFnName() const { 210 return def->getValueAsString("symbolToStringFnName"); 211 } 212 getSymbolToStringFnRetType() const213StringRef EnumAttr::getSymbolToStringFnRetType() const { 214 return def->getValueAsString("symbolToStringFnRetType"); 215 } 216 getMaxEnumValFnName() const217StringRef EnumAttr::getMaxEnumValFnName() const { 218 return def->getValueAsString("maxEnumValFnName"); 219 } 220 getAllCases() const221std::vector<EnumAttrCase> EnumAttr::getAllCases() const { 222 const auto *inits = def->getValueAsListInit("enumerants"); 223 224 std::vector<EnumAttrCase> cases; 225 cases.reserve(inits->size()); 226 227 for (const llvm::Init *init : *inits) { 228 cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init))); 229 } 230 231 return cases; 232 } 233 StructFieldAttr(const llvm::Record * record)234StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { 235 assert(def->isSubClassOf("StructFieldAttr") && 236 "must be subclass of TableGen 'StructFieldAttr' class"); 237 } 238 StructFieldAttr(const llvm::Record & record)239StructFieldAttr::StructFieldAttr(const llvm::Record &record) 240 : StructFieldAttr(&record) {} 241 StructFieldAttr(const llvm::DefInit * init)242StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) 243 : StructFieldAttr(init->getDef()) {} 244 getName() const245StringRef StructFieldAttr::getName() const { 246 return def->getValueAsString("name"); 247 } 248 getType() const249Attribute StructFieldAttr::getType() const { 250 auto init = def->getValueInit("type"); 251 return Attribute(cast<llvm::DefInit>(init)); 252 } 253 StructAttr(const llvm::Record * record)254StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { 255 assert(isSubClassOf("StructAttr") && 256 "must be subclass of TableGen 'StructAttr' class"); 257 } 258 StructAttr(const llvm::DefInit * init)259StructAttr::StructAttr(const llvm::DefInit *init) 260 : StructAttr(init->getDef()) {} 261 getStructClassName() const262StringRef StructAttr::getStructClassName() const { 263 return def->getValueAsString("className"); 264 } 265 getCppNamespace() const266StringRef StructAttr::getCppNamespace() const { 267 Dialect dialect(def->getValueAsDef("dialect")); 268 return dialect.getCppNamespace(); 269 } 270 getAllFields() const271std::vector<StructFieldAttr> StructAttr::getAllFields() const { 272 std::vector<StructFieldAttr> attributes; 273 274 const auto *inits = def->getValueAsListInit("fields"); 275 attributes.reserve(inits->size()); 276 277 for (const llvm::Init *init : *inits) { 278 attributes.emplace_back(cast<llvm::DefInit>(init)); 279 } 280 281 return attributes; 282 } 283 284 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; 285