1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/hlo_live_range.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
30
31 namespace xla {
32 /*static*/
Run(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const HloComputation * computation,bool module_scoped_analysis)33 StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
34 const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
35 const HloComputation* computation, bool module_scoped_analysis) {
36 std::unique_ptr<HloLiveRange> hlo_live_range(
37 new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
38 hlo_live_range->FlattenSchedule(*computation);
39 hlo_live_range->CalculateBufferStartEndMap();
40 hlo_live_range->NormalizeAliasedBuffers();
41 return std::move(hlo_live_range);
42 }
43
NormalizeAliasedBuffers()44 void HloLiveRange::NormalizeAliasedBuffers() {
45 absl::flat_hash_map<HloBuffer::Id, std::vector<TimeBound*>>
46 live_ranges_by_buffer;
47 for (auto& entry : buffer_live_ranges_) {
48 const HloValue& value = *entry.first;
49 const HloBuffer& buffer = alias_analysis_.GetBufferContainingValue(value);
50 live_ranges_by_buffer[buffer.id()].push_back(&entry.second);
51 }
52
53 for (auto& entry : live_ranges_by_buffer) {
54 std::vector<TimeBound*>& aliased_live_ranges = entry.second;
55 absl::c_sort(aliased_live_ranges,
56 [](const TimeBound* a, const TimeBound* b) {
57 return std::forward_as_tuple(a->start, a->end) <
58 std::forward_as_tuple(b->start, b->end);
59 });
60
61 for (int64_t i = 0; i + 1 < aliased_live_ranges.size(); ++i) {
62 TimeBound& live_range1 = *aliased_live_ranges[i];
63 TimeBound& live_range2 = *aliased_live_ranges[i + 1];
64 live_range2.end = std::max(live_range1.end, live_range2.end);
65 live_range1.end = std::min(live_range1.end, live_range2.start);
66 }
67 }
68 }
69
// FlattenSchedule walks through the computation and tracks down the ordinal
// number of each instruction in the schedule. With module-scoped analysis it
// recurses into called computations first, so a callee's instructions receive
// earlier ordinals than the calling instruction itself. `async_context` is
// the async computation (if any) that transitively contains `computation`.
void HloLiveRange::FlattenSchedule(const HloComputation& computation,
                                   const HloComputation* async_context) {
  auto it = schedule_.sequences().find(computation.unique_id());
  if (it == schedule_.sequences().end()) {
    // No sequence for this computation: the overall order is not total.
    total_order_scheduled_ = false;
    return;
  }

  // Check if we've already processed this computation.
  if (computation_span_times_.contains(&computation)) return;

  // Mark this computation into the async context, if available.
  if (async_context != nullptr) {
    computations_in_async_context_[&computation] = async_context;
  }

  // Span start is the current length of the flattened sequence.
  LogicalTime start_time = flattened_instruction_sequence_.size();

  const HloInstructionSequence& instruction_sequence = it->second;
  for (HloInstruction* instruction : instruction_sequence.instructions()) {
    if (module_scoped_analysis_) {
      // Recurse into sub computations if running with module scoped analysis
      // mode.
      if (instruction->opcode() == HloOpcode::kCall ||
          instruction->opcode() == HloOpcode::kConditional ||
          instruction->opcode() == HloOpcode::kAsyncStart) {
        for (const HloComputation* called_computation :
             instruction->called_computations()) {
          // AsyncStart starts an async context. Other ops that call
          // computations just propagate the existing one, if any.
          FlattenSchedule(*called_computation,
                          instruction->opcode() == HloOpcode::kAsyncStart
                              ? called_computation
                              : async_context);
        }
      } else if (instruction->opcode() == HloOpcode::kWhile) {
        FlattenSchedule(*instruction->while_condition(), async_context);
        FlattenSchedule(*instruction->while_body(), async_context);
      }
    }

    // Assign this instruction's ordinal AFTER recursing, so callee
    // instructions come earlier in the total order. The CHECK enforces that
    // each instruction is scheduled exactly once.
    LogicalTime time = flattened_instruction_sequence_.size();
    CHECK(instruction_schedule_.insert({instruction, time}).second);
    flattened_instruction_sequence_.push_back(instruction);
  }

  // Record the [start, end) span this computation occupies in the flattened
  // order (end is one past its last instruction).
  LogicalTime end_time = flattened_instruction_sequence_.size();
  computation_span_times_[&computation] = {start_time, end_time};
}
121
GetLastPosition(const HloValue & value,HloLiveRange::LogicalTime definition_end_time) const122 HloLiveRange::TimeBound HloLiveRange::GetLastPosition(
123 const HloValue& value,
124 HloLiveRange::LogicalTime definition_end_time) const {
125 LogicalTime end_time = definition_end_time;
126 const HloPosition* end_position = &value.defining_position();
127 // Loop over the non-defining positions to find the final one.
128 for (const HloPosition& position :
129 absl::Span<const HloPosition>(value.positions()).subspan(1)) {
130 const HloInstruction* position_inst = position.instruction;
131 LogicalTime position_time;
132 if (position_inst->IsRoot()) { // See comment above.
133 auto it = computation_span_times_.find(position_inst->parent());
134 if (it == computation_span_times_.end()) continue;
135 position_time = it->second.end;
136 } else {
137 auto it = instruction_schedule_.find(position_inst);
138 if (it == instruction_schedule_.end()) continue;
139 position_time = it->second;
140 }
141
142 if (position_time > end_time) {
143 end_time = position_time;
144 end_position = &position;
145 }
146 }
147 return {-1, end_time, *end_position};
148 }
149
GetLastUsageTime(const HloValue & value) const150 HloLiveRange::LogicalTime HloLiveRange::GetLastUsageTime(
151 const HloValue& value) const {
152 LogicalTime end_time = -1;
153 for (const HloUse& use : value.GetUses()) {
154 const HloInstruction* used = use.instruction;
155 // As an optimization, we deem a while's init value's live range ends as
156 // soon as the loop body starts. This optimization is only applicable in
157 // module scoped mode.
158 if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
159 // The current live range is at the end of the while, move it to
160 // the beginning of the body.
161 used = used->while_body()->parameter_instruction(0);
162 VLOG(1) << "Moved value " << value.ToShortString()
163 << " to while param: " << used->ToString();
164 }
165
166 // It's possible that we didn't track the instruction `used`. This
167 // happens when we do computation scope (versus module scope) heap
168 // simulation and the used instruction is outside of the computation
169 // being simulated.
170 auto it = instruction_schedule_.find(used);
171 if (it != instruction_schedule_.end()) {
172 end_time = std::max(end_time, it->second);
173 }
174 }
175 return end_time;
176 }
177
// Computes a [start, end] logical-time bound for every HloValue defined by a
// scheduled instruction and records it in buffer_live_ranges_. Must run after
// FlattenSchedule has populated instruction_schedule_ and
// computation_span_times_.
void HloLiveRange::CalculateBufferStartEndMap() {
  for (const auto& entry : instruction_schedule_) {
    const HloInstruction& instruction = *entry.first;
    const HloComputation* computation = instruction.parent();

    // Parameters are defined at the beginning of the computation. This prevents
    // any instruction that's scheduled before the parameter clobbers the
    // parameter's buffer.
    LogicalTime start_time = (instruction.opcode() == HloOpcode::kParameter)
                                 ? computation_span_times_[computation].start
                                 : entry.second;

    // If an instruction lives out, the live range of the instruction should be
    // extended to the end of the computation.
    LogicalTime definition_end_time =
        instruction.IsRoot() ? computation_span_times_[computation].end
                             : entry.second;

    // If the instruction is in an asynchronous context, extend the live range
    // until the end of the async-done instruction.
    auto async_context_it = computations_in_async_context_.find(computation);
    if (async_context_it != computations_in_async_context_.end()) {
      const HloComputation* async_context = async_context_it->second;
      CHECK(async_context->IsAsyncComputation());
      // Find the async-done op that closes this async context.
      auto async_done_it = absl::c_find_if(
          async_context->AsyncInstructions(),
          [](const HloInstruction* instruction) {
            return instruction->opcode() == HloOpcode::kAsyncDone;
          });
      CHECK(async_done_it != async_context->AsyncInstructions().end());
      // NOTE(review): operator[] would default-insert time 0 if the
      // async-done op were untracked — presumably it is always scheduled
      // here; verify if this map access is ever reached before flattening.
      definition_end_time =
          std::max(definition_end_time, instruction_schedule_[*async_done_it]);
      VLOG(2) << "Setting the definition end time for op in async context: "
              << definition_end_time;
    }

    const InstructionValueSet& value_set_tree =
        alias_analysis_.dataflow_analysis().GetInstructionValueSet(
            &instruction);

    // NOTE(review): this inner `entry` shadows the outer loop variable; here
    // it denotes one (shape index, value set) pair of the value-set tree.
    for (const auto& entry : value_set_tree) {
      for (const HloValue* value : entry.second.values()) {
        // The start time is only correct for the defining instruction.
        if (value->defining_instruction() != &instruction) continue;

        TimeBound live_range = GetLastPosition(*value, definition_end_time);
        live_range.start = start_time;

        // Readonly entry parameters (parameters that don't alias) live across
        // whole computation.
        const HloModule& module = *computation->parent();
        if (instruction.opcode() == HloOpcode::kParameter &&
            computation == module.entry_computation() &&
            !module.input_output_alias_config().ParameterHasAlias(
                instruction.parameter_number(), value->index())) {
          live_range.end = schedule_end_time();
        } else {
          // Otherwise extend the range to cover the last tracked use.
          live_range.end = std::max(live_range.end, GetLastUsageTime(*value));
        }

        CHECK_LE(live_range.start, live_range.end) << instruction.ToString();
        CHECK(buffer_live_ranges_.insert({value, live_range}).second);
      }
    }
  }
}
244
ComputePeakMemoryMoment() const245 int64_t HloLiveRange::ComputePeakMemoryMoment() const {
246 std::vector<std::tuple<int64_t /*time*/, bool /*is_end*/, const HloValue*>>
247 events;
248 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
249 auto it = buffer_live_ranges_.find(value);
250 if (it != buffer_live_ranges_.end()) {
251 events.emplace_back(it->second.start, false, value);
252 events.emplace_back(it->second.end + 1, true, value);
253 }
254 }
255 std::sort(events.begin(), events.end());
256
257 int64_t memory_usage = 0;
258 int64_t peak_usage = 0;
259 std::optional<int64_t> peak_time;
260 for (const auto& event : events) {
261 int64_t time;
262 bool is_end;
263 const HloValue* value;
264 std::tie(time, is_end, value) = event;
265 auto buffer_size = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
266 if (is_end) {
267 memory_usage -= buffer_size;
268 } else {
269 memory_usage += buffer_size;
270 }
271 if (peak_usage < memory_usage) {
272 peak_usage = memory_usage;
273 peak_time = time;
274 }
275 }
276 return peak_time.value_or(0);
277 }
278
ToString() const279 std::string HloLiveRange::ToString() const {
280 std::string output;
281 absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
282 schedule_end_time());
283 absl::StrAppendFormat(&output, " InstructionSequence:\n");
284 auto& instructions = flattened_instruction_sequence().instructions();
285 for (int64_t i = 0; i < instructions.size(); ++i) {
286 absl::StrAppendFormat(&output, " %d:%s\n", i, instructions[i]->name());
287 }
288
289 absl::StrAppendFormat(&output, " BufferLiveRange:\n");
290
291 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
292 auto it = buffer_live_ranges_.find(value);
293 if (it != buffer_live_ranges_.end()) {
294 absl::StrAppendFormat(
295 &output, " %s%s:%d-%d\n", value->instruction()->name(),
296 value->index().ToString(), it->second.start, it->second.end);
297 }
298 }
299
300 int64_t peak_moment = ComputePeakMemoryMoment();
301
302 absl::StrAppendFormat(&output, " Live ranges at %lld (peak):\n",
303 peak_moment);
304 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
305 auto it = buffer_live_ranges_.find(value);
306 if (it != buffer_live_ranges_.end()) {
307 if (it->second.start <= peak_moment && peak_moment <= it->second.end) {
308 int64_t bytes = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
309 absl::StrAppendFormat(&output, " %s: %lld bytes\n",
310 value->instruction()->name(), bytes);
311 }
312 }
313 }
314
315 return output;
316 }
317
318 } // namespace xla
319