• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7    http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
12 License for the specific language governing permissions and limitations under
13 the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
25 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
27 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
31 #include "tensorflow/compiler/xla/service/hlo_value.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/core/lib/core/status.h"
35 
36 namespace xla {
37 
38 // Class which computes live range of the output buffers of HLOs and their
39 // interference by flattening all computations. The live range is only available
40 // when all global computations (while, if, call, etc) have total order
41 // sequential orders.
42 class HloLiveRange {
43  public:
44   // Constructs a hlo live range object for the given module and computation
45   // assuming the given HLO instruction ordering.
46   static StatusOr<std::unique_ptr<HloLiveRange>> Run(
47       const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
48       const HloComputation* computation, bool module_scoped_analysis = true);
49 
50   // LogicalTime represents the time in a virtual clock. Each instruction has
51   // one monotonically increasing logical time assigned according to the
52   // schedule.
53   using LogicalTime = int64_t;
54 
55   struct TimeBound {
56     LogicalTime start;
57     LogicalTime end;
58     // The buffer can hold multiple instructions during its life time (each
59     // tenant exclusively owns the buffer at any given time). `end_instruction`
60     // represents the last instruction that the buffer holds.
61     HloPosition end_position;
62 
63     bool friend operator==(const TimeBound& a, const TimeBound& b) {
64       return a.start == b.start && a.end == b.end;
65     }
66     bool friend operator!=(const TimeBound& a, const TimeBound& b) {
67       return !(a == b);
68     }
69   };
70 
71   std::string ToString() const;
72 
flattened_instruction_sequence()73   const HloInstructionSequence& flattened_instruction_sequence() const {
74     return flattened_instruction_sequence_;
75   }
76 
77   // Returns the map from instruction to the end time of that instruction.
78   const absl::flat_hash_map<const HloInstruction*, LogicalTime>&
instruction_schedule()79   instruction_schedule() const {
80     return instruction_schedule_;
81   }
82 
83   // Returns the map from a hlo value to the definition time of that hlo value.
buffer_live_ranges()84   const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges()
85       const {
86     return buffer_live_ranges_;
87   }
88 
buffer_live_ranges()89   absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() {
90     return buffer_live_ranges_;
91   }
92 
93   // Returns the map from a computation and its time span in the schedule.
94   const absl::flat_hash_map<const HloComputation*, TimeBound>&
computation_span_times()95   computation_span_times() const {
96     return computation_span_times_;
97   }
98 
99   // Returns the time stamp of the end of the program.
schedule_end_time()100   LogicalTime schedule_end_time() const {
101     return flattened_instruction_sequence_.size();
102   }
103 
104   // Returns whether hlo live range is available on this entire module. Hlo live
105   // range is not available if the module is partially ordered.
total_order_scheduled()106   bool total_order_scheduled() const { return total_order_scheduled_; }
107 
108  private:
HloLiveRange(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,bool module_scoped_analysis)109   explicit HloLiveRange(const HloSchedule& schedule,
110                         const HloAliasAnalysis& alias_analysis,
111                         bool module_scoped_analysis)
112       : schedule_(schedule),
113         alias_analysis_(alias_analysis),
114         module_scoped_analysis_(module_scoped_analysis) {}
115 
116   // FlattenSchedule walks through the instructions in `computation`, and
117   // recurse into each called computations in module_scoped_analysis mode. As it
118   // walks it also tracks down the ordinal number of each instruction in the
119   // schedule and store it in the `instruction_schedule` and
120   // 'flattened_instruction_sequence`. async_context contains the asynchronous
121   // computation that this computation is in, if any. When this value is
122   // non-null, it means that this computation is called by an async op or
123   // another op in an asynchronous context.
124   void FlattenSchedule(const HloComputation& computation,
125                        const HloComputation* async_context = nullptr);
126 
127   // Returns the last position of a value.
128   TimeBound GetLastPosition(const HloValue& value,
129                             LogicalTime definition_end_time) const;
130 
131   // Returns the time of the last use of a value.
132   LogicalTime GetLastUsageTime(const HloValue& value) const;
133 
134   // Based on the flattened schedule, calculate the start and end of each
135   // buffer.
136   void CalculateBufferStartEndMap();
137 
138   // The aliased buffers could have overlapping live ranges.
139   // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer
140   // has disjoint live range while keeping the live range union the same. This
141   // avoid double counting aliased buffer sizes.
142   //
143   // Before(buffer1 and 2 are aliased):
144   //
145   //           +----+          live range of buffer1
146   //   +------------------+    live range of buffer2
147   //
148   // After:
149   //
150   //           +----------+    live range of buffer1
151   //   +-------+               live range of buffer2
152   //
153   // Before(buffer1 and 2 are aliased):
154   //
155   //           +----------+    live range of buffer1
156   //   +------------+          live range of buffer2
157   //
158   // After:
159   //
160   //           +----------+    live range of buffer1
161   //   +-------+               live range of buffer2
162   //
163   // Before(buffer1 and 2 are aliased):
164   //
165   //           +----------+    live range of buffer1
166   //   +---+                   live range of buffer2
167   //
168   // After(unchanged):
169   //
170   //           +----------+    live range of buffer1
171   //   +---+                   live range of buffer2
172   //
173   // As another example, imagine we have the following code sequence with live
174   // ranges of each while-aliased buffers:
175   //
176   //                     a      p1    p2    e     b
177   // a = ...             +
178   //                     |
179   // {                   |
180   //   p1 = param        |       +
181   //   ROOT true         |       |
182   // }                   |       +
183   // { // body           |
184   //   p2 = param        +             +
185   //   c = p2 + 1                      +
186   //   d = c + 1
187   //   ROOT e = d + 1                       +
188   // }                                      |
189   //                                        |
190   // b = while (a)                          +     +
191   //                                              |
192   // f = b + 1                                    +
193   //
194   // After normalization it becomes:
195   //
196   //                     a      p1    p2    e     b
197   // a = ...             +
198   //                     |
199   // {                   |
200   //   p1 = param        +       +
201   //   ROOT true                 |
202   // }                           |
203   // { // body                   |
204   //   p2 = param                +     +
205   //   c = p2 + 1                      +
206   //   d = c + 1
207   //   ROOT e = d + 1                       +
208   // }                                      |
209   //                                        |
210   // b = while (a)                          +     +
211   //                                              |
212   // f = b + 1                                    +
213   //
214   // Note there is no overlap of live ranges after normalization.
215   void NormalizeAliasedBuffers();
216 
217   LogicalTime ComputePeakMemoryMoment() const;
218 
219   const HloSchedule& schedule_;
220   const HloAliasAnalysis& alias_analysis_;
221   bool module_scoped_analysis_;
222   bool total_order_scheduled_ = true;
223 
224   HloInstructionSequence flattened_instruction_sequence_;
225   absl::flat_hash_map<const HloInstruction*, LogicalTime> instruction_schedule_;
226   absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_;
227   absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_;
228   absl::flat_hash_map<const HloComputation*, const HloComputation*>
229       computations_in_async_context_;
230 };
231 
232 }  // namespace xla
233 
234 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
235