• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7    http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
12 License for the specific language governing permissions and limitations under
13 the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
17 
#include <memory>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
34 
35 namespace xla {
36 
37 // Class which computes live range of the output buffers of HLOs and their
38 // interference by flattening all computations. The live range is only available
39 // when all global computations (while, if, call, etc) have total order
40 // sequential orders.
41 class HloLiveRange {
42  public:
43   // Constructs a hlo live range object for the given module and computation
44   // assuming the given HLO instruction ordering.
45   static StatusOr<std::unique_ptr<HloLiveRange>> Run(
46       const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
47       const HloComputation* computation, bool module_scoped_analysis = true);
48 
49   // LogicalTime represents the time in a virtual clock. Each instruction has
50   // one monotonically increasing logical time assigned according to the
51   // schedule.
52   using LogicalTime = int64;
53 
54   struct TimeBound {
55     LogicalTime start;
56     LogicalTime end;
57     // The buffer can hold multiple instructions during its life time (each
58     // tenant exclusively owns the buffer at any given time). `end_instruction`
59     // represents the last instruction that the buffer holds.
60     HloPosition end_position;
61 
62     bool friend operator==(const TimeBound& a, const TimeBound& b) {
63       return a.start == b.start && a.end == b.end;
64     }
65     bool friend operator!=(const TimeBound& a, const TimeBound& b) {
66       return !(a == b);
67     }
68   };
69 
70   std::string ToString() const;
71 
flattened_instruction_sequence()72   const HloInstructionSequence& flattened_instruction_sequence() const {
73     return flattened_instruction_sequence_;
74   }
75 
76   // Returns the map from instruction to the end time of that instruction.
77   const absl::flat_hash_map<const HloInstruction*, LogicalTime>&
instruction_schedule()78   instruction_schedule() const {
79     return instruction_schedule_;
80   }
81 
82   // Returns the map from a hlo value to the definition time of that hlo value.
buffer_live_ranges()83   const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges()
84       const {
85     return buffer_live_ranges_;
86   }
87 
buffer_live_ranges()88   absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() {
89     return buffer_live_ranges_;
90   }
91 
92   // Returns the map from a computation and its time span in the schedule.
93   const absl::flat_hash_map<const HloComputation*, TimeBound>&
computation_span_times()94   computation_span_times() const {
95     return computation_span_times_;
96   }
97 
98   // Returns the time stamp of the end of the program.
schedule_end_time()99   LogicalTime schedule_end_time() const { return schedule_end_time_; }
100 
101   // Returns whether hlo live range is available on this entire module. Hlo live
102   // range is not available if the module is partially ordered.
total_order_scheduled()103   bool total_order_scheduled() const { return total_order_scheduled_; }
104 
105  private:
HloLiveRange(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,bool module_scoped_analysis)106   explicit HloLiveRange(const HloSchedule& schedule,
107                         const HloAliasAnalysis& alias_analysis,
108                         bool module_scoped_analysis)
109       : schedule_(schedule),
110         alias_analysis_(alias_analysis),
111         module_scoped_analysis_(module_scoped_analysis) {}
112 
113   // FlattenSchedule walks through the instructions in `computation`, and
114   // recurse into each called computations in module_scoped_analysis mode. As it
115   // walks it also tracks down the ordinal number of each instruction in the
116   // schedule and store it in the `instruction_schedule` and
117   // 'flattened_instruction_sequence`. The end of each computation is tracked in
118   // `computation_end_time`.
119   int64 FlattenSchedule(const HloComputation& computation, int64_t start_time);
120 
121   // Based on the flattened schedule, calculate the start and end of each
122   // buffer.
123   void CalculateBufferStartEndMap();
124 
125   // The aliased buffers could have overlapping live ranges.
126   // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer
127   // has disjoint live range while keeping the live range union the same. This
128   // avoid double counting aliased buffer sizes.
129   //
130   // Before(buffer1 and 2 are aliased):
131   //
132   //           +----+          live range of buffer1
133   //   +------------------+    live range of buffer2
134   //
135   // After:
136   //
137   //           +----------+    live range of buffer1
138   //   +------+                live range of buffer2
139   //
140   // Before(buffer1 and 2 are aliased):
141   //
142   //           +----------+    live range of buffer1
143   //   +------------+          live range of buffer2
144   //
145   // After:
146   //
147   //           +----------+    live range of buffer1
148   //   +------+                live range of buffer2
149   //
150   // Before(buffer1 and 2 are aliased):
151   //
152   //           +----------+    live range of buffer1
153   //   +---+                   live range of buffer2
154   //
155   // After(unchanged):
156   //
157   //           +----------+    live range of buffer1
158   //   +---+                   live range of buffer2
159   //
160   // As another example, imagine we have the following code sequence with live
161   // ranges of each while-aliased buffers:
162   //
163   //                     a      p1    p2    e     b
164   // a = ...             +
165   //                     |
166   // {                   |
167   //   p1 = param        |       +
168   //   ROOT true         |       |
169   // }                   |       +
170   // { // body           |
171   //   p2 = param        +             +
172   //   c = p2 + 1                      +
173   //   d = c + 1
174   //   ROOT e = d + 1                       +
175   // }                                      |
176   //                                        |
177   // b = while (a)                          +     +
178   //                                              |
179   // f = b + 1                                    +
180   //
181   // After normalization it becomes:
182   //
183   //                     a      p1    p2    e     b
184   // a = ...             +
185   //                     |
186   // {                   +
187   //   p1 = param                +
188   //   ROOT true                 |
189   // }                           +
190   // { // body
191   //   p2 = param                      +
192   //   c = p2 + 1                      +
193   //   d = c + 1
194   //   ROOT e = d + 1                       +
195   // }                                      |
196   //                                        |
197   // b = while (a)                          +
198   //                                              +
199   // f = b + 1                                    +
200   //
201   // Note there is no overlap of live ranges after normalization.
202   void NormalizeAliasedBuffers();
203 
204   int64 ComputePeakMemoryMoment() const;
205 
206   const HloSchedule& schedule_;
207   const HloAliasAnalysis& alias_analysis_;
208   bool module_scoped_analysis_;
209   bool total_order_scheduled_ = true;
210 
211   HloInstructionSequence flattened_instruction_sequence_;
212   absl::flat_hash_map<const HloInstruction*, int64> instruction_schedule_;
213   absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_;
214   absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_;
215   LogicalTime schedule_end_time_;
216 };
217 
218 }  // namespace xla
219 
220 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_
221