//===- NormalizeMemRefs.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an interprocedural pass to normalize memrefs to have
// identity layout maps.
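//
// The pass operates on the whole module so that function signatures and call
// sites can be rewritten together (it is presumably registered under the
// `normalize-memrefs` flag, matching DEBUG_TYPE below; see the pass
// registration for the exact name).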
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/SmallSet.h"

#define DEBUG_TYPE "normalize-memrefs"

using namespace mlir;

namespace {

/// All memrefs passed across functions with non-trivial layout maps are
/// converted to ones with identity layout maps.
/// If all the memref types/uses in a function are normalizable, we treat
/// the function as normalizable. Also, if a normalizable function is known
/// to call a non-normalizable function, we treat that function as
/// non-normalizable as well. We assume external functions to be normalizable.
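///
/// As an illustrative sketch (the exact shapes depend on the input IR), an
/// allocation with a tiled layout such as
///   memref<16x16xf64, (d0, d1) -> (d0 floordiv 4, d1 floordiv 4,
///                                  d0 mod 4, d1 mod 4)>
/// is rewritten to the identity-layout type
///   memref<4x4x4x4xf64>
/// and all accesses are remapped through the old layout map.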
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
  void runOnOperation() override;
  void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
  bool areMemRefsNormalizable(FuncOp funcOp);
  void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
  void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
                                           DenseSet<FuncOp> &normalizableFuncs);
  Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
};

} // end anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
  return std::make_unique<NormalizeMemRefs>();
}

void NormalizeMemRefs::runOnOperation() {
  LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
  ModuleOp moduleOp = getOperation();
  // We maintain all normalizable FuncOps in a DenseSet. It is initialized
  // with all the functions within a module and then functions which are not
  // normalizable are removed from this set.
  // TODO: Change this to work on FuncLikeOp once there is an operation
  // interface for it.
  DenseSet<FuncOp> normalizableFuncs;
  // Initialize `normalizableFuncs` with all the functions within a module.
  moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });

  // Traverse through all the functions applying a filter which determines
  // whether that function is normalizable or not. All callers/callees of
  // a non-normalizable function will also become non-normalizable even if
  // they aren't passing any or specific non-normalizable memrefs. So,
  // functions that call or get called by a non-normalizable function become
  // non-normalizable themselves.
  moduleOp.walk([&](FuncOp funcOp) {
    if (normalizableFuncs.contains(funcOp)) {
      if (!areMemRefsNormalizable(funcOp)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "@" << funcOp.getName()
                   << " contains ops that cannot normalize MemRefs\n");
        // Since this function is not normalizable, we set all the caller
        // functions and the callees of this function as not normalizable.
        // TODO: Drop this conservative assumption in the future.
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
      }
    }
  });

  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
                          << " functions\n");
  // The functions that remain normalizable are now normalized.
  for (FuncOp &funcOp : normalizableFuncs)
    normalizeFuncOpMemRefs(funcOp, moduleOp);
}

/// Check whether all the uses of oldMemRef are either dereferencing uses, or
/// uses in ops of type DeallocOp, CallOp or ReturnOp. Only if these
/// constraints are satisfied will the value become a candidate for
/// replacement.
/// TODO: Extend this for DimOps.
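/// For instance (illustrative), a memref used only by dereferencing ops such
/// as affine.load/affine.store, plus dealloc, call and return ops, qualifies;
/// a single use by an op without the MemRefsNormalizable trait disqualifies
/// the value.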
static bool isMemRefNormalizable(Value::user_range opUsers) {
  return llvm::all_of(opUsers, [](Operation *op) {
    return op->hasTrait<OpTrait::MemRefsNormalizable>();
  });
}

/// Set all the calling functions and the callees of the function as not
/// normalizable.
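/// For example, given a call chain @a -> @b -> @c, marking @b non-normalizable
/// recursively removes @a and @c (and, transitively, their own callers and
/// callees) from `normalizableFuncs`.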
void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
    FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
  if (!normalizableFuncs.contains(funcOp))
    return;

  LLVM_DEBUG(
      llvm::dbgs() << "@" << funcOp.getName()
                   << " calls or is called by non-normalizable function\n");
  normalizableFuncs.erase(funcOp);
  // Callers of the function.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    // TODO: Extend this for ops that are FunctionLike. This would require
    // creating an OpInterface for FunctionLike ops.
    FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
    for (FuncOp &funcOp : normalizableFuncs) {
      if (parentFuncOp == funcOp) {
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
        break;
      }
    }
  }

  // Functions called by this function.
  funcOp.walk([&](CallOp callOp) {
    StringRef callee = callOp.getCallee();
    for (FuncOp &funcOp : normalizableFuncs) {
      // We compare the callee's name with the FuncOp's name.
      if (callee == funcOp.getName()) {
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
        break;
      }
    }
  });
}

/// Check whether all the uses of AllocOps, CallOps and function arguments of
/// a function are either of dereferencing type or are uses in DeallocOp,
/// CallOp or ReturnOp. Only if these constraints are satisfied will the
/// function become a candidate for normalization. We follow a conservative
/// approach here wherein even if the non-normalizable memref is not a part of
/// the function's argument or return type, we still label the entire function
/// as non-normalizable. We assume external functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
  // We assume external functions to be normalizable.
  if (funcOp.isExternal())
    return true;

  if (funcOp
          .walk([&](AllocOp allocOp) -> WalkResult {
            Value oldMemRef = allocOp.getResult();
            if (!isMemRefNormalizable(oldMemRef.getUsers()))
              return WalkResult::interrupt();
            return WalkResult::advance();
          })
          .wasInterrupted())
    return false;

  if (funcOp
          .walk([&](CallOp callOp) -> WalkResult {
            for (unsigned resIndex :
                 llvm::seq<unsigned>(0, callOp.getNumResults())) {
              Value oldMemRef = callOp.getResult(resIndex);
              if (oldMemRef.getType().isa<MemRefType>())
                if (!isMemRefNormalizable(oldMemRef.getUsers()))
                  return WalkResult::interrupt();
            }
            return WalkResult::advance();
          })
          .wasInterrupted())
    return false;

  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    BlockArgument oldMemRef = funcOp.getArgument(argIndex);
    if (oldMemRef.getType().isa<MemRefType>())
      if (!isMemRefNormalizable(oldMemRef.getUsers()))
        return false;
  }

  return true;
}

/// Fetch the updated argument list and result types of the function and
/// update the function signature. This also updates the function's return
/// type at each caller site, and if a returned value is a normalized memref,
/// it updates the calling function's signature in turn.
/// TODO: An update to the calling function signature is required only if the
/// returned value is in turn used in ReturnOp of the calling function.
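///
/// As an illustrative sketch (names and shapes are hypothetical), a call site
///   %0 = call @callee() : () -> memref<8x8xf64, #tiled>
/// is recreated after @callee's normalization as
///   %0 = call @callee() : () -> memref<2x2x4x4xf64>
/// with the uses of the old result remapped through the old layout map.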
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
                                               ModuleOp moduleOp) {
  FunctionType functionType = funcOp.getType();
  SmallVector<Type, 4> resultTypes;
  FunctionType newFuncType;
  resultTypes = llvm::to_vector<4>(functionType.getResults());

  // An external function's signature was already updated in
  // 'normalizeFuncOpMemRefs()'.
  if (!funcOp.isExternal()) {
    SmallVector<Type, 8> argTypes;
    for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
      argTypes.push_back(argEn.value().getType());

    // Traverse ReturnOps to check if an update to the return type in the
    // function signature is required.
    funcOp.walk([&](ReturnOp returnOp) {
      for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
        Type opType = operandEn.value().getType();
        MemRefType memrefType = opType.dyn_cast<MemRefType>();
        // If the type is not a memref, or if the memref type is the same as
        // that in the function's return signature, then no update is required.
        if (!memrefType || memrefType == resultTypes[operandEn.index()])
          continue;
        // Update the function's return type signature.
        // A return type gets normalized either as a result of function
        // argument normalization, AllocOp normalization, or an update made at
        // a CallOp. There can be many call flows inside a function, and an
        // update to a specific ReturnOp may not have been made yet. So we
        // check that the result memref type is normalized before updating.
        // TODO: When selective normalization is implemented, handle multiple
        // results case where some are normalized, some aren't.
        if (memrefType.getAffineMaps().empty())
          resultTypes[operandEn.index()] = memrefType;
      }
    });

    // We create a new function type and modify the function signature with
    // this new type.
    newFuncType = FunctionType::get(/*inputs=*/argTypes,
                                    /*results=*/resultTypes,
                                    /*context=*/&getContext());
  }

  // Since we update the function signature, it might affect the result types
  // at the caller site. Since this result might even be used by the caller
  // function in ReturnOps, the caller function's signature will also change.
  // Hence we record the caller functions in 'funcOpsToUpdate' to update their
  // signatures as well.
  llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
  // We iterate over all symbolic uses of the function and update the return
  // type at the caller site.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    Operation *userOp = symbolUse.getUser();
    OpBuilder builder(userOp);
    // When `userOp` cannot be cast to `CallOp`, it is skipped. This assumes
    // that the non-CallOp has no memrefs to be replaced.
    // TODO: Handle cases where a non-CallOp symbol use of a function deals
    // with memrefs.
    auto callOp = dyn_cast<CallOp>(userOp);
    if (!callOp)
      continue;
    StringRef callee = callOp.getCallee();
    Operation *newCallOp = builder.create<CallOp>(
        userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
        userOp->getOperands());
    bool replacingMemRefUsesFailed = false;
    bool returnTypeChanged = false;
    for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
      OpResult oldResult = userOp->getResult(resIndex);
      OpResult newResult = newCallOp->getResult(resIndex);
      // This condition ensures that if the result is not of type memref, or
      // if the resulting memref already had an identity (trivial) layout map,
      // then we need not perform any use replacement here.
      if (oldResult.getType() == newResult.getType())
        continue;
      AffineMap layoutMap =
          oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
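      // Note: as elsewhere in this pass, a single layout map is assumed here
      // (see the "Assume single layout map" TODO in normalizeFuncOpMemRefs()).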
      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                          /*extraIndices=*/{},
                                          /*indexRemap=*/layoutMap,
                                          /*extraOperands=*/{},
                                          /*symbolOperands=*/{},
                                          /*domInstFilter=*/nullptr,
                                          /*postDomInstFilter=*/nullptr,
                                          /*allowNonDereferencingOps=*/true,
                                          /*replaceInDeallocOp=*/true))) {
        // If the replacement failed (due to escapes, for example), bail out.
        // This part of the code should never be reached, because the function
        // is only called on normalizable functions.
        newCallOp->erase();
        replacingMemRefUsesFailed = true;
        break;
      }
      returnTypeChanged = true;
    }
    if (replacingMemRefUsesFailed)
      continue;
    // Replace all uses for other non-memref result types.
    userOp->replaceAllUsesWith(newCallOp);
    userOp->erase();
    if (returnTypeChanged) {
      // Since the return type changed, it might lead to a change in the
      // function's signature.
      // TODO: If funcOp doesn't return any memref type then no need to update
      // signature.
      // TODO: Further optimization - Check if the memref is indeed part of
      // ReturnOp at the parentFuncOp and update the signature only then.
      // TODO: Extend this for ops that are FunctionLike. This would require
      // creating an OpInterface for FunctionLike ops.
      FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
      funcOpsToUpdate.insert(parentFuncOp);
    }
  }
  // Because an external function's signature was already updated in
  // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
  if (!funcOp.isExternal())
    funcOp.setType(newFuncType);

  // Update the signature of the functions that call the current function.
  // Only if the return type of the current function has a normalized memref
  // will the caller function become a candidate for a signature update.
  for (FuncOp parentFuncOp : funcOpsToUpdate)
    updateFunctionSignature(parentFuncOp, moduleOp);
}

/// Normalizes the memrefs within a function, which includes those arising as
/// a result of AllocOps, CallOps, and the function's arguments. The ModuleOp
/// argument is used to help update the function's signature after
/// normalization.
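/// The rewriting proceeds in stages: AllocOps first, then function arguments,
/// then the results of other normalizable ops, and finally (for external
/// functions only) the memref result types in the signature itself.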
void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
                                              ModuleOp moduleOp) {
  // Turn memrefs' non-identity layout maps into ones with identity. Collect
  // alloc ops first and then process them, since normalizeMemRef
  // replaces/erases ops during memref rewriting.
  SmallVector<AllocOp, 4> allocOps;
  funcOp.walk([&](AllocOp op) { allocOps.push_back(op); });
  for (AllocOp allocOp : allocOps)
    normalizeMemRef(allocOp);

  // We use this OpBuilder to create new memref layouts later.
  OpBuilder b(funcOp);

  FunctionType functionType = funcOp.getType();
  SmallVector<Type, 8> inputTypes;
  // Walk over each argument of the function to perform memref normalization
  // (if required).
  for (unsigned argIndex :
       llvm::seq<unsigned>(0, functionType.getNumInputs())) {
    Type argType = functionType.getInput(argIndex);
    MemRefType memrefType = argType.dyn_cast<MemRefType>();
    // Check whether the argument is of MemRef type. Any other argument type
    // can simply be part of the final function signature.
    if (!memrefType) {
      inputTypes.push_back(argType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref to have an
    // identity map layout.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType || funcOp.isExternal()) {
      // Either memrefType already had an identity map or the map couldn't be
      // transformed to an identity map.
      inputTypes.push_back(newMemRefType);
      continue;
    }

    // Insert a new temporary argument with the new memref type.
    BlockArgument newMemRef =
        funcOp.front().insertArgument(argIndex, newMemRefType);
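    // Note that the insertion shifts the old argument from `argIndex` to
    // `argIndex + 1`, which is why it is read (and, on success, erased) at
    // `argIndex + 1` below.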
    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
    AffineMap layoutMap = memrefType.getAffineMaps().front();
    // Replace all uses of the old memref.
    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
                                        /*extraIndices=*/{},
                                        /*indexRemap=*/layoutMap,
                                        /*extraOperands=*/{},
                                        /*symbolOperands=*/{},
                                        /*domInstFilter=*/nullptr,
                                        /*postDomInstFilter=*/nullptr,
                                        /*allowNonDereferencingOps=*/true,
                                        /*replaceInDeallocOp=*/true))) {
      // If the replacement failed (due to escapes, for example), bail out,
      // removing the temporary argument inserted previously.
      funcOp.front().eraseArgument(argIndex);
      continue;
    }

    // All uses of the argument with the old memref type were replaced
    // successfully, so we can remove the old argument now.
    funcOp.front().eraseArgument(argIndex + 1);
  }

  // Walk over normalizable operations to normalize memrefs of the operation
  // results. When `op` has memrefs with affine maps in the operation results,
  // a new operation containing normalized memrefs is created. Then, the
  // memrefs are replaced. `CallOp` is skipped here because it is handled in
  // `updateFunctionSignature()`.
  funcOp.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
        op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
      // Create newOp containing normalized memref in the operation result.
      Operation *newOp = createOpResultsNormalized(funcOp, op);
      // When none of the operation results is a memref with an affine map,
      // `newOp` is the same as `op` and the following process is skipped.
      if (op != newOp) {
        bool replacingMemRefUsesFailed = false;
        for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
          // Replace all uses of the old memrefs.
          Value oldMemRef = op->getResult(resIndex);
          Value newMemRef = newOp->getResult(resIndex);
          MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
          // Check whether the operation result is MemRef type.
          if (!oldMemRefType)
            continue;
          MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
          if (oldMemRefType == newMemRefType)
            continue;
          // TODO: Assume single layout map. Multiple maps not supported.
          AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
          if (failed(replaceAllMemRefUsesWith(oldMemRef,
                                              /*newMemRef=*/newMemRef,
                                              /*extraIndices=*/{},
                                              /*indexRemap=*/layoutMap,
                                              /*extraOperands=*/{},
                                              /*symbolOperands=*/{},
                                              /*domInstFilter=*/nullptr,
                                              /*postDomInstFilter=*/nullptr,
                                              /*allowNonDereferencingOps=*/true,
                                              /*replaceInDeallocOp=*/true))) {
            newOp->erase();
            replacingMemRefUsesFailed = true;
            // `newOp` was erased, so we must not touch its remaining results.
            break;
          }
        }
        if (!replacingMemRefUsesFailed) {
          // Replace other ops with the new op and delete the old op when the
          // replacement succeeded.
          op->replaceAllUsesWith(newOp);
          op->erase();
        }
      }
    }
  });

  // In a normal function, memrefs in the return type signature get normalized
  // as a result of normalization of function arguments, AllocOps or CallOps'
  // result types. Since an external function doesn't have a body, memrefs in
  // the return type signature can only get normalized by iterating over the
  // individual return types.
  if (funcOp.isExternal()) {
    SmallVector<Type, 4> resultTypes;
    for (unsigned resIndex :
         llvm::seq<unsigned>(0, functionType.getNumResults())) {
      Type resType = functionType.getResult(resIndex);
      MemRefType memrefType = resType.dyn_cast<MemRefType>();
      // Check whether the result is of MemRef type. Any other result type
      // can simply be part of the final function signature.
      if (!memrefType) {
        resultTypes.push_back(resType);
        continue;
      }
      // Compute a new memref type after normalizing the old memref to have
      // an identity map layout.
      MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                     /*numSymbolicOperands=*/0);
      resultTypes.push_back(newMemRefType);
    }

    FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
                                                 /*results=*/resultTypes,
                                                 /*context=*/&getContext());
    // Set the new function signature for this external function.
    funcOp.setType(newFuncType);
  }
  updateFunctionSignature(funcOp, moduleOp);
}

/// Create an operation containing normalized memrefs in the operation
/// results. When the results of `oldOp` have memrefs with affine maps, the
/// memrefs are normalized, and a new operation containing them in the
/// operation results is returned. If none of the results of `oldOp` is a
/// memref with an affine map, `oldOp` is returned without modification.
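///
/// Note that the returned operation is not a complete replacement by itself:
/// operands and attributes are copied over verbatim, and the caller is
/// responsible for rewiring the uses of `oldOp` (as done in
/// normalizeFuncOpMemRefs()).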
Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
                                                       Operation *oldOp) {
  // Prepare OperationState to create newOp containing normalized memref in
  // the operation results.
  OperationState result(oldOp->getLoc(), oldOp->getName());
  result.addOperands(oldOp->getOperands());
  result.addAttributes(oldOp->getAttrs());
  // Add normalized MemRefType to the OperationState.
  SmallVector<Type, 4> resultTypes;
  OpBuilder b(funcOp);
  bool resultTypeNormalized = false;
  for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
    auto resultType = oldOp->getResult(resIndex).getType();
    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
    // Check whether the operation result is MemRef type.
    if (!memrefType) {
      resultTypes.push_back(resultType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType) {
      // Either memrefType already had an identity map or the map couldn't
      // be transformed to an identity map.
      resultTypes.push_back(memrefType);
      continue;
    }
    resultTypes.push_back(newMemRefType);
    resultTypeNormalized = true;
  }
  result.addTypes(resultTypes);
  // When none of the results of `oldOp` is a memref with an affine map,
  // `oldOp` is returned without modification.
  if (!resultTypeNormalized)
    return oldOp;
  OpBuilder bb(oldOp);
  return bb.createOperation(result);
}