1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Transforms/BufferUtils.h"
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/LoopLikeInterface.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/Passes.h"
22 #include "llvm/ADT/SetOperations.h"
23
24 using namespace mlir;
25
26 //===----------------------------------------------------------------------===//
27 // BufferPlacementAllocs
28 //===----------------------------------------------------------------------===//
29
30 /// Get the start operation to place the given alloc value withing the
31 // specified placement block.
getStartOperation(Value allocValue,Block * placementBlock,const Liveness & liveness)32 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
33 Block *placementBlock,
34 const Liveness &liveness) {
35 // We have to ensure that we place the alloc before its first use in this
36 // block.
37 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
38 Operation *startOperation = livenessInfo.getStartOperation(allocValue);
39 // Check whether the start operation lies in the desired placement block.
40 // If not, we will use the terminator as this is the last operation in
41 // this block.
42 if (startOperation->getBlock() != placementBlock) {
43 Operation *opInPlacementBlock =
44 placementBlock->findAncestorOpInBlock(*startOperation);
45 startOperation = opInPlacementBlock ? opInPlacementBlock
46 : placementBlock->getTerminator();
47 }
48
49 return startOperation;
50 }
51
52 /// Finds associated deallocs that can be linked to our allocation nodes (if
53 /// any).
findDealloc(Value allocValue)54 Operation *BufferPlacementAllocs::findDealloc(Value allocValue) {
55 auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) {
56 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(user);
57 if (!effectInterface)
58 return false;
59 // Try to find a free effect that is applied to one of our values
60 // that will be automatically freed by our pass.
61 SmallVector<MemoryEffects::EffectInstance, 2> effects;
62 effectInterface.getEffectsOnValue(allocValue, effects);
63 return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
64 return isa<MemoryEffects::Free>(it.getEffect());
65 });
66 });
67 // Assign the associated dealloc operation (if any).
68 return userIt != allocValue.user_end() ? *userIt : nullptr;
69 }
70
71 /// Initializes the internal list by discovering all supported allocation
72 /// nodes.
BufferPlacementAllocs(Operation * op)73 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
74
75 /// Searches for and registers all supported allocation entries.
build(Operation * op)76 void BufferPlacementAllocs::build(Operation *op) {
77 op->walk([&](MemoryEffectOpInterface opInterface) {
78 // Try to find a single allocation result.
79 SmallVector<MemoryEffects::EffectInstance, 2> effects;
80 opInterface.getEffects(effects);
81
82 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
83 llvm::copy_if(
84 effects, std::back_inserter(allocateResultEffects),
85 [=](MemoryEffects::EffectInstance &it) {
86 Value value = it.getValue();
87 return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
88 value.isa<OpResult>() &&
89 it.getResource() !=
90 SideEffects::AutomaticAllocationScopeResource::get();
91 });
92 // If there is one result only, we will be able to move the allocation and
93 // (possibly existing) deallocation ops.
94 if (allocateResultEffects.size() != 1)
95 return;
96 // Get allocation result.
97 Value allocValue = allocateResultEffects[0].getValue();
98 // Find the associated dealloc value and register the allocation entry.
99 allocs.push_back(std::make_tuple(allocValue, findDealloc(allocValue)));
100 });
101 }
102
103 //===----------------------------------------------------------------------===//
104 // BufferPlacementTransformationBase
105 //===----------------------------------------------------------------------===//
106
107 /// Constructs a new transformation base using the given root operation.
BufferPlacementTransformationBase(Operation * op)108 BufferPlacementTransformationBase::BufferPlacementTransformationBase(
109 Operation *op)
110 : aliases(op), allocs(op), liveness(op) {}
111
112 /// Returns true if the given operation represents a loop by testing whether it
113 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
114 /// the case of a `RegionBranchOpInterface`, it checks all region-based control-
115 /// flow edges for cycles.
isLoop(Operation * op)116 bool BufferPlacementTransformationBase::isLoop(Operation *op) {
117 // If the operation implements the `LoopLikeOpInterface` it can be considered
118 // a loop.
119 if (isa<LoopLikeOpInterface>(op))
120 return true;
121
122 // If the operation does not implement the `RegionBranchOpInterface`, it is
123 // (currently) not possible to detect a loop.
124 RegionBranchOpInterface regionInterface;
125 if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
126 return false;
127
128 // Recurses into a region using the current region interface to find potential
129 // cycles.
130 SmallPtrSet<Region *, 4> visitedRegions;
131 std::function<bool(Region *)> recurse = [&](Region *current) {
132 if (!current)
133 return false;
134 // If we have found a back edge, the parent operation induces a loop.
135 if (!visitedRegions.insert(current).second)
136 return true;
137 // Recurses into all region successors.
138 SmallVector<RegionSuccessor, 2> successors;
139 regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
140 for (RegionSuccessor ®ionEntry : successors)
141 if (recurse(regionEntry.getSuccessor()))
142 return true;
143 return false;
144 };
145
146 // Start with all entry regions and test whether they induce a loop.
147 SmallVector<RegionSuccessor, 2> successorRegions;
148 regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions);
149 for (RegionSuccessor ®ionEntry : successorRegions) {
150 if (recurse(regionEntry.getSuccessor()))
151 return true;
152 visitedRegions.clear();
153 }
154
155 return false;
156 }
157