1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 License for the specific language governing permissions and limitations under 13 the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 22 #include "absl/container/flat_hash_set.h" 23 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" 24 #include "tensorflow/compiler/xla/service/hlo_buffer.h" 25 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/service/hlo_module.h" 28 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 29 #include "tensorflow/compiler/xla/statusor.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 33 namespace xla { 34 35 // Class which computes live range of the output buffers of HLOs and their 36 // interference by flattening all computations. The live range is only available 37 // when all global computations (while, if, call, etc) have total order 38 // sequential orders. 39 class HloLiveRange { 40 public: 41 // Constructs a hlo live range object for the given module and computation 42 // assuming the given HLO instruction ordering. 43 static StatusOr<std::unique_ptr<HloLiveRange>> Run( 44 const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, 45 const HloComputation* computation, bool module_scoped_analysis = true); 46 47 // LogicalTime represents the time in a virtual clock. Each instruction has 48 // one monotonically increasing logical time assigned according to the 49 // schedule. 50 using LogicalTime = int64; 51 52 struct TimeBound { 53 LogicalTime start; 54 LogicalTime end; 55 56 bool friend operator==(const TimeBound& a, const TimeBound& b) { 57 return a.start == b.start && a.end == b.end; 58 } 59 bool friend operator!=(const TimeBound& a, const TimeBound& b) { 60 return !(a == b); 61 } 62 }; 63 64 std::string ToString() const; 65 flattened_instruction_sequence()66 const HloInstructionSequence& flattened_instruction_sequence() const { 67 return flattened_instruction_sequence_; 68 } 69 70 // Returns the map from instruction to the end time of that instruction. 71 const absl::flat_hash_map<const HloInstruction*, LogicalTime>& instruction_schedule()72 instruction_schedule() const { 73 return instruction_schedule_; 74 } 75 76 // Returns the map from a hlo value to the definition time of that hlo value. buffer_live_ranges()77 const absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() 78 const { 79 return buffer_live_ranges_; 80 } 81 buffer_live_ranges()82 absl::flat_hash_map<const HloValue*, TimeBound>& buffer_live_ranges() { 83 return buffer_live_ranges_; 84 } 85 86 // Returns the map from a computation and its time span in the schedule. 87 const absl::flat_hash_map<const HloComputation*, TimeBound>& computation_span_times()88 computation_span_times() const { 89 return computation_span_times_; 90 } 91 92 // Returns the time stamp of the end of the program. schedule_end_time()93 LogicalTime schedule_end_time() const { return schedule_end_time_; } 94 95 // Returns whether hlo live range is available on this entire module. Hlo live 96 // range is not available if the module is partially ordered. total_order_scheduled()97 bool total_order_scheduled() const { return total_order_scheduled_; } 98 99 private: HloLiveRange(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,bool module_scoped_analysis)100 explicit HloLiveRange(const HloSchedule& schedule, 101 const HloAliasAnalysis& alias_analysis, 102 bool module_scoped_analysis) 103 : schedule_(schedule), 104 alias_analysis_(alias_analysis), 105 module_scoped_analysis_(module_scoped_analysis) {} 106 107 // FlattenSchedule walks through the instructions in `computation`, and 108 // recurse into each called computations in module_scoped_analysis mode. As it 109 // walks it also tracks down the ordinal number of each instruction in the 110 // schedule and store it in the `instruction_schedule` and 111 // 'flattened_instruction_sequence`. The end of each computation is tracked in 112 // `computation_end_time`. 113 int64 FlattenSchedule(const HloComputation& computation, int64 start_time); 114 115 // Based on the flattened schedule, calculate the start and end of each 116 // buffer. 117 void CalculateBufferStartEndMap(); 118 119 // The aliased buffers could have overlapping live ranges. 120 // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer 121 // has disjoint live range while keeping the live range union the same. This 122 // avoid double counting aliased buffer sizes. 123 // 124 // Before(buffer1 and 2 are aliased): 125 // 126 // +----+ live range of buffer1 127 // +------------------+ live range of buffer2 128 // 129 // After: 130 // 131 // +----------+ live range of buffer1 132 // +------+ live range of buffer2 133 // 134 // Before(buffer1 and 2 are aliased): 135 // 136 // +----------+ live range of buffer1 137 // +------------+ live range of buffer2 138 // 139 // After: 140 // 141 // +----------+ live range of buffer1 142 // +------+ live range of buffer2 143 // 144 // Before(buffer1 and 2 are aliased): 145 // 146 // +----------+ live range of buffer1 147 // +---+ live range of buffer2 148 // 149 // After(unchanged): 150 // 151 // +----------+ live range of buffer1 152 // +---+ live range of buffer2 153 // 154 // As another example, imagine we have the following code sequence with live 155 // ranges of each while-aliased buffers: 156 // 157 // a p1 p2 e b 158 // a = ... + 159 // | 160 // { | 161 // p1 = param | + 162 // ROOT true | | 163 // } | + 164 // { // body | 165 // p2 = param + + 166 // c = p2 + 1 + 167 // d = c + 1 168 // ROOT e = d + 1 + 169 // } | 170 // | 171 // b = while (a) + + 172 // | 173 // f = b + 1 + 174 // 175 // After normalization it becomes: 176 // 177 // a p1 p2 e b 178 // a = ... + 179 // | 180 // { + 181 // p1 = param + 182 // ROOT true | 183 // } + 184 // { // body 185 // p2 = param + 186 // c = p2 + 1 + 187 // d = c + 1 188 // ROOT e = d + 1 + 189 // } | 190 // | 191 // b = while (a) + 192 // + 193 // f = b + 1 + 194 // 195 // Note there is no overlap of live ranges after normalization. 196 void NormalizeAliasedBuffers(); 197 198 const HloSchedule& schedule_; 199 const HloAliasAnalysis& alias_analysis_; 200 bool module_scoped_analysis_; 201 bool total_order_scheduled_ = true; 202 203 HloInstructionSequence flattened_instruction_sequence_; 204 absl::flat_hash_map<const HloInstruction*, int64> instruction_schedule_; 205 absl::flat_hash_map<const HloComputation*, TimeBound> computation_span_times_; 206 absl::flat_hash_map<const HloValue*, TimeBound> buffer_live_ranges_; 207 LogicalTime schedule_end_time_; 208 }; 209 210 } // namespace xla 211 212 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ 213