/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <algorithm>
#include <limits>
#include <map>
#include <queue>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// A list scheduler for HLO instructions, which produces a sequence that
// minimizes memory usage by preferring to schedule nodes that free larger
// buffers and define smaller outputs.
//
// Note that the list scheduler is a greedy algorithm and cannot guarantee a
// globally optimal solution. As a counterexample, consider the following
// graph:
//
//      +--> B ===> C -------+
// A -> |                    |
//      |                    v
//      +--> D ---> F=======>G
//      |           ^
//      |           |
//      +--> E -----+
//
//  --> : Buffer with size 1
//  ==> : Buffer with size 2
//
// The list scheduler will always try to defer scheduling B in a greedy way
// since its output buffer is bigger than its input. The sequence it creates
// will be:
//   A D E F B C G
// , which has a maximum memory usage of 6 (B is alive while F is executing).
//
// An optimal way to schedule the previous graph is:
//   A B C D E F G
// , which has a maximum memory usage of 5 (when F is executing).
//
class ListScheduler {
 public:
  // Constructs and returns a memory-minimizing sequence of the HLO
  // instructions in the given computation.
  static StatusOr<HloInstructionSequence> Run(
      HloComputation* computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const BufferValue::SizeFunction& size_function,
      const absl::flat_hash_map<const HloComputation*, int64>&
          memory_by_computation) {
    ListScheduler scheduler(computation, points_to_analysis, size_function,
                            memory_by_computation);
    return scheduler.CreateSchedule();
  }

  // Returns whether the memory used by the given HLO should be ignored by the
  // scheduling heuristic.
  static bool IgnoreInstruction(const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kParameter ||
           instruction.opcode() == HloOpcode::kConstant;
  }

 private:
  // The scheduling priority of an instruction: primarily the number of bytes
  // freed by scheduling the instruction, with the number of users as a
  // tie-breaker. This is represented as a std::pair containing these two
  // values (first element is the bytes freed); std::pair provides the
  // necessary comparison operators.
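  // E.g. Priority{16, 2} > Priority{8, 5}: freeing 16 bytes outranks freeing
  // 8 regardless of user count; the second element only breaks exact byte
  // ties, since std::pair compares lexicographically.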
  using Priority = std::pair<int64, int64>;

  ListScheduler(HloComputation* computation,
                const TuplePointsToAnalysis& points_to_analysis,
                const BufferValue::SizeFunction& size_function,
                const absl::flat_hash_map<const HloComputation*, int64>&
                    memory_by_computation)
      : computation_(computation),
        points_to_analysis_(points_to_analysis),
        size_function_(size_function),
        memory_by_computation_(memory_by_computation) {
    // Create a map containing the LogicalBuffer uses for each HLO
    // instruction. An HLO instruction "uses" a LogicalBuffer if the
    // LogicalBuffer is in an operand of the instruction as indicated by
    // points-to analysis.
    for (auto* instruction : computation->instructions()) {
      absl::flat_hash_set<const LogicalBuffer*> instr_uses;
      for (auto* operand : instruction->operands()) {
        points_to_analysis.GetPointsToSet(operand).ForEachElement(
            [&](const ShapeIndex& /*index*/,
                const PointsToSet::BufferList& buffers) {
              instr_uses.insert(buffers.begin(), buffers.end());
            });
      }
      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
          instr_uses.begin(), instr_uses.end());
    }

    // Create a map containing the number of unscheduled uses (HLO
    // instructions) of each logical buffer.
    unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
    for (auto* instruction : computation->instructions()) {
      for (auto* buffer :
           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
        unscheduled_use_count_[buffer] = 0;
      }
    }
    for (auto* instruction : computation->instructions()) {
      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
        ++unscheduled_use_count_[buffer];
      }
    }

    // Buffers live out of the computation have an implicit use at the end of
    // the computation.
    for (const LogicalBuffer* live_out_buffer :
         points_to_analysis.GetPointsToSet(computation->root_instruction())
             .CreateFlattenedSet()) {
      ++unscheduled_use_count_[live_out_buffer];
    }
  }

  // Returns whether the memory used by the given buffer should be ignored by
  // the scheduling heuristic.
  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
    return IgnoreInstruction(*buffer.instruction());
  }

  // An entry in the worklist used by CreateSchedule. Corresponds to one
  // HloInstruction, plus some cached metadata, saved for the purposes of
  // making BytesFreedIfScheduled fast.
  struct ReadyListEntry {
    HloInstruction* instruction;

    // The total size of all buffers defined by this instruction.
    int64 bytes_defined;

    // For each buffer B used by this instruction, we keep a pair (B, U), where
    // U is the number of uses of B that have not yet been scheduled. This pair
    // is a pointer into the unscheduled_use_count_ map, so it gets updated for
    // free when we update counts in the map.
    std::vector<const std::pair<const LogicalBuffer* const, int64>*>
        used_buffer_unscheduled_use_counts;
  };

  // Creates a ReadyListEntry for the given instruction.
  ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
    ReadyListEntry entry;
    entry.instruction = instruction;

    entry.bytes_defined = 0;
    for (auto* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (!IgnoreBuffer(*buffer)) {
        entry.bytes_defined += size_function_(*buffer);
      }
    }

    for (auto* buffer : buffer_uses_.at(instruction)) {
      if (IgnoreBuffer(*buffer)) {
        continue;
      }
      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
      CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
      entry.used_buffer_unscheduled_use_counts.push_back(
          &*unscheduled_use_count_it);
    }
    return entry;
  }

  // Returns the number of bytes freed *after* the HLO instruction finishes.
  // The current List algorithm only considers two states for an instruction:
  // right before it runs, and after it finishes. We don't represent memory
  // usage during the execution of an instruction. But if the instruction calls
  // subcomputations, they are only live during the instruction's execution.
  // We end up counting the memory used by subcomputations as memory "defined"
  // by the instruction. This is not entirely accurate, but it is more accurate
  // than not taking subcomputations into account at all. In the future, we may
  // improve accounting for subcomputation memory (b/65409243).
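  // For example, scheduling an instruction that consumes the last use of a
  // 32-byte buffer while defining a 24-byte output (and calling no
  // subcomputations) yields 32 - 24 = 8 bytes freed.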
  int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
    auto instruction = entry.instruction;
    int64 freed_bytes = 0;
    for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
      auto buffer = kv->first;
      auto use_count = kv->second;
      if (use_count == 1) {
        freed_bytes += size_function_(*buffer);
      }
    }
    // We only count the memory usage of the largest subcomputation, instead of
    // adding them all, because subcomputations won't execute in parallel.
    int64 max_subcomputation_bytes = 0;
    for (const auto* c : instruction->called_computations()) {
      auto it = memory_by_computation_.find(c);
      if (it != memory_by_computation_.end()) {
        int64 subcomputation_bytes = it->second;
        if (subcomputation_bytes > max_subcomputation_bytes) {
          max_subcomputation_bytes = subcomputation_bytes;
        }
      }
    }
    int64 bytes_defined;
    auto opcode = instruction->opcode();
    if (max_subcomputation_bytes > 0 &&
        (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
         opcode == HloOpcode::kConditional)) {
      // The output buffer of while/call/conditional is always aliased with the
      // output buffer of the root instruction in the body. Don't double count.
      bytes_defined = max_subcomputation_bytes;
    } else {
      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
    }
    return freed_bytes - bytes_defined;
  }

  // Constructs the scheduling priority of the given instruction.
  Priority GetPriority(const ReadyListEntry& entry) {
    // Try to cluster scalars as close together as possible so that if they are
    // in unfused hlos, they can still live in machine registers without
    // excessive spilling.
    if (ShapeUtil::IsEffectiveScalar(entry.instruction->shape())) {
      return {std::numeric_limits<int64>::max(),
              std::numeric_limits<int64>::max()};
    }
    return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
  }

  HloInstructionSequence CreateSchedule() {
    HloInstructionSequence schedule;

    // Populate the ready list with instructions which have no operands or
    // control predecessors.
    absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
    for (auto* instruction : computation_->instructions()) {
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : instruction->users()) {
        unscheduled_pred_count[user]++;
      }
      for (HloInstruction* succ : instruction->control_successors()) {
        unscheduled_pred_count[succ]++;
      }
    }

    // Use a multimap to sort ReadyListEntry according to their priority.
    std::multimap<Priority, ReadyListEntry> ready_queue;
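    // The multimap keeps entries sorted in ascending Priority order, so the
    // best candidate is always at the back; unlike std::map, it also admits
    // multiple entries with equal priority.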

    // Map of ready instructions to their iterators in ready_queue.
    absl::flat_hash_map<const HloInstruction*,
                        std::multimap<Priority, ReadyListEntry>::iterator>
        ready_instructions;

    auto add_to_ready_queue = [&](HloInstruction* inst) {
      auto entry = MakeReadyListEntry(inst);
      auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
      ready_instructions[inst] = it;
    };

    for (auto* instruction : computation_->instructions()) {
      if (instruction->operands().empty() &&
          instruction->control_predecessors().empty()) {
        add_to_ready_queue(instruction);
      }
    }

    while (!ready_queue.empty()) {
      // Remove the selected instruction from the ready list and add it to the
      // schedule.
      auto best_it = ready_queue.end();
      --best_it;
      HloInstruction* best = best_it->second.instruction;
      VLOG(2) << "Schedule instruction: " << best->ToShortString()
              << " Bytes freed: " << best_it->first.first;
      ready_queue.erase(best_it);
      ready_instructions.erase(best);
      schedule.push_back(best);
      scheduled_instructions_.insert(best);

      bool adjust_ready_queue = false;
      // Update the unscheduled uses of the logical buffers.
      for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
        int64& count = unscheduled_use_count_[buffer];
        CHECK_GT(count, 0);
        --count;
        if (count == 1) {
          adjust_ready_queue = true;
        }
      }

      // Add new instructions to ready list.
      auto update_pred_count = [&](HloInstruction* inst) {
        int64 pred_count = --unscheduled_pred_count.at(inst);
        CHECK_GE(pred_count, 0);
        if (pred_count == 0) {
          add_to_ready_queue(inst);
        }
      };
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : best->users()) {
        update_pred_count(user);
      }
      for (HloInstruction* succ : best->control_successors()) {
        update_pred_count(succ);
      }
      // The unscheduled use count for a buffer has changed to 1, so the
      // priorities of some ready instructions may go up. We update them in
      // the ready queue, so that they can appear earlier.
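      // Any affected entry must share one of those buffers with `best`, so we
      // search among the ready users of `best`'s operands.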
      if (adjust_ready_queue) {
        for (HloInstruction* operand : best->operands()) {
          for (HloInstruction* operand_user : operand->users()) {
            auto ready_instructions_it = ready_instructions.find(operand_user);
            if (ready_instructions_it == ready_instructions.end()) {
              continue;
            }
            auto ready_queue_it = ready_instructions_it->second;
            auto& entry = ready_queue_it->second;
            Priority new_priority = GetPriority(entry);
            if (new_priority == ready_queue_it->first) {
              continue;
            }
            // Create a new entry in ready_queue, then update
            // ready_instructions[operand_user] to refer to the new entry.
            ready_instructions_it->second =
                ready_queue.emplace(new_priority, std::move(entry));
            // Remove the old entry in ready_queue.
            ready_queue.erase(ready_queue_it);
          }
        }
      }
    }
    CHECK_EQ(schedule.size(), computation_->instruction_count());
    CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());

    return schedule;
  }

  HloComputation* computation_;
  const TuplePointsToAnalysis& points_to_analysis_;
  const BufferValue::SizeFunction& size_function_;
  // Computations are analyzed in post-order. When scheduling an instruction
  // that includes subcomputations, such as a while loop, we use this map to
  // look up the memory needed by subcomputations.
  const absl::flat_hash_map<const HloComputation*, int64>&
      memory_by_computation_;

  // A map containing the LogicalBuffers that each instruction uses.
  absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
      buffer_uses_;

  // A map containing the count of unscheduled HLOs that use a particular
  // LogicalBuffer.
  absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_;

  // Set of instructions which have been scheduled.
  absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};

int64 SumLogicalBufferSizes(
    const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
    const BufferValue::SizeFunction& size_function) {
  int64 size = 0;
  for (const LogicalBuffer* buffer : buffers) {
    size += size_function(*buffer);
  }
  return size;
}

StatusOr<HloInstructionSequence> ScheduleComputationHelper(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  VLOG(2) << "Computation: " << computation->name();

  if (algorithm) {
    return algorithm(computation, points_to_analysis, alias_analysis,
                     size_function, memory_by_computation, peak_memory);
  }
  return DefaultMemoryScheduler(computation, points_to_analysis,
                                alias_analysis, size_function,
                                memory_by_computation, peak_memory);
}

}  // namespace

StatusOr<HloInstructionSequence> DFSMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  // These variables are a hack to prevent overflows.
  int64 cumulative_total_size = 0;
  int64 total_hlos = computation->parent()->instruction_count();
  absl::flat_hash_map<const HloInstruction*, int64> extra_users;
  absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
  for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
    if (ListScheduler::IgnoreInstruction(*hlo)) {
      extra_users[hlo] = 0;
      total_sizes[hlo] = 0;
      continue;
    }
    // This ordering is based on DFS post-order, with a heuristic to decide
    // which operand to visit first. The heuristic is based on 'extra_users',
    // which is simply users - 1 for each instruction. By subtracting 1, we're
    // saying that instructions with no users or a single user don't count;
    // instructions with lots of fan-out will be visited earlier.
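    // E.g. an instruction with four users contributes 4 - 1 = 3 extra users
    // of its own, plus (accumulated below) the extra users of each unique
    // operand.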
    extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
    int64 logical_buffer_size = SumLogicalBufferSizes(
        points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
    total_sizes[hlo] = logical_buffer_size;
    cumulative_total_size += logical_buffer_size;
    absl::flat_hash_set<const HloInstruction*> unique_operands(
        hlo->operands().begin(), hlo->operands().end());
    for (const HloInstruction* operand : unique_operands) {
      extra_users[hlo] += extra_users[operand];
      total_sizes[hlo] += total_sizes[operand];
    }
    // total_sizes[hlo] transitively includes the sizes of all nodes that
    // lead to it. But the computation is a DAG, so we are double-counting
    // nodes, which can lead to overflows for large programs.
    // cumulative_total_size caps the size to prevent overflows.
    // Same for total_hlos: it prevents overflows on very large and branchy
    // models, where the number of paths is exponential to the number of nodes.
    // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
    // why we care about transitive sizes; when scheduling a node, its input
    // and output buffers should be all that matters, not its "history".
    total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
    extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
  }
  CHECK_EQ(extra_users.size(), computation->instruction_count());
  CHECK_EQ(total_sizes.size(), computation->instruction_count());

  // Construct a total order based on DFS post-order, visiting operands in
  // decreasing cumulative extra-user order, then by cumulative size, with a
  // tie-breaker by name for determinism.
  HloInstructionSequence sequence;
  FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
    sequence.push_back(hlo);
    return Status::OK();
  });
  visitor.ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
      &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
                                             const HloInstruction* b) {
        if (extra_users[a] != extra_users[b]) {
          return extra_users[a] > extra_users[b];
        }
        if (total_sizes[a] != total_sizes[b]) {
          return total_sizes[a] > total_sizes[b];
        }
        return a->name() < b->name();
      }));
  CHECK_EQ(sequence.size(), computation->instruction_count());
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis,
                          size_function, &memory_by_computation));
  }
  return sequence;
}

ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
    const MemorySchedulerAlgorithm& computation_scheduler) {
  return [computation_scheduler](
             HloModule* module,
             const TuplePointsToAnalysis& points_to_analysis,
             const HloAliasAnalysis& alias_analysis,
             const LogicalBuffer::SizeFunction& size_func,
             int64* peak_memory) -> StatusOr<HloSchedule> {
    HloSchedule schedule(module);
    absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
    for (auto* computation : module->MakeComputationPostOrder()) {
      if (!computation->IsFusionComputation()) {
        TF_ASSIGN_OR_RETURN(
            HloInstructionSequence computation_sequence,
            ScheduleComputationHelper(
                computation, points_to_analysis, alias_analysis, size_func,
                computation_scheduler, memory_by_computation, nullptr));
        schedule.set_sequence(computation, std::move(computation_sequence));
      }
    }
    if (peak_memory) {
      TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule(
                                            schedule, size_func));
    }
    return std::move(schedule);
  };
}

StatusOr<HloInstructionSequence> ListMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence sequence,
      ListScheduler::Run(computation, points_to_analysis, size_function,
                         memory_by_computation));
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis,
                          size_function, &memory_by_computation));
  }
  return sequence;
}

StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  HloInstructionSequence sequence(computation->MakeInstructionPostOrder());
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis,
                          size_function, &memory_by_computation));
  }
  return sequence;
}

StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  // We try a few schedulers and choose whichever returns a lower min-memory,
  // not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64 list_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence list_sequence,
      ListMemoryScheduler(computation, points_to_analysis, alias_analysis,
                          size_function, memory_by_computation, &list_memory));
  VLOG(2) << "Min-memory list sequence: "
          << HumanReadableNumBytes(list_memory);

  int64 dfs_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence dfs_sequence,
      DFSMemoryScheduler(computation, points_to_analysis, alias_analysis,
                         size_function, memory_by_computation, &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64 post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence post_order_sequence,
      PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis,
                               size_function, memory_by_computation,
                               &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

StatusOr<HloSchedule> DefaultModuleScheduler(
    HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function, int64* peak_memory) {
  // We try a few schedulers and choose whichever returns a lower min-memory,
  // not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64 list_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule list_sequence,
      ComputationSchedulerToModuleScheduler(ListMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &list_memory));

  VLOG(2) << "Min-memory list sequence: "
          << HumanReadableNumBytes(list_memory);

  int64 dfs_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule dfs_sequence,
      ComputationSchedulerToModuleScheduler(DFSMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64 post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule post_order_sequence,
      ComputationSchedulerToModuleScheduler(PostOrderMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

StatusOr<HloSchedule> ScheduleModule(
    HloModule* module, const BufferValue::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm, int64* peak_memory) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
      TuplePointsToAnalysis::Run(module));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      (algorithm ? algorithm : DefaultModuleScheduler)(
                          module, *points_to_analysis, *alias_analysis,
                          size_function, peak_memory));

  TF_RETURN_IF_ERROR(schedule.Verify());

  return std::move(schedule);
}

StatusOr<HloInstructionSequence> ScheduleComputation(
    HloComputation* computation,
    const BufferValue::SizeFunction& size_function) {
  CHECK(!computation->IsFusionComputation());
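  // Fusion computations never receive a standalone sequence; they are emitted
  // inline as part of their enclosing fusion instruction (hence this check and
  // the IsFusionComputation checks in the module schedulers above).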
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
      TuplePointsToAnalysis::Run(computation->parent()));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(computation->parent()));
  absl::flat_hash_map<const HloComputation*, int64> empty_map;
  return ScheduleComputationHelper(computation, *points_to_analysis,
                                   *alias_analysis, size_function, nullptr,
                                   empty_map, nullptr);
}

HloMemoryScheduler::HloMemoryScheduler(
    const BufferValue::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm)
    : size_function_(size_function), algorithm_(algorithm) {}

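// Example of registering this pass in a pipeline (a sketch, not code from
// this file; the pipeline name and the 8-byte pointer size are assumptions):
//
//   HloPassPipeline pipeline("scheduling");
//   pipeline.AddPass<HloMemoryScheduler>([](const BufferValue& buffer) {
//     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//   });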
StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module, size_function_, algorithm_));
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloTrivialScheduler::Run(HloModule* module) {
  HloSchedule schedule(module);
  for (HloComputation* computation : module->MakeComputationPostOrder()) {
    if (!computation->IsFusionComputation()) {
      HloInstructionSequence& computation_sequence =
          schedule.GetOrCreateSequence(computation);
      FunctionVisitor visitor(
          [&computation_sequence](HloInstruction* instruction) {
            computation_sequence.push_back(instruction);
            return Status::OK();
          });
      visitor.ReserveVisitStates(computation->instruction_count());
      TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    }
  }
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloDescheduler::Run(HloModule* module) {
  bool changed = module->has_schedule();
  module->clear_schedule();
  return changed;
}

}  // namespace xla