1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 License for the specific language governing permissions and limitations under 13 the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 22 #include "absl/container/flat_hash_set.h" 23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 24 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" 25 #include "tensorflow/compiler/xla/service/hlo_buffer.h" 26 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 30 #include "tensorflow/compiler/xla/service/hlo_value.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/core/lib/core/status.h" 34 35 namespace xla { 36 37 // Class which computes live range of the output buffers of HLOs and their 38 // interference by flattening all computations. The live range is only available 39 // when all global computations (while, if, call, etc) have total order 40 // sequential orders. 41 class HloLiveRange { 42 public: 43 // Constructs a hlo live range object for the given module and computation 44 // assuming the given HLO instruction ordering. 45 static StatusOr<std::unique_ptr<HloLiveRange>> Run( 46 const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, 47 const HloComputation* computation, bool module_scoped_analysis = true); 48 49 // LogicalTime represents the time in a virtual clock. Each instruction has 50 // one monotonically increasing logical time assigned according to the 51 // schedule. 52 using LogicalTime = int64; 53 54 struct TimeBound { 55 LogicalTime start; 56 LogicalTime end; 57 // The buffer can hold multiple instructions during its life time (each 58 // tenant exclusively owns the buffer at any given time). `end_instruction` 59 // represents the last instruction that the buffer holds. 60 HloPosition end_position; 61 62 bool friend operator==(const TimeBound& a, const TimeBound& b) { 63 return a.start == b.start && a.end == b.end; 64 } 65 bool friend operator!=(const TimeBound& a, const TimeBound& b) { 66 return !(a == b); 67 } 68 }; 69 70 std::string ToString() const; 71 flattened_instruction_sequence()72 const HloInstructionSequence& flattened_instruction_sequence() const { 73 return flattened_instruction_sequence_; 74 } 75 76 // Returns the map from instruction to the end time of that instruction. 77 const absl::flat_hash_map<const HloInstruction*, LogicalTime>& instruction_schedule()78 instruction_schedule() const { 79 return instruction_schedule_; 80 } 81 82 // Returns the map from a hlo value to the definition time of that hlo value. buffer_live_ranges()83 const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() 84 const { 85 return buffer_live_ranges_; 86 } 87 buffer_live_ranges()88 absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() { 89 return buffer_live_ranges_; 90 } 91 92 // Returns the map from a computation and its time span in the schedule. 93 const absl::flat_hash_map<const HloComputation*, TimeBound>& computation_span_times()94 computation_span_times() const { 95 return computation_span_times_; 96 } 97 98 // Returns the time stamp of the end of the program. schedule_end_time()99 LogicalTime schedule_end_time() const { return schedule_end_time_; } 100 101 // Returns whether hlo live range is available on this entire module. Hlo live 102 // range is not available if the module is partially ordered. total_order_scheduled()103 bool total_order_scheduled() const { return total_order_scheduled_; } 104 105 private: HloLiveRange(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,bool module_scoped_analysis)106 explicit HloLiveRange(const HloSchedule& schedule, 107 const HloAliasAnalysis& alias_analysis, 108 bool module_scoped_analysis) 109 : schedule_(schedule), 110 alias_analysis_(alias_analysis), 111 module_scoped_analysis_(module_scoped_analysis) {} 112 113 // FlattenSchedule walks through the instructions in `computation`, and 114 // recurse into each called computations in module_scoped_analysis mode. As it 115 // walks it also tracks down the ordinal number of each instruction in the 116 // schedule and store it in the `instruction_schedule` and 117 // 'flattened_instruction_sequence`. The end of each computation is tracked in 118 // `computation_end_time`. 119 int64 FlattenSchedule(const HloComputation& computation, int64_t start_time); 120 121 // Based on the flattened schedule, calculate the start and end of each 122 // buffer. 123 void CalculateBufferStartEndMap(); 124 125 // The aliased buffers could have overlapping live ranges. 126 // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer 127 // has disjoint live range while keeping the live range union the same. This 128 // avoid double counting aliased buffer sizes. 129 // 130 // Before(buffer1 and 2 are aliased): 131 // 132 // +----+ live range of buffer1 133 // +------------------+ live range of buffer2 134 // 135 // After: 136 // 137 // +----------+ live range of buffer1 138 // +------+ live range of buffer2 139 // 140 // Before(buffer1 and 2 are aliased): 141 // 142 // +----------+ live range of buffer1 143 // +------------+ live range of buffer2 144 // 145 // After: 146 // 147 // +----------+ live range of buffer1 148 // +------+ live range of buffer2 149 // 150 // Before(buffer1 and 2 are aliased): 151 // 152 // +----------+ live range of buffer1 153 // +---+ live range of buffer2 154 // 155 // After(unchanged): 156 // 157 // +----------+ live range of buffer1 158 // +---+ live range of buffer2 159 // 160 // As another example, imagine we have the following code sequence with live 161 // ranges of each while-aliased buffers: 162 // 163 // a p1 p2 e b 164 // a = ... + 165 // | 166 // { | 167 // p1 = param | + 168 // ROOT true | | 169 // } | + 170 // { // body | 171 // p2 = param + + 172 // c = p2 + 1 + 173 // d = c + 1 174 // ROOT e = d + 1 + 175 // } | 176 // | 177 // b = while (a) + + 178 // | 179 // f = b + 1 + 180 // 181 // After normalization it becomes: 182 // 183 // a p1 p2 e b 184 // a = ... + 185 // | 186 // { + 187 // p1 = param + 188 // ROOT true | 189 // } + 190 // { // body 191 // p2 = param + 192 // c = p2 + 1 + 193 // d = c + 1 194 // ROOT e = d + 1 + 195 // } | 196 // | 197 // b = while (a) + 198 // + 199 // f = b + 1 + 200 // 201 // Note there is no overlap of live ranges after normalization. 202 void NormalizeAliasedBuffers(); 203 204 int64 ComputePeakMemoryMoment() const; 205 206 const HloSchedule& schedule_; 207 const HloAliasAnalysis& alias_analysis_; 208 bool module_scoped_analysis_; 209 bool total_order_scheduled_ = true; 210 211 HloInstructionSequence flattened_instruction_sequence_; 212 absl::flat_hash_map<const HloInstruction*, int64> instruction_schedule_; 213 absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_; 214 absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_; 215 LogicalTime schedule_end_time_; 216 }; 217 218 } // namespace xla 219 220 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 221