1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 License for the specific language governing permissions and limitations under 13 the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/container/flat_hash_set.h" 24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 25 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" 26 #include "tensorflow/compiler/xla/service/hlo_buffer.h" 27 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 31 #include "tensorflow/compiler/xla/service/hlo_value.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/compiler/xla/types.h" 34 #include "tensorflow/core/lib/core/status.h" 35 36 namespace xla { 37 38 // Class which computes live range of the output buffers of HLOs and their 39 // interference by flattening all computations. The live range is only available 40 // when all global computations (while, if, call, etc) have total order 41 // sequential orders. 42 class HloLiveRange { 43 public: 44 // Constructs a hlo live range object for the given module and computation 45 // assuming the given HLO instruction ordering. 46 static StatusOr<std::unique_ptr<HloLiveRange>> Run( 47 const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, 48 const HloComputation* computation, bool module_scoped_analysis = true); 49 50 // LogicalTime represents the time in a virtual clock. Each instruction has 51 // one monotonically increasing logical time assigned according to the 52 // schedule. 53 using LogicalTime = int64_t; 54 55 struct TimeBound { 56 LogicalTime start; 57 LogicalTime end; 58 // The buffer can hold multiple instructions during its life time (each 59 // tenant exclusively owns the buffer at any given time). `end_instruction` 60 // represents the last instruction that the buffer holds. 61 HloPosition end_position; 62 63 bool friend operator==(const TimeBound& a, const TimeBound& b) { 64 return a.start == b.start && a.end == b.end; 65 } 66 bool friend operator!=(const TimeBound& a, const TimeBound& b) { 67 return !(a == b); 68 } 69 }; 70 71 std::string ToString() const; 72 flattened_instruction_sequence()73 const HloInstructionSequence& flattened_instruction_sequence() const { 74 return flattened_instruction_sequence_; 75 } 76 77 // Returns the map from instruction to the end time of that instruction. 78 const absl::flat_hash_map<const HloInstruction*, LogicalTime>& instruction_schedule()79 instruction_schedule() const { 80 return instruction_schedule_; 81 } 82 83 // Returns the map from a hlo value to the definition time of that hlo value. buffer_live_ranges()84 const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() 85 const { 86 return buffer_live_ranges_; 87 } 88 buffer_live_ranges()89 absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() { 90 return buffer_live_ranges_; 91 } 92 93 // Returns the map from a computation and its time span in the schedule. 94 const absl::flat_hash_map<const HloComputation*, TimeBound>& computation_span_times()95 computation_span_times() const { 96 return computation_span_times_; 97 } 98 99 // Returns the time stamp of the end of the program. schedule_end_time()100 LogicalTime schedule_end_time() const { 101 return flattened_instruction_sequence_.size(); 102 } 103 104 // Returns whether hlo live range is available on this entire module. Hlo live 105 // range is not available if the module is partially ordered. total_order_scheduled()106 bool total_order_scheduled() const { return total_order_scheduled_; } 107 108 private: HloLiveRange(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,bool module_scoped_analysis)109 explicit HloLiveRange(const HloSchedule& schedule, 110 const HloAliasAnalysis& alias_analysis, 111 bool module_scoped_analysis) 112 : schedule_(schedule), 113 alias_analysis_(alias_analysis), 114 module_scoped_analysis_(module_scoped_analysis) {} 115 116 // FlattenSchedule walks through the instructions in `computation`, and 117 // recurse into each called computations in module_scoped_analysis mode. As it 118 // walks it also tracks down the ordinal number of each instruction in the 119 // schedule and store it in the `instruction_schedule` and 120 // 'flattened_instruction_sequence`. async_context contains the asynchronous 121 // computation that this computation is in, if any. When this value is 122 // non-null, it means that this computation is called by an async op or 123 // another op in an asynchronous context. 124 void FlattenSchedule(const HloComputation& computation, 125 const HloComputation* async_context = nullptr); 126 127 // Returns the last position of a value. 128 TimeBound GetLastPosition(const HloValue& value, 129 LogicalTime definition_end_time) const; 130 131 // Returns the time of the last use of a value. 132 LogicalTime GetLastUsageTime(const HloValue& value) const; 133 134 // Based on the flattened schedule, calculate the start and end of each 135 // buffer. 136 void CalculateBufferStartEndMap(); 137 138 // The aliased buffers could have overlapping live ranges. 139 // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer 140 // has disjoint live range while keeping the live range union the same. This 141 // avoid double counting aliased buffer sizes. 142 // 143 // Before(buffer1 and 2 are aliased): 144 // 145 // +----+ live range of buffer1 146 // +------------------+ live range of buffer2 147 // 148 // After: 149 // 150 // +----------+ live range of buffer1 151 // +-------+ live range of buffer2 152 // 153 // Before(buffer1 and 2 are aliased): 154 // 155 // +----------+ live range of buffer1 156 // +------------+ live range of buffer2 157 // 158 // After: 159 // 160 // +----------+ live range of buffer1 161 // +-------+ live range of buffer2 162 // 163 // Before(buffer1 and 2 are aliased): 164 // 165 // +----------+ live range of buffer1 166 // +---+ live range of buffer2 167 // 168 // After(unchanged): 169 // 170 // +----------+ live range of buffer1 171 // +---+ live range of buffer2 172 // 173 // As another example, imagine we have the following code sequence with live 174 // ranges of each while-aliased buffers: 175 // 176 // a p1 p2 e b 177 // a = ... + 178 // | 179 // { | 180 // p1 = param | + 181 // ROOT true | | 182 // } | + 183 // { // body | 184 // p2 = param + + 185 // c = p2 + 1 + 186 // d = c + 1 187 // ROOT e = d + 1 + 188 // } | 189 // | 190 // b = while (a) + + 191 // | 192 // f = b + 1 + 193 // 194 // After normalization it becomes: 195 // 196 // a p1 p2 e b 197 // a = ... + 198 // | 199 // { | 200 // p1 = param + + 201 // ROOT true | 202 // } | 203 // { // body | 204 // p2 = param + + 205 // c = p2 + 1 + 206 // d = c + 1 207 // ROOT e = d + 1 + 208 // } | 209 // | 210 // b = while (a) + + 211 // | 212 // f = b + 1 + 213 // 214 // Note there is no overlap of live ranges after normalization. 215 void NormalizeAliasedBuffers(); 216 217 LogicalTime ComputePeakMemoryMoment() const; 218 219 const HloSchedule& schedule_; 220 const HloAliasAnalysis& alias_analysis_; 221 bool module_scoped_analysis_; 222 bool total_order_scheduled_ = true; 223 224 HloInstructionSequence flattened_instruction_sequence_; 225 absl::flat_hash_map<const HloInstruction*, LogicalTime> instruction_schedule_; 226 absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_; 227 absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_; 228 absl::flat_hash_map<const HloComputation*, const HloComputation*> 229 computations_in_async_context_; 230 }; 231 232 } // namespace xla 233 234 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 235