1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/raw_ostream.h"
19
20 using namespace mlir;
21
getIdentifier(StringRef str)22 Identifier Builder::getIdentifier(StringRef str) {
23 return Identifier::get(str, context);
24 }
25
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29
getUnknownLoc()30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31
getFileLineColLoc(Identifier filename,unsigned line,unsigned column)32 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33 unsigned column) {
34 return FileLineColLoc::get(filename, line, column, context);
35 }
36
getFusedLoc(ArrayRef<Location> locs,Attribute metadata)37 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
38 return FusedLoc::get(locs, metadata, context);
39 }
40
41 //===----------------------------------------------------------------------===//
42 // Types.
43 //===----------------------------------------------------------------------===//
44
getBF16Type()45 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46
getF16Type()47 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48
getF32Type()49 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50
getF64Type()51 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52
getIndexType()53 IndexType Builder::getIndexType() { return IndexType::get(context); }
54
getI1Type()55 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
56
getI32Type()57 IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
58
getI64Type()59 IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
60
getIntegerType(unsigned width)61 IntegerType Builder::getIntegerType(unsigned width) {
62 return IntegerType::get(width, context);
63 }
64
getIntegerType(unsigned width,bool isSigned)65 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
66 return IntegerType::get(
67 width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
68 }
69
getFunctionType(TypeRange inputs,TypeRange results)70 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
71 return FunctionType::get(inputs, results, context);
72 }
73
getTupleType(TypeRange elementTypes)74 TupleType Builder::getTupleType(TypeRange elementTypes) {
75 return TupleType::get(elementTypes, context);
76 }
77
getNoneType()78 NoneType Builder::getNoneType() { return NoneType::get(context); }
79
80 //===----------------------------------------------------------------------===//
81 // Attributes.
82 //===----------------------------------------------------------------------===//
83
getNamedAttr(StringRef name,Attribute val)84 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
85 return NamedAttribute(getIdentifier(name), val);
86 }
87
getUnitAttr()88 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
89
getBoolAttr(bool value)90 BoolAttr Builder::getBoolAttr(bool value) {
91 return BoolAttr::get(value, context);
92 }
93
getDictionaryAttr(ArrayRef<NamedAttribute> value)94 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
95 return DictionaryAttr::get(value, context);
96 }
97
getIndexAttr(int64_t value)98 IntegerAttr Builder::getIndexAttr(int64_t value) {
99 return IntegerAttr::get(getIndexType(), APInt(64, value));
100 }
101
getI64IntegerAttr(int64_t value)102 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
103 return IntegerAttr::get(getIntegerType(64), APInt(64, value));
104 }
105
getBoolVectorAttr(ArrayRef<bool> values)106 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
107 return DenseIntElementsAttr::get(
108 VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
109 values);
110 }
111
getI32VectorAttr(ArrayRef<int32_t> values)112 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
113 return DenseIntElementsAttr::get(
114 VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
115 values);
116 }
117
getI64VectorAttr(ArrayRef<int64_t> values)118 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
119 return DenseIntElementsAttr::get(
120 VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
121 values);
122 }
123
getI32TensorAttr(ArrayRef<int32_t> values)124 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
125 return DenseIntElementsAttr::get(
126 RankedTensorType::get(static_cast<int64_t>(values.size()),
127 getIntegerType(32)),
128 values);
129 }
130
getI64TensorAttr(ArrayRef<int64_t> values)131 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
132 return DenseIntElementsAttr::get(
133 RankedTensorType::get(static_cast<int64_t>(values.size()),
134 getIntegerType(64)),
135 values);
136 }
137
getIndexTensorAttr(ArrayRef<int64_t> values)138 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
139 return DenseIntElementsAttr::get(
140 RankedTensorType::get(static_cast<int64_t>(values.size()),
141 getIndexType()),
142 values);
143 }
144
getI32IntegerAttr(int32_t value)145 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
146 return IntegerAttr::get(getIntegerType(32), APInt(32, value));
147 }
148
getSI32IntegerAttr(int32_t value)149 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
150 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
151 APInt(32, value, /*isSigned=*/true));
152 }
153
getUI32IntegerAttr(uint32_t value)154 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
155 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
156 APInt(32, (uint64_t)value, /*isSigned=*/false));
157 }
158
getI16IntegerAttr(int16_t value)159 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
160 return IntegerAttr::get(getIntegerType(16), APInt(16, value));
161 }
162
getI8IntegerAttr(int8_t value)163 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
164 return IntegerAttr::get(getIntegerType(8), APInt(8, value));
165 }
166
getIntegerAttr(Type type,int64_t value)167 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
168 if (type.isIndex())
169 return IntegerAttr::get(type, APInt(64, value));
170 return IntegerAttr::get(
171 type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
172 }
173
getIntegerAttr(Type type,const APInt & value)174 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
175 return IntegerAttr::get(type, value);
176 }
177
getF64FloatAttr(double value)178 FloatAttr Builder::getF64FloatAttr(double value) {
179 return FloatAttr::get(getF64Type(), APFloat(value));
180 }
181
getF32FloatAttr(float value)182 FloatAttr Builder::getF32FloatAttr(float value) {
183 return FloatAttr::get(getF32Type(), APFloat(value));
184 }
185
getF16FloatAttr(float value)186 FloatAttr Builder::getF16FloatAttr(float value) {
187 return FloatAttr::get(getF16Type(), value);
188 }
189
getFloatAttr(Type type,double value)190 FloatAttr Builder::getFloatAttr(Type type, double value) {
191 return FloatAttr::get(type, value);
192 }
193
getFloatAttr(Type type,const APFloat & value)194 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
195 return FloatAttr::get(type, value);
196 }
197
getStringAttr(StringRef bytes)198 StringAttr Builder::getStringAttr(StringRef bytes) {
199 return StringAttr::get(bytes, context);
200 }
201
getArrayAttr(ArrayRef<Attribute> value)202 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
203 return ArrayAttr::get(value, context);
204 }
205
getSymbolRefAttr(Operation * value)206 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
207 auto symName =
208 value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
209 assert(symName && "value does not have a valid symbol name");
210 return getSymbolRefAttr(symName.getValue());
211 }
getSymbolRefAttr(StringRef value)212 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
213 return SymbolRefAttr::get(value, getContext());
214 }
215 SymbolRefAttr
getSymbolRefAttr(StringRef value,ArrayRef<FlatSymbolRefAttr> nestedReferences)216 Builder::getSymbolRefAttr(StringRef value,
217 ArrayRef<FlatSymbolRefAttr> nestedReferences) {
218 return SymbolRefAttr::get(value, nestedReferences, getContext());
219 }
220
getBoolArrayAttr(ArrayRef<bool> values)221 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
222 auto attrs = llvm::to_vector<8>(llvm::map_range(
223 values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
224 return getArrayAttr(attrs);
225 }
226
getI32ArrayAttr(ArrayRef<int32_t> values)227 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
228 auto attrs = llvm::to_vector<8>(llvm::map_range(
229 values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
230 return getArrayAttr(attrs);
231 }
getI64ArrayAttr(ArrayRef<int64_t> values)232 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
233 auto attrs = llvm::to_vector<8>(llvm::map_range(
234 values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
235 return getArrayAttr(attrs);
236 }
237
getIndexArrayAttr(ArrayRef<int64_t> values)238 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
239 auto attrs = llvm::to_vector<8>(
240 llvm::map_range(values, [this](int64_t v) -> Attribute {
241 return getIntegerAttr(IndexType::get(getContext()), v);
242 }));
243 return getArrayAttr(attrs);
244 }
245
getF32ArrayAttr(ArrayRef<float> values)246 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
247 auto attrs = llvm::to_vector<8>(llvm::map_range(
248 values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
249 return getArrayAttr(attrs);
250 }
251
getF64ArrayAttr(ArrayRef<double> values)252 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
253 auto attrs = llvm::to_vector<8>(llvm::map_range(
254 values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
255 return getArrayAttr(attrs);
256 }
257
getStrArrayAttr(ArrayRef<StringRef> values)258 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
259 auto attrs = llvm::to_vector<8>(llvm::map_range(
260 values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
261 return getArrayAttr(attrs);
262 }
263
getTypeArrayAttr(TypeRange values)264 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
265 auto attrs = llvm::to_vector<8>(llvm::map_range(
266 values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
267 return getArrayAttr(attrs);
268 }
269
getAffineMapArrayAttr(ArrayRef<AffineMap> values)270 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
271 auto attrs = llvm::to_vector<8>(llvm::map_range(
272 values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
273 return getArrayAttr(attrs);
274 }
275
getZeroAttr(Type type)276 Attribute Builder::getZeroAttr(Type type) {
277 if (type.isa<FloatType>())
278 return getFloatAttr(type, 0.0);
279 if (type.isa<IndexType>())
280 return getIndexAttr(0);
281 if (auto integerType = type.dyn_cast<IntegerType>())
282 return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
283 if (type.isa<RankedTensorType, VectorType>()) {
284 auto vtType = type.cast<ShapedType>();
285 auto element = getZeroAttr(vtType.getElementType());
286 if (!element)
287 return {};
288 return DenseElementsAttr::get(vtType, element);
289 }
290 return {};
291 }
292
293 //===----------------------------------------------------------------------===//
294 // Affine Expressions, Affine Maps, and Integer Sets.
295 //===----------------------------------------------------------------------===//
296
getAffineDimExpr(unsigned position)297 AffineExpr Builder::getAffineDimExpr(unsigned position) {
298 return mlir::getAffineDimExpr(position, context);
299 }
300
getAffineSymbolExpr(unsigned position)301 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
302 return mlir::getAffineSymbolExpr(position, context);
303 }
304
getAffineConstantExpr(int64_t constant)305 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
306 return mlir::getAffineConstantExpr(constant, context);
307 }
308
getEmptyAffineMap()309 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
310
getConstantAffineMap(int64_t val)311 AffineMap Builder::getConstantAffineMap(int64_t val) {
312 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
313 getAffineConstantExpr(val));
314 }
315
getDimIdentityMap()316 AffineMap Builder::getDimIdentityMap() {
317 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
318 }
319
getMultiDimIdentityMap(unsigned rank)320 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
321 SmallVector<AffineExpr, 4> dimExprs;
322 dimExprs.reserve(rank);
323 for (unsigned i = 0; i < rank; ++i)
324 dimExprs.push_back(getAffineDimExpr(i));
325 return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
326 context);
327 }
328
getSymbolIdentityMap()329 AffineMap Builder::getSymbolIdentityMap() {
330 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
331 getAffineSymbolExpr(0));
332 }
333
getSingleDimShiftAffineMap(int64_t shift)334 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
335 // expr = d0 + shift.
336 auto expr = getAffineDimExpr(0) + shift;
337 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
338 }
339
getShiftedAffineMap(AffineMap map,int64_t shift)340 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
341 SmallVector<AffineExpr, 4> shiftedResults;
342 shiftedResults.reserve(map.getNumResults());
343 for (auto resultExpr : map.getResults())
344 shiftedResults.push_back(resultExpr + shift);
345 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
346 context);
347 }
348
349 //===----------------------------------------------------------------------===//
350 // OpBuilder
351 //===----------------------------------------------------------------------===//
352
~Listener()353 OpBuilder::Listener::~Listener() {}
354
355 /// Insert the given operation at the current insertion point and return it.
insert(Operation * op)356 Operation *OpBuilder::insert(Operation *op) {
357 if (block)
358 block->getOperations().insert(insertPoint, op);
359
360 if (listener)
361 listener->notifyOperationInserted(op);
362 return op;
363 }
364
365 /// Add new block with 'argTypes' arguments and set the insertion point to the
366 /// end of it. The block is inserted at the provided insertion point of
367 /// 'parent'.
createBlock(Region * parent,Region::iterator insertPt,TypeRange argTypes)368 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
369 TypeRange argTypes) {
370 assert(parent && "expected valid parent region");
371 if (insertPt == Region::iterator())
372 insertPt = parent->end();
373
374 Block *b = new Block();
375 b->addArguments(argTypes);
376 parent->getBlocks().insert(insertPt, b);
377 setInsertionPointToEnd(b);
378
379 if (listener)
380 listener->notifyBlockCreated(b);
381 return b;
382 }
383
384 /// Add new block with 'argTypes' arguments and set the insertion point to the
385 /// end of it. The block is placed before 'insertBefore'.
createBlock(Block * insertBefore,TypeRange argTypes)386 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
387 assert(insertBefore && "expected valid insertion block");
388 return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
389 argTypes);
390 }
391
392 /// Create an operation given the fields represented as an OperationState.
createOperation(const OperationState & state)393 Operation *OpBuilder::createOperation(const OperationState &state) {
394 return insert(Operation::create(state));
395 }
396
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
///
/// On every failure path, 'results' is overwritten with the operation's own
/// results. Constants materialized for attribute fold results are inserted at
/// this builder's insertion point only once every result has been handled.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Shared failure exit: replace whatever was accumulated with op's results.
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Check to see if any operands to the operation is constant and whether
  // the operation knows how to constant fold itself. Slots of operands that
  // do not match presumably stay default-constructed (null).
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding. It is
  // built from the context alone (presumably with no insertion point), so the
  // constants it creates stay detached until inserted below.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation. `dialect` does not
    // change across iterations, so no constants have been generated yet when
    // this early exit is taken.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any generated constants.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants.
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
462
clone(Operation & op,BlockAndValueMapping & mapper)463 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
464 Operation *newOp = op.clone(mapper);
465 // The `insert` call below handles the notification for inserting `newOp`
466 // itself. But if `newOp` has any regions, we need to notify the listener
467 // about any ops that got inserted inside those regions as part of cloning.
468 if (listener) {
469 auto walkFn = [&](Operation *walkedOp) {
470 listener->notifyOperationInserted(walkedOp);
471 };
472 for (Region ®ion : newOp->getRegions())
473 region.walk(walkFn);
474 }
475 return insert(newOp);
476 }
477
clone(Operation & op)478 Operation *OpBuilder::clone(Operation &op) {
479 BlockAndValueMapping mapper;
480 return clone(op, mapper);
481 }
482