1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12
13 using namespace mlir;
14
15 //===----------------------------------------------------------------------===//
16 // ControlFlowInterfaces
17 //===----------------------------------------------------------------------===//
18
19 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
20
21 //===----------------------------------------------------------------------===//
22 // BranchOpInterface
23 //===----------------------------------------------------------------------===//
24
25 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
26 /// successor if 'operandIndex' is within the range of 'operands', or None if
27 /// `operandIndex` isn't a successor operand index.
28 Optional<BlockArgument>
getBranchSuccessorArgument(Optional<OperandRange> operands,unsigned operandIndex,Block * successor)29 detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
30 unsigned operandIndex, Block *successor) {
31 // Check that the operands are valid.
32 if (!operands || operands->empty())
33 return llvm::None;
34
35 // Check to ensure that this operand is within the range.
36 unsigned operandsStart = operands->getBeginOperandIndex();
37 if (operandIndex < operandsStart ||
38 operandIndex >= (operandsStart + operands->size()))
39 return llvm::None;
40
41 // Index the successor.
42 unsigned argIndex = operandIndex - operandsStart;
43 return successor->getArgument(argIndex);
44 }
45
46 /// Verify that the given operands match those of the given successor block.
47 LogicalResult
verifyBranchSuccessorOperands(Operation * op,unsigned succNo,Optional<OperandRange> operands)48 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
49 Optional<OperandRange> operands) {
50 if (!operands)
51 return success();
52
53 // Check the count.
54 unsigned operandCount = operands->size();
55 Block *destBB = op->getSuccessor(succNo);
56 if (operandCount != destBB->getNumArguments())
57 return op->emitError() << "branch has " << operandCount
58 << " operands for successor #" << succNo
59 << ", but target block has "
60 << destBB->getNumArguments();
61
62 // Check the types.
63 auto operandIt = operands->begin();
64 for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
65 if ((*operandIt).getType() != destBB->getArgument(i).getType())
66 return op->emitError() << "type mismatch for bb argument #" << i
67 << " of successor #" << succNo;
68 }
69 return success();
70 }
71
72 //===----------------------------------------------------------------------===//
73 // RegionBranchOpInterface
74 //===----------------------------------------------------------------------===//
75
76 // A constant value to represent unknown number of region invocations.
77 const int64_t mlir::kUnknownNumRegionInvocations = -1;
78
79 /// Verify that types match along all region control flow edges originating from
80 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
81 /// op). `getInputsTypesForRegion` is a function that returns the types of the
82 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
83 /// the exact type match verification is not necessary (e.g., if the Op verifies
84 /// the match itself).
85 static LogicalResult
verifyTypesAlongAllEdges(Operation * op,Optional<unsigned> sourceNo,function_ref<Optional<TypeRange> (Optional<unsigned>)> getInputsTypesForRegion)86 verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
87 function_ref<Optional<TypeRange>(Optional<unsigned>)>
88 getInputsTypesForRegion) {
89 auto regionInterface = cast<RegionBranchOpInterface>(op);
90
91 SmallVector<RegionSuccessor, 2> successors;
92 unsigned numInputs;
93 if (sourceNo) {
94 Region &srcRegion = op->getRegion(sourceNo.getValue());
95 numInputs = srcRegion.getNumArguments();
96 } else {
97 numInputs = op->getNumOperands();
98 }
99 SmallVector<Attribute, 2> operands(numInputs, nullptr);
100 regionInterface.getSuccessorRegions(sourceNo, operands, successors);
101
102 for (RegionSuccessor &succ : successors) {
103 Optional<unsigned> succRegionNo;
104 if (!succ.isParent())
105 succRegionNo = succ.getSuccessor()->getRegionNumber();
106
107 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108 diag << "from ";
109 if (sourceNo)
110 diag << "Region #" << sourceNo.getValue();
111 else
112 diag << "parent operands";
113
114 diag << " to ";
115 if (succRegionNo)
116 diag << "Region #" << succRegionNo.getValue();
117 else
118 diag << "parent results";
119 return diag;
120 };
121
122 Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
123 if (!sourceTypes.hasValue())
124 continue;
125
126 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
127 if (sourceTypes->size() != succInputsTypes.size()) {
128 InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
129 return printEdgeName(diag) << ": source has " << sourceTypes->size()
130 << " operands, but target successor needs "
131 << succInputsTypes.size();
132 }
133
134 for (auto typesIdx :
135 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136 Type sourceType = std::get<0>(typesIdx.value());
137 Type inputType = std::get<1>(typesIdx.value());
138 if (sourceType != inputType) {
139 InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140 return printEdgeName(diag)
141 << ": source type #" << typesIdx.index() << " " << sourceType
142 << " should match input type #" << typesIdx.index() << " "
143 << inputType;
144 }
145 }
146 }
147 return success();
148 }
149
150 /// Verify that types match along control flow edges described the given op.
verifyTypesAlongControlFlowEdges(Operation * op)151 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
152 auto regionInterface = cast<RegionBranchOpInterface>(op);
153
154 auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155 if (regionNo.hasValue()) {
156 return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
157 .getTypes();
158 }
159
160 // If the successor of a parent op is the parent itself
161 // RegionBranchOpInterface does not have an API to query what the entry
162 // operands will be in that case. Vend out the result types of the op in
163 // that case so that type checking succeeds for this case.
164 return op->getResultTypes();
165 };
166
167 // Verify types along control flow edges originating from the parent.
168 if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
169 return failure();
170
171 // RegionBranchOpInterface should not be implemented by Ops that do not have
172 // attached regions.
173 assert(op->getNumRegions() != 0);
174
175 // Verify types along control flow edges originating from each region.
176 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
177 Region ®ion = op->getRegion(regionNo);
178
179 // Since the interface cannot distinguish between different ReturnLike
180 // ops within the region branching to different successors, all ReturnLike
181 // ops in this region should have the same operand types. We will then use
182 // one of them as the representative for type matching.
183
184 Operation *regionReturn = nullptr;
185 for (Block &block : region) {
186 Operation *terminator = block.getTerminator();
187 if (!terminator->hasTrait<OpTrait::ReturnLike>())
188 continue;
189
190 if (!regionReturn) {
191 regionReturn = terminator;
192 continue;
193 }
194
195 // Found more than one ReturnLike terminator. Make sure the operand types
196 // match with the first one.
197 if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
198 return op->emitOpError("Region #")
199 << regionNo
200 << " operands mismatch between return-like terminators";
201 }
202
203 auto inputTypesFromRegion =
204 [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
205 // If there is no return-like terminator, the op itself should verify
206 // type consistency.
207 if (!regionReturn)
208 return llvm::None;
209
210 // All successors get the same set of operands.
211 return TypeRange(regionReturn->getOperands().getTypes());
212 };
213
214 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
215 return failure();
216 }
217
218 return success();
219 }
220