1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SCF/SCF.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/Support/MathExtras.h"
14 #include "mlir/Transforms/InliningUtils.h"
15
16 using namespace mlir;
17 using namespace mlir::scf;
18
19 //===----------------------------------------------------------------------===//
20 // SCFDialect Dialect Interfaces
21 //===----------------------------------------------------------------------===//
22
23 namespace {
24 struct SCFInlinerInterface : public DialectInlinerInterface {
25 using DialectInlinerInterface::DialectInlinerInterface;
26 // We don't have any special restrictions on what can be inlined into
27 // destination regions (e.g. while/conditional bodies). Always allow it.
isLegalToInline__anon89baa06c0111::SCFInlinerInterface28 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
29 BlockAndValueMapping &valueMapping) const final {
30 return true;
31 }
32 // Operations in scf dialect are always legal to inline since they are
33 // pure.
isLegalToInline__anon89baa06c0111::SCFInlinerInterface34 bool isLegalToInline(Operation *, Region *, bool,
35 BlockAndValueMapping &) const final {
36 return true;
37 }
38 // Handle the given inlined terminator by replacing it with a new operation
39 // as necessary. Required when the region has only one block.
handleTerminator__anon89baa06c0111::SCFInlinerInterface40 void handleTerminator(Operation *op,
41 ArrayRef<Value> valuesToRepl) const final {
42 auto retValOp = dyn_cast<scf::YieldOp>(op);
43 if (!retValOp)
44 return;
45
46 for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
47 std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
48 }
49 }
50 };
51 } // end anonymous namespace
52
53 //===----------------------------------------------------------------------===//
54 // SCFDialect
55 //===----------------------------------------------------------------------===//
56
/// Registers every SCF operation listed in the generated op list, plus the
/// SCF inliner interface, with this dialect.
void SCFDialect::initialize() {
  // The generated file expands to the comma-separated list of op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
}
64
65 /// Default callback for IfOp builders. Inserts a yield without arguments.
buildTerminatedBody(OpBuilder & builder,Location loc)66 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
67 builder.create<scf::YieldOp>(loc);
68 }
69
70 //===----------------------------------------------------------------------===//
71 // ForOp
72 //===----------------------------------------------------------------------===//
73
build(OpBuilder & builder,OperationState & result,Value lb,Value ub,Value step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)74 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
75 Value ub, Value step, ValueRange iterArgs,
76 BodyBuilderFn bodyBuilder) {
77 result.addOperands({lb, ub, step});
78 result.addOperands(iterArgs);
79 for (Value v : iterArgs)
80 result.addTypes(v.getType());
81 Region *bodyRegion = result.addRegion();
82 bodyRegion->push_back(new Block);
83 Block &bodyBlock = bodyRegion->front();
84 bodyBlock.addArgument(builder.getIndexType());
85 for (Value v : iterArgs)
86 bodyBlock.addArgument(v.getType());
87
88 // Create the default terminator if the builder is not provided and if the
89 // iteration arguments are not provided. Otherwise, leave this to the caller
90 // because we don't know which values to return from the loop.
91 if (iterArgs.empty() && !bodyBuilder) {
92 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
93 } else if (bodyBuilder) {
94 OpBuilder::InsertionGuard guard(builder);
95 builder.setInsertionPointToStart(&bodyBlock);
96 bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
97 bodyBlock.getArguments().drop_front());
98 }
99 }
100
verify(ForOp op)101 static LogicalResult verify(ForOp op) {
102 if (auto cst = op.step().getDefiningOp<ConstantIndexOp>())
103 if (cst.getValue() <= 0)
104 return op.emitOpError("constant step operand must be positive");
105
106 // Check that the body defines as single block argument for the induction
107 // variable.
108 auto *body = op.getBody();
109 if (!body->getArgument(0).getType().isIndex())
110 return op.emitOpError(
111 "expected body first argument to be an index argument for "
112 "the induction variable");
113
114 auto opNumResults = op.getNumResults();
115 if (opNumResults == 0)
116 return success();
117 // If ForOp defines values, check that the number and types of
118 // the defined values match ForOp initial iter operands and backedge
119 // basic block arguments.
120 if (op.getNumIterOperands() != opNumResults)
121 return op.emitOpError(
122 "mismatch in number of loop-carried values and defined values");
123 if (op.getNumRegionIterArgs() != opNumResults)
124 return op.emitOpError(
125 "mismatch in number of basic block args and defined values");
126 auto iterOperands = op.getIterOperands();
127 auto iterArgs = op.getRegionIterArgs();
128 auto opResults = op.getResults();
129 unsigned i = 0;
130 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
131 if (std::get<0>(e).getType() != std::get<2>(e).getType())
132 return op.emitOpError() << "types mismatch between " << i
133 << "th iter operand and defined value";
134 if (std::get<1>(e).getType() != std::get<2>(e).getType())
135 return op.emitOpError() << "types mismatch between " << i
136 << "th iter region arg and defined value";
137
138 i++;
139 }
140
141 return RegionBranchOpInterface::verifyTypes(op);
142 }
143
144 /// Prints the initialization list in the form of
145 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
146 /// where 'inner' values are assumed to be region arguments and 'outer' values
147 /// are regular SSA values.
printInitializationList(OpAsmPrinter & p,Block::BlockArgListType blocksArgs,ValueRange initializers,StringRef prefix="")148 static void printInitializationList(OpAsmPrinter &p,
149 Block::BlockArgListType blocksArgs,
150 ValueRange initializers,
151 StringRef prefix = "") {
152 assert(blocksArgs.size() == initializers.size() &&
153 "expected same length of arguments and initializers");
154 if (initializers.empty())
155 return;
156
157 p << prefix << '(';
158 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
159 p << std::get<0>(it) << " = " << std::get<1>(it);
160 });
161 p << ")";
162 }
163
print(OpAsmPrinter & p,ForOp op)164 static void print(OpAsmPrinter &p, ForOp op) {
165 p << op.getOperationName() << " " << op.getInductionVar() << " = "
166 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
167
168 printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
169 " iter_args");
170 if (!op.getIterOperands().empty())
171 p << " -> (" << op.getIterOperands().getTypes() << ')';
172 p.printRegion(op.region(),
173 /*printEntryBlockArgs=*/false,
174 /*printBlockTerminators=*/op.hasIterOperands());
175 p.printOptionalAttrDict(op.getAttrs());
176 }
177
parseForOp(OpAsmParser & parser,OperationState & result)178 static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
179 auto &builder = parser.getBuilder();
180 OpAsmParser::OperandType inductionVariable, lb, ub, step;
181 // Parse the induction variable followed by '='.
182 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
183 return failure();
184
185 // Parse loop bounds.
186 Type indexType = builder.getIndexType();
187 if (parser.parseOperand(lb) ||
188 parser.resolveOperand(lb, indexType, result.operands) ||
189 parser.parseKeyword("to") || parser.parseOperand(ub) ||
190 parser.resolveOperand(ub, indexType, result.operands) ||
191 parser.parseKeyword("step") || parser.parseOperand(step) ||
192 parser.resolveOperand(step, indexType, result.operands))
193 return failure();
194
195 // Parse the optional initial iteration arguments.
196 SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
197 SmallVector<Type, 4> argTypes;
198 regionArgs.push_back(inductionVariable);
199
200 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
201 // Parse assignment list and results type list.
202 if (parser.parseAssignmentList(regionArgs, operands) ||
203 parser.parseArrowTypeList(result.types))
204 return failure();
205 // Resolve input operands.
206 for (auto operand_type : llvm::zip(operands, result.types))
207 if (parser.resolveOperand(std::get<0>(operand_type),
208 std::get<1>(operand_type), result.operands))
209 return failure();
210 }
211 // Induction variable.
212 argTypes.push_back(indexType);
213 // Loop carried variables
214 argTypes.append(result.types.begin(), result.types.end());
215 // Parse the body region.
216 Region *body = result.addRegion();
217 if (regionArgs.size() != argTypes.size())
218 return parser.emitError(
219 parser.getNameLoc(),
220 "mismatch in number of loop-carried values and defined values");
221
222 if (parser.parseRegion(*body, regionArgs, argTypes))
223 return failure();
224
225 ForOp::ensureTerminator(*body, builder, result.location);
226
227 // Parse the optional attribute list.
228 if (parser.parseOptionalAttrDict(result.attributes))
229 return failure();
230
231 return success();
232 }
233
getLoopBody()234 Region &ForOp::getLoopBody() { return region(); }
235
isDefinedOutsideOfLoop(Value value)236 bool ForOp::isDefinedOutsideOfLoop(Value value) {
237 return !region().isAncestor(value.getParentRegion());
238 }
239
moveOutOfLoop(ArrayRef<Operation * > ops)240 LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
241 for (auto op : ops)
242 op->moveBefore(*this);
243 return success();
244 }
245
getForInductionVarOwner(Value val)246 ForOp mlir::scf::getForInductionVarOwner(Value val) {
247 auto ivArg = val.dyn_cast<BlockArgument>();
248 if (!ivArg)
249 return ForOp();
250 assert(ivArg.getOwner() && "unlinked block argument");
251 auto *containingOp = ivArg.getOwner()->getParentOp();
252 return dyn_cast_or_null<ForOp>(containingOp);
253 }
254
255 /// Return operands used when entering the region at 'index'. These operands
256 /// correspond to the loop iterator operands, i.e., those exclusing the
257 /// induction variable. LoopOp only has one region, so 0 is the only valid value
258 /// for `index`.
getSuccessorEntryOperands(unsigned index)259 OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
260 assert(index == 0 && "invalid region index");
261
262 // The initial operands map to the loop arguments after the induction
263 // variable.
264 return initArgs();
265 }
266
267 /// Given the region at `index`, or the parent operation if `index` is None,
268 /// return the successor regions. These are the regions that may be selected
269 /// during the flow of control. `operands` is a set of optional attributes that
270 /// correspond to a constant value for each operand, or null if that operand is
271 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)272 void ForOp::getSuccessorRegions(Optional<unsigned> index,
273 ArrayRef<Attribute> operands,
274 SmallVectorImpl<RegionSuccessor> ®ions) {
275 // If the predecessor is the ForOp, branch into the body using the iterator
276 // arguments.
277 if (!index.hasValue()) {
278 regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
279 return;
280 }
281
282 // Otherwise, the loop may branch back to itself or the parent operation.
283 assert(index.getValue() == 0 && "expected loop region");
284 regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
285 regions.push_back(RegionSuccessor(getResults()));
286 }
287
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)288 void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
289 SmallVectorImpl<int64_t> &countPerRegion) {
290 assert(countPerRegion.empty());
291 countPerRegion.resize(1);
292
293 auto lb = operands[0].dyn_cast_or_null<IntegerAttr>();
294 auto ub = operands[1].dyn_cast_or_null<IntegerAttr>();
295 auto step = operands[2].dyn_cast_or_null<IntegerAttr>();
296
297 // Loop bounds are not known statically.
298 if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
299 countPerRegion[0] = -1;
300 return;
301 }
302
303 countPerRegion[0] =
304 ceilDiv(ub.getValue().getSExtValue() - lb.getValue().getSExtValue(),
305 step.getValue().getSExtValue());
306 }
307
/// Builds a perfect nest of `scf.for` loops, one per entry of `lbs`, `ubs` and
/// `steps` (which must have equal sizes). `iterArgs` are threaded through
/// every level as loop-carried values.
///
/// `bodyBuilder` (may be null) is invoked once, at the start of the innermost
/// body, with the induction variables of all loops (outermost first) and the
/// innermost loop's region iteration arguments. It must return exactly as
/// many values as there are `iterArgs`; these are yielded by the innermost
/// loop. Every other loop yields the results of the loop nested directly in
/// it.
///
/// When `lbs` is empty, no loop is created: `bodyBuilder` is called directly
/// with no induction variables and the original `iterArgs`, and an empty
/// LoopNest is returned.
/// NOTE(review): in that case the values returned by `bodyBuilder` are only
/// checked against `iterArgs` in an assert and are otherwise dropped.
///
/// Returns the created loops, outermost first.
LoopNest mlir::scf::buildLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ValueRange iterArgs,
    function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilder) {
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");

  // If there are no bounds, call the body-building function and return early.
  if (lbs.empty()) {
    ValueVector results =
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
                    : ValueVector();
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
           "arguments");
    return LoopNest();
  }

  // First, create the loop structure iteratively using the body-builder
  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
  // The guard restores the caller's insertion point on return, since the
  // code below repeatedly moves it into the new loop bodies.
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp, 4> loops;
  SmallVector<Value, 4> ivs;
  loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
  // Each loop level is seeded with the body arguments of the previous level,
  // which the callback below captures.
  ValueRange currentIterArgs = iterArgs;
  Location currentLoc = loc;
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange args) {
          ivs.push_back(iv);
          // It is safe to store ValueRange args because it points to block
          // arguments of a loop operation that we also own.
          currentIterArgs = args;
          currentLoc = nestedLoc;
        });
    // Set the builder to point to the body of the newly created loop. We don't
    // do this in the callback because the builder is reset when the callback
    // returns.
    builder.setInsertionPointToStart(loop.getBody());
    loops.push_back(loop);
  }

  // For all loops but the innermost, yield the results of the nested loop.
  // `loops` is non-empty here because `lbs` is non-empty.
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.setInsertionPointToEnd(loops[i].getBody());
    builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
  }

  // In the body of the innermost loop, call the body building function if any
  // and yield its results.
  builder.setInsertionPointToStart(loops.back().getBody());
  ValueVector results = bodyBuilder
                            ? bodyBuilder(builder, currentLoc, ivs,
                                          loops.back().getRegionIterArgs())
                            : ValueVector();
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
         "arguments");
  builder.setInsertionPointToEnd(loops.back().getBody());
  builder.create<scf::YieldOp>(loc, results);

  // Return the loops.
  LoopNest res;
  res.loops.assign(loops.begin(), loops.end());
  return res;
}
380
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)381 LoopNest mlir::scf::buildLoopNest(
382 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
383 ValueRange steps,
384 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
385 // Delegate to the main function by wrapping the body builder.
386 return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
387 [&bodyBuilder](OpBuilder &nestedBuilder,
388 Location nestedLoc, ValueRange ivs,
389 ValueRange) -> ValueVector {
390 if (bodyBuilder)
391 bodyBuilder(nestedBuilder, nestedLoc, ivs);
392 return {};
393 });
394 }
395
396 /// Replaces the given op with the contents of the given single-block region,
397 /// using the operands of the block terminator to replace operation results.
replaceOpWithRegion(PatternRewriter & rewriter,Operation * op,Region & region,ValueRange blockArgs={})398 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
399 Region ®ion, ValueRange blockArgs = {}) {
400 assert(llvm::hasSingleElement(region) && "expected single-region block");
401 Block *block = ®ion.front();
402 Operation *terminator = block->getTerminator();
403 ValueRange results = terminator->getOperands();
404 rewriter.mergeBlockBefore(block, op, blockArgs);
405 rewriter.replaceOp(op, results);
406 rewriter.eraseOp(terminator);
407 }
408
409 namespace {
// Fold away ForOp iter arguments that are also yielded by the op.
// These arguments must be defined outside of the ForOp region and can just be
// forwarded after simplifying the op inits, yields and returns.
//
// The implementation uses `mergeBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    bool canonicalize = false;
    Block &block = forOp.region().front();
    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());

    // An internal flat vector of block transfer arguments,
    // `newBlockTransferArgs`, keeps the 1-1 mapping of original to transformed
    // block arguments. This plays the role of a BlockAndValueMapping for the
    // particular use case of calling into `mergeBlockBefore`.
    // `keepMask[i]` is true when the i-th iter arg is NOT forwarded and thus
    // survives in the new loop.
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(yieldOp.getNumOperands());
    // NOTE(review): `newYieldValues` is filled below but never read; the new
    // terminator is built from `keepMask` instead.
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getNumIterOperands());
    newYieldValues.reserve(yieldOp.getNumOperands());
    newResultValues.reserve(forOp.getNumResults());
    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
                             forOp.getRegionIterArgs(), // iter inside region
                             yieldOp.getOperands()      // iter yield
                             )) {
      // Forwarded is `true` when the region `iter` argument is yielded.
      bool forwarded = (std::get<1>(it) == std::get<2>(it));
      keepMask.push_back(!forwarded);
      canonicalize |= forwarded;
      if (forwarded) {
        // The value never changes: both the body argument and the loop result
        // can be replaced by the initial value from outside the loop.
        newBlockTransferArgs.push_back(std::get<0>(it));
        newResultValues.push_back(std::get<0>(it));
        continue;
      }
      newIterArgs.push_back(std::get<0>(it));
      newYieldValues.push_back(std::get<2>(it));
      newBlockTransferArgs.push_back(Value()); // placeholder with null value
      newResultValues.push_back(Value());      // placeholder with null value
    }

    if (!canonicalize)
      return failure();

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
        newIterArgs);
    Block &newBlock = newForOp.region().front();

    // Replace the null placeholders with newly constructed values.
    // `collapsedIdx` walks the surviving (non-forwarded) iter args of the new
    // loop in order.
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.region().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
      // The merged-in original yield now sits right before the new one.
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(forOp, newResultValues);
      return success();
    }

    // General case: the builder created no terminator (the new loop carries
    // values), so merge the old block and rewrite the merged-in terminator to
    // yield only the operands that are kept.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
                                    filteredOperands);
    };

    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
    rewriter.replaceOp(forOp, newResultValues);
    return success();
  }
};
516
517 /// Rewriting pattern that erases loops that are known not to iterate and
518 /// replaces single-iteration loops with their bodies.
519 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
520 using OpRewritePattern<ForOp>::OpRewritePattern;
521
matchAndRewrite__anon89baa06c0511::SimplifyTrivialLoops522 LogicalResult matchAndRewrite(ForOp op,
523 PatternRewriter &rewriter) const override {
524 // If the upper bound is the same as the lower bound, the loop does not
525 // iterate, just remove it.
526 if (op.lowerBound() == op.upperBound()) {
527 rewriter.replaceOp(op, op.getIterOperands());
528 return success();
529 }
530
531 auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
532 auto ub = op.upperBound().getDefiningOp<ConstantOp>();
533 if (!lb || !ub)
534 return failure();
535
536 // If the loop is known to have 0 iterations, remove it.
537 llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
538 llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
539 if (lbValue.sge(ubValue)) {
540 rewriter.replaceOp(op, op.getIterOperands());
541 return success();
542 }
543
544 auto step = op.step().getDefiningOp<ConstantOp>();
545 if (!step)
546 return failure();
547
548 // If the loop is known to have 1 iteration, inline its body and remove the
549 // loop.
550 llvm::APInt stepValue = lb.getValue().cast<IntegerAttr>().getValue();
551 if ((lbValue + stepValue).sge(ubValue)) {
552 SmallVector<Value, 4> blockArgs;
553 blockArgs.reserve(op.getNumIterOperands() + 1);
554 blockArgs.push_back(op.lowerBound());
555 llvm::append_range(blockArgs, op.getIterOperands());
556 replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
557 return success();
558 }
559
560 return failure();
561 }
562 };
563 } // namespace
564
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)565 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
566 MLIRContext *context) {
567 results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
568 }
569
570 //===----------------------------------------------------------------------===//
571 // IfOp
572 //===----------------------------------------------------------------------===//
573
build(OpBuilder & builder,OperationState & result,Value cond,bool withElseRegion)574 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
575 bool withElseRegion) {
576 build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
577 }
578
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,bool withElseRegion)579 void IfOp::build(OpBuilder &builder, OperationState &result,
580 TypeRange resultTypes, Value cond, bool withElseRegion) {
581 auto addTerminator = [&](OpBuilder &nested, Location loc) {
582 if (resultTypes.empty())
583 IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
584 loc);
585 };
586
587 build(builder, result, resultTypes, cond, addTerminator,
588 withElseRegion ? addTerminator
589 : function_ref<void(OpBuilder &, Location)>());
590 }
591
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)592 void IfOp::build(OpBuilder &builder, OperationState &result,
593 TypeRange resultTypes, Value cond,
594 function_ref<void(OpBuilder &, Location)> thenBuilder,
595 function_ref<void(OpBuilder &, Location)> elseBuilder) {
596 assert(thenBuilder && "the builder callback for 'then' must be present");
597
598 result.addOperands(cond);
599 result.addTypes(resultTypes);
600
601 OpBuilder::InsertionGuard guard(builder);
602 Region *thenRegion = result.addRegion();
603 builder.createBlock(thenRegion);
604 thenBuilder(builder, result.location);
605
606 Region *elseRegion = result.addRegion();
607 if (!elseBuilder)
608 return;
609
610 builder.createBlock(elseRegion);
611 elseBuilder(builder, result.location);
612 }
613
build(OpBuilder & builder,OperationState & result,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)614 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
615 function_ref<void(OpBuilder &, Location)> thenBuilder,
616 function_ref<void(OpBuilder &, Location)> elseBuilder) {
617 build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
618 }
619
verify(IfOp op)620 static LogicalResult verify(IfOp op) {
621 if (op.getNumResults() != 0 && op.elseRegion().empty())
622 return op.emitOpError("must have an else block if defining values");
623
624 return RegionBranchOpInterface::verifyTypes(op);
625 }
626
parseIfOp(OpAsmParser & parser,OperationState & result)627 static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
628 // Create the regions for 'then'.
629 result.regions.reserve(2);
630 Region *thenRegion = result.addRegion();
631 Region *elseRegion = result.addRegion();
632
633 auto &builder = parser.getBuilder();
634 OpAsmParser::OperandType cond;
635 Type i1Type = builder.getIntegerType(1);
636 if (parser.parseOperand(cond) ||
637 parser.resolveOperand(cond, i1Type, result.operands))
638 return failure();
639 // Parse optional results type list.
640 if (parser.parseOptionalArrowTypeList(result.types))
641 return failure();
642 // Parse the 'then' region.
643 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
644 return failure();
645 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
646
647 // If we find an 'else' keyword then parse the 'else' region.
648 if (!parser.parseOptionalKeyword("else")) {
649 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
650 return failure();
651 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
652 }
653
654 // Parse the optional attribute list.
655 if (parser.parseOptionalAttrDict(result.attributes))
656 return failure();
657 return success();
658 }
659
print(OpAsmPrinter & p,IfOp op)660 static void print(OpAsmPrinter &p, IfOp op) {
661 bool printBlockTerminators = false;
662
663 p << IfOp::getOperationName() << " " << op.condition();
664 if (!op.results().empty()) {
665 p << " -> (" << op.getResultTypes() << ")";
666 // Print yield explicitly if the op defines values.
667 printBlockTerminators = true;
668 }
669 p.printRegion(op.thenRegion(),
670 /*printEntryBlockArgs=*/false,
671 /*printBlockTerminators=*/printBlockTerminators);
672
673 // Print the 'else' regions if it exists and has a block.
674 auto &elseRegion = op.elseRegion();
675 if (!elseRegion.empty()) {
676 p << " else";
677 p.printRegion(elseRegion,
678 /*printEntryBlockArgs=*/false,
679 /*printBlockTerminators=*/printBlockTerminators);
680 }
681
682 p.printOptionalAttrDict(op.getAttrs());
683 }
684
685 /// Given the region at `index`, or the parent operation if `index` is None,
686 /// return the successor regions. These are the regions that may be selected
687 /// during the flow of control. `operands` is a set of optional attributes that
688 /// correspond to a constant value for each operand, or null if that operand is
689 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)690 void IfOp::getSuccessorRegions(Optional<unsigned> index,
691 ArrayRef<Attribute> operands,
692 SmallVectorImpl<RegionSuccessor> ®ions) {
693 // The `then` and the `else` region branch back to the parent operation.
694 if (index.hasValue()) {
695 regions.push_back(RegionSuccessor(getResults()));
696 return;
697 }
698
699 // Don't consider the else region if it is empty.
700 Region *elseRegion = &this->elseRegion();
701 if (elseRegion->empty())
702 elseRegion = nullptr;
703
704 // Otherwise, the successor is dependent on the condition.
705 bool condition;
706 if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
707 condition = condAttr.getValue().isOneValue();
708 } else {
709 // If the condition isn't constant, both regions may be executed.
710 regions.push_back(RegionSuccessor(&thenRegion()));
711 regions.push_back(RegionSuccessor(elseRegion));
712 return;
713 }
714
715 // Add the successor regions using the condition.
716 regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
717 }
718
719 namespace {
720 // Pattern to remove unused IfOp results.
721 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
722 using OpRewritePattern<IfOp>::OpRewritePattern;
723
transferBody__anon89baa06c0811::RemoveUnusedResults724 void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
725 PatternRewriter &rewriter) const {
726 // Move all operations to the destination block.
727 rewriter.mergeBlocks(source, dest);
728 // Replace the yield op by one that returns only the used values.
729 auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
730 SmallVector<Value, 4> usedOperands;
731 llvm::transform(usedResults, std::back_inserter(usedOperands),
732 [&](OpResult result) {
733 return yieldOp.getOperand(result.getResultNumber());
734 });
735 rewriter.updateRootInPlace(yieldOp,
736 [&]() { yieldOp->setOperands(usedOperands); });
737 }
738
matchAndRewrite__anon89baa06c0811::RemoveUnusedResults739 LogicalResult matchAndRewrite(IfOp op,
740 PatternRewriter &rewriter) const override {
741 // Compute the list of used results.
742 SmallVector<OpResult, 4> usedResults;
743 llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
744 [](OpResult result) { return !result.use_empty(); });
745
746 // Replace the operation if only a subset of its results have uses.
747 if (usedResults.size() == op.getNumResults())
748 return failure();
749
750 // Compute the result types of the replacement operation.
751 SmallVector<Type, 4> newTypes;
752 llvm::transform(usedResults, std::back_inserter(newTypes),
753 [](OpResult result) { return result.getType(); });
754
755 // Create a replacement operation with empty then and else regions.
756 auto emptyBuilder = [](OpBuilder &, Location) {};
757 auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.condition(),
758 emptyBuilder, emptyBuilder);
759
760 // Move the bodies and replace the terminators (note there is a then and
761 // an else region since the operation returns results).
762 transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
763 transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
764
765 // Replace the operation by the new one.
766 SmallVector<Value, 4> repResults(op.getNumResults());
767 for (auto en : llvm::enumerate(usedResults))
768 repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
769 rewriter.replaceOp(op, repResults);
770 return success();
771 }
772 };
773
774 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
775 using OpRewritePattern<IfOp>::OpRewritePattern;
776
matchAndRewrite__anon89baa06c0811::RemoveStaticCondition777 LogicalResult matchAndRewrite(IfOp op,
778 PatternRewriter &rewriter) const override {
779 auto constant = op.condition().getDefiningOp<ConstantOp>();
780 if (!constant)
781 return failure();
782
783 if (constant.getValue().cast<BoolAttr>().getValue())
784 replaceOpWithRegion(rewriter, op, op.thenRegion());
785 else if (!op.elseRegion().empty())
786 replaceOpWithRegion(rewriter, op, op.elseRegion());
787 else
788 rewriter.eraseOp(op);
789
790 return success();
791 }
792 };
793 } // namespace
794
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)795 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
796 MLIRContext *context) {
797 results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
798 }
799
800 //===----------------------------------------------------------------------===//
801 // ParallelOp
802 //===----------------------------------------------------------------------===//
803
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange initVals,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilderFn)804 void ParallelOp::build(
805 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
806 ValueRange upperBounds, ValueRange steps, ValueRange initVals,
807 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
808 bodyBuilderFn) {
809 result.addOperands(lowerBounds);
810 result.addOperands(upperBounds);
811 result.addOperands(steps);
812 result.addOperands(initVals);
813 result.addAttribute(
814 ParallelOp::getOperandSegmentSizeAttr(),
815 builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
816 static_cast<int32_t>(upperBounds.size()),
817 static_cast<int32_t>(steps.size()),
818 static_cast<int32_t>(initVals.size())}));
819 result.addTypes(initVals.getTypes());
820
821 OpBuilder::InsertionGuard guard(builder);
822 unsigned numIVs = steps.size();
823 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
824 Region *bodyRegion = result.addRegion();
825 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
826
827 if (bodyBuilderFn) {
828 builder.setInsertionPointToStart(bodyBlock);
829 bodyBuilderFn(builder, result.location,
830 bodyBlock->getArguments().take_front(numIVs),
831 bodyBlock->getArguments().drop_front(numIVs));
832 }
833 ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
834 }
835
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)836 void ParallelOp::build(
837 OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
838 ValueRange upperBounds, ValueRange steps,
839 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
840 // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
841 // we don't capture a reference to a temporary by constructing the lambda at
842 // function level.
843 auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
844 Location nestedLoc, ValueRange ivs,
845 ValueRange) {
846 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
847 };
848 function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
849 if (bodyBuilderFn)
850 wrapper = wrappedBuilderFn;
851
852 build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
853 wrapper);
854 }
855
verify(ParallelOp op)856 static LogicalResult verify(ParallelOp op) {
857 // Check that there is at least one value in lowerBound, upperBound and step.
858 // It is sufficient to test only step, because it is ensured already that the
859 // number of elements in lowerBound, upperBound and step are the same.
860 Operation::operand_range stepValues = op.step();
861 if (stepValues.empty())
862 return op.emitOpError(
863 "needs at least one tuple element for lowerBound, upperBound and step");
864
865 // Check whether all constant step values are positive.
866 for (Value stepValue : stepValues)
867 if (auto cst = stepValue.getDefiningOp<ConstantIndexOp>())
868 if (cst.getValue() <= 0)
869 return op.emitOpError("constant step operand must be positive");
870
871 // Check that the body defines the same number of block arguments as the
872 // number of tuple elements in step.
873 Block *body = op.getBody();
874 if (body->getNumArguments() != stepValues.size())
875 return op.emitOpError()
876 << "expects the same number of induction variables: "
877 << body->getNumArguments()
878 << " as bound and step values: " << stepValues.size();
879 for (auto arg : body->getArguments())
880 if (!arg.getType().isIndex())
881 return op.emitOpError(
882 "expects arguments for the induction variable to be of index type");
883
884 // Check that the yield has no results
885 Operation *yield = body->getTerminator();
886 if (yield->getNumOperands() != 0)
887 return yield->emitOpError() << "not allowed to have operands inside '"
888 << ParallelOp::getOperationName() << "'";
889
890 // Check that the number of results is the same as the number of ReduceOps.
891 SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
892 auto resultsSize = op.results().size();
893 auto reductionsSize = reductions.size();
894 auto initValsSize = op.initVals().size();
895 if (resultsSize != reductionsSize)
896 return op.emitOpError()
897 << "expects number of results: " << resultsSize
898 << " to be the same as number of reductions: " << reductionsSize;
899 if (resultsSize != initValsSize)
900 return op.emitOpError()
901 << "expects number of results: " << resultsSize
902 << " to be the same as number of initial values: " << initValsSize;
903
904 // Check that the types of the results and reductions are the same.
905 for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
906 auto resultType = std::get<0>(resultAndReduce).getType();
907 auto reduceOp = std::get<1>(resultAndReduce);
908 auto reduceType = reduceOp.operand().getType();
909 if (resultType != reduceType)
910 return reduceOp.emitOpError()
911 << "expects type of reduce: " << reduceType
912 << " to be the same as result type: " << resultType;
913 }
914 return success();
915 }
916
parseParallelOp(OpAsmParser & parser,OperationState & result)917 static ParseResult parseParallelOp(OpAsmParser &parser,
918 OperationState &result) {
919 auto &builder = parser.getBuilder();
920 // Parse an opening `(` followed by induction variables followed by `)`
921 SmallVector<OpAsmParser::OperandType, 4> ivs;
922 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
923 OpAsmParser::Delimiter::Paren))
924 return failure();
925
926 // Parse loop bounds.
927 SmallVector<OpAsmParser::OperandType, 4> lower;
928 if (parser.parseEqual() ||
929 parser.parseOperandList(lower, ivs.size(),
930 OpAsmParser::Delimiter::Paren) ||
931 parser.resolveOperands(lower, builder.getIndexType(), result.operands))
932 return failure();
933
934 SmallVector<OpAsmParser::OperandType, 4> upper;
935 if (parser.parseKeyword("to") ||
936 parser.parseOperandList(upper, ivs.size(),
937 OpAsmParser::Delimiter::Paren) ||
938 parser.resolveOperands(upper, builder.getIndexType(), result.operands))
939 return failure();
940
941 // Parse step values.
942 SmallVector<OpAsmParser::OperandType, 4> steps;
943 if (parser.parseKeyword("step") ||
944 parser.parseOperandList(steps, ivs.size(),
945 OpAsmParser::Delimiter::Paren) ||
946 parser.resolveOperands(steps, builder.getIndexType(), result.operands))
947 return failure();
948
949 // Parse init values.
950 SmallVector<OpAsmParser::OperandType, 4> initVals;
951 if (succeeded(parser.parseOptionalKeyword("init"))) {
952 if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
953 OpAsmParser::Delimiter::Paren))
954 return failure();
955 }
956
957 // Parse optional results in case there is a reduce.
958 if (parser.parseOptionalArrowTypeList(result.types))
959 return failure();
960
961 // Now parse the body.
962 Region *body = result.addRegion();
963 SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
964 if (parser.parseRegion(*body, ivs, types))
965 return failure();
966
967 // Set `operand_segment_sizes` attribute.
968 result.addAttribute(
969 ParallelOp::getOperandSegmentSizeAttr(),
970 builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
971 static_cast<int32_t>(upper.size()),
972 static_cast<int32_t>(steps.size()),
973 static_cast<int32_t>(initVals.size())}));
974
975 // Parse attributes.
976 if (parser.parseOptionalAttrDict(result.attributes))
977 return failure();
978
979 if (!initVals.empty())
980 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
981 result.operands);
982 // Add a terminator if none was parsed.
983 ForOp::ensureTerminator(*body, builder, result.location);
984
985 return success();
986 }
987
print(OpAsmPrinter & p,ParallelOp op)988 static void print(OpAsmPrinter &p, ParallelOp op) {
989 p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
990 << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
991 << ")";
992 if (!op.initVals().empty())
993 p << " init (" << op.initVals() << ")";
994 p.printOptionalArrowTypeList(op.getResultTypes());
995 p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
996 p.printOptionalAttrDict(
997 op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
998 }
999
getLoopBody()1000 Region &ParallelOp::getLoopBody() { return region(); }
1001
isDefinedOutsideOfLoop(Value value)1002 bool ParallelOp::isDefinedOutsideOfLoop(Value value) {
1003 return !region().isAncestor(value.getParentRegion());
1004 }
1005
moveOutOfLoop(ArrayRef<Operation * > ops)1006 LogicalResult ParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1007 for (auto op : ops)
1008 op->moveBefore(*this);
1009 return success();
1010 }
1011
getParallelForInductionVarOwner(Value val)1012 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
1013 auto ivArg = val.dyn_cast<BlockArgument>();
1014 if (!ivArg)
1015 return ParallelOp();
1016 assert(ivArg.getOwner() && "unlinked block argument");
1017 auto *containingOp = ivArg.getOwner()->getParentOp();
1018 return dyn_cast<ParallelOp>(containingOp);
1019 }
1020
1021 namespace {
1022 // Collapse loop dimensions that perform a single iteration.
1023 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
1024 using OpRewritePattern<ParallelOp>::OpRewritePattern;
1025
matchAndRewrite__anon89baa06c0f11::CollapseSingleIterationLoops1026 LogicalResult matchAndRewrite(ParallelOp op,
1027 PatternRewriter &rewriter) const override {
1028 BlockAndValueMapping mapping;
1029 // Compute new loop bounds that omit all single-iteration loop dimensions.
1030 SmallVector<Value, 2> newLowerBounds;
1031 SmallVector<Value, 2> newUpperBounds;
1032 SmallVector<Value, 2> newSteps;
1033 newLowerBounds.reserve(op.lowerBound().size());
1034 newUpperBounds.reserve(op.upperBound().size());
1035 newSteps.reserve(op.step().size());
1036 for (auto dim : llvm::zip(op.lowerBound(), op.upperBound(), op.step(),
1037 op.getInductionVars())) {
1038 Value lowerBound, upperBound, step, iv;
1039 std::tie(lowerBound, upperBound, step, iv) = dim;
1040 // Collect the statically known loop bounds.
1041 auto lowerBoundConstant =
1042 dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp());
1043 auto upperBoundConstant =
1044 dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp());
1045 auto stepConstant =
1046 dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp());
1047 // Replace the loop induction variable by the lower bound if the loop
1048 // performs a single iteration. Otherwise, copy the loop bounds.
1049 if (lowerBoundConstant && upperBoundConstant && stepConstant &&
1050 (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) > 0 &&
1051 (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) <=
1052 stepConstant.getValue()) {
1053 mapping.map(iv, lowerBound);
1054 } else {
1055 newLowerBounds.push_back(lowerBound);
1056 newUpperBounds.push_back(upperBound);
1057 newSteps.push_back(step);
1058 }
1059 }
1060 // Exit if all or none of the loop dimensions perform a single iteration.
1061 if (newLowerBounds.size() == 0 ||
1062 newLowerBounds.size() == op.lowerBound().size())
1063 return failure();
1064 // Replace the parallel loop by lower-dimensional parallel loop.
1065 auto newOp =
1066 rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
1067 newSteps, op.initVals(), nullptr);
1068 // Clone the loop body and remap the block arguments of the collapsed loops
1069 // (inlining does not support a cancellable block argument mapping).
1070 rewriter.cloneRegionBefore(op.region(), newOp.region(),
1071 newOp.region().begin(), mapping);
1072 rewriter.replaceOp(op, newOp.getResults());
1073 return success();
1074 }
1075 };
1076
1077 /// Removes parallel loops in which at least one lower/upper bound pair consists
1078 /// of the same values - such loops have an empty iteration domain.
1079 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
1080 using OpRewritePattern<ParallelOp>::OpRewritePattern;
1081
matchAndRewrite__anon89baa06c0f11::RemoveEmptyParallelLoops1082 LogicalResult matchAndRewrite(ParallelOp op,
1083 PatternRewriter &rewriter) const override {
1084 for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
1085 if (std::get<0>(dim) == std::get<1>(dim)) {
1086 rewriter.replaceOp(op, op.initVals());
1087 return success();
1088 }
1089 }
1090 return failure();
1091 }
1092 };
1093
1094 } // namespace
1095
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1096 void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1097 MLIRContext *context) {
1098 results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
1099 context);
1100 }
1101
1102 //===----------------------------------------------------------------------===//
1103 // ReduceOp
1104 //===----------------------------------------------------------------------===//
1105
build(OpBuilder & builder,OperationState & result,Value operand,function_ref<void (OpBuilder &,Location,Value,Value)> bodyBuilderFn)1106 void ReduceOp::build(
1107 OpBuilder &builder, OperationState &result, Value operand,
1108 function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
1109 auto type = operand.getType();
1110 result.addOperands(operand);
1111
1112 OpBuilder::InsertionGuard guard(builder);
1113 Region *bodyRegion = result.addRegion();
1114 Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type});
1115 if (bodyBuilderFn)
1116 bodyBuilderFn(builder, result.location, body->getArgument(0),
1117 body->getArgument(1));
1118 }
1119
verify(ReduceOp op)1120 static LogicalResult verify(ReduceOp op) {
1121 // The region of a ReduceOp has two arguments of the same type as its operand.
1122 auto type = op.operand().getType();
1123 Block &block = op.reductionOperator().front();
1124 if (block.empty())
1125 return op.emitOpError("the block inside reduce should not be empty");
1126 if (block.getNumArguments() != 2 ||
1127 llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
1128 return arg.getType() != type;
1129 }))
1130 return op.emitOpError()
1131 << "expects two arguments to reduce block of type " << type;
1132
1133 // Check that the block is terminated by a ReduceReturnOp.
1134 if (!isa<ReduceReturnOp>(block.getTerminator()))
1135 return op.emitOpError("the block inside reduce should be terminated with a "
1136 "'scf.reduce.return' op");
1137
1138 return success();
1139 }
1140
parseReduceOp(OpAsmParser & parser,OperationState & result)1141 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1142 // Parse an opening `(` followed by the reduced value followed by `)`
1143 OpAsmParser::OperandType operand;
1144 if (parser.parseLParen() || parser.parseOperand(operand) ||
1145 parser.parseRParen())
1146 return failure();
1147
1148 Type resultType;
1149 // Parse the type of the operand (and also what reduce computes on).
1150 if (parser.parseColonType(resultType) ||
1151 parser.resolveOperand(operand, resultType, result.operands))
1152 return failure();
1153
1154 // Now parse the body.
1155 Region *body = result.addRegion();
1156 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1157 return failure();
1158
1159 return success();
1160 }
1161
print(OpAsmPrinter & p,ReduceOp op)1162 static void print(OpAsmPrinter &p, ReduceOp op) {
1163 p << op.getOperationName() << "(" << op.operand() << ") ";
1164 p << " : " << op.operand().getType();
1165 p.printRegion(op.reductionOperator());
1166 }
1167
1168 //===----------------------------------------------------------------------===//
1169 // ReduceReturnOp
1170 //===----------------------------------------------------------------------===//
1171
verify(ReduceReturnOp op)1172 static LogicalResult verify(ReduceReturnOp op) {
1173 // The type of the return value should be the same type as the type of the
1174 // operand of the enclosing ReduceOp.
1175 auto reduceOp = cast<ReduceOp>(op->getParentOp());
1176 Type reduceType = reduceOp.operand().getType();
1177 if (reduceType != op.result().getType())
1178 return op.emitOpError() << "needs to have type " << reduceType
1179 << " (the type of the enclosing ReduceOp)";
1180 return success();
1181 }
1182
1183 //===----------------------------------------------------------------------===//
1184 // WhileOp
1185 //===----------------------------------------------------------------------===//
1186
getSuccessorEntryOperands(unsigned index)1187 OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
1188 assert(index == 0 &&
1189 "WhileOp is expected to branch only to the first region");
1190
1191 return inits();
1192 }
1193
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1194 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
1195 ArrayRef<Attribute> operands,
1196 SmallVectorImpl<RegionSuccessor> ®ions) {
1197 (void)operands;
1198
1199 if (!index.hasValue()) {
1200 regions.emplace_back(&before(), before().getArguments());
1201 return;
1202 }
1203
1204 assert(*index < 2 && "there are only two regions in a WhileOp");
1205 if (*index == 0) {
1206 regions.emplace_back(&after(), after().getArguments());
1207 regions.emplace_back(getResults());
1208 return;
1209 }
1210
1211 regions.emplace_back(&before(), before().getArguments());
1212 }
1213
1214 /// Parses a `while` op.
1215 ///
1216 /// op ::= `scf.while` assignments `:` function-type region `do` region
1217 /// `attributes` attribute-dict
1218 /// initializer ::= /* empty */ | `(` assignment-list `)`
1219 /// assignment-list ::= assignment | assignment `,` assignment-list
1220 /// assignment ::= ssa-value `=` ssa-value
parseWhileOp(OpAsmParser & parser,OperationState & result)1221 static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
1222 SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1223 Region *before = result.addRegion();
1224 Region *after = result.addRegion();
1225
1226 OptionalParseResult listResult =
1227 parser.parseOptionalAssignmentList(regionArgs, operands);
1228 if (listResult.hasValue() && failed(listResult.getValue()))
1229 return failure();
1230
1231 FunctionType functionType;
1232 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1233 if (failed(parser.parseColonType(functionType)))
1234 return failure();
1235
1236 result.addTypes(functionType.getResults());
1237
1238 if (functionType.getNumInputs() != operands.size()) {
1239 return parser.emitError(typeLoc)
1240 << "expected as many input types as operands "
1241 << "(expected " << operands.size() << " got "
1242 << functionType.getNumInputs() << ")";
1243 }
1244
1245 // Resolve input operands.
1246 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
1247 parser.getCurrentLocation(),
1248 result.operands)))
1249 return failure();
1250
1251 return failure(
1252 parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
1253 parser.parseKeyword("do") || parser.parseRegion(*after) ||
1254 parser.parseOptionalAttrDictWithKeyword(result.attributes));
1255 }
1256
1257 /// Prints a `while` op.
print(OpAsmPrinter & p,scf::WhileOp op)1258 static void print(OpAsmPrinter &p, scf::WhileOp op) {
1259 p << op.getOperationName();
1260 printInitializationList(p, op.before().front().getArguments(), op.inits(),
1261 " ");
1262 p << " : ";
1263 p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
1264 p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
1265 p << " do";
1266 p.printRegion(op.after());
1267 p.printOptionalAttrDictWithKeyword(op.getAttrs());
1268 }
1269
1270 /// Verifies that two ranges of types match, i.e. have the same number of
1271 /// entries and that types are pairwise equals. Reports errors on the given
1272 /// operation in case of mismatch.
1273 template <typename OpTy>
verifyTypeRangesMatch(OpTy op,TypeRange left,TypeRange right,StringRef message)1274 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
1275 TypeRange right, StringRef message) {
1276 if (left.size() != right.size())
1277 return op.emitOpError("expects the same number of ") << message;
1278
1279 for (unsigned i = 0, e = left.size(); i < e; ++i) {
1280 if (left[i] != right[i]) {
1281 InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
1282 << message;
1283 diag.attachNote() << "for argument " << i << ", found " << left[i]
1284 << " and " << right[i];
1285 return diag;
1286 }
1287 }
1288
1289 return success();
1290 }
1291
1292 /// Verifies that the first block of the given `region` is terminated by a
1293 /// YieldOp. Reports errors on the given operation if it is not the case.
1294 template <typename TerminatorTy>
verifyAndGetTerminator(scf::WhileOp op,Region & region,StringRef errorMessage)1295 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region ®ion,
1296 StringRef errorMessage) {
1297 Operation *terminatorOperation = region.front().getTerminator();
1298 if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
1299 return yield;
1300
1301 auto diag = op.emitOpError(errorMessage);
1302 if (terminatorOperation)
1303 diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
1304 return nullptr;
1305 }
1306
verify(scf::WhileOp op)1307 static LogicalResult verify(scf::WhileOp op) {
1308 if (failed(RegionBranchOpInterface::verifyTypes(op)))
1309 return failure();
1310
1311 auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
1312 op, op.before(),
1313 "expects the 'before' region to terminate with 'scf.condition'");
1314 if (!beforeTerminator)
1315 return failure();
1316
1317 TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
1318 if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
1319 op.after().getArgumentTypes(),
1320 "trailing operands of the 'before' block "
1321 "terminator and 'after' region arguments")))
1322 return failure();
1323
1324 if (failed(verifyTypeRangesMatch(
1325 op, trailingTerminatorOperands, op.getResultTypes(),
1326 "trailing operands of the 'before' block terminator and op results")))
1327 return failure();
1328
1329 auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
1330 op, op.after(),
1331 "expects the 'after' region to terminate with 'scf.yield'");
1332 return success(afterTerminator != nullptr);
1333 }
1334
1335 //===----------------------------------------------------------------------===//
1336 // YieldOp
1337 //===----------------------------------------------------------------------===//
1338
parseYieldOp(OpAsmParser & parser,OperationState & result)1339 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
1340 SmallVector<OpAsmParser::OperandType, 4> operands;
1341 SmallVector<Type, 4> types;
1342 llvm::SMLoc loc = parser.getCurrentLocation();
1343 // Parse variadic operands list, their types, and resolve operands to SSA
1344 // values.
1345 if (parser.parseOperandList(operands) ||
1346 parser.parseOptionalColonTypeList(types) ||
1347 parser.resolveOperands(operands, types, loc, result.operands))
1348 return failure();
1349 return success();
1350 }
1351
print(OpAsmPrinter & p,scf::YieldOp op)1352 static void print(OpAsmPrinter &p, scf::YieldOp op) {
1353 p << op.getOperationName();
1354 if (op.getNumOperands() != 0)
1355 p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
1356 }
1357
1358 //===----------------------------------------------------------------------===//
1359 // TableGen'd op method definitions
1360 //===----------------------------------------------------------------------===//
1361
1362 #define GET_OP_CLASSES
1363 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
1364