/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <algorithm>
#include <limits>
#include <map>
#include <queue>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Class implementing a list scheduler of HLO instructions which produces a
// sequence that minimizes memory usage by preferring to schedule nodes that
// free larger buffers and define smaller outputs.
//
// Note that the list scheduler is a greedy algorithm which cannot guarantee a
// globally optimal solution. As a counterexample, consider the following
// graph:
//
//      +--> B ===> C -------+
// A -> |                    |
//      |                    v
//      +--> D ---> F=======>G
//      |           ^
//      |           |
//      +--> E -----+
//
//  --> : Buffer with size 1
//  ==> : Buffer with size 2
//
// The list scheduler will always try to defer scheduling B in a greedy way
// since its output buffer is bigger than its input. The sequence it creates
// will be:
//   A D E F B C G
// , which has a maximum memory usage of 6 (B is alive while F is executing).
//
// An optimal way to schedule the previous graph is:
//   A B C D E F G
// , which has a maximum memory usage of 5 (when F is executing).
//
class ListScheduler {
 public:
  // Constructs and returns a memory-minimizing sequence of the HLO
  // instructions in the given HLO computation.
  static StatusOr<HloInstructionSequence> Run(
      HloComputation* computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const BufferValue::SizeFunction& size_function,
      const absl::flat_hash_map<const HloComputation*, int64>&
          memory_by_computation) {
    ListScheduler scheduler(computation, points_to_analysis, size_function,
                            memory_by_computation);
    return scheduler.CreateSchedule();
  }

  // Returns whether the memory used by the given HLO should be ignored by the
  // scheduling heuristic.
  static bool IgnoreInstruction(const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kParameter ||
           instruction.opcode() == HloOpcode::kConstant;
  }
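
  // Note: parameters and constants do not allocate fresh memory when they
  // execute (their buffers exist before the computation runs), so counting
  // their bytes would not help the heuristic distinguish between candidate
  // schedules.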

 private:
  // The scheduling priority of an instruction is primarily the number of
  // bytes freed by scheduling it, with the number of users as a tie-breaker.
  // This is represented as a std::pair of these two values (bytes freed
  // first), since std::pair provides the necessary comparison operators.
  using Priority = std::pair<int64, int64>;
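  // For example, Priority{10, 2} outranks Priority{8, 5} (more bytes freed
  // wins), and Priority{10, 3} outranks Priority{10, 2} (ties broken by user
  // count): std::pair compares lexicographically, element by element.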

  ListScheduler(HloComputation* computation,
                const TuplePointsToAnalysis& points_to_analysis,
                const BufferValue::SizeFunction& size_function,
                const absl::flat_hash_map<const HloComputation*, int64>&
                    memory_by_computation)
      : computation_(computation),
        points_to_analysis_(points_to_analysis),
        size_function_(size_function),
        memory_by_computation_(memory_by_computation) {
    // Create a map containing the LogicalBuffer uses for each HLO
    // instruction. An HLO instruction "uses" a LogicalBuffer if the
    // LogicalBuffer is in an operand of the instruction as indicated by
    // points-to analysis.
    for (auto* instruction : computation->instructions()) {
      absl::flat_hash_set<const LogicalBuffer*> instr_uses;
      for (auto* operand : instruction->operands()) {
        points_to_analysis.GetPointsToSet(operand).ForEachElement(
            [&](const ShapeIndex& /*index*/,
                const PointsToSet::BufferList& buffers) {
              instr_uses.insert(buffers.begin(), buffers.end());
            });
      }
      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
          instr_uses.begin(), instr_uses.end());
    }

    // Create a map containing the number of unscheduled uses (HLO
    // instructions) of each logical buffer.
    unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
    for (auto* instruction : computation->instructions()) {
      for (auto* buffer :
           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
        unscheduled_use_count_[buffer] = 0;
      }
    }
    for (auto* instruction : computation->instructions()) {
      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
        ++unscheduled_use_count_[buffer];
      }
    }

    // Buffers live out of the computation have an implicit use at the end of
    // the computation.
    for (const LogicalBuffer* live_out_buffer :
         points_to_analysis.GetPointsToSet(computation->root_instruction())
             .CreateFlattenedSet()) {
      ++unscheduled_use_count_[live_out_buffer];
    }
  }

  // Returns whether the memory used by the given buffer should be ignored by
  // the scheduling heuristic.
  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
    return IgnoreInstruction(*buffer.instruction());
  }

  // An entry in the worklist used by CreateSchedule.  Corresponds to one
  // HloInstruction, plus some cached metadata, saved for the purposes of
  // making BytesFreedIfScheduled fast.
  struct ReadyListEntry {
    HloInstruction* instruction;

    // The total size of all buffers defined by this instruction.
    int64 bytes_defined;

    // For each buffer B used by this instruction, we keep a pair (B, U), where
    // U is the number of uses of B that have not yet been scheduled. This pair
    // is a pointer into the unscheduled_use_count_ map, so it gets updated for
    // free when we update counts in the map.
    std::vector<const std::pair<const LogicalBuffer* const, int64>*>
        used_buffer_unscheduled_use_counts;
  };
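
  // Note: the pointers stored in used_buffer_unscheduled_use_counts stay
  // valid because unscheduled_use_count_ gains no new keys after the
  // constructor runs; absl::flat_hash_map invalidates pointers only when an
  // insertion forces a rehash.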

  // Creates a ReadyListEntry for the given instruction.
  ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
    ReadyListEntry entry;
    entry.instruction = instruction;

    entry.bytes_defined = 0;
    for (auto* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (!IgnoreBuffer(*buffer)) {
        entry.bytes_defined += size_function_(*buffer);
      }
    }

    for (auto* buffer : buffer_uses_.at(instruction)) {
      if (IgnoreBuffer(*buffer)) {
        continue;
      }
      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
      CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
      entry.used_buffer_unscheduled_use_counts.push_back(
          &*unscheduled_use_count_it);
    }
    return entry;
  }

  // Returns the number of bytes freed *after* the HLO instruction finishes.
  // The current List algorithm only considers two states for an instruction:
  // right before it runs, and after it finishes. We don't represent memory
  // usage during the execution of an instruction. But if the instruction calls
  // subcomputations, they are only live during the instruction's execution.
  // We end up counting the memory used by subcomputations as memory "defined"
  // by the instruction. This is not entirely accurate, but it is more accurate
  // than not taking subcomputations into account at all. In the future, we may
  // improve accounting for subcomputation memory (b/65409243).
  int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
    auto instruction = entry.instruction;
    int64 freed_bytes = 0;
    for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
      auto buffer = kv->first;
      auto use_count = kv->second;
      if (use_count == 1) {
        freed_bytes += size_function_(*buffer);
      }
    }
    // We only count the memory usage of the largest subcomputation, instead of
    // adding them all, because subcomputations won't execute in parallel.
    int64 max_subcomputation_bytes = 0;
    for (const auto* c : instruction->called_computations()) {
      auto it = memory_by_computation_.find(c);
      if (it != memory_by_computation_.end()) {
        int64 subcomputation_bytes = it->second;
        if (subcomputation_bytes > max_subcomputation_bytes) {
          max_subcomputation_bytes = subcomputation_bytes;
        }
      }
    }
    int64 bytes_defined;
    auto opcode = instruction->opcode();
    if (max_subcomputation_bytes > 0 &&
        (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
         opcode == HloOpcode::kConditional)) {
      // The output buffer of while/call/conditional is always aliased with the
      // output buffer of the root instruction in the body. Don't double count.
      bytes_defined = max_subcomputation_bytes;
    } else {
      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
    }
    return freed_bytes - bytes_defined;
  }
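
  // Worked example with illustrative numbers: an entry that uses buffers of
  // 16 and 32 bytes with remaining use counts 1 and 3, defines 8 bytes, and
  // calls no subcomputations frees only the 16-byte buffer, so the result is
  // 16 - 8 = 8. A negative result means scheduling the instruction would grow
  // live memory.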

  // Constructs the scheduling priority of the given instruction.
  Priority GetPriority(const ReadyListEntry& entry) {
    // Try to cluster scalars as close together as possible so that if they are
    // in unfused hlos, they can still live in machine registers without
    // excessive spilling.
    if (ShapeUtil::IsEffectiveScalar(entry.instruction->shape())) {
      return {std::numeric_limits<int64>::max(),
              std::numeric_limits<int64>::max()};
    }
    return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
  }
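
  // For instance, an entry that frees 10 bytes and whose instruction has 3
  // users gets Priority{10, 3}, while an effective scalar always receives the
  // maximum priority and is scheduled as soon as it becomes ready.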

  HloInstructionSequence CreateSchedule() {
    HloInstructionSequence schedule;

    // Populate the ready list with instructions which have no operands or
    // control predecessors.
    absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
    for (auto* instruction : computation_->instructions()) {
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : instruction->users()) {
        unscheduled_pred_count[user]++;
      }
      for (HloInstruction* succ : instruction->control_successors()) {
        unscheduled_pred_count[succ]++;
      }
    }

    // Use a multimap to sort ReadyListEntry according to their priority.
    std::multimap<Priority, ReadyListEntry> ready_queue;

    // Map of ready instructions to their iterators in ready_queue.
    absl::flat_hash_map<const HloInstruction*,
                        std::multimap<Priority, ReadyListEntry>::iterator>
        ready_instructions;

    auto add_to_ready_queue = [&](HloInstruction* inst) {
      auto entry = MakeReadyListEntry(inst);
      auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
      ready_instructions[inst] = it;
    };

    for (auto* instruction : computation_->instructions()) {
      if (instruction->operands().empty() &&
          instruction->control_predecessors().empty()) {
        add_to_ready_queue(instruction);
      }
    }

    while (!ready_queue.empty()) {
      // Remove the selected instruction from the ready list and add it to the
      // schedule.
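      // std::multimap keeps its entries sorted in ascending key order, so the
      // highest-priority entry is the last one in ready_queue.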
      auto best_it = ready_queue.end();
      --best_it;
      HloInstruction* best = best_it->second.instruction;
      VLOG(2) << "Schedule instruction: " << best->ToShortString()
              << " Bytes freed: " << best_it->first.first;
      ready_queue.erase(best_it);
      ready_instructions.erase(best);
      schedule.push_back(best);
      scheduled_instructions_.insert(best);

      bool adjust_ready_queue = false;
      // Update the unscheduled uses of the logical buffers.
      for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
        int64& count = unscheduled_use_count_[buffer];
        CHECK_GT(count, 0);
        --count;
        if (count == 1) {
          adjust_ready_queue = true;
        }
      }

      // Add new instructions to ready list.
      auto update_pred_count = [&](HloInstruction* inst) {
        int64 pred_count = --unscheduled_pred_count.at(inst);
        CHECK_GE(pred_count, 0);
        if (pred_count == 0) {
          add_to_ready_queue(inst);
        }
      };
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : best->users()) {
        update_pred_count(user);
      }
      for (HloInstruction* succ : best->control_successors()) {
        update_pred_count(succ);
      }
      // The unscheduled use count for a buffer has changed to 1, so the
      // priorities of some ready instructions may go up. We update them in the
      // ready queue, so that they can appear earlier.
      if (adjust_ready_queue) {
        for (HloInstruction* operand : best->operands()) {
          for (HloInstruction* operand_user : operand->users()) {
            auto ready_instructions_it = ready_instructions.find(operand_user);
            if (ready_instructions_it == ready_instructions.end()) {
              continue;
            }
            auto ready_queue_it = ready_instructions_it->second;
            auto& entry = ready_queue_it->second;
            Priority new_priority = GetPriority(entry);
            if (new_priority == ready_queue_it->first) {
              continue;
            }
            // Create a new entry in ready_queue, then update
            // ready_instructions[operand_user] to refer to the new entry.
            ready_instructions_it->second =
                ready_queue.emplace(new_priority, std::move(entry));
            // Remove the old entry in ready_queue.
            ready_queue.erase(ready_queue_it);
          }
        }
      }
    }
    CHECK_EQ(schedule.size(), computation_->instruction_count());
    CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());

    return schedule;
  }

  HloComputation* computation_;
  const TuplePointsToAnalysis& points_to_analysis_;
  const BufferValue::SizeFunction& size_function_;
  // Computations are analyzed in post-order. When scheduling an instruction
  // that includes subcomputations, such as a while loop, we use this map to
  // look up the memory needed by subcomputations.
  const absl::flat_hash_map<const HloComputation*, int64>&
      memory_by_computation_;

  // A map containing the LogicalBuffers that each instruction uses.
  absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
      buffer_uses_;

  // A map containing the count of unscheduled HLOs that use a particular
  // LogicalBuffer.
  absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_;

  // Set of instructions which have been scheduled.
  absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};

int64 SumLogicalBufferSizes(
    const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
    const BufferValue::SizeFunction& size_function) {
  int64 size = 0;
  for (const LogicalBuffer* buffer : buffers) {
    size += size_function(*buffer);
  }
  return size;
}
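
// A typical size function passed to these schedulers is a thin wrapper over
// ShapeUtil::ByteSizeOf. The lambda below is an illustrative sketch only; the
// pointer size of 8 is an assumption, not a value this file prescribes:
//
//   BufferValue::SizeFunction size_fn = [](const BufferValue& value) {
//     return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
//   };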

StatusOr<HloInstructionSequence> ScheduleComputationHelper(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  VLOG(2) << "Computation: " << computation->name();

  if (algorithm) {
    return algorithm(computation, points_to_analysis, alias_analysis,
                     size_function, memory_by_computation, peak_memory);
  }
  return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis,
                                size_function, memory_by_computation,
                                peak_memory);
}

}  // namespace

StatusOr<HloInstructionSequence> DFSMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  // These variables are a hack to prevent overflows.
  int64 cumulative_total_size = 0;
  int64 total_hlos = computation->parent()->instruction_count();
  absl::flat_hash_map<const HloInstruction*, int64> extra_users;
  absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
  for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
    if (ListScheduler::IgnoreInstruction(*hlo)) {
      extra_users[hlo] = 0;
      total_sizes[hlo] = 0;
      continue;
    }
    // This ordering is based on DFS post-order, with a heuristic to decide
    // which operand to visit first.  The heuristic is based on 'extra_users',
    // which is simply users-1 for each instruction.  By subtracting 1, we're
    // saying that instructions with no users or a single user don't count;
    // instructions with lots of fan-out will be visited earlier.
    extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
    int64 logical_buffer_size = SumLogicalBufferSizes(
        points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
    total_sizes[hlo] = logical_buffer_size;
    cumulative_total_size += logical_buffer_size;
    absl::flat_hash_set<const HloInstruction*> unique_operands(
        hlo->operands().begin(), hlo->operands().end());
    for (const HloInstruction* operand : unique_operands) {
      extra_users[hlo] += extra_users[operand];
      total_sizes[hlo] += total_sizes[operand];
    }
    // total_sizes[hlo] transitively includes the sizes of all nodes that
    // lead to it. But computation is a DAG, so we are double-counting nodes,
    // which can lead to overflows for large programs.
    // cumulative_total_size caps the size to prevent overflows.
    // Same for total_hlos: it prevents overflows on very large and branchy
    // models, where the number of paths is exponential to the number of nodes.
    // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
    // why we care about transitive sizes; when scheduling a node, its input
    // and output buffers should be all that matters, not its "history".
    total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
    extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
  }
  CHECK_EQ(extra_users.size(), computation->instruction_count());
  CHECK_EQ(total_sizes.size(), computation->instruction_count());
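
  // For example, an HLO with four direct users contributes extra_users of 3
  // for itself; an HLO with one user whose two unique operands have
  // extra_users of 3 and 1 accumulates 0 + 3 + 1 = 4.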

  // Construct a total order based on DFS post-order, visiting operands in
  // decreasing cumulative extra-user order, then in decreasing cumulative
  // size order, with a name tiebreaker for determinism.
  HloInstructionSequence sequence;
  FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
    sequence.push_back(hlo);
    return Status::OK();
  });
  visitor.ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
      &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
                                             const HloInstruction* b) {
        if (extra_users[a] != extra_users[b]) {
          return extra_users[a] > extra_users[b];
        }
        if (total_sizes[a] != total_sizes[b]) {
          return total_sizes[a] > total_sizes[b];
        }
        return a->name() < b->name();
      }));
  CHECK_EQ(sequence.size(), computation->instruction_count());
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis, size_function,
                          &memory_by_computation));
  }
  return sequence;
}

ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
    const MemorySchedulerAlgorithm& computation_scheduler) {
  return [computation_scheduler](
             HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
             const HloAliasAnalysis& alias_analysis,
             const LogicalBuffer::SizeFunction& size_func,
             int64* peak_memory) -> StatusOr<HloSchedule> {
    HloSchedule schedule(module);
    absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
    for (auto* computation : module->MakeComputationPostOrder()) {
      if (!computation->IsFusionComputation()) {
        TF_ASSIGN_OR_RETURN(
            HloInstructionSequence computation_sequence,
            ScheduleComputationHelper(
                computation, points_to_analysis, alias_analysis, size_func,
                computation_scheduler, memory_by_computation, nullptr));
        schedule.set_sequence(computation, std::move(computation_sequence));
      }
    }
    if (peak_memory) {
      TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule(
                                            schedule, size_func));
    }
    return std::move(schedule);
  };
}
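
// Illustrative use (a sketch; `module`, `points_to`, `alias`, and `size_fn`
// are hypothetical call-site variables, not names defined in this file):
//
//   ModuleSchedulerAlgorithm alg =
//       ComputationSchedulerToModuleScheduler(ListMemoryScheduler);
//   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
//                       alg(module, *points_to, *alias, size_fn,
//                           /*peak_memory=*/nullptr));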

StatusOr<HloInstructionSequence> ListMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence,
                      ListScheduler::Run(computation, points_to_analysis,
                                         size_function, memory_by_computation));
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis, size_function,
                          &memory_by_computation));
  }
  return sequence;
}

StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  HloInstructionSequence sequence(computation->MakeInstructionPostOrder());
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis, size_function,
                          &memory_by_computation));
  }
  return sequence;
}

StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    int64* peak_memory) {
  // We try a few schedulers and choose whichever returns a lower min-memory,
  // not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64 list_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence list_sequence,
      ListMemoryScheduler(computation, points_to_analysis, alias_analysis,
                          size_function, memory_by_computation, &list_memory));
  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

  int64 dfs_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence dfs_sequence,
      DFSMemoryScheduler(computation, points_to_analysis, alias_analysis,
                         size_function, memory_by_computation, &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64 post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence post_order_sequence,
      PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis,
                               size_function, memory_by_computation,
                               &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

StatusOr<HloSchedule> DefaultModuleScheduler(
    HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function, int64* peak_memory) {
  // We try a few schedulers and choose whichever returns a lower min-memory,
  // not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64 list_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule list_sequence,
      ComputationSchedulerToModuleScheduler(ListMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &list_memory));

  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

  int64 dfs_memory;
  TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence,
                      ComputationSchedulerToModuleScheduler(DFSMemoryScheduler)(
                          module, points_to_analysis, alias_analysis,
                          size_function, &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64 post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule post_order_sequence,
      ComputationSchedulerToModuleScheduler(PostOrderMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

StatusOr<HloSchedule> ScheduleModule(
    HloModule* module, const BufferValue::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm, int64* peak_memory) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      (algorithm ? algorithm : DefaultModuleScheduler)(
                          module, *points_to_analysis, *alias_analysis,
                          size_function, peak_memory));

  TF_RETURN_IF_ERROR(schedule.Verify());

  return std::move(schedule);
}

StatusOr<HloInstructionSequence> ScheduleComputation(
    HloComputation* computation,
    const BufferValue::SizeFunction& size_function) {
  CHECK(!computation->IsFusionComputation());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(computation->parent()));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(computation->parent()));
  absl::flat_hash_map<const HloComputation*, int64> empty_map;
  return ScheduleComputationHelper(computation, *points_to_analysis,
                                   *alias_analysis, size_function, nullptr,
                                   empty_map, nullptr);
}
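
// Example of scheduling a single computation (a sketch; `computation` and
// `size_fn` are hypothetical call-site variables):
//
//   TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence,
//                       ScheduleComputation(computation, size_fn));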

HloMemoryScheduler::HloMemoryScheduler(
    const BufferValue::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm)
    : size_function_(size_function), algorithm_(algorithm) {}

StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module, size_function_, algorithm_));
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}
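
// This pass is typically added to a pass pipeline. A sketch, assuming an
// HloPassPipeline named `pipeline` and a size function `size_fn` at the call
// site:
//
//   pipeline.AddPass<HloMemoryScheduler>(size_fn);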

StatusOr<bool> HloTrivialScheduler::Run(HloModule* module) {
  HloSchedule schedule(module);
  for (HloComputation* computation : module->MakeComputationPostOrder()) {
    if (!computation->IsFusionComputation()) {
      HloInstructionSequence& computation_sequence =
          schedule.GetOrCreateSequence(computation);
      FunctionVisitor visitor(
          [&computation_sequence](HloInstruction* instruction) {
            computation_sequence.push_back(instruction);
            return Status::OK();
          });
      visitor.ReserveVisitStates(computation->instruction_count());
      TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    }
  }
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloDescheduler::Run(HloModule* module) {
  bool changed = module->has_schedule();
  module->clear_schedule();
  return changed;
}

}  // namespace xla