/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <climits>
#include <map>
#include <queue>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;
// Class implementing a list scheduler of HLO instructions which produces a
// sequence which minimizes memory usage by preferring to schedule the node
// that frees the largest buffers and defines the smallest outputs.
//
// Note that the list scheduler is a greedy algorithm which cannot guarantee a
// globally optimal solution. As a counterexample, consider the following
// graph:
//
//      +--> B ===> C -------+
// A -> |                    |
//      |                    v
//      +--> D ---> F=======>G
//      |           ^
//      |           |
//      +--> E -----+
//
//  --> : Buffer with size 1
//  ==> : Buffer with size 2
//
// The list scheduler will always try to defer scheduling B in a greedy way
// since its output buffer is bigger than its input. The sequence it creates
// will be:
//   A D E F B C G
// , which has a maximum memory usage of 6 (B is alive while F is executing).
//
// An optimal way to schedule the previous graph is:
//   A B C D E F G
// , which has a maximum memory usage of 5 (when F is executing).
//
class ListScheduler {
 public:
  // Construct and return a memory-minimizing sequence of HLO instructions
  // containing the given HLO computation.
  static StatusOr<HloInstructionSequence> Run(
      HloComputation* computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const LogicalBuffer::SizeFunction& size_function,
      const absl::flat_hash_map<const HloComputation*, int64>&
          memory_by_computation) {
    ListScheduler scheduler(computation, points_to_analysis, size_function,
                            memory_by_computation);
    return scheduler.CreateSchedule();
  }

  // Returns whether the memory used by the given HLO should be ignored by the
  // scheduling heuristic.
  static bool IgnoreInstruction(const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kParameter ||
           instruction.opcode() == HloOpcode::kConstant;
  }

 private:
  // The scheduling priority of an instruction is first the number of bytes
  // freed by scheduling the instruction, and second (as a tie-breaker) the
  // number of users. This is represented as a std::pair containing these two
  // values (first element is the bytes freed). std::pair provides the
  // necessary comparison operators.
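  // For example, Priority{16, 3} > Priority{16, 2} > Priority{8, 100}: the
  // pair comparison is lexicographic, so bytes freed dominates and the user
  // count only breaks ties.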
  using Priority = std::pair<int64, int64>;

  ListScheduler(HloComputation* computation,
                const TuplePointsToAnalysis& points_to_analysis,
                const LogicalBuffer::SizeFunction& size_function,
                const absl::flat_hash_map<const HloComputation*, int64>&
                    memory_by_computation)
      : computation_(computation),
        points_to_analysis_(points_to_analysis),
        size_function_(size_function),
        memory_by_computation_(memory_by_computation) {
    // Create a map containing the LogicalBuffer uses for each HLO
    // instruction. An HLO instruction "uses" a LogicalBuffer if the
    // LogicalBuffer is in an operand of the instruction as indicated by
    // points-to analysis.
    for (auto* instruction : computation->instructions()) {
      absl::flat_hash_set<const LogicalBuffer*> instr_uses;
      for (auto* operand : instruction->operands()) {
        points_to_analysis.GetPointsToSet(operand).ForEachElement(
            [&](const ShapeIndex& /*index*/,
                const PointsToSet::BufferList& buffers) {
              instr_uses.insert(buffers.begin(), buffers.end());
            });
      }
      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
          instr_uses.begin(), instr_uses.end());
    }

    // Create a map containing the number of unscheduled uses (HLO
    // instructions) of each logical buffer.
    for (auto* instruction : computation->instructions()) {
      for (auto* buffer :
           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
        unscheduled_use_count_[buffer] = 0;
      }
    }
    for (auto* instruction : computation->instructions()) {
      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
        ++unscheduled_use_count_[buffer];
      }
    }

    // Buffers live out of the computation have an implicit use at the end of
    // the computation.
    for (const LogicalBuffer* live_out_buffer :
         points_to_analysis.GetPointsToSet(computation->root_instruction())
             .CreateFlattenedSet()) {
      ++unscheduled_use_count_[live_out_buffer];
    }
  }

  // Returns whether the memory used by the given buffer should be ignored by
  // the scheduling heuristic.
  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
    return IgnoreInstruction(*buffer.instruction());
  }

  // An entry in the worklist used by CreateSchedule. Corresponds to one
  // HloInstruction, plus some cached metadata, saved for the purposes of
  // making BytesFreedIfScheduled fast.
  struct ReadyListEntry {
    HloInstruction* instruction;

    // The total size of all buffers defined by this instruction.
    int64 bytes_defined;

    // For each buffer B used by this instruction, we keep a pair (B, U),
    // where U is the number of uses of B that have not yet been scheduled.
    // This pair is a pointer into the unscheduled_use_count_ map, so it gets
    // updated for free when we update counts in the map.
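    // This is only safe because unscheduled_use_count_ gains all of its keys
    // in the constructor and never grows afterwards; absl::flat_hash_map does
    // not keep pointers stable across later insertions.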
    std::vector<const std::pair<const LogicalBuffer* const, int64>*>
        used_buffer_unscheduled_use_counts;
  };

  // Creates a ReadyListEntry for the given instruction.
  ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
    ReadyListEntry entry;
    entry.instruction = instruction;

    entry.bytes_defined = 0;
    for (auto* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (!IgnoreBuffer(*buffer)) {
        entry.bytes_defined += size_function_(*buffer);
      }
    }

    for (auto* buffer : buffer_uses_.at(instruction)) {
      if (IgnoreBuffer(*buffer)) {
        continue;
      }
      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
      CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
      entry.used_buffer_unscheduled_use_counts.push_back(
          &*unscheduled_use_count_it);
    }
    return entry;
  }

  // Returns the number of bytes freed *after* the HLO instruction finishes.
  // The current List algorithm only considers two states for an instruction:
  // right before it runs, and after it finishes. We don't represent memory
  // usage during the execution of an instruction. But if the instruction
  // calls subcomputations, they are only live during the instruction's
  // execution. We end up counting the memory used by subcomputations as
  // memory "defined" by the instruction. This is not entirely accurate, but
  // it is more accurate than not taking subcomputations into account at all.
  // In the future, we may improve accounting for subcomputation memory
  // (b/65409243).
  int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
    auto instruction = entry.instruction;
    auto opcode = instruction->opcode();
    // To keep the device busy between a host send and send-done, we schedule
    // the send-done as late as possible. Same for host recv-done. This is a
    // hack because packing of computation between channel instructions
    // normally happens in the module group scheduler, and the memory
    // scheduler only tries to minimize memory.
    if ((opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone) &&
        DynCast<HloSendRecvInstruction>(instruction)->is_host_transfer()) {
      return INT_MIN;
    }

    int64 freed_bytes = 0;
    for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
      auto buffer = kv->first;
      auto use_count = kv->second;
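      // A use count of 1 means this instruction is the last unscheduled use
      // of the buffer, so scheduling it frees the buffer. Live-out buffers
      // never reach 1 here, thanks to the extra implicit use added in the
      // constructor.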
      if (use_count == 1) {
        freed_bytes += size_function_(*buffer);
      }
    }
    // We only count the memory usage of the largest subcomputation, instead
    // of adding them all, because subcomputations won't execute in parallel.
    int64 max_subcomputation_bytes = 0;
    for (const auto* c : instruction->called_computations()) {
      auto it = memory_by_computation_.find(c);
      if (it != memory_by_computation_.end()) {
        int64 subcomputation_bytes = it->second;
        if (subcomputation_bytes > max_subcomputation_bytes) {
          max_subcomputation_bytes = subcomputation_bytes;
        }
      }
    }
    int64 bytes_defined;
    if (max_subcomputation_bytes > 0 &&
        (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
         opcode == HloOpcode::kConditional)) {
      // The output buffer of while/call/conditional is always aliased with
      // the output buffer of the root instruction in the body. Don't double
      // count.
      bytes_defined = max_subcomputation_bytes;
    } else {
      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
    }
    return freed_bytes - bytes_defined;
  }

  // Constructs the scheduling priority of the given instruction.
  Priority GetPriority(const ReadyListEntry& entry) {
    return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
  }

  HloInstructionSequence CreateSchedule() {
    HloInstructionSequence schedule;

    // Populate the ready list with instructions which have no operands or
    // control predecessors.
    absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
    for (auto* instruction : computation_->instructions()) {
      // TODO(b/34466113): Replace this and below with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : instruction->users()) {
        unscheduled_pred_count[user]++;
      }
      for (HloInstruction* succ : instruction->control_successors()) {
        unscheduled_pred_count[succ]++;
      }
    }

    // Use a multimap to sort ReadyListEntry objects according to their
    // priority.
    std::multimap<Priority, ReadyListEntry> ready_queue;

    // Map of ready instructions to their iterators in ready_queue.
    absl::flat_hash_map<const HloInstruction*,
                        std::multimap<Priority, ReadyListEntry>::iterator>
        ready_instructions;
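
    // Together, ready_queue and ready_instructions act as an updatable
    // priority queue: std::priority_queue is not used because it cannot
    // remove or re-key an arbitrary element when its priority changes.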

    auto add_to_ready_queue = [&](HloInstruction* inst) {
      auto entry = MakeReadyListEntry(inst);
      auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
      ready_instructions[inst] = it;
    };

    for (auto* instruction : computation_->instructions()) {
      if (instruction->operands().empty() &&
          instruction->control_predecessors().empty()) {
        add_to_ready_queue(instruction);
      }
    }

    while (!ready_queue.empty()) {
      // Remove the selected instruction from the ready list and add it to
      // the schedule.
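      // The multimap is sorted in ascending priority order, so the best
      // (highest-priority) entry is the last one.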
      auto best_it = ready_queue.end();
      --best_it;
      HloInstruction* best = best_it->second.instruction;
      VLOG(2) << "Schedule instruction: " << best->ToShortString()
              << " Bytes freed: " << best_it->first.first;
      ready_queue.erase(best_it);
      ready_instructions.erase(best);
      schedule.push_back(best);
      scheduled_instructions_.insert(best);

      bool adjust_ready_queue = false;
      // Update the unscheduled uses of the logical buffers.
      for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
        int64& count = unscheduled_use_count_[buffer];
        CHECK_GT(count, 0);
        --count;
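        // When a count drops to 1, the buffer's one remaining user would
        // free it by running, so that user's priority may have increased.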
        if (count == 1) {
          adjust_ready_queue = true;
        }
      }

      // Add new instructions to the ready list.
      auto update_pred_count = [&](HloInstruction* inst) {
        int64 pred_count = --unscheduled_pred_count.at(inst);
        CHECK_GE(pred_count, 0);
        if (pred_count == 0) {
          add_to_ready_queue(inst);
        }
      };
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : best->users()) {
        update_pred_count(user);
      }
      for (HloInstruction* succ : best->control_successors()) {
        update_pred_count(succ);
      }
      // The unscheduled use count for a buffer has changed to 1, so the
      // priorities of some ready instructions may go up. We update them in
      // the ready queue, so that they can appear earlier.
      if (adjust_ready_queue) {
        for (HloInstruction* operand : best->operands()) {
          for (HloInstruction* operand_user : operand->users()) {
            auto ready_instructions_it = ready_instructions.find(operand_user);
            if (ready_instructions_it == ready_instructions.end()) {
              continue;
            }
            auto ready_queue_it = ready_instructions_it->second;
            auto& entry = ready_queue_it->second;
            Priority new_priority = GetPriority(entry);
            if (new_priority == ready_queue_it->first) {
              continue;
            }
            // Create a new entry in ready_queue, then update
            // ready_instructions[operand_user] to refer to the new entry.
            ready_instructions_it->second =
                ready_queue.emplace(new_priority, std::move(entry));
            // Remove the old entry in ready_queue.
            ready_queue.erase(ready_queue_it);
          }
        }
      }
    }
    CHECK_EQ(schedule.size(), computation_->instruction_count());
    CHECK_EQ(scheduled_instructions_.size(),
             computation_->instruction_count());

    return schedule;
  }

  HloComputation* computation_;
  const TuplePointsToAnalysis& points_to_analysis_;
  const LogicalBuffer::SizeFunction& size_function_;
  // Computations are analyzed in post-order. When scheduling an instruction
  // that includes subcomputations, such as a while loop, we use this map to
  // look up the memory needed by subcomputations.
  const absl::flat_hash_map<const HloComputation*, int64>&
      memory_by_computation_;

  // A map containing the LogicalBuffers that each instruction uses.
  absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
      buffer_uses_;

  // A map containing the count of unscheduled HLOs which use a particular
  // LogicalBuffer.
  absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_;

  // Set of instructions which have been scheduled.
  absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};

int64 SumLogicalBufferSizes(
    const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
    const LogicalBuffer::SizeFunction& size_function) {
  int64 size = 0;
  for (const LogicalBuffer* buffer : buffers) {
    size += size_function(*buffer);
  }
  return size;
}

StatusOr<HloInstructionSequence> ScheduleComputationHelper(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  VLOG(2) << "Computation: " << computation->name();
  if (algorithm) {
    return algorithm(computation, points_to_analysis, size_function,
                     memory_by_computation);
  }
  return DefaultMemoryScheduler(computation, points_to_analysis, size_function,
                                memory_by_computation);
}

}  // namespace

StatusOr<HloInstructionSequence> DFSMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // These variables are a hack to prevent overflows.
  int64 cumulative_total_size = 0;
  int64 total_hlos = computation->parent()->instruction_count();
  absl::flat_hash_map<const HloInstruction*, int64> extra_users;
  absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
  for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
    if (ListScheduler::IgnoreInstruction(*hlo)) {
      extra_users[hlo] = 0;
      total_sizes[hlo] = 0;
      continue;
    }
    // This ordering is based on DFS post-order, with a heuristic to decide
    // which operand to visit first. The heuristic is based on 'extra_users',
    // which is simply users - 1 for each instruction. By subtracting 1, we're
    // saying that instructions with no users or a single user don't count;
    // instructions with lots of fan-out will be visited earlier.
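    // For example, an instruction with four users contributes three extra
    // users of its own before accumulating the counts of its operands below.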
    extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
    int64 logical_buffer_size = SumLogicalBufferSizes(
        points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
    total_sizes[hlo] = logical_buffer_size;
    cumulative_total_size += logical_buffer_size;
    absl::flat_hash_set<const HloInstruction*> unique_operands(
        hlo->operands().begin(), hlo->operands().end());
    for (const HloInstruction* operand : unique_operands) {
      extra_users[hlo] += extra_users[operand];
      total_sizes[hlo] += total_sizes[operand];
    }
    // total_sizes[hlo] transitively includes the sizes of all nodes that
    // lead to it. But the computation is a DAG, so we are double-counting
    // nodes, which can lead to overflows for large programs.
    // cumulative_total_size caps the size to prevent overflows.
    // Same for total_hlos: it prevents overflows on very large and branchy
    // models, where the number of paths is exponential to the number of
    // nodes.
    // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
    // why we care about transitive sizes; when scheduling a node, its input
    // and output buffers should be all that matters, not its "history".
    total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
    extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
  }
  CHECK_EQ(extra_users.size(), computation->instruction_count());
  CHECK_EQ(total_sizes.size(), computation->instruction_count());

  // Construct a total order based on DFS post-order, visiting operands in
  // decreasing cumulative extra user order, then by cumulative size, with a
  // tiebreaker by name for determinism.
  HloInstructionSequence sequence;
  FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
    sequence.push_back(hlo);
    return Status::OK();
  });
  TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
      &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
                                             const HloInstruction* b) {
        if (extra_users[a] != extra_users[b]) {
          return extra_users[a] > extra_users[b];
        }
        if (total_sizes[a] != total_sizes[b]) {
          return total_sizes[a] > total_sizes[b];
        }
        return a->name() < b->name();
      }));
  CHECK_EQ(sequence.size(), computation->instruction_count());
  return sequence;
}

StatusOr<HloInstructionSequence> ListMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  return ListScheduler::Run(computation, points_to_analysis, size_function,
                            memory_by_computation);
}

StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  return HloInstructionSequence(computation->MakeInstructionPostOrder());
}

StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We try a few schedulers and choose whichever returns a lower min-memory,
  // not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence list_sequence,
      ListMemoryScheduler(computation, points_to_analysis, size_function,
                          memory_by_computation));
  TF_ASSIGN_OR_RETURN(const int64 list_memory,
                      HeapSimulator::MinimumMemoryForComputation(
                          *computation, list_sequence, points_to_analysis,
                          size_function, &memory_by_computation));
  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

  TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence,
                      DFSMemoryScheduler(computation, points_to_analysis,
                                         size_function, memory_by_computation));
  TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
                      HeapSimulator::MinimumMemoryForComputation(
                          *computation, dfs_sequence, points_to_analysis,
                          size_function, &memory_by_computation));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence post_order_sequence,
      PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
                               memory_by_computation));
  TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
                      HeapSimulator::MinimumMemoryForComputation(
                          *computation, post_order_sequence, points_to_analysis,
                          size_function, &memory_by_computation));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

StatusOr<HloSchedule> ScheduleModule(
    HloModule* module, const LogicalBuffer::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm) {
  HloSchedule schedule(module);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (!computation->IsFusionComputation()) {
      TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
                          ScheduleComputationHelper(
                              computation, *points_to_analysis, size_function,
                              algorithm, memory_by_computation));
      memory_by_computation[computation] =
          HeapSimulator::MinimumMemoryForComputation(
              *computation, computation_sequence, *points_to_analysis,
              size_function, &memory_by_computation)
              .ValueOrDie();
      schedule.set_sequence(computation, std::move(computation_sequence));
    }
  }
  VLOG(1) << "Module schedule:\n" << schedule;

  TF_RETURN_IF_ERROR(schedule.Verify());

  return std::move(schedule);
}

StatusOr<HloInstructionSequence> ScheduleComputation(
    HloComputation* computation,
    const LogicalBuffer::SizeFunction& size_function) {
  CHECK(!computation->IsFusionComputation());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(computation->parent()));
  absl::flat_hash_map<const HloComputation*, int64> empty_map;
  return ScheduleComputationHelper(computation, *points_to_analysis,
                                   size_function, nullptr, empty_map);
}

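// Example of running this pass in a pipeline (a sketch, not from the original
// source; assumes an HloModule* named `module` is in scope and an 8-byte
// pointer size):
//
//   HloPassPipeline pipeline("scheduling");
//   pipeline.AddPass<HloMemoryScheduler>([](const LogicalBuffer& buffer) {
//     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//   });
//   TF_CHECK_OK(pipeline.Run(module).status());
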
HloMemoryScheduler::HloMemoryScheduler(
    const LogicalBuffer::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm)
    : size_function_(size_function), algorithm_(algorithm) {}

StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module, size_function_, algorithm_));
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloTrivialScheduler::Run(HloModule* module) {
  HloSchedule schedule(module);
  for (HloComputation* computation : module->MakeComputationPostOrder()) {
    if (!computation->IsFusionComputation()) {
      HloInstructionSequence& computation_sequence =
          schedule.GetOrCreateSequence(computation);
      TF_RETURN_IF_ERROR(computation->Accept(
          [&computation_sequence](HloInstruction* instruction) {
            computation_sequence.push_back(instruction);
            return Status::OK();
          }));
    }
  }
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloDescheduler::Run(HloModule* module) {
  bool changed = module->has_schedule();
  module->clear_schedule();
  return changed;
}

}  // namespace xla