1 //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_BUILDERS_H 10 #define MLIR_IR_BUILDERS_H 11 12 #include "mlir/IR/OpDefinition.h" 13 14 namespace mlir { 15 16 class AffineExpr; 17 class BlockAndValueMapping; 18 class UnknownLoc; 19 class FileLineColLoc; 20 class Type; 21 class PrimitiveType; 22 class IntegerType; 23 class FloatType; 24 class FunctionType; 25 class IndexType; 26 class MemRefType; 27 class VectorType; 28 class RankedTensorType; 29 class UnrankedTensorType; 30 class TupleType; 31 class NoneType; 32 class BoolAttr; 33 class IntegerAttr; 34 class FloatAttr; 35 class StringAttr; 36 class TypeAttr; 37 class ArrayAttr; 38 class SymbolRefAttr; 39 class ElementsAttr; 40 class DenseElementsAttr; 41 class DenseIntElementsAttr; 42 class AffineMapAttr; 43 class AffineMap; 44 class UnitAttr; 45 46 /// This class is a general helper class for creating context-global objects 47 /// like types, attributes, and affine expressions. 48 class Builder { 49 public: Builder(MLIRContext * context)50 explicit Builder(MLIRContext *context) : context(context) {} Builder(Operation * op)51 explicit Builder(Operation *op) : Builder(op->getContext()) {} 52 getContext()53 MLIRContext *getContext() const { return context; } 54 55 Identifier getIdentifier(StringRef str); 56 57 // Locations. 58 Location getUnknownLoc(); 59 Location getFileLineColLoc(Identifier filename, unsigned line, 60 unsigned column); 61 Location getFusedLoc(ArrayRef<Location> locs, 62 Attribute metadata = Attribute()); 63 64 // Types. 65 FloatType getBF16Type(); 66 FloatType getF16Type(); 67 FloatType getF32Type(); 68 FloatType getF64Type(); 69 70 IndexType getIndexType(); 71 72 IntegerType getI1Type(); 73 IntegerType getI32Type(); 74 IntegerType getI64Type(); 75 IntegerType getIntegerType(unsigned width); 76 IntegerType getIntegerType(unsigned width, bool isSigned); 77 FunctionType getFunctionType(TypeRange inputs, TypeRange results); 78 TupleType getTupleType(TypeRange elementTypes); 79 NoneType getNoneType(); 80 81 /// Get or construct an instance of the type 'ty' with provided arguments. getType(Args...args)82 template <typename Ty, typename... Args> Ty getType(Args... args) { 83 return Ty::get(context, args...); 84 } 85 86 // Attributes. 87 NamedAttribute getNamedAttr(StringRef name, Attribute val); 88 89 UnitAttr getUnitAttr(); 90 BoolAttr getBoolAttr(bool value); 91 DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value); 92 IntegerAttr getIntegerAttr(Type type, int64_t value); 93 IntegerAttr getIntegerAttr(Type type, const APInt &value); 94 FloatAttr getFloatAttr(Type type, double value); 95 FloatAttr getFloatAttr(Type type, const APFloat &value); 96 StringAttr getStringAttr(StringRef bytes); 97 ArrayAttr getArrayAttr(ArrayRef<Attribute> value); 98 FlatSymbolRefAttr getSymbolRefAttr(Operation *value); 99 FlatSymbolRefAttr getSymbolRefAttr(StringRef value); 100 SymbolRefAttr getSymbolRefAttr(StringRef value, 101 ArrayRef<FlatSymbolRefAttr> nestedReferences); 102 103 // Returns a 0-valued attribute of the given `type`. This function only 104 // supports boolean, integer, and 16-/32-/64-bit float types, and vector or 105 // ranked tensor of them. Returns null attribute otherwise. 106 Attribute getZeroAttr(Type type); 107 108 // Convenience methods for fixed types. 109 FloatAttr getF16FloatAttr(float value); 110 FloatAttr getF32FloatAttr(float value); 111 FloatAttr getF64FloatAttr(double value); 112 113 IntegerAttr getI8IntegerAttr(int8_t value); 114 IntegerAttr getI16IntegerAttr(int16_t value); 115 IntegerAttr getI32IntegerAttr(int32_t value); 116 IntegerAttr getI64IntegerAttr(int64_t value); 117 IntegerAttr getIndexAttr(int64_t value); 118 119 /// Signed and unsigned integer attribute getters. 120 IntegerAttr getSI32IntegerAttr(int32_t value); 121 IntegerAttr getUI32IntegerAttr(uint32_t value); 122 123 /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty. 124 DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values); 125 DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values); 126 DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values); 127 128 /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty. 129 /// These are generally preferable for representing general lists of integers 130 /// as attributes. 131 DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values); 132 DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values); 133 DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values); 134 135 ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); 136 ArrayAttr getBoolArrayAttr(ArrayRef<bool> values); 137 ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); 138 ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); 139 ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values); 140 ArrayAttr getF32ArrayAttr(ArrayRef<float> values); 141 ArrayAttr getF64ArrayAttr(ArrayRef<double> values); 142 ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); 143 ArrayAttr getTypeArrayAttr(TypeRange values); 144 145 // Affine expressions and affine maps. 146 AffineExpr getAffineDimExpr(unsigned position); 147 AffineExpr getAffineSymbolExpr(unsigned position); 148 AffineExpr getAffineConstantExpr(int64_t constant); 149 150 // Special cases of affine maps and integer sets 151 /// Returns a zero result affine map with no dimensions or symbols: () -> (). 152 AffineMap getEmptyAffineMap(); 153 /// Returns a single constant result affine map with 0 dimensions and 0 154 /// symbols. One constant result: () -> (val). 155 AffineMap getConstantAffineMap(int64_t val); 156 // One dimension id identity map: (i) -> (i). 157 AffineMap getDimIdentityMap(); 158 // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2). 159 AffineMap getMultiDimIdentityMap(unsigned rank); 160 // One symbol identity map: ()[s] -> (s). 161 AffineMap getSymbolIdentityMap(); 162 163 /// Returns a map that shifts its (single) input dimension by 'shift'. 164 /// (d0) -> (d0 + shift) 165 AffineMap getSingleDimShiftAffineMap(int64_t shift); 166 167 /// Returns an affine map that is a translation (shift) of all result 168 /// expressions in 'map' by 'shift'. 169 /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2 170 /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) 171 AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); 172 173 protected: 174 MLIRContext *context; 175 }; 176 177 /// This class helps build Operations. Operations that are created are 178 /// automatically inserted at an insertion point. The builder is copyable. 179 class OpBuilder : public Builder { 180 public: 181 struct Listener; 182 183 /// Create a builder with the given context. 184 explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) Builder(ctx)185 : Builder(ctx), listener(listener) {} 186 187 /// Create a builder and set the insertion point to the start of the region. 188 explicit OpBuilder(Region *region, Listener *listener = nullptr) 189 : OpBuilder(region->getContext(), listener) { 190 if (!region->empty()) 191 setInsertionPoint(®ion->front(), region->front().begin()); 192 } 193 explicit OpBuilder(Region ®ion, Listener *listener = nullptr) 194 : OpBuilder(®ion, listener) {} 195 196 /// Create a builder and set insertion point to the given operation, which 197 /// will cause subsequent insertions to go right before it. 198 explicit OpBuilder(Operation *op, Listener *listener = nullptr) 199 : OpBuilder(op->getContext(), listener) { 200 setInsertionPoint(op); 201 } 202 203 OpBuilder(Block *block, Block::iterator insertPoint, 204 Listener *listener = nullptr) 205 : OpBuilder(block->getParent()->getContext(), listener) { 206 setInsertionPoint(block, insertPoint); 207 } 208 209 /// Create a builder and set the insertion point to before the first operation 210 /// in the block but still inside the block. 211 static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { 212 return OpBuilder(block, block->begin(), listener); 213 } 214 215 /// Create a builder and set the insertion point to after the last operation 216 /// in the block but still inside the block. 217 static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { 218 return OpBuilder(block, block->end(), listener); 219 } 220 221 /// Create a builder and set the insertion point to before the block 222 /// terminator. 223 static OpBuilder atBlockTerminator(Block *block, 224 Listener *listener = nullptr) { 225 auto *terminator = block->getTerminator(); 226 assert(terminator != nullptr && "the block has no terminator"); 227 return OpBuilder(block, Block::iterator(terminator), listener); 228 } 229 230 //===--------------------------------------------------------------------===// 231 // Listeners 232 //===--------------------------------------------------------------------===// 233 234 /// This class represents a listener that may be used to hook into various 235 /// actions within an OpBuilder. 236 struct Listener { 237 virtual ~Listener(); 238 239 /// Notification handler for when an operation is inserted into the builder. 240 /// `op` is the operation that was inserted. notifyOperationInsertedListener241 virtual void notifyOperationInserted(Operation *op) {} 242 243 /// Notification handler for when a block is created using the builder. 244 /// `block` is the block that was created. notifyBlockCreatedListener245 virtual void notifyBlockCreated(Block *block) {} 246 }; 247 248 /// Sets the listener of this builder to the one provided. setListener(Listener * newListener)249 void setListener(Listener *newListener) { listener = newListener; } 250 251 /// Returns the current listener of this builder, or nullptr if this builder 252 /// doesn't have a listener. getListener()253 Listener *getListener() const { return listener; } 254 255 //===--------------------------------------------------------------------===// 256 // Insertion Point Management 257 //===--------------------------------------------------------------------===// 258 259 /// This class represents a saved insertion point. 260 class InsertPoint { 261 public: 262 /// Creates a new insertion point which doesn't point to anything. 263 InsertPoint() = default; 264 265 /// Creates a new insertion point at the given location. InsertPoint(Block * insertBlock,Block::iterator insertPt)266 InsertPoint(Block *insertBlock, Block::iterator insertPt) 267 : block(insertBlock), point(insertPt) {} 268 269 /// Returns true if this insert point is set. isSet()270 bool isSet() const { return (block != nullptr); } 271 getBlock()272 Block *getBlock() const { return block; } getPoint()273 Block::iterator getPoint() const { return point; } 274 275 private: 276 Block *block = nullptr; 277 Block::iterator point; 278 }; 279 280 /// RAII guard to reset the insertion point of the builder when destroyed. 281 class InsertionGuard { 282 public: InsertionGuard(OpBuilder & builder)283 InsertionGuard(OpBuilder &builder) 284 : builder(builder), ip(builder.saveInsertionPoint()) {} ~InsertionGuard()285 ~InsertionGuard() { builder.restoreInsertionPoint(ip); } 286 287 private: 288 OpBuilder &builder; 289 OpBuilder::InsertPoint ip; 290 }; 291 292 /// Reset the insertion point to no location. Creating an operation without a 293 /// set insertion point is an error, but this can still be useful when the 294 /// current insertion point a builder refers to is being removed. clearInsertionPoint()295 void clearInsertionPoint() { 296 this->block = nullptr; 297 insertPoint = Block::iterator(); 298 } 299 300 /// Return a saved insertion point. saveInsertionPoint()301 InsertPoint saveInsertionPoint() const { 302 return InsertPoint(getInsertionBlock(), getInsertionPoint()); 303 } 304 305 /// Restore the insert point to a previously saved point. restoreInsertionPoint(InsertPoint ip)306 void restoreInsertionPoint(InsertPoint ip) { 307 if (ip.isSet()) 308 setInsertionPoint(ip.getBlock(), ip.getPoint()); 309 else 310 clearInsertionPoint(); 311 } 312 313 /// Set the insertion point to the specified location. setInsertionPoint(Block * block,Block::iterator insertPoint)314 void setInsertionPoint(Block *block, Block::iterator insertPoint) { 315 // TODO: check that insertPoint is in this rather than some other block. 316 this->block = block; 317 this->insertPoint = insertPoint; 318 } 319 320 /// Sets the insertion point to the specified operation, which will cause 321 /// subsequent insertions to go right before it. setInsertionPoint(Operation * op)322 void setInsertionPoint(Operation *op) { 323 setInsertionPoint(op->getBlock(), Block::iterator(op)); 324 } 325 326 /// Sets the insertion point to the node after the specified operation, which 327 /// will cause subsequent insertions to go right after it. setInsertionPointAfter(Operation * op)328 void setInsertionPointAfter(Operation *op) { 329 setInsertionPoint(op->getBlock(), ++Block::iterator(op)); 330 } 331 332 /// Sets the insertion point to the node after the specified value. If value 333 /// has a defining operation, sets the insertion point to the node after such 334 /// defining operation. This will cause subsequent insertions to go right 335 /// after it. Otherwise, value is a BlockArgumen. Sets the insertion point to 336 /// the start of its block. setInsertionPointAfterValue(Value val)337 void setInsertionPointAfterValue(Value val) { 338 if (Operation *op = val.getDefiningOp()) { 339 setInsertionPointAfter(op); 340 } else { 341 auto blockArg = val.cast<BlockArgument>(); 342 setInsertionPointToStart(blockArg.getOwner()); 343 } 344 } 345 346 /// Sets the insertion point to the start of the specified block. setInsertionPointToStart(Block * block)347 void setInsertionPointToStart(Block *block) { 348 setInsertionPoint(block, block->begin()); 349 } 350 351 /// Sets the insertion point to the end of the specified block. setInsertionPointToEnd(Block * block)352 void setInsertionPointToEnd(Block *block) { 353 setInsertionPoint(block, block->end()); 354 } 355 356 /// Return the block the current insertion point belongs to. Note that the 357 /// the insertion point is not necessarily the end of the block. getInsertionBlock()358 Block *getInsertionBlock() const { return block; } 359 360 /// Returns the current insertion point of the builder. getInsertionPoint()361 Block::iterator getInsertionPoint() const { return insertPoint; } 362 363 /// Returns the current block of the builder. getBlock()364 Block *getBlock() const { return block; } 365 366 //===--------------------------------------------------------------------===// 367 // Block Creation 368 //===--------------------------------------------------------------------===// 369 370 /// Add new block with 'argTypes' arguments and set the insertion point to the 371 /// end of it. The block is inserted at the provided insertion point of 372 /// 'parent'. 373 Block *createBlock(Region *parent, Region::iterator insertPt = {}, 374 TypeRange argTypes = llvm::None); 375 376 /// Add new block with 'argTypes' arguments and set the insertion point to the 377 /// end of it. The block is placed before 'insertBefore'. 378 Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None); 379 380 //===--------------------------------------------------------------------===// 381 // Operation Creation 382 //===--------------------------------------------------------------------===// 383 384 /// Insert the given operation at the current insertion point and return it. 385 Operation *insert(Operation *op); 386 387 /// Creates an operation given the fields represented as an OperationState. 388 Operation *createOperation(const OperationState &state); 389 390 /// Create an operation of specific op type at the current insertion point. 391 template <typename OpTy, typename... Args> create(Location location,Args &&...args)392 OpTy create(Location location, Args &&... args) { 393 OperationState state(location, OpTy::getOperationName()); 394 if (!state.name.getAbstractOperation()) 395 llvm::report_fatal_error("Building op `" + 396 state.name.getStringRef().str() + 397 "` but it isn't registered in this MLIRContext"); 398 OpTy::build(*this, state, std::forward<Args>(args)...); 399 auto *op = createOperation(state); 400 auto result = dyn_cast<OpTy>(op); 401 assert(result && "builder didn't return the right type"); 402 return result; 403 } 404 405 /// Create an operation of specific op type at the current insertion point, 406 /// and immediately try to fold it. This functions populates 'results' with 407 /// the results after folding the operation. 408 template <typename OpTy, typename... Args> createOrFold(SmallVectorImpl<Value> & results,Location location,Args &&...args)409 void createOrFold(SmallVectorImpl<Value> &results, Location location, 410 Args &&... args) { 411 // Create the operation without using 'createOperation' as we don't want to 412 // insert it yet. 413 OperationState state(location, OpTy::getOperationName()); 414 if (!state.name.getAbstractOperation()) 415 llvm::report_fatal_error("Building op `" + 416 state.name.getStringRef().str() + 417 "` but it isn't registered in this MLIRContext"); 418 OpTy::build(*this, state, std::forward<Args>(args)...); 419 Operation *op = Operation::create(state); 420 421 // Fold the operation. If successful destroy it, otherwise insert it. 422 if (succeeded(tryFold(op, results))) 423 op->destroy(); 424 else 425 insert(op); 426 } 427 428 /// Overload to create or fold a single result operation. 429 template <typename OpTy, typename... Args> 430 typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(), 431 Value>::type createOrFold(Location location,Args &&...args)432 createOrFold(Location location, Args &&... args) { 433 SmallVector<Value, 1> results; 434 createOrFold<OpTy>(results, location, std::forward<Args>(args)...); 435 return results.front(); 436 } 437 438 /// Overload to create or fold a zero result operation. 439 template <typename OpTy, typename... Args> 440 typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(), 441 OpTy>::type createOrFold(Location location,Args &&...args)442 createOrFold(Location location, Args &&... args) { 443 auto op = create<OpTy>(location, std::forward<Args>(args)...); 444 SmallVector<Value, 0> unused; 445 tryFold(op.getOperation(), unused); 446 447 // Folding cannot remove a zero-result operation, so for convenience we 448 // continue to return it. 449 return op; 450 } 451 452 /// Attempts to fold the given operation and places new results within 453 /// 'results'. Returns success if the operation was folded, failure otherwise. 454 /// Note: This function does not erase the operation on a successful fold. 455 LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results); 456 457 /// Creates a deep copy of the specified operation, remapping any operands 458 /// that use values outside of the operation using the map that is provided 459 /// ( leaving them alone if no entry is present). Replaces references to 460 /// cloned sub-operations to the corresponding operation that is copied, 461 /// and adds those mappings to the map. 462 Operation *clone(Operation &op, BlockAndValueMapping &mapper); 463 Operation *clone(Operation &op); 464 465 /// Creates a deep copy of this operation but keep the operation regions 466 /// empty. Operands are remapped using `mapper` (if present), and `mapper` is 467 /// updated to contain the results. cloneWithoutRegions(Operation & op,BlockAndValueMapping & mapper)468 Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) { 469 return insert(op.cloneWithoutRegions(mapper)); 470 } cloneWithoutRegions(Operation & op)471 Operation *cloneWithoutRegions(Operation &op) { 472 return insert(op.cloneWithoutRegions()); 473 } cloneWithoutRegions(OpT op)474 template <typename OpT> OpT cloneWithoutRegions(OpT op) { 475 return cast<OpT>(cloneWithoutRegions(*op.getOperation())); 476 } 477 478 private: 479 /// The current block this builder is inserting into. 480 Block *block = nullptr; 481 /// The insertion point within the block that this builder is inserting 482 /// before. 483 Block::iterator insertPoint; 484 /// The optional listener for events of this builder. 485 Listener *listener; 486 }; 487 488 } // namespace mlir 489 490 #endif 491