• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- DialectImplementation.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains utility classes for implementing dialect attributes and
10 // types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_DIALECTIMPLEMENTATION_H
15 #define MLIR_IR_DIALECTIMPLEMENTATION_H
16 
17 #include "mlir/IR/OpImplementation.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/SMLoc.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 namespace mlir {
23 
24 class Builder;
25 
26 //===----------------------------------------------------------------------===//
27 // DialectAsmPrinter
28 //===----------------------------------------------------------------------===//
29 
30 /// This is a pure-virtual base class that exposes the asmprinter hooks
31 /// necessary to implement a custom printAttribute/printType() method on a
32 /// dialect.
33 class DialectAsmPrinter {
34 public:
DialectAsmPrinter()35   DialectAsmPrinter() {}
36   virtual ~DialectAsmPrinter();
37   virtual raw_ostream &getStream() const = 0;
38 
39   /// Print the given attribute to the stream.
40   virtual void printAttribute(Attribute attr) = 0;
41 
42   /// Print the given floating point value in a stabilized form that can be
43   /// roundtripped through the IR. This is the companion to the 'parseFloat'
44   /// hook on the DialectAsmParser.
45   virtual void printFloat(const APFloat &value) = 0;
46 
47   /// Print the given type to the stream.
48   virtual void printType(Type type) = 0;
49 
50 private:
51   DialectAsmPrinter(const DialectAsmPrinter &) = delete;
52   void operator=(const DialectAsmPrinter &) = delete;
53 };
54 
55 // Make the implementations convenient to use.
56 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) {
57   p.printAttribute(attr);
58   return p;
59 }
60 
61 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p,
62                                      const APFloat &value) {
63   p.printFloat(value);
64   return p;
65 }
66 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) {
67   return p << APFloat(value);
68 }
69 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) {
70   return p << APFloat(value);
71 }
72 
73 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) {
74   p.printType(type);
75   return p;
76 }
77 
78 // Support printing anything that isn't convertible to one of the above types,
79 // even if it isn't exactly one of them.  For example, we want to print
80 // FunctionType with the Type version above, not have it match this.
81 template <typename T, typename std::enable_if<
82                           !std::is_convertible<T &, Attribute &>::value &&
83                               !std::is_convertible<T &, Type &>::value &&
84                               !std::is_convertible<T &, APFloat &>::value &&
85                               !llvm::is_one_of<T, double, float>::value,
86                           T>::type * = nullptr>
87 inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) {
88   p.getStream() << other;
89   return p;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // DialectAsmParser
94 //===----------------------------------------------------------------------===//
95 
96 /// The DialectAsmParser has methods for interacting with the asm parser:
97 /// parsing things from it, emitting errors etc.  It has an intentionally
98 /// high-level API that is designed to reduce/constrain syntax innovation in
99 /// individual attributes or types.
100 class DialectAsmParser {
101 public:
102   virtual ~DialectAsmParser();
103 
104   /// Emit a diagnostic at the specified location and return failure.
105   virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
106                                        const Twine &message = {}) = 0;
107 
108   /// Return a builder which provides useful access to MLIRContext, global
109   /// objects like types and attributes.
110   virtual Builder &getBuilder() const = 0;
111 
112   /// Get the location of the next token and store it into the argument.  This
113   /// always succeeds.
114   virtual llvm::SMLoc getCurrentLocation() = 0;
getCurrentLocation(llvm::SMLoc * loc)115   ParseResult getCurrentLocation(llvm::SMLoc *loc) {
116     *loc = getCurrentLocation();
117     return success();
118   }
119 
120   /// Return the location of the original name token.
121   virtual llvm::SMLoc getNameLoc() const = 0;
122 
123   /// Re-encode the given source location as an MLIR location and return it.
124   virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
125 
126   /// Returns the full specification of the symbol being parsed. This allows for
127   /// using a separate parser if necessary.
128   virtual StringRef getFullSymbolSpec() const = 0;
129 
130   // These methods emit an error and return failure or success. This allows
131   // these to be chained together into a linear sequence of || expressions in
132   // many cases.
133 
134   /// Parse a floating point value from the stream.
135   virtual ParseResult parseFloat(double &result) = 0;
136 
137   /// Parse an integer value from the stream.
parseInteger(IntT & result)138   template <typename IntT> ParseResult parseInteger(IntT &result) {
139     auto loc = getCurrentLocation();
140     OptionalParseResult parseResult = parseOptionalInteger(result);
141     if (!parseResult.hasValue())
142       return emitError(loc, "expected integer value");
143     return *parseResult;
144   }
145 
146   /// Parse an optional integer value from the stream.
147   virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0;
148 
149   template <typename IntT>
parseOptionalInteger(IntT & result)150   OptionalParseResult parseOptionalInteger(IntT &result) {
151     auto loc = getCurrentLocation();
152 
153     // Parse the unsigned variant.
154     uint64_t uintResult;
155     OptionalParseResult parseResult = parseOptionalInteger(uintResult);
156     if (!parseResult.hasValue() || failed(*parseResult))
157       return parseResult;
158 
159     // Try to convert to the provided integer type.
160     result = IntT(uintResult);
161     if (uint64_t(result) != uintResult)
162       return emitError(loc, "integer value too large");
163     return success();
164   }
165 
166   //===--------------------------------------------------------------------===//
167   // Token Parsing
168   //===--------------------------------------------------------------------===//
169 
170   /// Parse a '->' token.
171   virtual ParseResult parseArrow() = 0;
172 
173   /// Parse a '->' token if present
174   virtual ParseResult parseOptionalArrow() = 0;
175 
176   /// Parse a '{' token.
177   virtual ParseResult parseLBrace() = 0;
178 
179   /// Parse a '{' token if present
180   virtual ParseResult parseOptionalLBrace() = 0;
181 
182   /// Parse a `}` token.
183   virtual ParseResult parseRBrace() = 0;
184 
185   /// Parse a `}` token if present
186   virtual ParseResult parseOptionalRBrace() = 0;
187 
188   /// Parse a `:` token.
189   virtual ParseResult parseColon() = 0;
190 
191   /// Parse a `:` token if present.
192   virtual ParseResult parseOptionalColon() = 0;
193 
194   /// Parse a `,` token.
195   virtual ParseResult parseComma() = 0;
196 
197   /// Parse a `,` token if present.
198   virtual ParseResult parseOptionalComma() = 0;
199 
200   /// Parse a `=` token.
201   virtual ParseResult parseEqual() = 0;
202 
203   /// Parse a `=` token if present.
204   virtual ParseResult parseOptionalEqual() = 0;
205 
206   /// Parse a quoted string token if present.
207   virtual ParseResult parseOptionalString(StringRef *string) = 0;
208 
209   /// Parse a given keyword.
210   ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
211     auto loc = getCurrentLocation();
212     if (parseOptionalKeyword(keyword))
213       return emitError(loc, "expected '") << keyword << "'" << msg;
214     return success();
215   }
216 
217   /// Parse a keyword into 'keyword'.
parseKeyword(StringRef * keyword)218   ParseResult parseKeyword(StringRef *keyword) {
219     auto loc = getCurrentLocation();
220     if (parseOptionalKeyword(keyword))
221       return emitError(loc, "expected valid keyword");
222     return success();
223   }
224 
225   /// Parse the given keyword if present.
226   virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
227 
228   /// Parse a keyword, if present, into 'keyword'.
229   virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
230 
231   /// Parse a '<' token.
232   virtual ParseResult parseLess() = 0;
233 
234   /// Parse a `<` token if present.
235   virtual ParseResult parseOptionalLess() = 0;
236 
237   /// Parse a '>' token.
238   virtual ParseResult parseGreater() = 0;
239 
240   /// Parse a `>` token if present.
241   virtual ParseResult parseOptionalGreater() = 0;
242 
243   /// Parse a `(` token.
244   virtual ParseResult parseLParen() = 0;
245 
246   /// Parse a `(` token if present.
247   virtual ParseResult parseOptionalLParen() = 0;
248 
249   /// Parse a `)` token.
250   virtual ParseResult parseRParen() = 0;
251 
252   /// Parse a `)` token if present.
253   virtual ParseResult parseOptionalRParen() = 0;
254 
255   /// Parse a `[` token.
256   virtual ParseResult parseLSquare() = 0;
257 
258   /// Parse a `[` token if present.
259   virtual ParseResult parseOptionalLSquare() = 0;
260 
261   /// Parse a `]` token.
262   virtual ParseResult parseRSquare() = 0;
263 
264   /// Parse a `]` token if present.
265   virtual ParseResult parseOptionalRSquare() = 0;
266 
267   /// Parse a `...` token if present;
268   virtual ParseResult parseOptionalEllipsis() = 0;
269 
270   /// Parse a `?` token.
271   virtual ParseResult parseOptionalQuestion() = 0;
272 
273   /// Parse a `*` token.
274   virtual ParseResult parseOptionalStar() = 0;
275 
276   //===--------------------------------------------------------------------===//
277   // Attribute Parsing
278   //===--------------------------------------------------------------------===//
279 
280   /// Parse an arbitrary attribute and return it in result.
281   virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
282 
283   /// Parse an attribute of a specific kind and type.
284   template <typename AttrType>
285   ParseResult parseAttribute(AttrType &result, Type type = {}) {
286     llvm::SMLoc loc = getCurrentLocation();
287 
288     // Parse any kind of attribute.
289     Attribute attr;
290     if (parseAttribute(attr, type))
291       return failure();
292 
293     // Check for the right kind of attribute.
294     result = attr.dyn_cast<AttrType>();
295     if (!result)
296       return emitError(loc, "invalid kind of attribute specified");
297     return success();
298   }
299 
300   /// Parse an affine map instance into 'map'.
301   virtual ParseResult parseAffineMap(AffineMap &map) = 0;
302 
303   /// Parse an integer set instance into 'set'.
304   virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
305 
306   //===--------------------------------------------------------------------===//
307   // Type Parsing
308   //===--------------------------------------------------------------------===//
309 
310   /// Parse a type.
311   virtual ParseResult parseType(Type &result) = 0;
312 
313   /// Parse a type of a specific kind, e.g. a FunctionType.
parseType(TypeType & result)314   template <typename TypeType> ParseResult parseType(TypeType &result) {
315     llvm::SMLoc loc = getCurrentLocation();
316 
317     // Parse any kind of type.
318     Type type;
319     if (parseType(type))
320       return failure();
321 
322     // Check for the right kind of attribute.
323     result = type.dyn_cast<TypeType>();
324     if (!result)
325       return emitError(loc, "invalid kind of type specified");
326     return success();
327   }
328 
329   /// Parse a type if present.
330   virtual OptionalParseResult parseOptionalType(Type &result) = 0;
331 
332   /// Parse a 'x' separated dimension list. This populates the dimension list,
333   /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
334   /// `?` otherwise.
335   ///
336   ///   dimension-list ::= (dimension `x`)*
337   ///   dimension ::= `?` | integer
338   ///
339   /// When `allowDynamic` is not set, this is used to parse:
340   ///
341   ///   static-dimension-list ::= (integer `x`)*
342   virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
343                                          bool allowDynamic = true) = 0;
344 };
345 
346 } // end namespace mlir
347 
348 #endif
349