• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Operator.h - Operator class ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Operator wrapper to simplify using TableGen Record defining an MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TABLEGEN_OPERATOR_H_
14 #define MLIR_TABLEGEN_OPERATOR_H_
15 
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Dialect.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Region.h"
22 #include "mlir/TableGen/Successor.h"
23 #include "mlir/TableGen/Type.h"
24 #include "llvm/ADT/PointerUnion.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringMap.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/SMLoc.h"
29 
30 namespace llvm {
31 class DefInit;
32 class Record;
33 class StringInit;
34 } // end namespace llvm
35 
36 namespace mlir {
37 namespace tblgen {
38 
39 // Wrapper class that contains an MLIR op's information (e.g., operands,
40 // attributes) defined in TableGen and provides helper methods for
41 // accessing them.
42 class Operator {
43 public:
44   explicit Operator(const llvm::Record &def);
Operator(const llvm::Record * def)45   explicit Operator(const llvm::Record *def) : Operator(*def) {}
46 
47   // Returns this op's dialect name.
48   StringRef getDialectName() const;
49 
50   // Returns the operation name. The name will follow the "<dialect>.<op-name>"
51   // format if its dialect name is not empty.
52   std::string getOperationName() const;
53 
54   // Returns this op's C++ class name.
55   StringRef getCppClassName() const;
56 
57   // Returns this op's C++ class name prefixed with namespaces.
58   std::string getQualCppClassName() const;
59 
60   // Returns the name of op's adaptor C++ class.
61   std::string getAdaptorName() const;
62 
63   /// A class used to represent the decorators of an operator variable, i.e.
64   /// argument or result.
65   struct VariableDecorator {
66   public:
VariableDecoratorVariableDecorator67     explicit VariableDecorator(const llvm::Record *def) : def(def) {}
getDefVariableDecorator68     const llvm::Record &getDef() const { return *def; }
69 
70   protected:
71     // The TableGen definition of this decorator.
72     const llvm::Record *def;
73   };
74 
75   // A utility iterator over a list of variable decorators.
76   struct VariableDecoratorIterator
77       : public llvm::mapped_iterator<llvm::Init *const *,
78                                      VariableDecorator (*)(llvm::Init *)> {
79     using reference = VariableDecorator;
80 
81     /// Initializes the iterator to the specified iterator.
VariableDecoratorIteratorVariableDecoratorIterator82     VariableDecoratorIterator(llvm::Init *const *it)
83         : llvm::mapped_iterator<llvm::Init *const *,
84                                 VariableDecorator (*)(llvm::Init *)>(it,
85                                                                      &unwrap) {}
86     static VariableDecorator unwrap(llvm::Init *init);
87   };
88   using var_decorator_iterator = VariableDecoratorIterator;
89   using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
90 
91   using value_iterator = NamedTypeConstraint *;
92   using value_range = llvm::iterator_range<value_iterator>;
93 
94   // Returns true if this op has variable length operands or results.
95   bool isVariadic() const;
96 
97   // Returns true if default builders should not be generated.
98   bool skipDefaultBuilders() const;
99 
100   // Op result iterators.
101   value_iterator result_begin();
102   value_iterator result_end();
103   value_range getResults();
104 
105   // Returns the number of results this op produces.
106   int getNumResults() const;
107 
108   // Returns the op result at the given `index`.
getResult(int index)109   NamedTypeConstraint &getResult(int index) { return results[index]; }
getResult(int index)110   const NamedTypeConstraint &getResult(int index) const {
111     return results[index];
112   }
113 
114   // Returns the `index`-th result's type constraint.
115   TypeConstraint getResultTypeConstraint(int index) const;
116   // Returns the `index`-th result's name.
117   StringRef getResultName(int index) const;
118   // Returns the `index`-th result's decorators.
119   var_decorator_range getResultDecorators(int index) const;
120 
121   // Returns the number of variable length results in this operation.
122   unsigned getNumVariableLengthResults() const;
123 
124   // Op attribute iterators.
125   using attribute_iterator = const NamedAttribute *;
126   attribute_iterator attribute_begin() const;
127   attribute_iterator attribute_end() const;
128   llvm::iterator_range<attribute_iterator> getAttributes() const;
129 
getNumAttributes()130   int getNumAttributes() const { return attributes.size(); }
getNumNativeAttributes()131   int getNumNativeAttributes() const { return numNativeAttributes; }
132 
133   // Op attribute accessors.
getAttribute(int index)134   NamedAttribute &getAttribute(int index) { return attributes[index]; }
135 
136   // Op operand iterators.
137   value_iterator operand_begin();
138   value_iterator operand_end();
139   value_range getOperands();
140 
getNumOperands()141   int getNumOperands() const { return operands.size(); }
getOperand(int index)142   NamedTypeConstraint &getOperand(int index) { return operands[index]; }
getOperand(int index)143   const NamedTypeConstraint &getOperand(int index) const {
144     return operands[index];
145   }
146 
147   // Returns the number of variadic operands in this operation.
148   unsigned getNumVariableLengthOperands() const;
149 
150   // Returns the total number of arguments.
getNumArgs()151   int getNumArgs() const { return arguments.size(); }
152 
153   // Returns true if the operation has a single variadic arg.
154   bool hasSingleVariadicArg() const;
155 
156   // Returns true if the operation has a single variadic result.
hasSingleVariadicResult()157   bool hasSingleVariadicResult() const {
158     return getNumResults() == 1 && getResult(0).isVariadic();
159   }
160 
161   // Returns true of the operation has no variadic regions.
hasNoVariadicRegions()162   bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
163 
164   using arg_iterator = const Argument *;
165   using arg_range = llvm::iterator_range<arg_iterator>;
166 
167   // Op argument (attribute or operand) iterators.
168   arg_iterator arg_begin() const;
169   arg_iterator arg_end() const;
170   arg_range getArgs() const;
171 
172   // Op argument (attribute or operand) accessors.
173   Argument getArg(int index) const;
174   StringRef getArgName(int index) const;
175   var_decorator_range getArgDecorators(int index) const;
176 
177   // Returns the trait wrapper for the given MLIR C++ `trait`.
178   // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
179   // requiring the raw MLIR trait here.
180   const OpTrait *getTrait(llvm::StringRef trait) const;
181 
182   // Regions.
183   using const_region_iterator = const NamedRegion *;
184   const_region_iterator region_begin() const;
185   const_region_iterator region_end() const;
186   llvm::iterator_range<const_region_iterator> getRegions() const;
187 
188   // Returns the number of regions.
189   unsigned getNumRegions() const;
190   // Returns the `index`-th region.
191   const NamedRegion &getRegion(unsigned index) const;
192 
193   // Returns the number of variadic regions in this operation.
194   unsigned getNumVariadicRegions() const;
195 
196   // Successors.
197   using const_successor_iterator = const NamedSuccessor *;
198   const_successor_iterator successor_begin() const;
199   const_successor_iterator successor_end() const;
200   llvm::iterator_range<const_successor_iterator> getSuccessors() const;
201 
202   // Returns the number of successors.
203   unsigned getNumSuccessors() const;
204   // Returns the `index`-th successor.
205   const NamedSuccessor &getSuccessor(unsigned index) const;
206 
207   // Returns the number of variadic successors in this operation.
208   unsigned getNumVariadicSuccessors() const;
209 
210   // Trait.
211   using const_trait_iterator = const OpTrait *;
212   const_trait_iterator trait_begin() const;
213   const_trait_iterator trait_end() const;
214   llvm::iterator_range<const_trait_iterator> getTraits() const;
215 
216   ArrayRef<llvm::SMLoc> getLoc() const;
217 
218   // Query functions for the documentation of the operator.
219   bool hasDescription() const;
220   StringRef getDescription() const;
221   bool hasSummary() const;
222   StringRef getSummary() const;
223 
224   // Query functions for the assembly format of the operator.
225   bool hasAssemblyFormat() const;
226   StringRef getAssemblyFormat() const;
227 
228   // Returns this op's extra class declaration code.
229   StringRef getExtraClassDeclaration() const;
230 
231   // Returns the Tablegen definition this operator was constructed from.
232   // TODO: do not expose the TableGen record, this is a temporary solution to
233   // OpEmitter requiring a Record because Operator does not provide enough
234   // methods.
235   const llvm::Record &getDef() const;
236 
237   // Returns the dialect of the op.
getDialect()238   const Dialect &getDialect() const { return dialect; }
239 
240   // Prints the contents in this operator to the given `os`. This is used for
241   // debugging purposes.
242   void print(llvm::raw_ostream &os) const;
243 
244   // Return whether all the result types are known.
allResultTypesKnown()245   bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
246 
247   // Pair representing either a index to an argument or a type constraint. Only
248   // one of these entries should have the non-default value.
249   struct ArgOrType {
ArgOrTypeArgOrType250     explicit ArgOrType(int index) : index(index), constraint(None) {}
ArgOrTypeArgOrType251     explicit ArgOrType(TypeConstraint constraint)
252         : index(None), constraint(constraint) {}
isArgArgOrType253     bool isArg() const {
254       assert(constraint.hasValue() ^ index.hasValue());
255       return index.hasValue();
256     }
isTypeArgOrType257     bool isType() const {
258       assert(constraint.hasValue() ^ index.hasValue());
259       return constraint.hasValue();
260     }
261 
getArgArgOrType262     int getArg() const { return *index; }
getTypeArgOrType263     TypeConstraint getType() const { return *constraint; }
264 
265   private:
266     Optional<int> index;
267     Optional<TypeConstraint> constraint;
268   };
269 
270   // Return all arguments or type constraints with same type as result[index].
271   // Requires: all result types are known.
272   ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
273 
// Pair consisting of the kind of an argument (operand or attribute) and its
// index into the operands or attributes. Both are packed into a single int:
// bit 0 holds the kind (1 for Attribute, 0 for Operand) and the remaining
// bits hold the index.
struct OperandOrAttribute {
  enum class Kind { Operand, Attribute };

  // Packs `kind` and `index` together. The two parts must be combined with a
  // bitwise OR: bit 0 of `index << 1` is always clear, so combining them with
  // AND would always yield 0 and lose both the index and the kind.
  // `index` is expected to be non-negative.
  OperandOrAttribute(Kind kind, int index)
      : packed((index << 1) | (kind == Kind::Attribute ? 1 : 0)) {}

  // Returns the index that was passed at construction.
  int operandOrAttributeIndex() const { return packed >> 1; }

  // Returns the kind that was passed at construction. This is a const member
  // so it can be called on const objects.
  Kind kind() const {
    return (packed & 0x1) ? Kind::Attribute : Kind::Operand;
  }

private:
  // Bit 0: kind; bits 1 and up: index.
  int packed;
};
286 
287   // Returns the OperandOrAttribute corresponding to the index.
288   OperandOrAttribute getArgToOperandOrAttribute(int index) const;
289 
290 private:
291   // Populates the vectors containing operands, attributes, results and traits.
292   void populateOpStructure();
293 
294   // Populates type inference info (mostly equality) with input a mapping from
295   // names to indices for arguments and results.
296   void populateTypeInferenceInfo(
297       const llvm::StringMap<int> &argumentsAndResultsIndex);
298 
299   // The dialect of this op.
300   Dialect dialect;
301 
302   // The unqualified C++ class name of the op.
303   StringRef cppClassName;
304 
305   // The operands of the op.
306   SmallVector<NamedTypeConstraint, 4> operands;
307 
308   // The attributes of the op.  Contains native attributes (corresponding to the
309   // actual stored attributes of the operation) followed by derived attributes
310   // (corresponding to dynamic properties of the operation that are computed
311   // upon request).
312   SmallVector<NamedAttribute, 4> attributes;
313 
314   // The arguments of the op (operands and native attributes).
315   SmallVector<Argument, 4> arguments;
316 
317   // The results of the op.
318   SmallVector<NamedTypeConstraint, 4> results;
319 
320   // The successors of this op.
321   SmallVector<NamedSuccessor, 0> successors;
322 
323   // The traits of the op.
324   SmallVector<OpTrait, 4> traits;
325 
326   // The regions of this op.
327   SmallVector<NamedRegion, 1> regions;
328 
329   // The argument with the same type as the result.
330   SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
331 
332   // Map from argument to attribute or operand number.
333   SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
334 
335   // The number of native attributes stored in the leading positions of
336   // `attributes`.
337   int numNativeAttributes;
338 
339   // The TableGen definition of this op.
340   const llvm::Record &def;
341 
342   // Whether the type of all results are known.
343   bool allResultsHaveKnownTypes;
344 };
345 
346 } // end namespace tblgen
347 } // end namespace mlir
348 
349 #endif // MLIR_TABLEGEN_OPERATOR_H_
350