• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
17 
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
20 #include "tensorflow/compiler/xla/service/hlo_value.h"
21 
22 namespace xla {
23 /*static*/
Run(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const HloComputation * computation,bool module_scoped_analysis)24 StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
25     const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
26     const HloComputation* computation, bool module_scoped_analysis) {
27   std::unique_ptr<HloLiveRange> hlo_live_range(
28       new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
29   hlo_live_range->schedule_end_time_ =
30       hlo_live_range->FlattenSchedule(*computation, 0);
31   hlo_live_range->CalculateBufferStartEndMap();
32   hlo_live_range->NormalizeAliasedBuffers();
33   return std::move(hlo_live_range);
34 }
35 
NormalizeAliasedBuffers()36 void HloLiveRange::NormalizeAliasedBuffers() {
37   for (const HloBuffer& hlo_buffer : alias_analysis_.buffers()) {
38     std::vector<const HloValue*> aliased_buffers;
39     for (const HloValue* hlo_value : hlo_buffer.values()) {
40       if (buffer_live_ranges_.contains(hlo_value)) {
41         aliased_buffers.push_back(hlo_value);
42       }
43     }
44     absl::c_sort(
45         aliased_buffers, [&](const HloValue* value1, const HloValue* value2) {
46           const TimeBound& live_range1 = buffer_live_ranges_.at(value1);
47           const TimeBound& live_range2 = buffer_live_ranges_.at(value2);
48 
49           return std::forward_as_tuple(live_range1.start, live_range1.end) <
50                  std::forward_as_tuple(live_range2.start, live_range2.end);
51         });
52 
53     for (int64_t i = 0; i + 1 < aliased_buffers.size(); ++i) {
54       const HloValue* value1 = aliased_buffers[i];
55       const HloValue* value2 = aliased_buffers[i + 1];
56       TimeBound& live_range1 = buffer_live_ranges_[value1];
57       TimeBound& live_range2 = buffer_live_ranges_[value2];
58       if (live_range1.start == live_range2.start) {
59         // If value1 has the same start time as value2, make value1 disappear
60         // by setting the end time same as start time:
61         //
62         // Before:
63         // +----+           value1
64         // +----------+     value2
65         //
66         // After:
67         // +                value1
68         // +----------+     value2
69         //
70         // Note that only when heap simulator runs before copy insertion can
71         // this happen where one instruction defines multiple aliased buffers
72         // -- This is illegle to execute and can be fixed by copy insertion
73         // later.
74         live_range1.end = live_range2.end;
75         continue;
76       }
77 
78       if (live_range1.end < live_range2.start) {
79         continue;
80       }
81 
82       if (live_range1.end > live_range2.end) {
83         live_range2.end = live_range1.end;
84       }
85       live_range1.end = live_range2.start - 1;
86     }
87   }
88 }
89 
90 // FlattenSchedule walks through the computation and tracks down the ordinal
91 // number of each instruction in the schedule.
FlattenSchedule(const HloComputation & computation,int64_t start_time)92 int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
93                                     int64_t start_time) {
94   if (!schedule_.is_computation_scheduled(&computation)) {
95     total_order_scheduled_ = false;
96     return start_time;
97   }
98 
99   const HloInstructionSequence& instruction_sequence =
100       schedule_.sequence(&computation);
101   int64_t time = start_time;
102   for (HloInstruction* instruction : instruction_sequence.instructions()) {
103     if (module_scoped_analysis_) {
104       // Recurse into sub computations if running with module scoped analysis
105       // mode.
106       if (instruction->opcode() == HloOpcode::kCall ||
107           instruction->opcode() == HloOpcode::kConditional) {
108         for (const HloComputation* called_computation :
109              instruction->called_computations()) {
110           time = FlattenSchedule(*called_computation, time);
111         }
112       }
113       if (instruction->opcode() == HloOpcode::kWhile) {
114         time = FlattenSchedule(*instruction->while_condition(), time);
115         time = FlattenSchedule(*instruction->while_body(), time);
116       }
117     }
118     if (instruction_schedule_.count(instruction) != 0) {
119       continue;
120     }
121     instruction_schedule_.insert({instruction, time++});
122     flattened_instruction_sequence_.push_back(instruction);
123   }
124   computation_span_times_.try_emplace(&computation,
125                                       TimeBound{start_time, time});
126   DCHECK_EQ(instruction_schedule_.size(),
127             flattened_instruction_sequence_.size());
128   DCHECK_LE(instruction_schedule_.size(), time);
129   return time;
130 }
131 
CalculateBufferStartEndMap()132 void HloLiveRange::CalculateBufferStartEndMap() {
133   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
134     // Ignore buffers that are not defined.
135     if (instruction_schedule_.count(value->defining_instruction()) == 0) {
136       continue;
137     }
138 
139     int64_t buffer_start_time = instruction_schedule_[value->instruction()];
140 
141     int64_t buffer_end_time = -1;
142     for (const HloUse& use : value->uses()) {
143       const HloInstruction* used = use.instruction;
144       // As an optimization, we deem a while's init value's live range ends as
145       // soon as the loop body starts. This optimization is only applicable in
146       // module scoped mode.
147       if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
148         // The current live range is at the end of the while, move it to the
149         // beginning of the body.
150         used = used->while_body()->parameter_instruction(0);
151         VLOG(1) << "Moved value " << value->ToShortString()
152                 << " to while param: " << used->ToString();
153       }
154       if (instruction_schedule_.count(used) == 0) {
155         // We didn't track the instruction `used`. This happens when we do
156         // computation scope (versus module scope) heap simulation and when
157         // the used instruction is outside of the computation being simulated.
158         continue;
159       }
160       buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]);
161     }
162 
163     // Parameters are defined at the beginning of the computation. This prevents
164     // any instruction that's scheduled before the parameter clobbers the
165     // parameter's buffer.
166     if (value->instruction()->opcode() == HloOpcode::kParameter) {
167       const HloComputation* computation = value->instruction()->parent();
168       auto it = computation_span_times_.find(computation);
169       if (it != computation_span_times_.end()) {
170         buffer_start_time = std::min(buffer_start_time, it->second.start);
171       }
172     }
173 
174     if (buffer_end_time == -1) {
175       buffer_end_time = buffer_start_time;
176     }
177 
178     HloPosition end_position;
179     int64_t max_end_time = 0;
180     for (const HloPosition& position : value->positions()) {
181       if (instruction_schedule_[position.instruction] >= max_end_time) {
182         max_end_time = instruction_schedule_[value->instruction()];
183         end_position = position;
184       }
185       const HloComputation* position_comp = position.instruction->parent();
186       // If this instruction lives out, the live range of the instruction
187       // should be extended to the end of the computation.
188       if (position.instruction == position_comp->root_instruction()) {
189         auto it = computation_span_times_.find(position_comp);
190         if (it == computation_span_times_.end()) {
191           continue;
192         }
193         if (buffer_end_time < it->second.end) {
194           buffer_end_time = it->second.end;
195           end_position = position;
196         }
197       }
198     }
199 
200     const HloModule* module = value->instruction()->parent()->parent();
201 
202     // Readonly entry parameters (parameters that don't alias) live across whole
203     // computation.
204     if (value->instruction()->opcode() == HloOpcode::kParameter &&
205         value->instruction()->parent() == module->entry_computation() &&
206         !module->input_output_alias_config().ParameterHasAlias(
207             value->instruction()->parameter_number(), value->index())) {
208       buffer_end_time = schedule_end_time_;
209     }
210 
211     CHECK(buffer_start_time <= buffer_end_time)
212         << buffer_start_time << ", " << buffer_end_time
213         << value->instruction()->ToString();
214 
215     auto& live_range = buffer_live_ranges_[value];
216     live_range.start = buffer_start_time;
217     live_range.end = buffer_end_time;
218     live_range.end_position = end_position;
219   }
220 }
221 
ComputePeakMemoryMoment() const222 int64 HloLiveRange::ComputePeakMemoryMoment() const {
223   std::vector<std::tuple<int64 /*time*/, bool /*is_end*/, const HloValue*>>
224       events;
225   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
226     auto it = buffer_live_ranges_.find(value);
227     if (it != buffer_live_ranges_.end()) {
228       events.emplace_back(it->second.start, false, value);
229       events.emplace_back(it->second.end + 1, true, value);
230     }
231   }
232   std::sort(events.begin(), events.end());
233 
234   int64_t memory_usage = 0;
235   int64_t peak_usage = 0;
236   absl::optional<int64> peak_time;
237   for (const auto& event : events) {
238     int64_t time;
239     bool is_end;
240     const HloValue* value;
241     std::tie(time, is_end, value) = event;
242     auto buffer_size = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
243     if (is_end) {
244       memory_usage -= buffer_size;
245     } else {
246       memory_usage += buffer_size;
247     }
248     if (peak_usage < memory_usage) {
249       peak_usage = memory_usage;
250       peak_time = time;
251     }
252   }
253   return peak_time.value_or(0);
254 }
255 
ToString() const256 std::string HloLiveRange::ToString() const {
257   std::string output;
258   absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
259                         schedule_end_time_);
260   absl::StrAppendFormat(&output, "  InstructionSequence:\n");
261   auto& instructions = flattened_instruction_sequence().instructions();
262   for (int64_t i = 0; i < instructions.size(); ++i) {
263     absl::StrAppendFormat(&output, "    %d:%s\n", i, instructions[i]->name());
264   }
265 
266   absl::StrAppendFormat(&output, "  BufferLiveRange:\n");
267 
268   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
269     auto it = buffer_live_ranges_.find(value);
270     if (it != buffer_live_ranges_.end()) {
271       absl::StrAppendFormat(
272           &output, "    %s%s:%d-%d\n", value->instruction()->name(),
273           value->index().ToString(), it->second.start, it->second.end);
274     }
275   }
276 
277   int64_t peak_moment = ComputePeakMemoryMoment();
278 
279   absl::StrAppendFormat(&output, "  Live ranges at %lld (peak):\n",
280                         peak_moment);
281   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
282     auto it = buffer_live_ranges_.find(value);
283     if (it != buffer_live_ranges_.end()) {
284       if (it->second.start <= peak_moment && peak_moment <= it->second.end) {
285         int64_t bytes = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
286         absl::StrAppendFormat(&output, "    %s: %lld bytes\n",
287                               value->instruction()->name(), bytes);
288       }
289     }
290   }
291 
292   return output;
293 }
294 
295 }  // namespace xla
296