• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILDERS_H
10 #define MLIR_IR_BUILDERS_H
11 
12 #include "mlir/IR/OpDefinition.h"
13 
14 namespace mlir {
15 
// Forward declarations for MLIR classes named by the builder APIs in this
// header. Declaring them here avoids pulling in their full definitions.

// Affine constructs and value remapping.
class AffineExpr;
class AffineMap;
class BlockAndValueMapping;

// Locations.
class UnknownLoc;
class FileLineColLoc;

// Types.
class Type;
class PrimitiveType;
class IntegerType;
class FloatType;
class FunctionType;
class IndexType;
class MemRefType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;
class TupleType;
class NoneType;

// Attributes.
class BoolAttr;
class IntegerAttr;
class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class SymbolRefAttr;
class ElementsAttr;
class DenseElementsAttr;
class DenseIntElementsAttr;
class AffineMapAttr;
class UnitAttr;
45 
46 /// This class is a general helper class for creating context-global objects
47 /// like types, attributes, and affine expressions.
48 class Builder {
49 public:
Builder(MLIRContext * context)50   explicit Builder(MLIRContext *context) : context(context) {}
Builder(Operation * op)51   explicit Builder(Operation *op) : Builder(op->getContext()) {}
52 
getContext()53   MLIRContext *getContext() const { return context; }
54 
55   Identifier getIdentifier(StringRef str);
56 
57   // Locations.
58   Location getUnknownLoc();
59   Location getFileLineColLoc(Identifier filename, unsigned line,
60                              unsigned column);
61   Location getFusedLoc(ArrayRef<Location> locs,
62                        Attribute metadata = Attribute());
63 
64   // Types.
65   FloatType getBF16Type();
66   FloatType getF16Type();
67   FloatType getF32Type();
68   FloatType getF64Type();
69 
70   IndexType getIndexType();
71 
72   IntegerType getI1Type();
73   IntegerType getI32Type();
74   IntegerType getI64Type();
75   IntegerType getIntegerType(unsigned width);
76   IntegerType getIntegerType(unsigned width, bool isSigned);
77   FunctionType getFunctionType(TypeRange inputs, TypeRange results);
78   TupleType getTupleType(TypeRange elementTypes);
79   NoneType getNoneType();
80 
81   /// Get or construct an instance of the type 'ty' with provided arguments.
getType(Args...args)82   template <typename Ty, typename... Args> Ty getType(Args... args) {
83     return Ty::get(context, args...);
84   }
85 
86   // Attributes.
87   NamedAttribute getNamedAttr(StringRef name, Attribute val);
88 
89   UnitAttr getUnitAttr();
90   BoolAttr getBoolAttr(bool value);
91   DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
92   IntegerAttr getIntegerAttr(Type type, int64_t value);
93   IntegerAttr getIntegerAttr(Type type, const APInt &value);
94   FloatAttr getFloatAttr(Type type, double value);
95   FloatAttr getFloatAttr(Type type, const APFloat &value);
96   StringAttr getStringAttr(StringRef bytes);
97   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
98   FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
99   FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
100   SymbolRefAttr getSymbolRefAttr(StringRef value,
101                                  ArrayRef<FlatSymbolRefAttr> nestedReferences);
102 
103   // Returns a 0-valued attribute of the given `type`. This function only
104   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
105   // ranked tensor of them. Returns null attribute otherwise.
106   Attribute getZeroAttr(Type type);
107 
108   // Convenience methods for fixed types.
109   FloatAttr getF16FloatAttr(float value);
110   FloatAttr getF32FloatAttr(float value);
111   FloatAttr getF64FloatAttr(double value);
112 
113   IntegerAttr getI8IntegerAttr(int8_t value);
114   IntegerAttr getI16IntegerAttr(int16_t value);
115   IntegerAttr getI32IntegerAttr(int32_t value);
116   IntegerAttr getI64IntegerAttr(int64_t value);
117   IntegerAttr getIndexAttr(int64_t value);
118 
119   /// Signed and unsigned integer attribute getters.
120   IntegerAttr getSI32IntegerAttr(int32_t value);
121   IntegerAttr getUI32IntegerAttr(uint32_t value);
122 
123   /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty.
124   DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
125   DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
126   DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
127 
128   /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
129   /// These are generally preferable for representing general lists of integers
130   /// as attributes.
131   DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
132   DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
133   DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);
134 
135   ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
136   ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
137   ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
138   ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
139   ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
140   ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
141   ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
142   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
143   ArrayAttr getTypeArrayAttr(TypeRange values);
144 
145   // Affine expressions and affine maps.
146   AffineExpr getAffineDimExpr(unsigned position);
147   AffineExpr getAffineSymbolExpr(unsigned position);
148   AffineExpr getAffineConstantExpr(int64_t constant);
149 
150   // Special cases of affine maps and integer sets
151   /// Returns a zero result affine map with no dimensions or symbols: () -> ().
152   AffineMap getEmptyAffineMap();
153   /// Returns a single constant result affine map with 0 dimensions and 0
154   /// symbols.  One constant result: () -> (val).
155   AffineMap getConstantAffineMap(int64_t val);
156   // One dimension id identity map: (i) -> (i).
157   AffineMap getDimIdentityMap();
158   // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
159   AffineMap getMultiDimIdentityMap(unsigned rank);
160   // One symbol identity map: ()[s] -> (s).
161   AffineMap getSymbolIdentityMap();
162 
163   /// Returns a map that shifts its (single) input dimension by 'shift'.
164   /// (d0) -> (d0 + shift)
165   AffineMap getSingleDimShiftAffineMap(int64_t shift);
166 
167   /// Returns an affine map that is a translation (shift) of all result
168   /// expressions in 'map' by 'shift'.
169   /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
170   ///   returns:    (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
171   AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
172 
173 protected:
174   MLIRContext *context;
175 };
176 
177 /// This class helps build Operations. Operations that are created are
178 /// automatically inserted at an insertion point. The builder is copyable.
179 class OpBuilder : public Builder {
180 public:
181   struct Listener;
182 
183   /// Create a builder with the given context.
184   explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr)
Builder(ctx)185       : Builder(ctx), listener(listener) {}
186 
187   /// Create a builder and set the insertion point to the start of the region.
188   explicit OpBuilder(Region *region, Listener *listener = nullptr)
189       : OpBuilder(region->getContext(), listener) {
190     if (!region->empty())
191       setInsertionPoint(&region->front(), region->front().begin());
192   }
193   explicit OpBuilder(Region &region, Listener *listener = nullptr)
194       : OpBuilder(&region, listener) {}
195 
196   /// Create a builder and set insertion point to the given operation, which
197   /// will cause subsequent insertions to go right before it.
198   explicit OpBuilder(Operation *op, Listener *listener = nullptr)
199       : OpBuilder(op->getContext(), listener) {
200     setInsertionPoint(op);
201   }
202 
203   OpBuilder(Block *block, Block::iterator insertPoint,
204             Listener *listener = nullptr)
205       : OpBuilder(block->getParent()->getContext(), listener) {
206     setInsertionPoint(block, insertPoint);
207   }
208 
209   /// Create a builder and set the insertion point to before the first operation
210   /// in the block but still inside the block.
211   static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
212     return OpBuilder(block, block->begin(), listener);
213   }
214 
215   /// Create a builder and set the insertion point to after the last operation
216   /// in the block but still inside the block.
217   static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) {
218     return OpBuilder(block, block->end(), listener);
219   }
220 
221   /// Create a builder and set the insertion point to before the block
222   /// terminator.
223   static OpBuilder atBlockTerminator(Block *block,
224                                      Listener *listener = nullptr) {
225     auto *terminator = block->getTerminator();
226     assert(terminator != nullptr && "the block has no terminator");
227     return OpBuilder(block, Block::iterator(terminator), listener);
228   }
229 
230   //===--------------------------------------------------------------------===//
231   // Listeners
232   //===--------------------------------------------------------------------===//
233 
234   /// This class represents a listener that may be used to hook into various
235   /// actions within an OpBuilder.
236   struct Listener {
237     virtual ~Listener();
238 
239     /// Notification handler for when an operation is inserted into the builder.
240     /// `op` is the operation that was inserted.
notifyOperationInsertedListener241     virtual void notifyOperationInserted(Operation *op) {}
242 
243     /// Notification handler for when a block is created using the builder.
244     /// `block` is the block that was created.
notifyBlockCreatedListener245     virtual void notifyBlockCreated(Block *block) {}
246   };
247 
248   /// Sets the listener of this builder to the one provided.
setListener(Listener * newListener)249   void setListener(Listener *newListener) { listener = newListener; }
250 
251   /// Returns the current listener of this builder, or nullptr if this builder
252   /// doesn't have a listener.
getListener()253   Listener *getListener() const { return listener; }
254 
255   //===--------------------------------------------------------------------===//
256   // Insertion Point Management
257   //===--------------------------------------------------------------------===//
258 
259   /// This class represents a saved insertion point.
260   class InsertPoint {
261   public:
262     /// Creates a new insertion point which doesn't point to anything.
263     InsertPoint() = default;
264 
265     /// Creates a new insertion point at the given location.
InsertPoint(Block * insertBlock,Block::iterator insertPt)266     InsertPoint(Block *insertBlock, Block::iterator insertPt)
267         : block(insertBlock), point(insertPt) {}
268 
269     /// Returns true if this insert point is set.
isSet()270     bool isSet() const { return (block != nullptr); }
271 
getBlock()272     Block *getBlock() const { return block; }
getPoint()273     Block::iterator getPoint() const { return point; }
274 
275   private:
276     Block *block = nullptr;
277     Block::iterator point;
278   };
279 
280   /// RAII guard to reset the insertion point of the builder when destroyed.
281   class InsertionGuard {
282   public:
InsertionGuard(OpBuilder & builder)283     InsertionGuard(OpBuilder &builder)
284         : builder(builder), ip(builder.saveInsertionPoint()) {}
~InsertionGuard()285     ~InsertionGuard() { builder.restoreInsertionPoint(ip); }
286 
287   private:
288     OpBuilder &builder;
289     OpBuilder::InsertPoint ip;
290   };
291 
292   /// Reset the insertion point to no location.  Creating an operation without a
293   /// set insertion point is an error, but this can still be useful when the
294   /// current insertion point a builder refers to is being removed.
clearInsertionPoint()295   void clearInsertionPoint() {
296     this->block = nullptr;
297     insertPoint = Block::iterator();
298   }
299 
300   /// Return a saved insertion point.
saveInsertionPoint()301   InsertPoint saveInsertionPoint() const {
302     return InsertPoint(getInsertionBlock(), getInsertionPoint());
303   }
304 
305   /// Restore the insert point to a previously saved point.
restoreInsertionPoint(InsertPoint ip)306   void restoreInsertionPoint(InsertPoint ip) {
307     if (ip.isSet())
308       setInsertionPoint(ip.getBlock(), ip.getPoint());
309     else
310       clearInsertionPoint();
311   }
312 
313   /// Set the insertion point to the specified location.
setInsertionPoint(Block * block,Block::iterator insertPoint)314   void setInsertionPoint(Block *block, Block::iterator insertPoint) {
315     // TODO: check that insertPoint is in this rather than some other block.
316     this->block = block;
317     this->insertPoint = insertPoint;
318   }
319 
320   /// Sets the insertion point to the specified operation, which will cause
321   /// subsequent insertions to go right before it.
setInsertionPoint(Operation * op)322   void setInsertionPoint(Operation *op) {
323     setInsertionPoint(op->getBlock(), Block::iterator(op));
324   }
325 
326   /// Sets the insertion point to the node after the specified operation, which
327   /// will cause subsequent insertions to go right after it.
setInsertionPointAfter(Operation * op)328   void setInsertionPointAfter(Operation *op) {
329     setInsertionPoint(op->getBlock(), ++Block::iterator(op));
330   }
331 
332   /// Sets the insertion point to the node after the specified value. If value
333   /// has a defining operation, sets the insertion point to the node after such
334   /// defining operation. This will cause subsequent insertions to go right
335   /// after it. Otherwise, value is a BlockArgumen. Sets the insertion point to
336   /// the start of its block.
setInsertionPointAfterValue(Value val)337   void setInsertionPointAfterValue(Value val) {
338     if (Operation *op = val.getDefiningOp()) {
339       setInsertionPointAfter(op);
340     } else {
341       auto blockArg = val.cast<BlockArgument>();
342       setInsertionPointToStart(blockArg.getOwner());
343     }
344   }
345 
346   /// Sets the insertion point to the start of the specified block.
setInsertionPointToStart(Block * block)347   void setInsertionPointToStart(Block *block) {
348     setInsertionPoint(block, block->begin());
349   }
350 
351   /// Sets the insertion point to the end of the specified block.
setInsertionPointToEnd(Block * block)352   void setInsertionPointToEnd(Block *block) {
353     setInsertionPoint(block, block->end());
354   }
355 
356   /// Return the block the current insertion point belongs to.  Note that the
357   /// the insertion point is not necessarily the end of the block.
getInsertionBlock()358   Block *getInsertionBlock() const { return block; }
359 
360   /// Returns the current insertion point of the builder.
getInsertionPoint()361   Block::iterator getInsertionPoint() const { return insertPoint; }
362 
363   /// Returns the current block of the builder.
getBlock()364   Block *getBlock() const { return block; }
365 
366   //===--------------------------------------------------------------------===//
367   // Block Creation
368   //===--------------------------------------------------------------------===//
369 
370   /// Add new block with 'argTypes' arguments and set the insertion point to the
371   /// end of it. The block is inserted at the provided insertion point of
372   /// 'parent'.
373   Block *createBlock(Region *parent, Region::iterator insertPt = {},
374                      TypeRange argTypes = llvm::None);
375 
376   /// Add new block with 'argTypes' arguments and set the insertion point to the
377   /// end of it. The block is placed before 'insertBefore'.
378   Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
379 
380   //===--------------------------------------------------------------------===//
381   // Operation Creation
382   //===--------------------------------------------------------------------===//
383 
384   /// Insert the given operation at the current insertion point and return it.
385   Operation *insert(Operation *op);
386 
387   /// Creates an operation given the fields represented as an OperationState.
388   Operation *createOperation(const OperationState &state);
389 
390   /// Create an operation of specific op type at the current insertion point.
391   template <typename OpTy, typename... Args>
create(Location location,Args &&...args)392   OpTy create(Location location, Args &&... args) {
393     OperationState state(location, OpTy::getOperationName());
394     if (!state.name.getAbstractOperation())
395       llvm::report_fatal_error("Building op `" +
396                                state.name.getStringRef().str() +
397                                "` but it isn't registered in this MLIRContext");
398     OpTy::build(*this, state, std::forward<Args>(args)...);
399     auto *op = createOperation(state);
400     auto result = dyn_cast<OpTy>(op);
401     assert(result && "builder didn't return the right type");
402     return result;
403   }
404 
405   /// Create an operation of specific op type at the current insertion point,
406   /// and immediately try to fold it. This functions populates 'results' with
407   /// the results after folding the operation.
408   template <typename OpTy, typename... Args>
createOrFold(SmallVectorImpl<Value> & results,Location location,Args &&...args)409   void createOrFold(SmallVectorImpl<Value> &results, Location location,
410                     Args &&... args) {
411     // Create the operation without using 'createOperation' as we don't want to
412     // insert it yet.
413     OperationState state(location, OpTy::getOperationName());
414     if (!state.name.getAbstractOperation())
415       llvm::report_fatal_error("Building op `" +
416                                state.name.getStringRef().str() +
417                                "` but it isn't registered in this MLIRContext");
418     OpTy::build(*this, state, std::forward<Args>(args)...);
419     Operation *op = Operation::create(state);
420 
421     // Fold the operation. If successful destroy it, otherwise insert it.
422     if (succeeded(tryFold(op, results)))
423       op->destroy();
424     else
425       insert(op);
426   }
427 
428   /// Overload to create or fold a single result operation.
429   template <typename OpTy, typename... Args>
430   typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
431                           Value>::type
createOrFold(Location location,Args &&...args)432   createOrFold(Location location, Args &&... args) {
433     SmallVector<Value, 1> results;
434     createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
435     return results.front();
436   }
437 
438   /// Overload to create or fold a zero result operation.
439   template <typename OpTy, typename... Args>
440   typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
441                           OpTy>::type
createOrFold(Location location,Args &&...args)442   createOrFold(Location location, Args &&... args) {
443     auto op = create<OpTy>(location, std::forward<Args>(args)...);
444     SmallVector<Value, 0> unused;
445     tryFold(op.getOperation(), unused);
446 
447     // Folding cannot remove a zero-result operation, so for convenience we
448     // continue to return it.
449     return op;
450   }
451 
452   /// Attempts to fold the given operation and places new results within
453   /// 'results'. Returns success if the operation was folded, failure otherwise.
454   /// Note: This function does not erase the operation on a successful fold.
455   LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
456 
457   /// Creates a deep copy of the specified operation, remapping any operands
458   /// that use values outside of the operation using the map that is provided
459   /// ( leaving them alone if no entry is present).  Replaces references to
460   /// cloned sub-operations to the corresponding operation that is copied,
461   /// and adds those mappings to the map.
462   Operation *clone(Operation &op, BlockAndValueMapping &mapper);
463   Operation *clone(Operation &op);
464 
465   /// Creates a deep copy of this operation but keep the operation regions
466   /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
467   /// updated to contain the results.
cloneWithoutRegions(Operation & op,BlockAndValueMapping & mapper)468   Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
469     return insert(op.cloneWithoutRegions(mapper));
470   }
cloneWithoutRegions(Operation & op)471   Operation *cloneWithoutRegions(Operation &op) {
472     return insert(op.cloneWithoutRegions());
473   }
cloneWithoutRegions(OpT op)474   template <typename OpT> OpT cloneWithoutRegions(OpT op) {
475     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
476   }
477 
478 private:
479   /// The current block this builder is inserting into.
480   Block *block = nullptr;
481   /// The insertion point within the block that this builder is inserting
482   /// before.
483   Block::iterator insertPoint;
484   /// The optional listener for events of this builder.
485   Listener *listener;
486 };
487 
488 } // namespace mlir
489 
490 #endif
491