• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
17 
18 #include <algorithm>
19 #include <tuple>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
27 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_value.h"
30 
31 namespace xla {
32 /*static*/
Run(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const HloComputation * computation,bool module_scoped_analysis)33 StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
34     const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
35     const HloComputation* computation, bool module_scoped_analysis) {
36   std::unique_ptr<HloLiveRange> hlo_live_range(
37       new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
38   hlo_live_range->FlattenSchedule(*computation);
39   hlo_live_range->CalculateBufferStartEndMap();
40   hlo_live_range->NormalizeAliasedBuffers();
41   return std::move(hlo_live_range);
42 }
43 
NormalizeAliasedBuffers()44 void HloLiveRange::NormalizeAliasedBuffers() {
45   absl::flat_hash_map<HloBuffer::Id, std::vector<TimeBound*>>
46       live_ranges_by_buffer;
47   for (auto& entry : buffer_live_ranges_) {
48     const HloValue& value = *entry.first;
49     const HloBuffer& buffer = alias_analysis_.GetBufferContainingValue(value);
50     live_ranges_by_buffer[buffer.id()].push_back(&entry.second);
51   }
52 
53   for (auto& entry : live_ranges_by_buffer) {
54     std::vector<TimeBound*>& aliased_live_ranges = entry.second;
55     absl::c_sort(aliased_live_ranges,
56                  [](const TimeBound* a, const TimeBound* b) {
57                    return std::forward_as_tuple(a->start, a->end) <
58                           std::forward_as_tuple(b->start, b->end);
59                  });
60 
61     for (int64_t i = 0; i + 1 < aliased_live_ranges.size(); ++i) {
62       TimeBound& live_range1 = *aliased_live_ranges[i];
63       TimeBound& live_range2 = *aliased_live_ranges[i + 1];
64       live_range2.end = std::max(live_range1.end, live_range2.end);
65       live_range1.end = std::min(live_range1.end, live_range2.start);
66     }
67   }
68 }
69 
70 // FlattenSchedule walks through the computation and tracks down the ordinal
71 // number of each instruction in the schedule.
FlattenSchedule(const HloComputation & computation,const HloComputation * async_context)72 void HloLiveRange::FlattenSchedule(const HloComputation& computation,
73                                    const HloComputation* async_context) {
74   auto it = schedule_.sequences().find(computation.unique_id());
75   if (it == schedule_.sequences().end()) {
76     total_order_scheduled_ = false;
77     return;
78   }
79 
80   // Check if we've already processed this computation.
81   if (computation_span_times_.contains(&computation)) return;
82 
83   // Mark this computation into the async context, if available.
84   if (async_context != nullptr) {
85     computations_in_async_context_[&computation] = async_context;
86   }
87 
88   LogicalTime start_time = flattened_instruction_sequence_.size();
89 
90   const HloInstructionSequence& instruction_sequence = it->second;
91   for (HloInstruction* instruction : instruction_sequence.instructions()) {
92     if (module_scoped_analysis_) {
93       // Recurse into sub computations if running with module scoped analysis
94       // mode.
95       if (instruction->opcode() == HloOpcode::kCall ||
96           instruction->opcode() == HloOpcode::kConditional ||
97           instruction->opcode() == HloOpcode::kAsyncStart) {
98         for (const HloComputation* called_computation :
99              instruction->called_computations()) {
100           // AsyncStart starts an async context. Other ops that call
101           // computations just propagate the existing one, if any.
102           FlattenSchedule(*called_computation,
103                           instruction->opcode() == HloOpcode::kAsyncStart
104                               ? called_computation
105                               : async_context);
106         }
107       } else if (instruction->opcode() == HloOpcode::kWhile) {
108         FlattenSchedule(*instruction->while_condition(), async_context);
109         FlattenSchedule(*instruction->while_body(), async_context);
110       }
111     }
112 
113     LogicalTime time = flattened_instruction_sequence_.size();
114     CHECK(instruction_schedule_.insert({instruction, time}).second);
115     flattened_instruction_sequence_.push_back(instruction);
116   }
117 
118   LogicalTime end_time = flattened_instruction_sequence_.size();
119   computation_span_times_[&computation] = {start_time, end_time};
120 }
121 
GetLastPosition(const HloValue & value,HloLiveRange::LogicalTime definition_end_time) const122 HloLiveRange::TimeBound HloLiveRange::GetLastPosition(
123     const HloValue& value,
124     HloLiveRange::LogicalTime definition_end_time) const {
125   LogicalTime end_time = definition_end_time;
126   const HloPosition* end_position = &value.defining_position();
127   // Loop over the non-defining positions to find the final one.
128   for (const HloPosition& position :
129        absl::Span<const HloPosition>(value.positions()).subspan(1)) {
130     const HloInstruction* position_inst = position.instruction;
131     LogicalTime position_time;
132     if (position_inst->IsRoot()) {  // See comment above.
133       auto it = computation_span_times_.find(position_inst->parent());
134       if (it == computation_span_times_.end()) continue;
135       position_time = it->second.end;
136     } else {
137       auto it = instruction_schedule_.find(position_inst);
138       if (it == instruction_schedule_.end()) continue;
139       position_time = it->second;
140     }
141 
142     if (position_time > end_time) {
143       end_time = position_time;
144       end_position = &position;
145     }
146   }
147   return {-1, end_time, *end_position};
148 }
149 
GetLastUsageTime(const HloValue & value) const150 HloLiveRange::LogicalTime HloLiveRange::GetLastUsageTime(
151     const HloValue& value) const {
152   LogicalTime end_time = -1;
153   for (const HloUse& use : value.GetUses()) {
154     const HloInstruction* used = use.instruction;
155     // As an optimization, we deem a while's init value's live range ends as
156     // soon as the loop body starts. This optimization is only applicable in
157     // module scoped mode.
158     if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
159       // The current live range is at the end of the while, move it to
160       // the beginning of the body.
161       used = used->while_body()->parameter_instruction(0);
162       VLOG(1) << "Moved value " << value.ToShortString()
163               << " to while param: " << used->ToString();
164     }
165 
166     // It's possible that we didn't track the instruction `used`. This
167     // happens when we do computation scope (versus module scope) heap
168     // simulation and the used instruction is outside of the computation
169     // being simulated.
170     auto it = instruction_schedule_.find(used);
171     if (it != instruction_schedule_.end()) {
172       end_time = std::max(end_time, it->second);
173     }
174   }
175   return end_time;
176 }
177 
CalculateBufferStartEndMap()178 void HloLiveRange::CalculateBufferStartEndMap() {
179   for (const auto& entry : instruction_schedule_) {
180     const HloInstruction& instruction = *entry.first;
181     const HloComputation* computation = instruction.parent();
182 
183     // Parameters are defined at the beginning of the computation. This prevents
184     // any instruction that's scheduled before the parameter clobbers the
185     // parameter's buffer.
186     LogicalTime start_time = (instruction.opcode() == HloOpcode::kParameter)
187                                  ? computation_span_times_[computation].start
188                                  : entry.second;
189 
190     // If an instruction lives out, the live range of the instruction should be
191     // extended to the end of the computation.
192     LogicalTime definition_end_time =
193         instruction.IsRoot() ? computation_span_times_[computation].end
194                              : entry.second;
195 
196     // If the instruction is in an asynchronous context, extend the live range
197     // until the end of the async-done instruction.
198     auto async_context_it = computations_in_async_context_.find(computation);
199     if (async_context_it != computations_in_async_context_.end()) {
200       const HloComputation* async_context = async_context_it->second;
201       CHECK(async_context->IsAsyncComputation());
202       auto async_done_it = absl::c_find_if(
203           async_context->AsyncInstructions(),
204           [](const HloInstruction* instruction) {
205             return instruction->opcode() == HloOpcode::kAsyncDone;
206           });
207       CHECK(async_done_it != async_context->AsyncInstructions().end());
208       definition_end_time =
209           std::max(definition_end_time, instruction_schedule_[*async_done_it]);
210       VLOG(2) << "Setting the definition end time for op in async context: "
211               << definition_end_time;
212     }
213 
214     const InstructionValueSet& value_set_tree =
215         alias_analysis_.dataflow_analysis().GetInstructionValueSet(
216             &instruction);
217 
218     for (const auto& entry : value_set_tree) {
219       for (const HloValue* value : entry.second.values()) {
220         // The start time is only correct for the defining instruction.
221         if (value->defining_instruction() != &instruction) continue;
222 
223         TimeBound live_range = GetLastPosition(*value, definition_end_time);
224         live_range.start = start_time;
225 
226         // Readonly entry parameters (parameters that don't alias) live across
227         // whole computation.
228         const HloModule& module = *computation->parent();
229         if (instruction.opcode() == HloOpcode::kParameter &&
230             computation == module.entry_computation() &&
231             !module.input_output_alias_config().ParameterHasAlias(
232                 instruction.parameter_number(), value->index())) {
233           live_range.end = schedule_end_time();
234         } else {
235           live_range.end = std::max(live_range.end, GetLastUsageTime(*value));
236         }
237 
238         CHECK_LE(live_range.start, live_range.end) << instruction.ToString();
239         CHECK(buffer_live_ranges_.insert({value, live_range}).second);
240       }
241     }
242   }
243 }
244 
ComputePeakMemoryMoment() const245 int64_t HloLiveRange::ComputePeakMemoryMoment() const {
246   std::vector<std::tuple<int64_t /*time*/, bool /*is_end*/, const HloValue*>>
247       events;
248   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
249     auto it = buffer_live_ranges_.find(value);
250     if (it != buffer_live_ranges_.end()) {
251       events.emplace_back(it->second.start, false, value);
252       events.emplace_back(it->second.end + 1, true, value);
253     }
254   }
255   std::sort(events.begin(), events.end());
256 
257   int64_t memory_usage = 0;
258   int64_t peak_usage = 0;
259   std::optional<int64_t> peak_time;
260   for (const auto& event : events) {
261     int64_t time;
262     bool is_end;
263     const HloValue* value;
264     std::tie(time, is_end, value) = event;
265     auto buffer_size = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
266     if (is_end) {
267       memory_usage -= buffer_size;
268     } else {
269       memory_usage += buffer_size;
270     }
271     if (peak_usage < memory_usage) {
272       peak_usage = memory_usage;
273       peak_time = time;
274     }
275   }
276   return peak_time.value_or(0);
277 }
278 
ToString() const279 std::string HloLiveRange::ToString() const {
280   std::string output;
281   absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
282                         schedule_end_time());
283   absl::StrAppendFormat(&output, "  InstructionSequence:\n");
284   auto& instructions = flattened_instruction_sequence().instructions();
285   for (int64_t i = 0; i < instructions.size(); ++i) {
286     absl::StrAppendFormat(&output, "    %d:%s\n", i, instructions[i]->name());
287   }
288 
289   absl::StrAppendFormat(&output, "  BufferLiveRange:\n");
290 
291   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
292     auto it = buffer_live_ranges_.find(value);
293     if (it != buffer_live_ranges_.end()) {
294       absl::StrAppendFormat(
295           &output, "    %s%s:%d-%d\n", value->instruction()->name(),
296           value->index().ToString(), it->second.start, it->second.end);
297     }
298   }
299 
300   int64_t peak_moment = ComputePeakMemoryMoment();
301 
302   absl::StrAppendFormat(&output, "  Live ranges at %lld (peak):\n",
303                         peak_moment);
304   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
305     auto it = buffer_live_ranges_.find(value);
306     if (it != buffer_live_ranges_.end()) {
307       if (it->second.start <= peak_moment && peak_moment <= it->second.end) {
308         int64_t bytes = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
309         absl::StrAppendFormat(&output, "    %s: %lld bytes\n",
310                               value->instruction()->name(), bytes);
311       }
312     }
313   }
314 
315   return output;
316 }
317 
318 }  // namespace xla
319