1 //===- NormalizeMemRefs.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an interprocedural pass to normalize memrefs to have
10 // identity layout maps.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Transforms/Passes.h"
17 #include "mlir/Transforms/Utils.h"
18 #include "llvm/ADT/SmallSet.h"
19
20 #define DEBUG_TYPE "normalize-memrefs"
21
22 using namespace mlir;
23
24 namespace {
25
26 /// All memrefs passed across functions with non-trivial layout maps are
27 /// converted to ones with trivial identity layout ones.
28 /// If all the memref types/uses in a function are normalizable, we treat
29 /// such functions as normalizable. Also, if a normalizable function is known
30 /// to call a non-normalizable function, we treat that function as
31 /// non-normalizable as well. We assume external functions to be normalizable.
32 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
33 void runOnOperation() override;
34 void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
35 bool areMemRefsNormalizable(FuncOp funcOp);
36 void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
37 void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
38 DenseSet<FuncOp> &normalizableFuncs);
39 Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
40 };
41
42 } // end anonymous namespace
43
createNormalizeMemRefsPass()44 std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
45 return std::make_unique<NormalizeMemRefs>();
46 }
47
runOnOperation()48 void NormalizeMemRefs::runOnOperation() {
49 LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
50 ModuleOp moduleOp = getOperation();
51 // We maintain all normalizable FuncOps in a DenseSet. It is initialized
52 // with all the functions within a module and then functions which are not
53 // normalizable are removed from this set.
54 // TODO: Change this to work on FuncLikeOp once there is an operation
55 // interface for it.
56 DenseSet<FuncOp> normalizableFuncs;
57 // Initialize `normalizableFuncs` with all the functions within a module.
58 moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
59
60 // Traverse through all the functions applying a filter which determines
61 // whether that function is normalizable or not. All callers/callees of
62 // a non-normalizable function will also become non-normalizable even if
63 // they aren't passing any or specific non-normalizable memrefs. So,
64 // functions which calls or get called by a non-normalizable becomes non-
65 // normalizable functions themselves.
66 moduleOp.walk([&](FuncOp funcOp) {
67 if (normalizableFuncs.contains(funcOp)) {
68 if (!areMemRefsNormalizable(funcOp)) {
69 LLVM_DEBUG(llvm::dbgs()
70 << "@" << funcOp.getName()
71 << " contains ops that cannot normalize MemRefs\n");
72 // Since this function is not normalizable, we set all the caller
73 // functions and the callees of this function as not normalizable.
74 // TODO: Drop this conservative assumption in the future.
75 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
76 normalizableFuncs);
77 }
78 }
79 });
80
81 LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
82 << " functions\n");
83 // Those functions which can be normalized are subjected to normalization.
84 for (FuncOp &funcOp : normalizableFuncs)
85 normalizeFuncOpMemRefs(funcOp, moduleOp);
86 }
87
88 /// Check whether all the uses of oldMemRef are either dereferencing uses or the
89 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
90 /// are satisfied will the value become a candidate for replacement.
91 /// TODO: Extend this for DimOps.
isMemRefNormalizable(Value::user_range opUsers)92 static bool isMemRefNormalizable(Value::user_range opUsers) {
93 if (llvm::any_of(opUsers, [](Operation *op) {
94 if (op->hasTrait<OpTrait::MemRefsNormalizable>())
95 return false;
96 return true;
97 }))
98 return false;
99 return true;
100 }
101
102 /// Set all the calling functions and the callees of the function as not
103 /// normalizable.
setCalleesAndCallersNonNormalizable(FuncOp funcOp,ModuleOp moduleOp,DenseSet<FuncOp> & normalizableFuncs)104 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
105 FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
106 if (!normalizableFuncs.contains(funcOp))
107 return;
108
109 LLVM_DEBUG(
110 llvm::dbgs() << "@" << funcOp.getName()
111 << " calls or is called by non-normalizable function\n");
112 normalizableFuncs.erase(funcOp);
113 // Caller of the function.
114 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
115 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
116 // TODO: Extend this for ops that are FunctionLike. This would require
117 // creating an OpInterface for FunctionLike ops.
118 FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
119 for (FuncOp &funcOp : normalizableFuncs) {
120 if (parentFuncOp == funcOp) {
121 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
122 normalizableFuncs);
123 break;
124 }
125 }
126 }
127
128 // Functions called by this function.
129 funcOp.walk([&](CallOp callOp) {
130 StringRef callee = callOp.getCallee();
131 for (FuncOp &funcOp : normalizableFuncs) {
132 // We compare FuncOp and callee's name.
133 if (callee == funcOp.getName()) {
134 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
135 normalizableFuncs);
136 break;
137 }
138 }
139 });
140 }
141
142 /// Check whether all the uses of AllocOps, CallOps and function arguments of a
143 /// function are either of dereferencing type or are uses in: DeallocOp, CallOp
144 /// or ReturnOp. Only if these constraints are satisfied will the function
145 /// become a candidate for normalization. We follow a conservative approach here
146 /// wherein even if the non-normalizable memref is not a part of the function's
147 /// argument or return type, we still label the entire function as
148 /// non-normalizable. We assume external functions to be normalizable.
areMemRefsNormalizable(FuncOp funcOp)149 bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
150 // We assume external functions to be normalizable.
151 if (funcOp.isExternal())
152 return true;
153
154 if (funcOp
155 .walk([&](AllocOp allocOp) -> WalkResult {
156 Value oldMemRef = allocOp.getResult();
157 if (!isMemRefNormalizable(oldMemRef.getUsers()))
158 return WalkResult::interrupt();
159 return WalkResult::advance();
160 })
161 .wasInterrupted())
162 return false;
163
164 if (funcOp
165 .walk([&](CallOp callOp) -> WalkResult {
166 for (unsigned resIndex :
167 llvm::seq<unsigned>(0, callOp.getNumResults())) {
168 Value oldMemRef = callOp.getResult(resIndex);
169 if (oldMemRef.getType().isa<MemRefType>())
170 if (!isMemRefNormalizable(oldMemRef.getUsers()))
171 return WalkResult::interrupt();
172 }
173 return WalkResult::advance();
174 })
175 .wasInterrupted())
176 return false;
177
178 for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
179 BlockArgument oldMemRef = funcOp.getArgument(argIndex);
180 if (oldMemRef.getType().isa<MemRefType>())
181 if (!isMemRefNormalizable(oldMemRef.getUsers()))
182 return false;
183 }
184
185 return true;
186 }
187
188 /// Fetch the updated argument list and result of the function and update the
189 /// function signature. This updates the function's return type at the caller
190 /// site and in case the return type is a normalized memref then it updates
191 /// the calling function's signature.
192 /// TODO: An update to the calling function signature is required only if the
193 /// returned value is in turn used in ReturnOp of the calling function.
updateFunctionSignature(FuncOp funcOp,ModuleOp moduleOp)194 void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
195 ModuleOp moduleOp) {
196 FunctionType functionType = funcOp.getType();
197 SmallVector<Type, 4> resultTypes;
198 FunctionType newFuncType;
199 resultTypes = llvm::to_vector<4>(functionType.getResults());
200
201 // External function's signature was already updated in
202 // 'normalizeFuncOpMemRefs()'.
203 if (!funcOp.isExternal()) {
204 SmallVector<Type, 8> argTypes;
205 for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
206 argTypes.push_back(argEn.value().getType());
207
208 // Traverse ReturnOps to check if an update to the return type in the
209 // function signature is required.
210 funcOp.walk([&](ReturnOp returnOp) {
211 for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
212 Type opType = operandEn.value().getType();
213 MemRefType memrefType = opType.dyn_cast<MemRefType>();
214 // If type is not memref or if the memref type is same as that in
215 // function's return signature then no update is required.
216 if (!memrefType || memrefType == resultTypes[operandEn.index()])
217 continue;
218 // Update function's return type signature.
219 // Return type gets normalized either as a result of function argument
220 // normalization, AllocOp normalization or an update made at CallOp.
221 // There can be many call flows inside a function and an update to a
222 // specific ReturnOp has not yet been made. So we check that the result
223 // memref type is normalized.
224 // TODO: When selective normalization is implemented, handle multiple
225 // results case where some are normalized, some aren't.
226 if (memrefType.getAffineMaps().empty())
227 resultTypes[operandEn.index()] = memrefType;
228 }
229 });
230
231 // We create a new function type and modify the function signature with this
232 // new type.
233 newFuncType = FunctionType::get(/*inputs=*/argTypes,
234 /*results=*/resultTypes,
235 /*context=*/&getContext());
236 }
237
238 // Since we update the function signature, it might affect the result types at
239 // the caller site. Since this result might even be used by the caller
240 // function in ReturnOps, the caller function's signature will also change.
241 // Hence we record the caller function in 'funcOpsToUpdate' to update their
242 // signature as well.
243 llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
244 // We iterate over all symbolic uses of the function and update the return
245 // type at the caller site.
246 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
247 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
248 Operation *userOp = symbolUse.getUser();
249 OpBuilder builder(userOp);
250 // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
251 // that the non-CallOp has no memrefs to be replaced.
252 // TODO: Handle cases where a non-CallOp symbol use of a function deals with
253 // memrefs.
254 auto callOp = dyn_cast<CallOp>(userOp);
255 if (!callOp)
256 continue;
257 StringRef callee = callOp.getCallee();
258 Operation *newCallOp = builder.create<CallOp>(
259 userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
260 userOp->getOperands());
261 bool replacingMemRefUsesFailed = false;
262 bool returnTypeChanged = false;
263 for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
264 OpResult oldResult = userOp->getResult(resIndex);
265 OpResult newResult = newCallOp->getResult(resIndex);
266 // This condition ensures that if the result is not of type memref or if
267 // the resulting memref was already having a trivial map layout then we
268 // need not perform any use replacement here.
269 if (oldResult.getType() == newResult.getType())
270 continue;
271 AffineMap layoutMap =
272 oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
273 if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
274 /*extraIndices=*/{},
275 /*indexRemap=*/layoutMap,
276 /*extraOperands=*/{},
277 /*symbolOperands=*/{},
278 /*domInstFilter=*/nullptr,
279 /*postDomInstFilter=*/nullptr,
280 /*allowDereferencingOps=*/true,
281 /*replaceInDeallocOp=*/true))) {
282 // If it failed (due to escapes for example), bail out.
283 // It should never hit this part of the code because it is called by
284 // only those functions which are normalizable.
285 newCallOp->erase();
286 replacingMemRefUsesFailed = true;
287 break;
288 }
289 returnTypeChanged = true;
290 }
291 if (replacingMemRefUsesFailed)
292 continue;
293 // Replace all uses for other non-memref result types.
294 userOp->replaceAllUsesWith(newCallOp);
295 userOp->erase();
296 if (returnTypeChanged) {
297 // Since the return type changed it might lead to a change in function's
298 // signature.
299 // TODO: If funcOp doesn't return any memref type then no need to update
300 // signature.
301 // TODO: Further optimization - Check if the memref is indeed part of
302 // ReturnOp at the parentFuncOp and only then updation of signature is
303 // required.
304 // TODO: Extend this for ops that are FunctionLike. This would require
305 // creating an OpInterface for FunctionLike ops.
306 FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
307 funcOpsToUpdate.insert(parentFuncOp);
308 }
309 }
310 // Because external function's signature is already updated in
311 // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
312 if (!funcOp.isExternal())
313 funcOp.setType(newFuncType);
314
315 // Updating the signature type of those functions which call the current
316 // function. Only if the return type of the current function has a normalized
317 // memref will the caller function become a candidate for signature update.
318 for (FuncOp parentFuncOp : funcOpsToUpdate)
319 updateFunctionSignature(parentFuncOp, moduleOp);
320 }
321
322 /// Normalizes the memrefs within a function which includes those arising as a
323 /// result of AllocOps, CallOps and function's argument. The ModuleOp argument
324 /// is used to help update function's signature after normalization.
normalizeFuncOpMemRefs(FuncOp funcOp,ModuleOp moduleOp)325 void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
326 ModuleOp moduleOp) {
327 // Turn memrefs' non-identity layouts maps into ones with identity. Collect
328 // alloc ops first and then process since normalizeMemRef replaces/erases ops
329 // during memref rewriting.
330 SmallVector<AllocOp, 4> allocOps;
331 funcOp.walk([&](AllocOp op) { allocOps.push_back(op); });
332 for (AllocOp allocOp : allocOps)
333 normalizeMemRef(allocOp);
334
335 // We use this OpBuilder to create new memref layout later.
336 OpBuilder b(funcOp);
337
338 FunctionType functionType = funcOp.getType();
339 SmallVector<Type, 8> inputTypes;
340 // Walk over each argument of a function to perform memref normalization (if
341 for (unsigned argIndex :
342 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
343 Type argType = functionType.getInput(argIndex);
344 MemRefType memrefType = argType.dyn_cast<MemRefType>();
345 // Check whether argument is of MemRef type. Any other argument type can
346 // simply be part of the final function signature.
347 if (!memrefType) {
348 inputTypes.push_back(argType);
349 continue;
350 }
351 // Fetch a new memref type after normalizing the old memref to have an
352 // identity map layout.
353 MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
354 /*numSymbolicOperands=*/0);
355 if (newMemRefType == memrefType || funcOp.isExternal()) {
356 // Either memrefType already had an identity map or the map couldn't be
357 // transformed to an identity map.
358 inputTypes.push_back(newMemRefType);
359 continue;
360 }
361
362 // Insert a new temporary argument with the new memref type.
363 BlockArgument newMemRef =
364 funcOp.front().insertArgument(argIndex, newMemRefType);
365 BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
366 AffineMap layoutMap = memrefType.getAffineMaps().front();
367 // Replace all uses of the old memref.
368 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
369 /*extraIndices=*/{},
370 /*indexRemap=*/layoutMap,
371 /*extraOperands=*/{},
372 /*symbolOperands=*/{},
373 /*domInstFilter=*/nullptr,
374 /*postDomInstFilter=*/nullptr,
375 /*allowNonDereferencingOps=*/true,
376 /*replaceInDeallocOp=*/true))) {
377 // If it failed (due to escapes for example), bail out. Removing the
378 // temporary argument inserted previously.
379 funcOp.front().eraseArgument(argIndex);
380 continue;
381 }
382
383 // All uses for the argument with old memref type were replaced
384 // successfully. So we remove the old argument now.
385 funcOp.front().eraseArgument(argIndex + 1);
386 }
387
388 // Walk over normalizable operations to normalize memrefs of the operation
389 // results. When `op` has memrefs with affine map in the operation results,
390 // new operation containin normalized memrefs is created. Then, the memrefs
391 // are replaced. `CallOp` is skipped here because it is handled in
392 // `updateFunctionSignature()`.
393 funcOp.walk([&](Operation *op) {
394 if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
395 op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
396 // Create newOp containing normalized memref in the operation result.
397 Operation *newOp = createOpResultsNormalized(funcOp, op);
398 // When all of the operation results have no memrefs or memrefs without
399 // affine map, `newOp` is the same with `op` and following process is
400 // skipped.
401 if (op != newOp) {
402 bool replacingMemRefUsesFailed = false;
403 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
404 // Replace all uses of the old memrefs.
405 Value oldMemRef = op->getResult(resIndex);
406 Value newMemRef = newOp->getResult(resIndex);
407 MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
408 // Check whether the operation result is MemRef type.
409 if (!oldMemRefType)
410 continue;
411 MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
412 if (oldMemRefType == newMemRefType)
413 continue;
414 // TODO: Assume single layout map. Multiple maps not supported.
415 AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
416 if (failed(replaceAllMemRefUsesWith(oldMemRef,
417 /*newMemRef=*/newMemRef,
418 /*extraIndices=*/{},
419 /*indexRemap=*/layoutMap,
420 /*extraOperands=*/{},
421 /*symbolOperands=*/{},
422 /*domInstFilter=*/nullptr,
423 /*postDomInstFilter=*/nullptr,
424 /*allowDereferencingOps=*/true,
425 /*replaceInDeallocOp=*/true))) {
426 newOp->erase();
427 replacingMemRefUsesFailed = true;
428 continue;
429 }
430 }
431 if (!replacingMemRefUsesFailed) {
432 // Replace other ops with new op and delete the old op when the
433 // replacement succeeded.
434 op->replaceAllUsesWith(newOp);
435 op->erase();
436 }
437 }
438 }
439 });
440
441 // In a normal function, memrefs in the return type signature gets normalized
442 // as a result of normalization of functions arguments, AllocOps or CallOps'
443 // result types. Since an external function doesn't have a body, memrefs in
444 // the return type signature can only get normalized by iterating over the
445 // individual return types.
446 if (funcOp.isExternal()) {
447 SmallVector<Type, 4> resultTypes;
448 for (unsigned resIndex :
449 llvm::seq<unsigned>(0, functionType.getNumResults())) {
450 Type resType = functionType.getResult(resIndex);
451 MemRefType memrefType = resType.dyn_cast<MemRefType>();
452 // Check whether result is of MemRef type. Any other argument type can
453 // simply be part of the final function signature.
454 if (!memrefType) {
455 resultTypes.push_back(resType);
456 continue;
457 }
458 // Computing a new memref type after normalizing the old memref to have an
459 // identity map layout.
460 MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
461 /*numSymbolicOperands=*/0);
462 resultTypes.push_back(newMemRefType);
463 continue;
464 }
465
466 FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
467 /*results=*/resultTypes,
468 /*context=*/&getContext());
469 // Setting the new function signature for this external function.
470 funcOp.setType(newFuncType);
471 }
472 updateFunctionSignature(funcOp, moduleOp);
473 }
474
475 /// Create an operation containing normalized memrefs in the operation results.
476 /// When the results of `oldOp` have memrefs with affine map, the memrefs are
477 /// normalized, and new operation containing them in the operation results is
478 /// returned. If all of the results of `oldOp` have no memrefs or memrefs
479 /// without affine map, `oldOp` is returned without modification.
createOpResultsNormalized(FuncOp funcOp,Operation * oldOp)480 Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
481 Operation *oldOp) {
482 // Prepare OperationState to create newOp containing normalized memref in
483 // the operation results.
484 OperationState result(oldOp->getLoc(), oldOp->getName());
485 result.addOperands(oldOp->getOperands());
486 result.addAttributes(oldOp->getAttrs());
487 // Add normalized MemRefType to the OperationState.
488 SmallVector<Type, 4> resultTypes;
489 OpBuilder b(funcOp);
490 bool resultTypeNormalized = false;
491 for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
492 auto resultType = oldOp->getResult(resIndex).getType();
493 MemRefType memrefType = resultType.dyn_cast<MemRefType>();
494 // Check whether the operation result is MemRef type.
495 if (!memrefType) {
496 resultTypes.push_back(resultType);
497 continue;
498 }
499 // Fetch a new memref type after normalizing the old memref.
500 MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
501 /*numSymbolicOperands=*/0);
502 if (newMemRefType == memrefType) {
503 // Either memrefType already had an identity map or the map couldn't
504 // be transformed to an identity map.
505 resultTypes.push_back(memrefType);
506 continue;
507 }
508 resultTypes.push_back(newMemRefType);
509 resultTypeNormalized = true;
510 }
511 result.addTypes(resultTypes);
512 // When all of the results of `oldOp` have no memrefs or memrefs without
513 // affine map, `oldOp` is returned without modification.
514 if (resultTypeNormalized) {
515 OpBuilder bb(oldOp);
516 return bb.createOperation(result);
517 } else
518 return oldOp;
519 }
520