1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_live_range.h"
17
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
20 #include "tensorflow/compiler/xla/service/hlo_value.h"
21
22 namespace xla {
23 /*static*/
Run(const HloSchedule & schedule,const HloAliasAnalysis & alias_analysis,const HloComputation * computation,bool module_scoped_analysis)24 StatusOr<std::unique_ptr<HloLiveRange>> HloLiveRange::Run(
25 const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
26 const HloComputation* computation, bool module_scoped_analysis) {
27 std::unique_ptr<HloLiveRange> hlo_live_range(
28 new HloLiveRange(schedule, alias_analysis, module_scoped_analysis));
29 hlo_live_range->schedule_end_time_ =
30 hlo_live_range->FlattenSchedule(*computation, 0);
31 hlo_live_range->CalculateBufferStartEndMap();
32 hlo_live_range->NormalizeAliasedBuffers();
33 return std::move(hlo_live_range);
34 }
35
// Values that belong to the same HloBuffer alias the same storage, so their
// computed live ranges must not overlap. For each buffer, sort its values by
// (start, end) and adjust neighboring intervals in place so they become
// disjoint.
void HloLiveRange::NormalizeAliasedBuffers() {
  for (const HloBuffer& hlo_buffer : alias_analysis_.buffers()) {
    // Only consider values for which a live range was actually computed.
    std::vector<const HloValue*> aliased_buffers;
    for (const HloValue* hlo_value : hlo_buffer.values()) {
      if (buffer_live_ranges_.contains(hlo_value)) {
        aliased_buffers.push_back(hlo_value);
      }
    }
    // Sort by start time, breaking ties with end time, so adjacent entries
    // are the only candidates for overlap repair below.
    absl::c_sort(
        aliased_buffers, [&](const HloValue* value1, const HloValue* value2) {
          const TimeBound& live_range1 = buffer_live_ranges_.at(value1);
          const TimeBound& live_range2 = buffer_live_ranges_.at(value2);

          return std::forward_as_tuple(live_range1.start, live_range1.end) <
                 std::forward_as_tuple(live_range2.start, live_range2.end);
        });

    // Walk adjacent pairs; each iteration mutates the stored live ranges.
    for (int64_t i = 0; i + 1 < aliased_buffers.size(); ++i) {
      const HloValue* value1 = aliased_buffers[i];
      const HloValue* value2 = aliased_buffers[i + 1];
      TimeBound& live_range1 = buffer_live_ranges_[value1];
      TimeBound& live_range2 = buffer_live_ranges_[value2];
      if (live_range1.start == live_range2.start) {
        // If value1 has the same start time as value2, make value1 disappear
        // by setting the end time same as start time:
        //
        // Before:
        // +----+ value1
        // +----------+ value2
        //
        // After:
        // + value1
        // +----------+ value2
        //
        // Note that only when heap simulator runs before copy insertion can
        // this happen where one instruction defines multiple aliased buffers
        // -- This is illegal to execute and can be fixed by copy insertion
        // later.
        // NOTE(review): the code extends value1's end to value2's end (the
        // diagram above describes shrink-to-start); subsequent iterations
        // then clip it against the next interval — confirm intent.
        live_range1.end = live_range2.end;
        continue;
      }

      // Already disjoint: nothing to repair for this pair.
      if (live_range1.end < live_range2.start) {
        continue;
      }

      // Overlap: value2 inherits the later of the two ends, and value1 is
      // clipped to end just before value2 begins.
      if (live_range1.end > live_range2.end) {
        live_range2.end = live_range1.end;
      }
      live_range1.end = live_range2.start - 1;
    }
  }
}
89
90 // FlattenSchedule walks through the computation and tracks down the ordinal
91 // number of each instruction in the schedule.
FlattenSchedule(const HloComputation & computation,int64_t start_time)92 int64 HloLiveRange::FlattenSchedule(const HloComputation& computation,
93 int64_t start_time) {
94 if (!schedule_.is_computation_scheduled(&computation)) {
95 total_order_scheduled_ = false;
96 return start_time;
97 }
98
99 const HloInstructionSequence& instruction_sequence =
100 schedule_.sequence(&computation);
101 int64_t time = start_time;
102 for (HloInstruction* instruction : instruction_sequence.instructions()) {
103 if (module_scoped_analysis_) {
104 // Recurse into sub computations if running with module scoped analysis
105 // mode.
106 if (instruction->opcode() == HloOpcode::kCall ||
107 instruction->opcode() == HloOpcode::kConditional) {
108 for (const HloComputation* called_computation :
109 instruction->called_computations()) {
110 time = FlattenSchedule(*called_computation, time);
111 }
112 }
113 if (instruction->opcode() == HloOpcode::kWhile) {
114 time = FlattenSchedule(*instruction->while_condition(), time);
115 time = FlattenSchedule(*instruction->while_body(), time);
116 }
117 }
118 if (instruction_schedule_.count(instruction) != 0) {
119 continue;
120 }
121 instruction_schedule_.insert({instruction, time++});
122 flattened_instruction_sequence_.push_back(instruction);
123 }
124 computation_span_times_.try_emplace(&computation,
125 TimeBound{start_time, time});
126 DCHECK_EQ(instruction_schedule_.size(),
127 flattened_instruction_sequence_.size());
128 DCHECK_LE(instruction_schedule_.size(), time);
129 return time;
130 }
131
CalculateBufferStartEndMap()132 void HloLiveRange::CalculateBufferStartEndMap() {
133 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
134 // Ignore buffers that are not defined.
135 if (instruction_schedule_.count(value->defining_instruction()) == 0) {
136 continue;
137 }
138
139 int64_t buffer_start_time = instruction_schedule_[value->instruction()];
140
141 int64_t buffer_end_time = -1;
142 for (const HloUse& use : value->uses()) {
143 const HloInstruction* used = use.instruction;
144 // As an optimization, we deem a while's init value's live range ends as
145 // soon as the loop body starts. This optimization is only applicable in
146 // module scoped mode.
147 if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) {
148 // The current live range is at the end of the while, move it to the
149 // beginning of the body.
150 used = used->while_body()->parameter_instruction(0);
151 VLOG(1) << "Moved value " << value->ToShortString()
152 << " to while param: " << used->ToString();
153 }
154 if (instruction_schedule_.count(used) == 0) {
155 // We didn't track the instruction `used`. This happens when we do
156 // computation scope (versus module scope) heap simulation and when
157 // the used instruction is outside of the computation being simulated.
158 continue;
159 }
160 buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]);
161 }
162
163 // Parameters are defined at the beginning of the computation. This prevents
164 // any instruction that's scheduled before the parameter clobbers the
165 // parameter's buffer.
166 if (value->instruction()->opcode() == HloOpcode::kParameter) {
167 const HloComputation* computation = value->instruction()->parent();
168 auto it = computation_span_times_.find(computation);
169 if (it != computation_span_times_.end()) {
170 buffer_start_time = std::min(buffer_start_time, it->second.start);
171 }
172 }
173
174 if (buffer_end_time == -1) {
175 buffer_end_time = buffer_start_time;
176 }
177
178 HloPosition end_position;
179 int64_t max_end_time = 0;
180 for (const HloPosition& position : value->positions()) {
181 if (instruction_schedule_[position.instruction] >= max_end_time) {
182 max_end_time = instruction_schedule_[value->instruction()];
183 end_position = position;
184 }
185 const HloComputation* position_comp = position.instruction->parent();
186 // If this instruction lives out, the live range of the instruction
187 // should be extended to the end of the computation.
188 if (position.instruction == position_comp->root_instruction()) {
189 auto it = computation_span_times_.find(position_comp);
190 if (it == computation_span_times_.end()) {
191 continue;
192 }
193 if (buffer_end_time < it->second.end) {
194 buffer_end_time = it->second.end;
195 end_position = position;
196 }
197 }
198 }
199
200 const HloModule* module = value->instruction()->parent()->parent();
201
202 // Readonly entry parameters (parameters that don't alias) live across whole
203 // computation.
204 if (value->instruction()->opcode() == HloOpcode::kParameter &&
205 value->instruction()->parent() == module->entry_computation() &&
206 !module->input_output_alias_config().ParameterHasAlias(
207 value->instruction()->parameter_number(), value->index())) {
208 buffer_end_time = schedule_end_time_;
209 }
210
211 CHECK(buffer_start_time <= buffer_end_time)
212 << buffer_start_time << ", " << buffer_end_time
213 << value->instruction()->ToString();
214
215 auto& live_range = buffer_live_ranges_[value];
216 live_range.start = buffer_start_time;
217 live_range.end = buffer_end_time;
218 live_range.end_position = end_position;
219 }
220 }
221
ComputePeakMemoryMoment() const222 int64 HloLiveRange::ComputePeakMemoryMoment() const {
223 std::vector<std::tuple<int64 /*time*/, bool /*is_end*/, const HloValue*>>
224 events;
225 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
226 auto it = buffer_live_ranges_.find(value);
227 if (it != buffer_live_ranges_.end()) {
228 events.emplace_back(it->second.start, false, value);
229 events.emplace_back(it->second.end + 1, true, value);
230 }
231 }
232 std::sort(events.begin(), events.end());
233
234 int64_t memory_usage = 0;
235 int64_t peak_usage = 0;
236 absl::optional<int64> peak_time;
237 for (const auto& event : events) {
238 int64_t time;
239 bool is_end;
240 const HloValue* value;
241 std::tie(time, is_end, value) = event;
242 auto buffer_size = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
243 if (is_end) {
244 memory_usage -= buffer_size;
245 } else {
246 memory_usage += buffer_size;
247 }
248 if (peak_usage < memory_usage) {
249 peak_usage = memory_usage;
250 peak_time = time;
251 }
252 }
253 return peak_time.value_or(0);
254 }
255
ToString() const256 std::string HloLiveRange::ToString() const {
257 std::string output;
258 absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n",
259 schedule_end_time_);
260 absl::StrAppendFormat(&output, " InstructionSequence:\n");
261 auto& instructions = flattened_instruction_sequence().instructions();
262 for (int64_t i = 0; i < instructions.size(); ++i) {
263 absl::StrAppendFormat(&output, " %d:%s\n", i, instructions[i]->name());
264 }
265
266 absl::StrAppendFormat(&output, " BufferLiveRange:\n");
267
268 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
269 auto it = buffer_live_ranges_.find(value);
270 if (it != buffer_live_ranges_.end()) {
271 absl::StrAppendFormat(
272 &output, " %s%s:%d-%d\n", value->instruction()->name(),
273 value->index().ToString(), it->second.start, it->second.end);
274 }
275 }
276
277 int64_t peak_moment = ComputePeakMemoryMoment();
278
279 absl::StrAppendFormat(&output, " Live ranges at %lld (peak):\n",
280 peak_moment);
281 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
282 auto it = buffer_live_ranges_.find(value);
283 if (it != buffer_live_ranges_.end()) {
284 if (it->second.start <= peak_moment && peak_moment <= it->second.end) {
285 int64_t bytes = ShapeUtil::ByteSizeOf(value->instruction()->shape(), 8);
286 absl::StrAppendFormat(&output, " %s: %lld bytes\n",
287 value->instruction()->name(), bytes);
288 }
289 }
290 }
291
292 return output;
293 }
294
295 } // namespace xla
296