1 //===- DialectImplementation.h ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains utilities classes for implementing dialect attributes and 10 // types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_IR_DIALECTIMPLEMENTATION_H 15 #define MLIR_IR_DIALECTIMPLEMENTATION_H 16 17 #include "mlir/IR/OpImplementation.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/SMLoc.h" 20 #include "llvm/Support/raw_ostream.h" 21 22 namespace mlir { 23 24 class Builder; 25 26 //===----------------------------------------------------------------------===// 27 // DialectAsmPrinter 28 //===----------------------------------------------------------------------===// 29 30 /// This is a pure-virtual base class that exposes the asmprinter hooks 31 /// necessary to implement a custom printAttribute/printType() method on a 32 /// dialect. 33 class DialectAsmPrinter { 34 public: DialectAsmPrinter()35 DialectAsmPrinter() {} 36 virtual ~DialectAsmPrinter(); 37 virtual raw_ostream &getStream() const = 0; 38 39 /// Print the given attribute to the stream. 40 virtual void printAttribute(Attribute attr) = 0; 41 42 /// Print the given floating point value in a stabilized form that can be 43 /// roundtripped through the IR. This is the companion to the 'parseFloat' 44 /// hook on the DialectAsmParser. 45 virtual void printFloat(const APFloat &value) = 0; 46 47 /// Print the given type to the stream. 48 virtual void printType(Type type) = 0; 49 50 private: 51 DialectAsmPrinter(const DialectAsmPrinter &) = delete; 52 void operator=(const DialectAsmPrinter &) = delete; 53 }; 54 55 // Make the implementations convenient to use. 56 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) { 57 p.printAttribute(attr); 58 return p; 59 } 60 61 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, 62 const APFloat &value) { 63 p.printFloat(value); 64 return p; 65 } 66 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) { 67 return p << APFloat(value); 68 } 69 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) { 70 return p << APFloat(value); 71 } 72 73 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) { 74 p.printType(type); 75 return p; 76 } 77 78 // Support printing anything that isn't convertible to one of the above types, 79 // even if it isn't exactly one of them. For example, we want to print 80 // FunctionType with the Type version above, not have it match this. 81 template <typename T, typename std::enable_if< 82 !std::is_convertible<T &, Attribute &>::value && 83 !std::is_convertible<T &, Type &>::value && 84 !std::is_convertible<T &, APFloat &>::value && 85 !llvm::is_one_of<T, double, float>::value, 86 T>::type * = nullptr> 87 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) { 88 p.getStream() << other; 89 return p; 90 } 91 92 //===----------------------------------------------------------------------===// 93 // DialectAsmParser 94 //===----------------------------------------------------------------------===// 95 96 /// The DialectAsmParser has methods for interacting with the asm parser: 97 /// parsing things from it, emitting errors etc. It has an intentionally 98 /// high-level API that is designed to reduce/constrain syntax innovation in 99 /// individual attributes or types. 100 class DialectAsmParser { 101 public: 102 virtual ~DialectAsmParser(); 103 104 /// Emit a diagnostic at the specified location and return failure. 105 virtual InFlightDiagnostic emitError(llvm::SMLoc loc, 106 const Twine &message = {}) = 0; 107 108 /// Return a builder which provides useful access to MLIRContext, global 109 /// objects like types and attributes. 110 virtual Builder &getBuilder() const = 0; 111 112 /// Get the location of the next token and store it into the argument. This 113 /// always succeeds. 114 virtual llvm::SMLoc getCurrentLocation() = 0; getCurrentLocation(llvm::SMLoc * loc)115 ParseResult getCurrentLocation(llvm::SMLoc *loc) { 116 *loc = getCurrentLocation(); 117 return success(); 118 } 119 120 /// Return the location of the original name token. 121 virtual llvm::SMLoc getNameLoc() const = 0; 122 123 /// Re-encode the given source location as an MLIR location and return it. 124 virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; 125 126 /// Returns the full specification of the symbol being parsed. This allows for 127 /// using a separate parser if necessary. 128 virtual StringRef getFullSymbolSpec() const = 0; 129 130 // These methods emit an error and return failure or success. This allows 131 // these to be chained together into a linear sequence of || expressions in 132 // many cases. 133 134 /// Parse a floating point value from the stream. 135 virtual ParseResult parseFloat(double &result) = 0; 136 137 /// Parse an integer value from the stream. parseInteger(IntT & result)138 template <typename IntT> ParseResult parseInteger(IntT &result) { 139 auto loc = getCurrentLocation(); 140 OptionalParseResult parseResult = parseOptionalInteger(result); 141 if (!parseResult.hasValue()) 142 return emitError(loc, "expected integer value"); 143 return *parseResult; 144 } 145 146 /// Parse an optional integer value from the stream. 147 virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; 148 149 template <typename IntT> parseOptionalInteger(IntT & result)150 OptionalParseResult parseOptionalInteger(IntT &result) { 151 auto loc = getCurrentLocation(); 152 153 // Parse the unsigned variant. 154 uint64_t uintResult; 155 OptionalParseResult parseResult = parseOptionalInteger(uintResult); 156 if (!parseResult.hasValue() || failed(*parseResult)) 157 return parseResult; 158 159 // Try to convert to the provided integer type. 160 result = IntT(uintResult); 161 if (uint64_t(result) != uintResult) 162 return emitError(loc, "integer value too large"); 163 return success(); 164 } 165 166 //===--------------------------------------------------------------------===// 167 // Token Parsing 168 //===--------------------------------------------------------------------===// 169 170 /// Parse a '->' token. 171 virtual ParseResult parseArrow() = 0; 172 173 /// Parse a '->' token if present 174 virtual ParseResult parseOptionalArrow() = 0; 175 176 /// Parse a '{' token. 177 virtual ParseResult parseLBrace() = 0; 178 179 /// Parse a '{' token if present 180 virtual ParseResult parseOptionalLBrace() = 0; 181 182 /// Parse a `}` token. 183 virtual ParseResult parseRBrace() = 0; 184 185 /// Parse a `}` token if present 186 virtual ParseResult parseOptionalRBrace() = 0; 187 188 /// Parse a `:` token. 189 virtual ParseResult parseColon() = 0; 190 191 /// Parse a `:` token if present. 192 virtual ParseResult parseOptionalColon() = 0; 193 194 /// Parse a `,` token. 195 virtual ParseResult parseComma() = 0; 196 197 /// Parse a `,` token if present. 198 virtual ParseResult parseOptionalComma() = 0; 199 200 /// Parse a `=` token. 201 virtual ParseResult parseEqual() = 0; 202 203 /// Parse a `=` token if present. 204 virtual ParseResult parseOptionalEqual() = 0; 205 206 /// Parse a quoted string token if present. 207 virtual ParseResult parseOptionalString(StringRef *string) = 0; 208 209 /// Parse a given keyword. 210 ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { 211 auto loc = getCurrentLocation(); 212 if (parseOptionalKeyword(keyword)) 213 return emitError(loc, "expected '") << keyword << "'" << msg; 214 return success(); 215 } 216 217 /// Parse a keyword into 'keyword'. parseKeyword(StringRef * keyword)218 ParseResult parseKeyword(StringRef *keyword) { 219 auto loc = getCurrentLocation(); 220 if (parseOptionalKeyword(keyword)) 221 return emitError(loc, "expected valid keyword"); 222 return success(); 223 } 224 225 /// Parse the given keyword if present. 226 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; 227 228 /// Parse a keyword, if present, into 'keyword'. 229 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; 230 231 /// Parse a '<' token. 232 virtual ParseResult parseLess() = 0; 233 234 /// Parse a `<` token if present. 235 virtual ParseResult parseOptionalLess() = 0; 236 237 /// Parse a '>' token. 238 virtual ParseResult parseGreater() = 0; 239 240 /// Parse a `>` token if present. 241 virtual ParseResult parseOptionalGreater() = 0; 242 243 /// Parse a `(` token. 244 virtual ParseResult parseLParen() = 0; 245 246 /// Parse a `(` token if present. 247 virtual ParseResult parseOptionalLParen() = 0; 248 249 /// Parse a `)` token. 250 virtual ParseResult parseRParen() = 0; 251 252 /// Parse a `)` token if present. 253 virtual ParseResult parseOptionalRParen() = 0; 254 255 /// Parse a `[` token. 256 virtual ParseResult parseLSquare() = 0; 257 258 /// Parse a `[` token if present. 259 virtual ParseResult parseOptionalLSquare() = 0; 260 261 /// Parse a `]` token. 262 virtual ParseResult parseRSquare() = 0; 263 264 /// Parse a `]` token if present. 265 virtual ParseResult parseOptionalRSquare() = 0; 266 267 /// Parse a `...` token if present; 268 virtual ParseResult parseOptionalEllipsis() = 0; 269 270 /// Parse a `?` token. 271 virtual ParseResult parseOptionalQuestion() = 0; 272 273 /// Parse a `*` token. 274 virtual ParseResult parseOptionalStar() = 0; 275 276 //===--------------------------------------------------------------------===// 277 // Attribute Parsing 278 //===--------------------------------------------------------------------===// 279 280 /// Parse an arbitrary attribute and return it in result. 281 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; 282 283 /// Parse an attribute of a specific kind and type. 284 template <typename AttrType> 285 ParseResult parseAttribute(AttrType &result, Type type = {}) { 286 llvm::SMLoc loc = getCurrentLocation(); 287 288 // Parse any kind of attribute. 289 Attribute attr; 290 if (parseAttribute(attr, type)) 291 return failure(); 292 293 // Check for the right kind of attribute. 294 result = attr.dyn_cast<AttrType>(); 295 if (!result) 296 return emitError(loc, "invalid kind of attribute specified"); 297 return success(); 298 } 299 300 /// Parse an affine map instance into 'map'. 301 virtual ParseResult parseAffineMap(AffineMap &map) = 0; 302 303 /// Parse an integer set instance into 'set'. 304 virtual ParseResult printIntegerSet(IntegerSet &set) = 0; 305 306 //===--------------------------------------------------------------------===// 307 // Type Parsing 308 //===--------------------------------------------------------------------===// 309 310 /// Parse a type. 311 virtual ParseResult parseType(Type &result) = 0; 312 313 /// Parse a type of a specific kind, e.g. a FunctionType. parseType(TypeType & result)314 template <typename TypeType> ParseResult parseType(TypeType &result) { 315 llvm::SMLoc loc = getCurrentLocation(); 316 317 // Parse any kind of type. 318 Type type; 319 if (parseType(type)) 320 return failure(); 321 322 // Check for the right kind of attribute. 323 result = type.dyn_cast<TypeType>(); 324 if (!result) 325 return emitError(loc, "invalid kind of type specified"); 326 return success(); 327 } 328 329 /// Parse a type if present. 330 virtual OptionalParseResult parseOptionalType(Type &result) = 0; 331 332 /// Parse a 'x' separated dimension list. This populates the dimension list, 333 /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on 334 /// `?` otherwise. 335 /// 336 /// dimension-list ::= (dimension `x`)* 337 /// dimension ::= `?` | integer 338 /// 339 /// When `allowDynamic` is not set, this is used to parse: 340 /// 341 /// static-dimension-list ::= (integer `x`)* 342 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, 343 bool allowDynamic = true) = 0; 344 }; 345 346 } // end namespace mlir 347 348 #endif 349