• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several classes for Op C++ code emission. They are only
// expected to be used by MLIR TableGen backends.
//
// We emit the op declarations and definitions into separate files: *Ops.h.inc
// and *Ops.cpp.inc. The former is included in the dialect's *Ops.h and the
// latter in the dialect's *Ops.cpp. Keeping them separate gives dialect users a
// cleaner interface, since headers only see declarations.
//
// To do this split, we need to track method signatures and implementation
// logic separately. Signature information is used for both the declaration
// and the definition, while implementation logic is only used for the
// definition. Hence the following classes for C++ code emission.
//
//===----------------------------------------------------------------------===//
22 
23 #ifndef MLIR_TABLEGEN_OPCLASS_H_
24 #define MLIR_TABLEGEN_OPCLASS_H_
25 
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #include <set>
34 #include <string>
35 
36 namespace mlir {
37 namespace tblgen {
38 class FmtObjectBase;
39 
40 // Class for holding a single parameter of an op's method for C++ code emission.
41 class OpMethodParameter {
42 public:
43   // Properties (qualifiers) for the parameter.
44   enum Property {
45     PP_None = 0x0,
46     PP_Optional = 0x1,
47   };
48 
49   OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
50                     Property properties = PP_None)
type(type)51       : type(type), name(name), defaultValue(defaultValue),
52         properties(properties) {}
53 
OpMethodParameter(StringRef type,StringRef name,Property property)54   OpMethodParameter(StringRef type, StringRef name, Property property)
55       : OpMethodParameter(type, name, "", property) {}
56 
57   // Writes the parameter as a part of a method declaration to `os`.
writeDeclTo(raw_ostream & os)58   void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
59 
60   // Writes the parameter as a part of a method definition to `os`
writeDefTo(raw_ostream & os)61   void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
62 
getType()63   const std::string &getType() const { return type; }
hasDefaultValue()64   bool hasDefaultValue() const { return !defaultValue.empty(); }
65 
66 private:
67   void writeTo(raw_ostream &os, bool emitDefault) const;
68 
69   std::string type;
70   std::string name;
71   std::string defaultValue;
72   Property properties;
73 };
74 
75 // Base class for holding parameters of an op's method for C++ code emission.
76 class OpMethodParameters {
77 public:
78   // Discriminator for LLVM-style RTTI.
79   enum ParamsKind {
80     // Separate type and name for each parameter is not known.
81     PK_Unresolved,
82     // Each parameter is resolved to a type and name.
83     PK_Resolved,
84   };
85 
OpMethodParameters(ParamsKind kind)86   OpMethodParameters(ParamsKind kind) : kind(kind) {}
~OpMethodParameters()87   virtual ~OpMethodParameters() {}
88 
89   // LLVM-style RTTI support.
getKind()90   ParamsKind getKind() const { return kind; }
91 
92   // Writes the parameters as a part of a method declaration to `os`.
93   virtual void writeDeclTo(raw_ostream &os) const = 0;
94 
95   // Writes the parameters as a part of a method definition to `os`
96   virtual void writeDefTo(raw_ostream &os) const = 0;
97 
98   // Factory methods to create the correct type of `OpMethodParameters`
99   // object based on the arguments.
100   static std::unique_ptr<OpMethodParameters> create();
101 
102   static std::unique_ptr<OpMethodParameters> create(StringRef params);
103 
104   static std::unique_ptr<OpMethodParameters>
105   create(llvm::SmallVectorImpl<OpMethodParameter> &&params);
106 
107   static std::unique_ptr<OpMethodParameters>
108   create(StringRef type, StringRef name, StringRef defaultValue = "");
109 
110 private:
111   const ParamsKind kind;
112 };
113 
114 // Class for holding unresolved parameters.
115 class OpMethodUnresolvedParameters : public OpMethodParameters {
116 public:
OpMethodUnresolvedParameters(StringRef params)117   OpMethodUnresolvedParameters(StringRef params)
118       : OpMethodParameters(PK_Unresolved), parameters(params) {}
119 
120   // write the parameters as a part of a method declaration to the given `os`.
121   void writeDeclTo(raw_ostream &os) const override;
122 
123   // write the parameters as a part of a method definition to the given `os`
124   void writeDefTo(raw_ostream &os) const override;
125 
126   // LLVM-style RTTI support.
classof(const OpMethodParameters * params)127   static bool classof(const OpMethodParameters *params) {
128     return params->getKind() == PK_Unresolved;
129   }
130 
131 private:
132   std::string parameters;
133 };
134 
135 // Class for holding resolved parameters.
136 class OpMethodResolvedParameters : public OpMethodParameters {
137 public:
OpMethodResolvedParameters()138   OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
139 
OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> && params)140   OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &&params)
141       : OpMethodParameters(PK_Resolved) {
142     for (OpMethodParameter &param : params)
143       parameters.emplace_back(std::move(param));
144   }
145 
OpMethodResolvedParameters(StringRef type,StringRef name,StringRef defaultValue)146   OpMethodResolvedParameters(StringRef type, StringRef name,
147                              StringRef defaultValue)
148       : OpMethodParameters(PK_Resolved) {
149     parameters.emplace_back(type, name, defaultValue);
150   }
151 
152   // Returns the number of parameters.
getNumParameters()153   size_t getNumParameters() const { return parameters.size(); }
154 
155   // Returns if this method makes the `other` method redundant. Note that this
156   // is more than just finding conflicting methods. This method determines if
157   // the 2 set of parameters are conflicting and if so, returns true if this
158   // method has a more general set of parameters that can replace all possible
159   // calls to the `other` method.
160   bool makesRedundant(const OpMethodResolvedParameters &other) const;
161 
162   // write the parameters as a part of a method declaration to the given `os`.
163   void writeDeclTo(raw_ostream &os) const override;
164 
165   // write the parameters as a part of a method definition to the given `os`
166   void writeDefTo(raw_ostream &os) const override;
167 
168   // LLVM-style RTTI support.
classof(const OpMethodParameters * params)169   static bool classof(const OpMethodParameters *params) {
170     return params->getKind() == PK_Resolved;
171   }
172 
173 private:
174   llvm::SmallVector<OpMethodParameter, 4> parameters;
175 };
176 
177 // Class for holding the signature of an op's method for C++ code emission
178 class OpMethodSignature {
179 public:
180   template <typename... Args>
OpMethodSignature(StringRef retType,StringRef name,Args &&...args)181   OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
182       : returnType(retType), methodName(name),
183         parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
184   OpMethodSignature(OpMethodSignature &&) = default;
185 
186   // Returns if a method with this signature makes a method with `other`
187   // signature redundant. Only supports resolved parameters.
188   bool makesRedundant(const OpMethodSignature &other) const;
189 
190   // Returns the number of parameters (for resolved parameters).
getNumParameters()191   size_t getNumParameters() const {
192     return cast<OpMethodResolvedParameters>(parameters.get())
193         ->getNumParameters();
194   }
195 
196   // Returns the name of the method.
getName()197   StringRef getName() const { return methodName; }
198 
199   // Writes the signature as a method declaration to the given `os`.
200   void writeDeclTo(raw_ostream &os) const;
201 
202   // Writes the signature as the start of a method definition to the given `os`.
203   // `namePrefix` is the prefix to be prepended to the method name (typically
204   // namespaces for qualifying the method definition).
205   void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
206 
207 private:
208   std::string returnType;
209   std::string methodName;
210   std::unique_ptr<OpMethodParameters> parameters;
211 };
212 
213 // Class for holding the body of an op's method for C++ code emission
214 class OpMethodBody {
215 public:
216   explicit OpMethodBody(bool declOnly);
217 
218   OpMethodBody &operator<<(Twine content);
219   OpMethodBody &operator<<(int content);
220   OpMethodBody &operator<<(const FmtObjectBase &content);
221 
222   void writeTo(raw_ostream &os) const;
223 
224 private:
225   // Whether this class should record method body.
226   bool isEffective;
227   std::string body;
228 };
229 
230 // Class for holding an op's method for C++ code emission
231 class OpMethod {
232 public:
233   // Properties (qualifiers) of class methods. Bitfield is used here to help
234   // querying properties.
235   enum Property {
236     MP_None = 0x0,
237     MP_Static = 0x1,
238     MP_Constructor = 0x2,
239     MP_Private = 0x4,
240     MP_Declaration = 0x8,
241     MP_StaticDeclaration = MP_Static | MP_Declaration,
242   };
243 
244   template <typename... Args>
OpMethod(StringRef retType,StringRef name,Property property,unsigned id,Args &&...args)245   OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
246            Args &&...args)
247       : properties(property),
248         methodSignature(retType, name, std::forward<Args>(args)...),
249         methodBody(properties & MP_Declaration), id(id) {}
250 
251   OpMethod(OpMethod &&) = default;
252 
253   virtual ~OpMethod() = default;
254 
body()255   OpMethodBody &body() { return methodBody; }
256 
257   // Returns true if this is a static method.
isStatic()258   bool isStatic() const { return properties & MP_Static; }
259 
260   // Returns true if this is a private method.
isPrivate()261   bool isPrivate() const { return properties & MP_Private; }
262 
263   // Returns the name of this method.
getName()264   StringRef getName() const { return methodSignature.getName(); }
265 
266   // Returns the ID for this method
getID()267   unsigned getID() const { return id; }
268 
269   // Returns if this method makes the `other` method redundant.
makesRedundant(const OpMethod & other)270   bool makesRedundant(const OpMethod &other) const {
271     return methodSignature.makesRedundant(other.methodSignature);
272   }
273 
274   // Writes the method as a declaration to the given `os`.
275   virtual void writeDeclTo(raw_ostream &os) const;
276 
277   // Writes the method as a definition to the given `os`. `namePrefix` is the
278   // prefix to be prepended to the method name (typically namespaces for
279   // qualifying the method definition).
280   virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
281 
282 protected:
283   Property properties;
284   OpMethodSignature methodSignature;
285   OpMethodBody methodBody;
286   const unsigned id;
287 };
288 
289 // Class for holding an op's constructor method for C++ code emission.
290 class OpConstructor : public OpMethod {
291 public:
292   template <typename... Args>
OpConstructor(StringRef className,Property property,unsigned id,Args &&...args)293   OpConstructor(StringRef className, Property property, unsigned id,
294                 Args &&...args)
295       : OpMethod("", className, property, id, std::forward<Args>(args)...) {}
296 
297   // Add member initializer to constructor initializing `name` with `value`.
298   void addMemberInitializer(StringRef name, StringRef value);
299 
300   // Writes the method as a definition to the given `os`. `namePrefix` is the
301   // prefix to be prepended to the method name (typically namespaces for
302   // qualifying the method definition).
303   void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
304 
305 private:
306   // Member initializers.
307   std::string memberInitializers;
308 };
309 
310 // A class used to emit C++ classes from Tablegen.  Contains a list of public
311 // methods and a list of private fields to be emitted.
312 class Class {
313 public:
314   explicit Class(StringRef name);
315 
316   // Adds a new method to this class and prune redundant methods. Returns null
317   // if the method was not added (because an existing method would make it
318   // redundant), else returns a pointer to the added method. Note that this call
319   // may also delete existing methods that are made redundant by a method to the
320   // class.
321   template <typename... Args>
addMethodAndPrune(StringRef retType,StringRef name,OpMethod::Property properties,Args &&...args)322   OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
323                               OpMethod::Property properties, Args &&...args) {
324     auto newMethod = std::make_unique<OpMethod>(
325         retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
326     return addMethodAndPrune(methods, std::move(newMethod));
327   }
328 
329   template <typename... Args>
addMethodAndPrune(StringRef retType,StringRef name,Args &&...args)330   OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
331                               Args &&...args) {
332     return addMethodAndPrune(retType, name, OpMethod::MP_None,
333                              std::forward<Args>(args)...);
334   }
335 
336   template <typename... Args>
addConstructorAndPrune(Args &&...args)337   OpConstructor *addConstructorAndPrune(Args &&...args) {
338     auto newConstructor = std::make_unique<OpConstructor>(
339         getClassName(), OpMethod::MP_Constructor, nextMethodID++,
340         std::forward<Args>(args)...);
341     return addMethodAndPrune(constructors, std::move(newConstructor));
342   }
343 
344   // Creates a new field in this class.
345   void newField(StringRef type, StringRef name, StringRef defaultValue = "");
346 
347   // Writes this op's class as a declaration to the given `os`.
348   void writeDeclTo(raw_ostream &os) const;
349   // Writes the method definitions in this op's class to the given `os`.
350   void writeDefTo(raw_ostream &os) const;
351 
352   // Returns the C++ class name of the op.
getClassName()353   StringRef getClassName() const { return className; }
354 
355 protected:
356   // Get a list of all the methods to emit, filtering out hidden ones.
forAllMethods(llvm::function_ref<void (const OpMethod &)> func)357   void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
358     using ConsRef = const std::unique_ptr<OpConstructor> &;
359     using MethodRef = const std::unique_ptr<OpMethod> &;
360     llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
361     llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
362   }
363 
364   // For deterministic code generation, keep methods sorted in the order in
365   // which they were generated.
366   template <typename MethodTy>
367   struct MethodCompare {
operatorMethodCompare368     bool operator()(const std::unique_ptr<MethodTy> &x,
369                     const std::unique_ptr<MethodTy> &y) const {
370       return x->getID() < y->getID();
371     }
372   };
373 
374   template <typename MethodTy>
375   using MethodSet =
376       std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
377 
378   template <typename MethodTy>
addMethodAndPrune(MethodSet<MethodTy> & set,std::unique_ptr<MethodTy> && newMethod)379   MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
380                               std::unique_ptr<MethodTy> &&newMethod) {
381     // Check if the new method will be made redundant by existing methods.
382     for (auto &method : set)
383       if (method->makesRedundant(*newMethod))
384         return nullptr;
385 
386     // We can add this a method to the set. Prune any existing methods that will
387     // be made redundant by adding this new method. Note that the redundant
388     // check between two methods is more than a conflict check. makesRedundant()
389     // below will check if the new method conflicts with an existing method and
390     // if so, returns true if the new method makes the existing method redundant
391     // because all calls to the existing method can be subsumed by the new
392     // method. So makesRedundant() does a combined job of finding conflicts and
393     // deciding which of the 2 conflicting methods survive.
394     //
395     // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
396     // it manually here.
397     for (auto it = set.begin(), end = set.end(); it != end;) {
398       if (newMethod->makesRedundant(*(it->get())))
399         it = set.erase(it);
400       else
401         ++it;
402     }
403 
404     MethodTy *ret = newMethod.get();
405     set.insert(std::move(newMethod));
406     return ret;
407   }
408 
409   std::string className;
410   MethodSet<OpConstructor> constructors;
411   MethodSet<OpMethod> methods;
412   unsigned nextMethodID = 0;
413   SmallVector<std::string, 4> fields;
414 };
415 
416 // Class for holding an op for C++ code emission
417 class OpClass : public Class {
418 public:
419   explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
420 
421   // Adds an op trait.
422   void addTrait(Twine trait);
423 
424   // Writes this op's class as a declaration to the given `os`.  Redefines
425   // Class::writeDeclTo to also emit traits and extra class declarations.
426   void writeDeclTo(raw_ostream &os) const;
427 
428 private:
429   StringRef extraClassDeclaration;
430   SmallVector<std::string, 4> traitsVec;
431   StringSet<> traitsSet;
432 };
433 
434 } // namespace tblgen
435 } // namespace mlir
436 
437 #endif // MLIR_TABLEGEN_OPCLASS_H_
438