1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H 17 #define MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H 18 19 #include <vector> 20 21 #include "mlir/Analysis/Liveness.h" 22 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 23 #include "mlir/IR/Operation.h" 24 #include "mlir/IR/Value.h" 25 26 namespace mlir { 27 28 /// Representation of an inclusive Interval for the Userange. 29 struct UseInterval { 30 using Vector = SmallVector<UseInterval, 8>; 31 32 public: 33 /// UseInterval Constructor. 34 UseInterval(); 35 /// Empty UseInterval Constructor. UseIntervalUseInterval36 UseInterval(size_t start, size_t end) : start(start), end(end) {} 37 38 /// Checks if the given UseInterval overlaps with this UseInterval. isOverlappingUseInterval39 bool isOverlapping(const UseInterval &other) const { 40 return start <= other.end && end >= other.start; 41 } 42 43 /// Checks if the given UseInterval is contiguous with this UseInterval in 44 /// terms of doubled Ids. 45 /// For example: (0, 2) and (4, 6) are contiguous where (0, 2) and (5, 6) are 46 /// not. isContiguousUseInterval47 bool isContiguous(const UseInterval &other) const { 48 return start <= other.end + 2 && end + 2 >= other.start; 49 } 50 51 /// Checks if the given position is inside this UseInterval. containsUseInterval52 bool contains(size_t position) const { 53 return start <= position && end >= position; 54 } 55 56 /// Merges this UseInterval with the given UseInterval by updating start and 57 /// end. mergeWithUseInterval58 bool mergeWith(const UseInterval &other) { 59 if (!isContiguous(other)) return false; 60 start = std::min(start, other.start); 61 end = std::max(end, other.end); 62 return true; 63 } 64 65 /// Performs an interval subtraction => A = A - B. 66 static void intervalSubtract(Vector &a, const Vector &b); 67 68 /// Performs an interval intersection => A = A ^ B. 69 static void intervalIntersect(Vector &a, const Vector &b); 70 71 /// Performs an interval merge => A = A u B. 72 /// Note: All overlapping and contiguous UseIntervals are merged. 73 static void intervalMerge(Vector &a, const Vector &b); 74 75 /// Merge the UseIntervals and erase overlapping and contiguouse UseIntervals 76 /// of the UseInterval::Vector. 77 static void mergeAndEraseContiguousIntervals(Vector &interval, 78 UseInterval *iter, 79 const UseInterval &toMerge); 80 81 bool operator<(const UseInterval &other) const { return end < other.start; } 82 83 bool operator>(const UseInterval &other) const { return start > other.end; } 84 85 bool operator==(const UseInterval &other) const { 86 return start == other.start && end == other.end; 87 } 88 89 /// The start of this UseInterval. 90 size_t start; 91 92 /// The end of this UseInterval. 93 size_t end; 94 }; 95 96 /// Represents an analysis for computing the useranges of all alloc values 97 /// inside a given function operation. The analysis uses liveness information to 98 /// compute intervals starting at the first and ending with the last use of 99 /// every alloc value. 100 class UserangeAnalysis { 101 public: 102 using UsePosition = std::pair<size_t, Operation *>; 103 using UsePositionList = std::vector<UsePosition>; 104 105 UserangeAnalysis(Operation *op, 106 const bufferization::BufferPlacementAllocs &allocs, 107 const BufferViewFlowAnalysis &aliases); 108 109 /// Returns the index of the first operation that uses the given value or an 110 /// empty Optional if the value has no uses. getFirstUseIndex(Value value)111 llvm::Optional<size_t> getFirstUseIndex(Value value) const { 112 auto &intervals = useIntervalMap.find(value)->second; 113 if (intervals.empty()) return llvm::None; 114 return intervals.begin()->start; 115 } 116 117 /// Returns the UseInterval::Vector of the given value. getUserangeInterval(Value value)118 llvm::Optional<const UseInterval::Vector *> getUserangeInterval( 119 Value value) const { 120 auto intervals = useIntervalMap.find(value); 121 if (intervals == useIntervalMap.end()) return llvm::None; 122 return &intervals->second; 123 } 124 125 /// Returns an UsePositionList* of the given value or an empty Optional 126 /// if the value has no uses. getUserangePositions(Value value)127 llvm::Optional<const UsePositionList *> getUserangePositions( 128 Value value) const { 129 auto usePosition = usePositionMap.find(value); 130 if (usePosition == usePositionMap.end() || usePosition->second.empty()) 131 return llvm::None; 132 return &usePosition->second; 133 } 134 135 /// Returns the operation associated with a given Id. getOperation(size_t id)136 Operation *getOperation(size_t id) const { return operations[unwrapId(id)]; }; 137 138 /// Computes the doubled Id for the given value inside the operation based on 139 /// the program sequence. If the value has only read effects, the returning ID 140 /// will be even, otherwise odd. 141 size_t computeId(Value v, Operation *op) const; 142 143 /// Checks if the use intervals of the given values interfere. 144 bool rangesInterfere(Value itemA, Value itemB) const; 145 146 /// Merges the userange of itemB into the userange of itemA. 147 void unionRanges(Value itemA, Value itemB); 148 149 /// Merges listB into listA, sorts the result and removes all duplicates. 150 static void mergeUsePositions(UsePositionList &listA, 151 const UsePositionList &listB); 152 153 /// Dumps the liveness information to the given stream. 154 void dump(raw_ostream &os); 155 156 private: 157 using ValueSetT = BufferViewFlowAnalysis::ValueSetT; 158 using OperationListT = Liveness::OperationListT; 159 160 /// Builds an UseInterval::Vector corresponding to the given OperationList. 161 UseInterval::Vector computeInterval( 162 Value value, const Liveness::OperationListT &operationList); 163 164 /// Computes the UsePositions of the given Value, sorts and inserts them into 165 /// the usePositionMap. 166 void computeUsePositions(Value v); 167 168 /// Checks each operand within the operation for its memory effects and 169 /// separates them into read and write. 170 void gatherMemoryEffects(Operation *op); 171 172 /// Computes the doubled Id back to the OperationId. 173 size_t unwrapId(size_t id) const; 174 175 /// Maps each Operation to a unique ID according to the program sequence. 176 DenseMap<Operation *, size_t> operationIds; 177 178 /// Stores all operations according to the program sequence. 179 std::vector<Operation *> operations; 180 181 /// Maps a value to its UseInterval::Vector. 182 DenseMap<Value, UseInterval::Vector> useIntervalMap; 183 184 /// Maps an Operation to a pair of read and write Operands. 185 DenseMap<Operation *, std::pair<SmallPtrSet<Value, 2>, SmallPtrSet<Value, 2>>> 186 opReadWriteMap; 187 188 /// Maps aliasValues to their use ranges. This is necessary to prevent 189 /// recomputations of the use range intervals of the aliases. 190 DenseMap<Value, OperationListT> aliasUseranges; 191 192 /// Maps a Value to a UsePostionList which contains all uses of the Value and 193 /// their userange position. 194 DenseMap<Value, UsePositionList> usePositionMap; 195 196 /// Cache the alias lists for all values to avoid recomputation. 197 BufferViewFlowAnalysis::ValueMapT aliasCache; 198 199 /// The current liveness info. 200 Liveness liveness; 201 }; 202 203 } // namespace mlir 204 205 #endif // MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H 206