• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H
17 #define MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H
18 
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
25 
26 namespace mlir {
27 
28 /// Representation of an inclusive Interval for the Userange.
29 struct UseInterval {
30   using Vector = SmallVector<UseInterval, 8>;
31 
32  public:
33   /// UseInterval Constructor.
34   UseInterval();
35   /// Empty UseInterval Constructor.
UseIntervalUseInterval36   UseInterval(size_t start, size_t end) : start(start), end(end) {}
37 
38   /// Checks if the given UseInterval overlaps with this UseInterval.
isOverlappingUseInterval39   bool isOverlapping(const UseInterval &other) const {
40     return start <= other.end && end >= other.start;
41   }
42 
43   /// Checks if the given UseInterval is contiguous with this UseInterval in
44   /// terms of doubled Ids.
45   /// For example: (0, 2) and (4, 6) are contiguous where (0, 2) and (5, 6) are
46   ///              not.
isContiguousUseInterval47   bool isContiguous(const UseInterval &other) const {
48     return start <= other.end + 2 && end + 2 >= other.start;
49   }
50 
51   /// Checks if the given position is inside this UseInterval.
containsUseInterval52   bool contains(size_t position) const {
53     return start <= position && end >= position;
54   }
55 
56   /// Merges this UseInterval with the given UseInterval by updating start and
57   /// end.
mergeWithUseInterval58   bool mergeWith(const UseInterval &other) {
59     if (!isContiguous(other)) return false;
60     start = std::min(start, other.start);
61     end = std::max(end, other.end);
62     return true;
63   }
64 
65   /// Performs an interval subtraction => A = A - B.
66   static void intervalSubtract(Vector &a, const Vector &b);
67 
68   /// Performs an interval intersection => A = A ^ B.
69   static void intervalIntersect(Vector &a, const Vector &b);
70 
71   /// Performs an interval merge => A = A u B.
72   /// Note: All overlapping and contiguous UseIntervals are merged.
73   static void intervalMerge(Vector &a, const Vector &b);
74 
75   /// Merge the UseIntervals and erase overlapping and contiguouse UseIntervals
76   /// of the UseInterval::Vector.
77   static void mergeAndEraseContiguousIntervals(Vector &interval,
78                                                UseInterval *iter,
79                                                const UseInterval &toMerge);
80 
81   bool operator<(const UseInterval &other) const { return end < other.start; }
82 
83   bool operator>(const UseInterval &other) const { return start > other.end; }
84 
85   bool operator==(const UseInterval &other) const {
86     return start == other.start && end == other.end;
87   }
88 
89   /// The start of this UseInterval.
90   size_t start;
91 
92   /// The end of this UseInterval.
93   size_t end;
94 };
95 
96 /// Represents an analysis for computing the useranges of all alloc values
97 /// inside a given function operation. The analysis uses liveness information to
98 /// compute intervals starting at the first and ending with the last use of
99 /// every alloc value.
100 class UserangeAnalysis {
101  public:
102   using UsePosition = std::pair<size_t, Operation *>;
103   using UsePositionList = std::vector<UsePosition>;
104 
105   UserangeAnalysis(Operation *op,
106                    const bufferization::BufferPlacementAllocs &allocs,
107                    const BufferViewFlowAnalysis &aliases);
108 
109   /// Returns the index of the first operation that uses the given value or an
110   /// empty Optional if the value has no uses.
getFirstUseIndex(Value value)111   llvm::Optional<size_t> getFirstUseIndex(Value value) const {
112     auto &intervals = useIntervalMap.find(value)->second;
113     if (intervals.empty()) return llvm::None;
114     return intervals.begin()->start;
115   }
116 
117   /// Returns the UseInterval::Vector of the given value.
getUserangeInterval(Value value)118   llvm::Optional<const UseInterval::Vector *> getUserangeInterval(
119       Value value) const {
120     auto intervals = useIntervalMap.find(value);
121     if (intervals == useIntervalMap.end()) return llvm::None;
122     return &intervals->second;
123   }
124 
125   /// Returns an UsePositionList* of the given value or an empty Optional
126   /// if the value has no uses.
getUserangePositions(Value value)127   llvm::Optional<const UsePositionList *> getUserangePositions(
128       Value value) const {
129     auto usePosition = usePositionMap.find(value);
130     if (usePosition == usePositionMap.end() || usePosition->second.empty())
131       return llvm::None;
132     return &usePosition->second;
133   }
134 
135   /// Returns the operation associated with a given Id.
getOperation(size_t id)136   Operation *getOperation(size_t id) const { return operations[unwrapId(id)]; };
137 
138   /// Computes the doubled Id for the given value inside the operation based on
139   /// the program sequence. If the value has only read effects, the returning ID
140   /// will be even, otherwise odd.
141   size_t computeId(Value v, Operation *op) const;
142 
143   /// Checks if the use intervals of the given values interfere.
144   bool rangesInterfere(Value itemA, Value itemB) const;
145 
146   /// Merges the userange of itemB into the userange of itemA.
147   void unionRanges(Value itemA, Value itemB);
148 
149   /// Merges listB into listA, sorts the result and removes all duplicates.
150   static void mergeUsePositions(UsePositionList &listA,
151                                 const UsePositionList &listB);
152 
153   /// Dumps the liveness information to the given stream.
154   void dump(raw_ostream &os);
155 
156  private:
157   using ValueSetT = BufferViewFlowAnalysis::ValueSetT;
158   using OperationListT = Liveness::OperationListT;
159 
160   /// Builds an UseInterval::Vector corresponding to the given OperationList.
161   UseInterval::Vector computeInterval(
162       Value value, const Liveness::OperationListT &operationList);
163 
164   /// Computes the UsePositions of the given Value, sorts and inserts them into
165   /// the usePositionMap.
166   void computeUsePositions(Value v);
167 
168   /// Checks each operand within the operation for its memory effects and
169   /// separates them into read and write.
170   void gatherMemoryEffects(Operation *op);
171 
172   /// Computes the doubled Id back to the OperationId.
173   size_t unwrapId(size_t id) const;
174 
175   /// Maps each Operation to a unique ID according to the program sequence.
176   DenseMap<Operation *, size_t> operationIds;
177 
178   /// Stores all operations according to the program sequence.
179   std::vector<Operation *> operations;
180 
181   /// Maps a value to its UseInterval::Vector.
182   DenseMap<Value, UseInterval::Vector> useIntervalMap;
183 
184   /// Maps an Operation to a pair of read and write Operands.
185   DenseMap<Operation *, std::pair<SmallPtrSet<Value, 2>, SmallPtrSet<Value, 2>>>
186       opReadWriteMap;
187 
188   /// Maps aliasValues to their use ranges. This is necessary to prevent
189   /// recomputations of the use range intervals of the aliases.
190   DenseMap<Value, OperationListT> aliasUseranges;
191 
192   /// Maps a Value to a UsePostionList which contains all uses of the Value and
193   /// their userange position.
194   DenseMap<Value, UsePositionList> usePositionMap;
195 
196   /// Cache the alias lists for all values to avoid recomputation.
197   BufferViewFlowAnalysis::ValueMapT aliasCache;
198 
199   /// The current liveness info.
200   Liveness liveness;
201 };
202 
203 }  // namespace mlir
204 
205 #endif  // MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H
206