1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous inlining utilities.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Transforms/InliningUtils.h"
14
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Operation.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/raw_ostream.h"
22
23 #define DEBUG_TYPE "inlining"
24
25 using namespace mlir;
26
27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the
28 /// provided caller location.
29 static void
remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,Location callerLoc)30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
31 Location callerLoc) {
32 DenseMap<Location, Location> mappedLocations;
33 auto remapOpLoc = [&](Operation *op) {
34 auto it = mappedLocations.find(op->getLoc());
35 if (it == mappedLocations.end()) {
36 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc);
37 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first;
38 }
39 op->setLoc(it->second);
40 };
41 for (auto &block : inlinedBlocks)
42 block.walk(remapOpLoc);
43 }
44
remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,BlockAndValueMapping & mapper)45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
46 BlockAndValueMapping &mapper) {
47 auto remapOperands = [&](Operation *op) {
48 for (auto &operand : op->getOpOperands())
49 if (auto mappedOp = mapper.lookupOrNull(operand.get()))
50 operand.set(mappedOp);
51 };
52 for (auto &block : inlinedBlocks)
53 block.walk(remapOperands);
54 }
55
56 //===----------------------------------------------------------------------===//
57 // InlinerInterface
58 //===----------------------------------------------------------------------===//
59
isLegalToInline(Operation * call,Operation * callable,bool wouldBeCloned) const60 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
61 bool wouldBeCloned) const {
62 if (auto *handler = getInterfaceFor(call))
63 return handler->isLegalToInline(call, callable, wouldBeCloned);
64 return false;
65 }
66
isLegalToInline(Region * dest,Region * src,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const67 bool InlinerInterface::isLegalToInline(
68 Region *dest, Region *src, bool wouldBeCloned,
69 BlockAndValueMapping &valueMapping) const {
70 // Regions can always be inlined into functions.
71 if (isa<FuncOp>(dest->getParentOp()))
72 return true;
73
74 if (auto *handler = getInterfaceFor(dest->getParentOp()))
75 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
76 return false;
77 }
78
isLegalToInline(Operation * op,Region * dest,bool wouldBeCloned,BlockAndValueMapping & valueMapping) const79 bool InlinerInterface::isLegalToInline(
80 Operation *op, Region *dest, bool wouldBeCloned,
81 BlockAndValueMapping &valueMapping) const {
82 if (auto *handler = getInterfaceFor(op))
83 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
84 return false;
85 }
86
shouldAnalyzeRecursively(Operation * op) const87 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
88 auto *handler = getInterfaceFor(op);
89 return handler ? handler->shouldAnalyzeRecursively(op) : true;
90 }
91
92 /// Handle the given inlined terminator by replacing it with a new operation
93 /// as necessary.
handleTerminator(Operation * op,Block * newDest) const94 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
95 auto *handler = getInterfaceFor(op);
96 assert(handler && "expected valid dialect handler");
97 handler->handleTerminator(op, newDest);
98 }
99
100 /// Handle the given inlined terminator by replacing it with a new operation
101 /// as necessary.
handleTerminator(Operation * op,ArrayRef<Value> valuesToRepl) const102 void InlinerInterface::handleTerminator(Operation *op,
103 ArrayRef<Value> valuesToRepl) const {
104 auto *handler = getInterfaceFor(op);
105 assert(handler && "expected valid dialect handler");
106 handler->handleTerminator(op, valuesToRepl);
107 }
108
109 /// Utility to check that all of the operations within 'src' can be inlined.
isLegalToInline(InlinerInterface & interface,Region * src,Region * insertRegion,bool shouldCloneInlinedRegion,BlockAndValueMapping & valueMapping)110 static bool isLegalToInline(InlinerInterface &interface, Region *src,
111 Region *insertRegion, bool shouldCloneInlinedRegion,
112 BlockAndValueMapping &valueMapping) {
113 for (auto &block : *src) {
114 for (auto &op : block) {
115 // Check this operation.
116 if (!interface.isLegalToInline(&op, insertRegion,
117 shouldCloneInlinedRegion, valueMapping)) {
118 LLVM_DEBUG({
119 llvm::dbgs() << "* Illegal to inline because of op: ";
120 op.dump();
121 });
122 return false;
123 }
124 // Check any nested regions.
125 if (interface.shouldAnalyzeRecursively(&op) &&
126 llvm::any_of(op.getRegions(), [&](Region ®ion) {
127 return !isLegalToInline(interface, ®ion, insertRegion,
128 shouldCloneInlinedRegion, valueMapping);
129 }))
130 return false;
131 }
132 }
133 return true;
134 }
135
136 //===----------------------------------------------------------------------===//
137 // Inline Methods
138 //===----------------------------------------------------------------------===//
139
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,BlockAndValueMapping & mapper,ValueRange resultsToReplace,TypeRange regionResultTypes,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)140 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
141 Operation *inlinePoint,
142 BlockAndValueMapping &mapper,
143 ValueRange resultsToReplace,
144 TypeRange regionResultTypes,
145 Optional<Location> inlineLoc,
146 bool shouldCloneInlinedRegion) {
147 assert(resultsToReplace.size() == regionResultTypes.size());
148 // We expect the region to have at least one block.
149 if (src->empty())
150 return failure();
151
152 // Check that all of the region arguments have been mapped.
153 auto *srcEntryBlock = &src->front();
154 if (llvm::any_of(srcEntryBlock->getArguments(),
155 [&](BlockArgument arg) { return !mapper.contains(arg); }))
156 return failure();
157
158 // The insertion point must be within a block.
159 Block *insertBlock = inlinePoint->getBlock();
160 if (!insertBlock)
161 return failure();
162 Region *insertRegion = insertBlock->getParent();
163
164 // Check that the operations within the source region are valid to inline.
165 if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
166 mapper) ||
167 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
168 mapper))
169 return failure();
170
171 // Split the insertion block.
172 Block *postInsertBlock =
173 insertBlock->splitBlock(++inlinePoint->getIterator());
174
175 // Check to see if the region is being cloned, or moved inline. In either
176 // case, move the new blocks after the 'insertBlock' to improve IR
177 // readability.
178 if (shouldCloneInlinedRegion)
179 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
180 else
181 insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
182 src->getBlocks(), src->begin(),
183 src->end());
184
185 // Get the range of newly inserted blocks.
186 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
187 postInsertBlock->getIterator());
188 Block *firstNewBlock = &*newBlocks.begin();
189
190 // Remap the locations of the inlined operations if a valid source location
191 // was provided.
192 if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
193 remapInlinedLocations(newBlocks, *inlineLoc);
194
195 // If the blocks were moved in-place, make sure to remap any necessary
196 // operands.
197 if (!shouldCloneInlinedRegion)
198 remapInlinedOperands(newBlocks, mapper);
199
200 // Process the newly inlined blocks.
201 interface.processInlinedBlocks(newBlocks);
202
203 // Handle the case where only a single block was inlined.
204 if (std::next(newBlocks.begin()) == newBlocks.end()) {
205 // Have the interface handle the terminator of this block.
206 auto *firstBlockTerminator = firstNewBlock->getTerminator();
207 interface.handleTerminator(firstBlockTerminator,
208 llvm::to_vector<6>(resultsToReplace));
209 firstBlockTerminator->erase();
210
211 // Merge the post insert block into the cloned entry block.
212 firstNewBlock->getOperations().splice(firstNewBlock->end(),
213 postInsertBlock->getOperations());
214 postInsertBlock->erase();
215 } else {
216 // Otherwise, there were multiple blocks inlined. Add arguments to the post
217 // insertion block to represent the results to replace.
218 for (auto resultToRepl : llvm::enumerate(resultsToReplace)) {
219 resultToRepl.value().replaceAllUsesWith(postInsertBlock->addArgument(
220 regionResultTypes[resultToRepl.index()]));
221 }
222
223 /// Handle the terminators for each of the new blocks.
224 for (auto &newBlock : newBlocks)
225 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
226 }
227
228 // Splice the instructions of the inlined entry block into the insert block.
229 insertBlock->getOperations().splice(insertBlock->end(),
230 firstNewBlock->getOperations());
231 firstNewBlock->erase();
232 return success();
233 }
234
235 /// This function is an overload of the above 'inlineRegion' that allows for
236 /// providing the set of operands ('inlinedOperands') that should be used
237 /// in-favor of the region arguments when inlining.
inlineRegion(InlinerInterface & interface,Region * src,Operation * inlinePoint,ValueRange inlinedOperands,ValueRange resultsToReplace,Optional<Location> inlineLoc,bool shouldCloneInlinedRegion)238 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
239 Operation *inlinePoint,
240 ValueRange inlinedOperands,
241 ValueRange resultsToReplace,
242 Optional<Location> inlineLoc,
243 bool shouldCloneInlinedRegion) {
244 // We expect the region to have at least one block.
245 if (src->empty())
246 return failure();
247
248 auto *entryBlock = &src->front();
249 if (inlinedOperands.size() != entryBlock->getNumArguments())
250 return failure();
251
252 // Map the provided call operands to the arguments of the region.
253 BlockAndValueMapping mapper;
254 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
255 // Verify that the types of the provided values match the function argument
256 // types.
257 BlockArgument regionArg = entryBlock->getArgument(i);
258 if (inlinedOperands[i].getType() != regionArg.getType())
259 return failure();
260 mapper.map(regionArg, inlinedOperands[i]);
261 }
262
263 // Call into the main region inliner function.
264 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
265 resultsToReplace.getTypes(), inlineLoc,
266 shouldCloneInlinedRegion);
267 }
268
269 /// Utility function used to generate a cast operation from the given interface,
270 /// or return nullptr if a cast could not be generated.
materializeConversion(const DialectInlinerInterface * interface,SmallVectorImpl<Operation * > & castOps,OpBuilder & castBuilder,Value arg,Type type,Location conversionLoc)271 static Value materializeConversion(const DialectInlinerInterface *interface,
272 SmallVectorImpl<Operation *> &castOps,
273 OpBuilder &castBuilder, Value arg, Type type,
274 Location conversionLoc) {
275 if (!interface)
276 return nullptr;
277
278 // Check to see if the interface for the call can materialize a conversion.
279 Operation *castOp = interface->materializeCallConversion(castBuilder, arg,
280 type, conversionLoc);
281 if (!castOp)
282 return nullptr;
283 castOps.push_back(castOp);
284
285 // Ensure that the generated cast is correct.
286 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
287 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
288 return castOp->getResult(0);
289 }
290
291 /// This function inlines a given region, 'src', of a callable operation,
292 /// 'callable', into the location defined by the given call operation. This
293 /// function returns failure if inlining is not possible, success otherwise. On
294 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
295 /// corresponds to whether the source region should be cloned into the 'call' or
296 /// spliced directly.
inlineCall(InlinerInterface & interface,CallOpInterface call,CallableOpInterface callable,Region * src,bool shouldCloneInlinedRegion)297 LogicalResult mlir::inlineCall(InlinerInterface &interface,
298 CallOpInterface call,
299 CallableOpInterface callable, Region *src,
300 bool shouldCloneInlinedRegion) {
301 // We expect the region to have at least one block.
302 if (src->empty())
303 return failure();
304 auto *entryBlock = &src->front();
305 ArrayRef<Type> callableResultTypes = callable.getCallableResults();
306
307 // Make sure that the number of arguments and results matchup between the call
308 // and the region.
309 SmallVector<Value, 8> callOperands(call.getArgOperands());
310 SmallVector<Value, 8> callResults(call->getResults());
311 if (callOperands.size() != entryBlock->getNumArguments() ||
312 callResults.size() != callableResultTypes.size())
313 return failure();
314
315 // A set of cast operations generated to matchup the signature of the region
316 // with the signature of the call.
317 SmallVector<Operation *, 4> castOps;
318 castOps.reserve(callOperands.size() + callResults.size());
319
320 // Functor used to cleanup generated state on failure.
321 auto cleanupState = [&] {
322 for (auto *op : castOps) {
323 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
324 op->erase();
325 }
326 return failure();
327 };
328
329 // Builder used for any conversion operations that need to be materialized.
330 OpBuilder castBuilder(call);
331 Location castLoc = call.getLoc();
332 const auto *callInterface = interface.getInterfaceFor(call->getDialect());
333
334 // Map the provided call operands to the arguments of the region.
335 BlockAndValueMapping mapper;
336 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
337 BlockArgument regionArg = entryBlock->getArgument(i);
338 Value operand = callOperands[i];
339
340 // If the call operand doesn't match the expected region argument, try to
341 // generate a cast.
342 Type regionArgType = regionArg.getType();
343 if (operand.getType() != regionArgType) {
344 if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
345 operand, regionArgType, castLoc)))
346 return cleanupState();
347 }
348 mapper.map(regionArg, operand);
349 }
350
351 // Ensure that the resultant values of the call match the callable.
352 castBuilder.setInsertionPointAfter(call);
353 for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
354 Value callResult = callResults[i];
355 if (callResult.getType() == callableResultTypes[i])
356 continue;
357
358 // Generate a conversion that will produce the original type, so that the IR
359 // is still valid after the original call gets replaced.
360 Value castResult =
361 materializeConversion(callInterface, castOps, castBuilder, callResult,
362 callResult.getType(), castLoc);
363 if (!castResult)
364 return cleanupState();
365 callResult.replaceAllUsesWith(castResult);
366 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
367 }
368
369 // Check that it is legal to inline the callable into the call.
370 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
371 return cleanupState();
372
373 // Attempt to inline the call.
374 if (failed(inlineRegion(interface, src, call, mapper, callResults,
375 callableResultTypes, call.getLoc(),
376 shouldCloneInlinedRegion)))
377 return cleanupState();
378 return success();
379 }
380