• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <limits>
#include <map>
#include <queue>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 namespace {
41 
42 using ::tensorflow::strings::HumanReadableNumBytes;
43 
44 // Class implementing a list scheduler of HLO instructions which produces a
45 // sequence which minimizes memory usage by preferring to schedule the node that
46 // frees bigger buffer and defines smaller outputs.
47 //
48 // Note that list scheduler is a greedy algorithm which cannot guarantee a
49 // global optimal solution. As a counterexample, considering the following
50 // graph:
51 //
52 //      +--> B ===> C -------+
53 // A -> |                    |
54 //      |                    v
55 //      +--> D ---> F=======>G
56 //      |           ^
57 //      |           |
58 //      +--> E -----+
59 //
60 //  --> : Buffer with size 1
61 //  ==> : Buffer with size 2
62 //
63 // The list scheduler will always try to defer scheduling B in a greedy way
64 // since its output buffer is bigger than input. The sequence it creates will
65 // be:
66 //   A D E F B C G
67 // , which has a maximum memory usage of 6 (B is alive while F is executing).
68 //
69 // An optimal way to shedule the previous graph is:
70 //   A B C D E F G
71 // , which has a maximum memory usage of 5 (when F is executing).
72 //
73 class ListScheduler {
74  public:
75   // Construct and return a memory-minimizing sequence of HLO instructions
76   // containing the given HLO computation.
Run(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)77   static StatusOr<HloInstructionSequence> Run(
78       HloComputation* computation,
79       const TuplePointsToAnalysis& points_to_analysis,
80       const LogicalBuffer::SizeFunction& size_function,
81       const absl::flat_hash_map<const HloComputation*, int64>&
82           memory_by_computation) {
83     ListScheduler scheduler(computation, points_to_analysis, size_function,
84                             memory_by_computation);
85     return scheduler.CreateSchedule();
86   }
87 
88   // Returns whether the memory used by the given HLO should be ignored by the
89   // scheduling heuristic.
IgnoreInstruction(const HloInstruction & instruction)90   static bool IgnoreInstruction(const HloInstruction& instruction) {
91     return instruction.opcode() == HloOpcode::kParameter ||
92            instruction.opcode() == HloOpcode::kConstant;
93   }
94 
95  private:
96   // The scheduling priority of an instruction is first the number of bytes
97   // freed by scheduling the instruction, and second (tie-breaker) by the number
98   // of users. This is represented as a std::pair containing these two values
99   // (first element is the bytes freed). std::pair provides the necessary
100   // comparison operators.
101   using Priority = std::pair<int64, int64>;
102 
ListScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)103   ListScheduler(HloComputation* computation,
104                 const TuplePointsToAnalysis& points_to_analysis,
105                 const LogicalBuffer::SizeFunction& size_function,
106                 const absl::flat_hash_map<const HloComputation*, int64>&
107                     memory_by_computation)
108       : computation_(computation),
109         points_to_analysis_(points_to_analysis),
110         size_function_(size_function),
111         memory_by_computation_(memory_by_computation) {
112     // Create a map containing the LogicalBuffer uses for each HLO
113     // instruction. An HLO instruction "uses" a LogicalBuffer if the
114     // LogicalBuffer is in an operand of the instruction as indicated by
115     // points-to analysis.
116     for (auto* instruction : computation->instructions()) {
117       absl::flat_hash_set<const LogicalBuffer*> instr_uses;
118       for (auto* operand : instruction->operands()) {
119         points_to_analysis.GetPointsToSet(operand).ForEachElement(
120             [&](const ShapeIndex& /*index*/,
121                 const PointsToSet::BufferList& buffers) {
122               instr_uses.insert(buffers.begin(), buffers.end());
123             });
124       }
125       buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
126           instr_uses.begin(), instr_uses.end());
127     }
128 
129     // Create map containing the number of unscheduled uses (hlo instructions)
130     // of each logical buffer.
131     for (auto* instruction : computation->instructions()) {
132       for (auto* buffer :
133            points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
134         unscheduled_use_count_[buffer] = 0;
135       }
136     }
137     for (auto* instruction : computation->instructions()) {
138       for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
139         ++unscheduled_use_count_[buffer];
140       }
141     }
142 
143     // Buffers live out of the computation have an implicit use at the end of
144     // the computation.
145     for (const LogicalBuffer* live_out_buffer :
146          points_to_analysis.GetPointsToSet(computation->root_instruction())
147              .CreateFlattenedSet()) {
148       ++unscheduled_use_count_[live_out_buffer];
149     }
150   }
151 
152   // Returns whether the memory used by the given buffer should be ignored by
153   // the scheduling heuristic.
IgnoreBuffer(const LogicalBuffer & buffer)154   static bool IgnoreBuffer(const LogicalBuffer& buffer) {
155     return IgnoreInstruction(*buffer.instruction());
156   }
157 
158   // An entry in the worklist used by CreateSchedule.  Corresponds to one
159   // HloInstruction, plus some cached metadata, saved for the purposes of making
160   // BytesFreedIfScheduled fast.
161   struct ReadyListEntry {
162     HloInstruction* instruction;
163 
164     // The total size of all buffers defined by this instruction.
165     int64 bytes_defined;
166 
167     // For each buffer B used by this instruction, we keep a pair (B, U), where
168     // U is the number of uses of B that have not yet been scheduled. This pair
169     // is a pointer into the unscheduled_use_count_ map, so it gets updated for
170     // free when we update counts in the map.
171     std::vector<const std::pair<const LogicalBuffer* const, int64>*>
172         used_buffer_unscheduled_use_counts;
173   };
174 
175   // Creates a ReadyListEntry for the given instruction.
MakeReadyListEntry(HloInstruction * instruction)176   ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
177     ReadyListEntry entry;
178     entry.instruction = instruction;
179 
180     entry.bytes_defined = 0;
181     for (auto* buffer :
182          points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
183       if (!IgnoreBuffer(*buffer)) {
184         entry.bytes_defined += size_function_(*buffer);
185       }
186     }
187 
188     for (auto* buffer : buffer_uses_.at(instruction)) {
189       if (IgnoreBuffer(*buffer)) {
190         continue;
191       }
192       auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
193       CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
194       entry.used_buffer_unscheduled_use_counts.push_back(
195           &*unscheduled_use_count_it);
196     }
197     return entry;
198   }
199 
200   // Returns the number of bytes freed *after* the HLO instruction finishes.
201   // The current List algorithm only considers two states for an instruction:
202   // right before it runs, and after it finishes. We don't represent memory
203   // usage during the execution of an instruction. But if the instruction calls
204   // subcomputations, they are only live during the instruction's execution.
205   // We end up counting the memory used by subcomputations as memory "defined"
206   // by the instruction. This is not entirely accurate, but it is more accurate
207   // than not taking subcomputations into account at all. In the future, we may
208   // improve accounting for subcomputation memory (b/65409243).
BytesFreedIfScheduled(const ReadyListEntry & entry)209   int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
210     auto instruction = entry.instruction;
211     auto opcode = instruction->opcode();
212     // To keep the device busy between a host send and send-done, we schedule
213     // the send done as late as possible. Same for host recv-done. This is a
214     // hack because packing of computation between channel instructions
215     // normally happens in the module group scheduler, and the memory scheduler
216     // only tries to minimize memory.
217     if ((opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone) &&
218         DynCast<HloSendRecvInstruction>(instruction)->is_host_transfer()) {
219       return INT_MIN;
220     }
221 
222     int64 freed_bytes = 0;
223     for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
224       auto buffer = kv->first;
225       auto use_count = kv->second;
226       if (use_count == 1) {
227         freed_bytes += size_function_(*buffer);
228       }
229     }
230     // We only count the memory usage of the largest subcomputation, instead of
231     // adding them all, because subcomputations won't execute in parallel.
232     int64 max_subcomputation_bytes = 0;
233     for (const auto* c : instruction->called_computations()) {
234       auto it = memory_by_computation_.find(c);
235       if (it != memory_by_computation_.end()) {
236         int64 subcomputation_bytes = it->second;
237         if (subcomputation_bytes > max_subcomputation_bytes) {
238           max_subcomputation_bytes = subcomputation_bytes;
239         }
240       }
241     }
242     int64 bytes_defined;
243     if (max_subcomputation_bytes > 0 &&
244         (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
245          opcode == HloOpcode::kConditional)) {
246       // The output buffer of while/call/conditional is always aliased with the
247       // output buffer of the root instruction in the body. Don't double count.
248       bytes_defined = max_subcomputation_bytes;
249     } else {
250       bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
251     }
252     return freed_bytes - bytes_defined;
253   }
254 
255   // Constructs the scheduling priority of the given instruction.
GetPriority(const ReadyListEntry & entry)256   Priority GetPriority(const ReadyListEntry& entry) {
257     return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
258   }
259 
CreateSchedule()260   HloInstructionSequence CreateSchedule() {
261     HloInstructionSequence schedule;
262 
263     // Populate the ready list with instructions which have no operands or
264     // control predecessors.
265     absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
266     for (auto* instruction : computation_->instructions()) {
267       // TODO(b/34466113): Replace this and above with successors() or
268       // predecessors() when these methods are added to HloInstruction.
269       for (HloInstruction* user : instruction->users()) {
270         unscheduled_pred_count[user]++;
271       }
272       for (HloInstruction* succ : instruction->control_successors()) {
273         unscheduled_pred_count[succ]++;
274       }
275     }
276 
277     // Use a multimap to sort ReadyListEntry according to their priority.
278     std::multimap<Priority, ReadyListEntry> ready_queue;
279 
280     // Map of ready instructions to their iterators in ready_queue.
281     absl::flat_hash_map<const HloInstruction*,
282                         std::multimap<Priority, ReadyListEntry>::iterator>
283         ready_instructions;
284 
285     auto add_to_ready_queue = [&](HloInstruction* inst) {
286       auto entry = MakeReadyListEntry(inst);
287       auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
288       ready_instructions[inst] = it;
289     };
290 
291     for (auto* instruction : computation_->instructions()) {
292       if (instruction->operands().empty() &&
293           instruction->control_predecessors().empty()) {
294         add_to_ready_queue(instruction);
295       }
296     }
297 
298     while (!ready_queue.empty()) {
299       // Remove the selected instruction from the ready list and add it to the
300       // schedule.
301       auto best_it = ready_queue.end();
302       --best_it;
303       HloInstruction* best = best_it->second.instruction;
304       VLOG(2) << "Schedule instruction: " << best->ToShortString()
305               << " Bytes freed: " << best_it->first.first;
306       ready_queue.erase(best_it);
307       ready_instructions.erase(best);
308       schedule.push_back(best);
309       scheduled_instructions_.insert(best);
310 
311       bool adjust_ready_queue = false;
312       // Update the unscheduled uses of the logical buffers.
313       for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
314         int64& count = unscheduled_use_count_[buffer];
315         CHECK_GT(count, 0);
316         --count;
317         if (count == 1) {
318           adjust_ready_queue = true;
319         }
320       }
321 
322       // Add new instructions to ready list.
323       auto update_pred_count = [&](HloInstruction* inst) {
324         int64 pred_count = --unscheduled_pred_count.at(inst);
325         CHECK_GE(pred_count, 0);
326         if (pred_count == 0) {
327           add_to_ready_queue(inst);
328         }
329       };
330       // TODO(b/34466113): Replace this and above with successors() or
331       // predecessors() when these methods are added to HloInstruction.
332       for (HloInstruction* user : best->users()) {
333         update_pred_count(user);
334       }
335       for (HloInstruction* succ : best->control_successors()) {
336         update_pred_count(succ);
337       }
338       // The unscheduled use count for a buffer has changed to 1, so the
339       // priorities of some ready instructions may go up. We update them in the
340       // ready queue, so that they can appear earlier.
341       if (adjust_ready_queue) {
342         for (HloInstruction* operand : best->operands()) {
343           for (HloInstruction* operand_user : operand->users()) {
344             auto ready_instructions_it = ready_instructions.find(operand_user);
345             if (ready_instructions_it == ready_instructions.end()) {
346               continue;
347             }
348             auto ready_queue_it = ready_instructions_it->second;
349             auto& entry = ready_queue_it->second;
350             Priority new_priority = GetPriority(entry);
351             if (new_priority == ready_queue_it->first) {
352               continue;
353             }
354             // Create a new entry in ready_queue, then update
355             // ready_instructions[operand_user] to refer to the new entry.
356             ready_instructions_it->second =
357                 ready_queue.emplace(new_priority, std::move(entry));
358             // Remove the old entry in ready_queue.
359             ready_queue.erase(ready_queue_it);
360           }
361         }
362       }
363     }
364     CHECK_EQ(schedule.size(), computation_->instruction_count());
365     CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
366 
367     return schedule;
368   }
369 
370   HloComputation* computation_;
371   const TuplePointsToAnalysis& points_to_analysis_;
372   const LogicalBuffer::SizeFunction& size_function_;
373   // Computations are analyzed in post-order. When scheduling an instruction
374   // that includes subcomputations, such as a while loop, we use this map to
375   // look up the memory needed by subcomputations.
376   const absl::flat_hash_map<const HloComputation*, int64>&
377       memory_by_computation_;
378 
379   // A map containing the LogicalBuffers that each instruction uses.
380   absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
381       buffer_uses_;
382 
383   // A map containing the count of unscheduled HLOs which using a particular
384   // LogicalBuffer.
385   absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_;
386 
387   // Set of instructions which have been scheduled.
388   absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
389 };
390 
SumLogicalBufferSizes(const TuplePointsToAnalysis::BufferDefinitionVector & buffers,const LogicalBuffer::SizeFunction & size_function)391 int64 SumLogicalBufferSizes(
392     const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
393     const LogicalBuffer::SizeFunction& size_function) {
394   int64 size = 0;
395   for (const LogicalBuffer* buffer : buffers) {
396     size += size_function(*buffer);
397   }
398   return size;
399 }
400 
ScheduleComputationHelper(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const MemorySchedulerAlgorithm & algorithm,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)401 StatusOr<HloInstructionSequence> ScheduleComputationHelper(
402     HloComputation* computation,
403     const TuplePointsToAnalysis& points_to_analysis,
404     const LogicalBuffer::SizeFunction& size_function,
405     const MemorySchedulerAlgorithm& algorithm,
406     const absl::flat_hash_map<const HloComputation*, int64>&
407         memory_by_computation) {
408   VLOG(2) << "Computation: " << computation->name();
409   if (algorithm) {
410     return algorithm(computation, points_to_analysis, size_function,
411                      memory_by_computation);
412   }
413   return DefaultMemoryScheduler(computation, points_to_analysis, size_function,
414                                 memory_by_computation);
415 }
416 
417 }  // namespace
418 
DFSMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)419 StatusOr<HloInstructionSequence> DFSMemoryScheduler(
420     HloComputation* computation,
421     const TuplePointsToAnalysis& points_to_analysis,
422     const LogicalBuffer::SizeFunction& size_function,
423     const absl::flat_hash_map<const HloComputation*, int64>&
424         memory_by_computation) {
425   // These variables are a hack to prevent overflows.
426   int64 cumulative_total_size = 0;
427   int64 total_hlos = computation->parent()->instruction_count();
428   absl::flat_hash_map<const HloInstruction*, int64> extra_users;
429   absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
430   for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
431     if (ListScheduler::IgnoreInstruction(*hlo)) {
432       extra_users[hlo] = 0;
433       total_sizes[hlo] = 0;
434       continue;
435     }
436     // This ordering is based on DFS post-order, with a heuristic to decide
437     // which operand to visit first.  The heuristic is based on 'extra_users',
438     // which is simply users-1 for each instruction.  By subtracting 1, we're
439     // saying that instructions with no users or a single user don't count;
440     // instructions with lots of fan-out will be visited earlier.
441     extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
442     int64 logical_buffer_size = SumLogicalBufferSizes(
443         points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
444     total_sizes[hlo] = logical_buffer_size;
445     cumulative_total_size += logical_buffer_size;
446     absl::flat_hash_set<const HloInstruction*> unique_operands(
447         hlo->operands().begin(), hlo->operands().end());
448     for (const HloInstruction* operand : unique_operands) {
449       extra_users[hlo] += extra_users[operand];
450       total_sizes[hlo] += total_sizes[operand];
451     }
452     // total_sizes[hlo] transitively includes the sizes of all nodes that
453     // lead to it. But computation is a DAG, so we are double-counting nodes,
454     // which can lead to overflows for large programs.
455     // cumulative_total_size caps the size to prevent overflows.
456     // Same for total_hlos: it prevents overflows on very large and branchy
457     // models, where the number of paths is exponential to the number of nodes.
458     // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
459     // why we care about transitive sizes; when scheduling a node, its input
460     // and output buffers should be all that matters, not its "history".
461     total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
462     extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
463   }
464   CHECK_EQ(extra_users.size(), computation->instruction_count());
465   CHECK_EQ(total_sizes.size(), computation->instruction_count());
466 
467   // Construct a total order based on DFS post-order, visiting operands in
468   // decreasing cumulative extra user order, and next by cumulative size, with a
469   // tiebreaker by name for determinism.
470   HloInstructionSequence sequence;
471   FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
472     sequence.push_back(hlo);
473     return Status::OK();
474   });
475   TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
476       &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
477                                              const HloInstruction* b) {
478         if (extra_users[a] != extra_users[b]) {
479           return extra_users[a] > extra_users[b];
480         }
481         if (total_sizes[a] != total_sizes[b]) {
482           return total_sizes[a] > total_sizes[b];
483         }
484         return a->name() < b->name();
485       }));
486   CHECK_EQ(sequence.size(), computation->instruction_count());
487   return sequence;
488 }  // namespace xla
489 
ListMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)490 StatusOr<HloInstructionSequence> ListMemoryScheduler(
491     HloComputation* computation,
492     const TuplePointsToAnalysis& points_to_analysis,
493     const LogicalBuffer::SizeFunction& size_function,
494     const absl::flat_hash_map<const HloComputation*, int64>&
495         memory_by_computation) {
496   return ListScheduler::Run(computation, points_to_analysis, size_function,
497                             memory_by_computation);
498 }
499 
PostOrderMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)500 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
501     HloComputation* computation,
502     const TuplePointsToAnalysis& points_to_analysis,
503     const LogicalBuffer::SizeFunction& size_function,
504     const absl::flat_hash_map<const HloComputation*, int64>&
505         memory_by_computation) {
506   return HloInstructionSequence(computation->MakeInstructionPostOrder());
507 }
508 
DefaultMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)509 StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
510     HloComputation* computation,
511     const TuplePointsToAnalysis& points_to_analysis,
512     const LogicalBuffer::SizeFunction& size_function,
513     const absl::flat_hash_map<const HloComputation*, int64>&
514         memory_by_computation) {
515   // We try a few schedulers and choose whichever returns a lower min-memory,
516   // not accounting for fragmentation.
517   // - List is a scheduler that uses greedy heuristics.
518   // - DFS visits HLOs in postorder, with a heuristic to decide the order of
519   //   children.
520   // - Postorder does not use any heuristics.
521   // List wins for most of our benchmarks; postorder-based schedulers win for
522   // some RNNs.
523   TF_ASSIGN_OR_RETURN(
524       HloInstructionSequence list_sequence,
525       ListMemoryScheduler(computation, points_to_analysis, size_function,
526                           memory_by_computation));
527   TF_ASSIGN_OR_RETURN(const int64 list_memory,
528                       HeapSimulator::MinimumMemoryForComputation(
529                           *computation, list_sequence, points_to_analysis,
530                           size_function, &memory_by_computation));
531   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
532 
533   TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence,
534                       DFSMemoryScheduler(computation, points_to_analysis,
535                                          size_function, memory_by_computation));
536   TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
537                       HeapSimulator::MinimumMemoryForComputation(
538                           *computation, dfs_sequence, points_to_analysis,
539                           size_function, &memory_by_computation));
540   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
541 
542   TF_ASSIGN_OR_RETURN(
543       HloInstructionSequence post_order_sequence,
544       PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
545                                memory_by_computation));
546   TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
547                       HeapSimulator::MinimumMemoryForComputation(
548                           *computation, post_order_sequence, points_to_analysis,
549                           size_function, &memory_by_computation));
550   VLOG(2) << "Min-memory post order sequence: "
551           << HumanReadableNumBytes(post_order_memory);
552 
553   auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
554 
555   if (min_memory == list_memory) {
556     VLOG(2) << "Chose min-memory list sequence: "
557             << HumanReadableNumBytes(list_memory);
558     return list_sequence;
559   } else if (min_memory == dfs_memory) {
560     VLOG(2) << "Chose min-memory dfs sequence: "
561             << HumanReadableNumBytes(dfs_memory);
562     return dfs_sequence;
563   } else {
564     VLOG(2) << "Chose min-memory post_order sequence: "
565             << HumanReadableNumBytes(post_order_memory);
566     return post_order_sequence;
567   }
568 }
569 
ScheduleModule(HloModule * module,const LogicalBuffer::SizeFunction & size_function,const MemorySchedulerAlgorithm & algorithm)570 StatusOr<HloSchedule> ScheduleModule(
571     HloModule* module, const LogicalBuffer::SizeFunction& size_function,
572     const MemorySchedulerAlgorithm& algorithm) {
573   HloSchedule schedule(module);
574   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
575                       TuplePointsToAnalysis::Run(module));
576   absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
577   for (auto* computation : module->MakeComputationPostOrder()) {
578     if (!computation->IsFusionComputation()) {
579       TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
580                           ScheduleComputationHelper(
581                               computation, *points_to_analysis, size_function,
582                               algorithm, memory_by_computation));
583       memory_by_computation[computation] =
584           HeapSimulator::MinimumMemoryForComputation(
585               *computation, computation_sequence, *points_to_analysis,
586               size_function, &memory_by_computation)
587               .ValueOrDie();
588       schedule.set_sequence(computation, std::move(computation_sequence));
589     }
590   }
591   VLOG(1) << "Module schedule:\n" << schedule;
592 
593   TF_RETURN_IF_ERROR(schedule.Verify());
594 
595   return std::move(schedule);
596 }
597 
ScheduleComputation(HloComputation * computation,const LogicalBuffer::SizeFunction & size_function)598 StatusOr<HloInstructionSequence> ScheduleComputation(
599     HloComputation* computation,
600     const LogicalBuffer::SizeFunction& size_function) {
601   CHECK(!computation->IsFusionComputation());
602   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
603                       TuplePointsToAnalysis::Run(computation->parent()));
604   absl::flat_hash_map<const HloComputation*, int64> empty_map;
605   return ScheduleComputationHelper(computation, *points_to_analysis,
606                                    size_function, nullptr, empty_map);
607 }
608 
HloMemoryScheduler(const LogicalBuffer::SizeFunction & size_function,const MemorySchedulerAlgorithm & algorithm)609 HloMemoryScheduler::HloMemoryScheduler(
610     const LogicalBuffer::SizeFunction& size_function,
611     const MemorySchedulerAlgorithm& algorithm)
612     : size_function_(size_function), algorithm_(algorithm) {}
613 
Run(HloModule * module)614 StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
615   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
616                       ScheduleModule(module, size_function_, algorithm_));
617   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
618   return true;
619 }
620 
Run(HloModule * module)621 StatusOr<bool> HloTrivialScheduler::Run(HloModule* module) {
622   HloSchedule schedule(module);
623   for (HloComputation* computation : module->MakeComputationPostOrder()) {
624     if (!computation->IsFusionComputation()) {
625       HloInstructionSequence& computation_sequence =
626           schedule.GetOrCreateSequence(computation);
627       TF_RETURN_IF_ERROR(computation->Accept(
628           [&computation_sequence](HloInstruction* instruction) {
629             computation_sequence.push_back(instruction);
630             return Status::OK();
631           }));
632     }
633   }
634   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
635   return true;
636 }
637 
Run(HloModule * module)638 StatusOr<bool> HloDescheduler::Run(HloModule* module) {
639   bool changed = module->has_schedule();
640   module->clear_schedule();
641   return changed;
642 }
643 
644 }  // namespace xla
645