1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <set>
22 #include <string>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "absl/strings/str_join.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/primitive_util.h"
33 #include "tensorflow/compiler/xla/service/buffer_value.h"
34 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
35 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_dce.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
44 #include "tensorflow/compiler/xla/service/logical_buffer.h"
45 #include "tensorflow/compiler/xla/status_macros.h"
46 #include "tensorflow/compiler/xla/statusor.h"
47 #include "tensorflow/compiler/xla/types.h"
48 #include "tensorflow/compiler/xla/util.h"
49 #include "tensorflow/core/platform/logging.h"
50
51 namespace xla {
52 namespace {
53
54 using ::tensorflow::strings::HumanReadableNumBytes;
55
56 // Potential optimizations:
57 // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
58 // of candidates.
59 // . Cache IsRematerializable in Item? Only correct if control
60 // predecessors and successors don't change.
61
62 // Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
64 if (instruction->opcode() == HloOpcode::kCopy) {
65 if (LayoutUtil::Equal(instruction->shape().layout(),
66 instruction->operand(0)->shape().layout())) {
67 // Don't rematerialize copies added by copy insertion (layout doesn't
68 // change).
69 return false;
70 }
71 }
72
73 // Don't rematerialize instructions with side effects or instructions which
74 // cannot be cloned safely.
75 switch (instruction->opcode()) {
76 case HloOpcode::kCall:
77 case HloOpcode::kConstant:
78 case HloOpcode::kConditional:
79 case HloOpcode::kAllReduce:
80 case HloOpcode::kCustomCall:
81 case HloOpcode::kParameter:
82 case HloOpcode::kWhile:
83 return false;
84 default:
85 return !instruction->HasSideEffect();
86 }
87 }
88
// Checks whether an instruction can be rematerialized, first consulting the
// cache and, on a miss, falling back to IsRematerializable().
bool CanBeRematerialized(
92 const HloInstruction* instruction,
93 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
94 auto it = rematerializable_map->find(instruction);
95 if (it != rematerializable_map->end()) {
96 return it->second;
97 }
98 bool rematerializable = IsRematerializable(instruction);
99 (*rematerializable_map)[instruction] = rematerializable;
100 return rematerializable;
101 }
102
// Returns true if this instruction relays the buffers it uses to its own
// users (e.g. bitcast or get-tuple-element) and is one of the indirect-user
// kinds whose rematerialization we support.
bool IsSupportedIndirectUser(const HloInstruction* instruction) {
107 return instruction->opcode() == HloOpcode::kBitcast ||
108 instruction->opcode() == HloOpcode::kGetTupleElement;
109 }
110
111 // Type holding a unique identifier for each Buffer object.
112 using BufferId = int64;
113 using BufferIdList = absl::InlinedVector<BufferId, 3>;
114
115 struct RematStrategy {
116 enum {
117 // Recompute the node at a later program point.
118 kRecompute,
119 // Change the layout into a compact form and uncompress it back at a later
120 // program point.
121 kCompress,
122 } kind;
123 Shape compact_shape;
124 };
125
126 // We wrap HloInstruction* with an Item that holds auxiliary
127 // per-instruction state.
128 struct Item {
129 HloInstruction* instruction;
130
131 // True once the instruction is marked as placed (when BeginInstruction
132 // has been called for this instruction).
133 bool placed = false;
134
135 // To avoid an infinite loop rematerializing the same set of
136 // instructions ad infinitum, keep a denylist of instructions
137 // which should not be rematerialized.
138 bool denylisted = false;
139
140 // The buffers defined by this instruction.
141 BufferIdList buffers_defined;
142
143 // Output buffers of this instruction. This is used to track outputs by GTE
144 // instructions (where the instruction doesn't define a buffer).
145 BufferIdList buffers_output;
146
147 // The buffers used by this instruction.
148 BufferIdList buffers_used;
149
150 bool is_skip_node = false;
151
152 private:
153 friend class InstructionList;
154
155 // Items are arranged in a doubly linked list.
156 Item* next = nullptr;
157 Item* prev = nullptr;
158
159 Item* prev_skip_node = nullptr;
160 Item* next_skip_node = nullptr;
161
  // The list is ordered by position, which may be duplicated as new
  // instructions are inserted. See the InsertBeforeInstructions comment for
  // details.
165 int64 position;
166 };
167
// Records a user of the buffer defined by an Item. It also records the
// operand_number through which the use occurs, so that indirect uses (for
// example, a buffer used through a bitcast) can be identified precisely.
172 struct ItemUse {
173 Item* user;
174 int64 operand_number;
175 absl::optional<int64> index;
176
  ItemUse(Item* user, int64 op_num, absl::optional<int64> index)
178 : user(user), operand_number(op_num), index(index) {}
  bool operator==(const ItemUse& other) const {
180 return user == other.user && operand_number == other.operand_number &&
181 index == other.index;
182 }
183 };
184
185 using ItemList = absl::InlinedVector<Item*, 3>;
186 using UsesList = absl::InlinedVector<ItemUse, 3>;
187
188 // Class which maintains an ordered list of instructions with fast insertion
189 // before arbitrary elements.
190 //
191 // This is a skip list structure that has two lanes: express lane and slow lane.
192 // All nodes are presented on the slow lane but a node can be promoted into
193 // express lane for fast iteration.
194 //
// In the following case, node 2 and node n+1 are connected via an express
// lane.
//                    +--------------------------+----------->: Express lane
//                    |                          |
//  node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...: Slow lane
199 //
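// A minimal construction-and-iteration sketch (illustrative only; `schedule`
// and `computation` are hypothetical names for whatever the caller already
// has in hand):
//
//   InstructionList instruction_list(schedule.sequence(computation));
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     VLOG(2) << item->instruction->name();
//   }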
200 class InstructionList {
201 public:
  explicit InstructionList(const HloInstructionSequence& order) {
203 int64 position = 0;
204 Item* last = nullptr;
205 last_skip_node_ = nullptr;
206 first_skip_node_ = nullptr;
207 for (HloInstruction* inst : order.instructions()) {
208 // Add a new item to the linked list.
209 Item* item = new Item;
210 item->next = nullptr;
211 item->prev = last;
212 if (last == nullptr) {
213 first_ = item;
214 } else {
215 last->next = item;
216 }
217 last = item;
218
      // Initially position numbers are uniquely assigned in order. Later, as
      // instructions are added with InsertBefore* methods, some instructions
      // may have duplicate position numbers, but the values are guaranteed to
      // be monotonically non-decreasing through the list, so they remain
      // useful for quickly(-ish) determining the order of arbitrary
      // instructions in the list.
225 item->instruction = inst;
226 item->position = position;
227 position++;
228
229 item_map_[inst] = item;
230 }
231 }
232
  ~InstructionList() {
234 for (Item* item = first_; item != nullptr;) {
235 Item* next = item->next;
236 delete item;
237 item = next;
238 }
239 }
240
  size_t size() const { return item_map_.size(); }
242
243 // For ordered iteration over items.
244 // for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }
247
  Item* first_skip_node() const { return first_skip_node_; }
  Item* next_skip_node(Item* item) const { return item->next_skip_node; }
250
251 // Creates an Item for the given instruction, but doesn't add it to the list.
252 // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
254 Item* item = new Item;
255 item->instruction = inst;
256 CHECK(item_map_.insert({inst, item}).second)
257 << "inserting inst twice " << inst->name();
258 return item;
259 }
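  // Illustrative sketch (hypothetical names): a rematerialized clone is
  // typically given an Item via CreateItem and then spliced in ahead of its
  // remaining uses:
  //
  //   Item* remat_item = instruction_list.CreateItem(remat_instruction);
  //   instruction_list.InsertBeforeInstructions(remat_item, {first_use_item});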
260
261 // Return the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
263 auto iter = item_map_.find(inst);
264 CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
265 return iter->second;
266 }
267
268 // Insert instruction 'to_insert' immediately before the earliest instruction
269 // in 'before_instructions'.
270 //
271 // Each instruction gets a non-decreasing ordinal number. We use this to let
272 // InsertBeforeInstructions quickly insert an instruction before the earliest
273 // instruction in a set of instructions. If position_number_[a] <
274 // position_number_[b] then 'a' comes before 'b' in the list. If the position
275 // numbers are the same then nothing can be said about their order without
276 // examining the list.
277 //
278 // On object construction this ordinal is precisely the instruction's index
279 // in the list. Later, instructions inserted via InsertBefore receive
280 // duplicate values. However, monotonicity is preserved.
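  //
  // Worked example (illustrative): with positions [0, 1, 2, 3], inserting a
  // new item before the item at position 2 gives positions [0, 1, 2, 2, 3];
  // the relative order of the two items numbered 2 can then only be recovered
  // by walking the list.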
  void InsertBeforeInstructions(Item* to_insert,
282 absl::Span<Item* const> before_instructions) {
283 VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
284 << " before {"
285 << absl::StrJoin(before_instructions, ", ",
286 [](string* out, Item* item) {
287 absl::StrAppend(out, item->instruction->name());
288 })
289 << "}";
290
291 // Find the minimal position number of any instruction in
292 // 'before_instructions'.
293 CHECK(!before_instructions.empty());
294 Item* min_position_item = nullptr;
295 for (Item* item : before_instructions) {
296 if (min_position_item == nullptr ||
297 item->position < min_position_item->position) {
298 min_position_item = item;
299 }
300 }
301
302 // Because more than one instruction in 'before_instructions' may have a
303 // position number of 'min_position_number', find the first such instruction
304 // with position number 'min_position_number'.
305
306 // First find first instruction with the min position.
307 while (min_position_item->prev != nullptr &&
308 min_position_item->position == min_position_item->prev->position) {
309 min_position_item = min_position_item->prev;
310 }
311
312 // Now scan forwards until we find one of the before_instructions.
313 while (!absl::c_linear_search(before_instructions, min_position_item)) {
314 min_position_item = min_position_item->next;
315 }
316 return InsertBefore(to_insert, min_position_item);
317 }
318
  // Scans the list and promotes nodes to the express lane if
  // should_promote(Item) returns true.
  void PromoteNodesToSkip(std::function<bool(Item*)> should_promote) {
322 int64 count = 0;
323 for (auto* item = first(); item != nullptr; item = next(item)) {
324 if (should_promote(item)) {
325 count += 1;
326 if (first_skip_node_ == nullptr) {
327 first_skip_node_ = item;
328 }
329 item->is_skip_node = true;
330 item->prev_skip_node = last_skip_node_;
331 if (last_skip_node_ != nullptr) {
332 last_skip_node_->next_skip_node = item;
333 }
334 last_skip_node_ = item;
335 }
336 }
337 VLOG(1) << " Rematerialization has " << count << " items in express lane";
338 }
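  // Illustrative usage (the predicate below is a hypothetical example):
  //
  //   instruction_list.PromoteNodesToSkip(
  //       [](Item* item) { return !item->buffers_defined.empty(); });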
339
  void InsertAfterInstructions(Item* to_insert,
341 absl::Span<Item* const> after_instructions) {
342 VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
343 << " after {"
344 << absl::StrJoin(after_instructions, ", ",
345 [](string* out, Item* item) {
346 absl::StrAppend(out, item->instruction->name());
347 })
348 << "}";
349
350 // Find the max position number of any instruction in
351 // 'after_instructions'.
352 CHECK(!after_instructions.empty());
353 Item* max_position_item = nullptr;
354 for (Item* item : after_instructions) {
355 if (max_position_item == nullptr ||
356 item->position > max_position_item->position) {
357 max_position_item = item;
358 }
359 }
360 // No rematerializable instruction should be inserted at the end of the
361 // computation.
362 CHECK(max_position_item->next != nullptr);
363 InsertBeforeInstructions(to_insert, {max_position_item->next});
364 }
365
  void Denylist(const HloInstruction* inst) {
367 GetItem(inst)->denylisted = true;
368 }
369
370 private:
371 // Insert instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
373 VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
374 << before->instruction->name();
375 // Always place new nodes on express lane for the ease of implementation.
376 item->is_skip_node = true;
377 // Find the next express node starting from 'before'. Set up the node's
378 // express pointers.
379 Item* cursor = before;
380 while (cursor != nullptr && !cursor->is_skip_node) {
381 cursor = cursor->next;
382 }
383 CHECK(cursor == nullptr || cursor->is_skip_node);
384 if (cursor == nullptr) {
385 //
386 // last_skip_node_<---+ : express lane
387 // |
388 // ...<->`item`<-> .. <-> `cursor`(null) : slow lane
389 //
      // Reached the end. Set the item's prev_skip_node to last_skip_node_ and
      // make the item the new last_skip_node_.
392 item->prev_skip_node = last_skip_node_;
393 item->next_skip_node = nullptr;
394 last_skip_node_ = item;
395 } else {
396 //
397 // <-+------------+----------------+---------> : express lane
398 // | | |
399 // prev_express..<->`item`<-> .. <-> `cursor` <-> ...: slow lane
400 //
401 // Reached the next skip node, sets up express pointers accordingly.
402 CHECK(cursor->is_skip_node);
403 item->prev_skip_node = cursor->prev_skip_node;
404 if (item->prev_skip_node != nullptr) {
405 item->prev_skip_node->next_skip_node = item;
406 }
407 item->next_skip_node = cursor;
408 cursor->prev_skip_node = item;
409 }
410 if (first_skip_node_ == cursor) {
411 first_skip_node_ = item;
412 }
413 // Insert new item into linked list.
414 item->prev = before->prev;
415 item->next = before;
416 before->prev = item;
417 if (item->prev != nullptr) {
418 item->prev->next = item;
419 } else {
420 first_ = item;
421 }
422
423 // Assign the same position number to the newly added instruction as
424 // 'before'. This guarantees monotonicity of the position numbers, but not
425 // uniqueness.
426 item->position = before->position;
427 }
428
429 Item* first_;
430
431 // First skip node of this list.
432 Item* first_skip_node_;
433
434 // Last skip node of this list.
435 Item* last_skip_node_;
436
437 // Item for each instruction.
438 absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
439 };
440
// Returns the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is
// indirect if the instruction defining logical_buffer is not an operand of
// the using instruction. This can happen via buffer aliasing (e.g., tuples).
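// For example (illustrative HLO): given t = tuple(a, b) and g =
// get-tuple-element(t) with index 0, an instruction consuming g uses the
// buffer defined by a even though a is not one of its operands; the use
// reaches a's buffer only through aliasing.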
UsesList GetUsers(const InstructionList& instruction_list,
446 const LogicalBuffer* logical_buffer,
447 const TuplePointsToAnalysis& points_to_analysis,
448 bool* has_indirect_users) {
449 UsesList users;
450 // To identify uses iterate through all HloInstruction users of the
451 // BufferAliases of the logical buffer.
452 *has_indirect_users = false;
453 for (const BufferAlias& buffer_alias :
454 points_to_analysis.GetBufferAliases(*logical_buffer)) {
455 for (const HloInstruction* user : buffer_alias.instruction()->users()) {
456 if (points_to_analysis.DoesNotUseOperandBuffer(
457 buffer_alias.instruction(), buffer_alias.index(), user)) {
458 // The alias may be an operand of 'user', but the LogicalBuffer cannot
459 // possibly be used by the instruction so ignore 'user'. This is the
460 // case, for example, for the tuple element buffers in a GetTupleElement
461 // instruction (the GTE instruction only uses the pointer vector).
462 continue;
463 }
464 if (buffer_alias.instruction() != logical_buffer->instruction() &&
465 !IsSupportedIndirectUser(buffer_alias.instruction())) {
466 *has_indirect_users = true;
467 }
468 // A buffer may be used by the instruction via more than one alias. For
469 // example, a buffer which appears in more than one element of a tuple.
470 Item* user_item = instruction_list.GetItem(user);
471 absl::optional<int64> user_index =
472 logical_buffer->index().size() != 1
473 ? absl::nullopt
474 : absl::make_optional(logical_buffer->index().back());
475 for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
476 if (!absl::c_linear_search(
477 users,
478 ItemUse{user_item, static_cast<int>(op_idx), user_index})) {
479 users.push_back(
480 ItemUse{user_item, static_cast<int>(op_idx), user_index});
481 }
482 }
483 }
484 }
485 return users;
486 }
487
488 // Class for tracking memory usage of a computation as the instructions are
489 // placed sequentially. Memory usage is the sum of the sizes of live values
490 // (LogicalBuffers) at the current point in the instruction sequence.
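//
// A minimal usage sketch (illustrative; assumes `tracker` and
// `instruction_list` are set up as elsewhere in this file):
//
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
//     // ... consider rematerialization at this program point ...
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());
//   }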
491 class MemoryUsageTracker {
492 public:
493 MemoryUsageTracker(
494 const HloComputation* computation,
495 const HloRematerialization::ShapeSizeFunction& size_function,
496 const HloRematerialization::CompactShapeFunction& compact_shape_function,
497 const TuplePointsToAnalysis& points_to_analysis,
498 const InstructionList& instruction_list,
499 HloRematerialization::RematerializationMode mode);
500
501 // Starts the placement of the given instruction. This adds the sizes of the
502 // LogicalBuffers defined by the instruction to the current memory
503 // usage. Placement is broken into two steps (BeginInstruction and
504 // EndInstruction) to accurately model memory usage. At BeginInstruction the
505 // memory for the output value(s) of the current instruction is allocated. At
506 // EndInstruction memory for dead operand(s) is freed.
507 Status BeginInstruction(Item* item);
508
  int64 RematerializationCost(const std::vector<Item*>& items,
510 int64 memory_reduced, int64 memory_limit_bytes) {
511 // If none of the users of any 'item' have been placed in the
512 // sequence (as tracked by memory_tracker), then rematerialization of
513 // 'item' is a zero-cost move of 'item->instruction' in the sequence.
514 bool zero_cost_move = true;
515 for (auto* item : items) {
516 auto* instruction = item->instruction;
517 if (absl::c_any_of(
518 instruction->users(),
519 [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
520 zero_cost_move = false;
521 break;
522 }
523 }
524 if (zero_cost_move) {
525 return 0;
526 }
527
528 CHECK_GT(memory_reduced, 0);
529 // Return the inverse of the benefit of rematerialization.
530 return memory_limit_bytes / memory_reduced;
531 }
532
533 // Finishes the placement of the current instruction. This frees any dead
534 // operands or dead result of the instruction. This must be called after
535 // each call to BeginInstruction.
536 Status EndInstruction();
537
  // Returns the number of bytes by which the current memory usage would be
  // reduced if the given item's output were compressed to compact_shape.
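  // For example (illustrative): if the item's single output buffer is 4 MiB
  // and its compact form is 1 MiB, the reported reduction is 3 MiB, since the
  // compact copy stays live while the original becomes dead.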
540 int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
541
542 // Returns the number of bytes that the current memory usage will be reduced
543 // by if the given sequence of instructions is rematerialized.
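  // For example (illustrative): rematerializing an instruction whose defined
  // buffer is 8 MiB, when one of its operand buffers (2 MiB, currently dead)
  // would have to be kept live across this program point, reduces memory by
  // 8 MiB - 2 MiB = 6 MiB.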
544 int64 MemoryReducedIfRematerialized(
545 absl::Span<const Item* const> items) const;
546
547 Status AddCompressInstructions(Item* original_item, Item* compressed_item,
548 Item* uncompressed_item);
549
550 // Adjusts memory usage to account for the rematerialization of
551 // original_item for all remaining unplaced uses. The rematerialization
552 // is remat_item. This method should be called after the HLO graph has
553 // been transformed (rematerialization instruction created and connected
554 // to uses).
555 Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
556 absl::Span<Item*> indirect_users);
557
558 // Selects and returns the best candidate instructions for rematerialization.
559 // A sequence of candidate instructions of length between min_block_size and
560 // max_block_size (both inclusive) with the lowest rematerialization cost is
561 // selected among those candidates which reduce memory use at the program
562 // point of the current instruction as indicated by memory_tracker. Returns an
563 // empty vector if no candidates are found.
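  // For example (illustrative): with min_block_size = 1 and max_block_size =
  // 3, each start instruction i is evaluated as the candidate blocks {i},
  // {i, i+1} and {i, i+1, i+2}, and the lowest-cost block over all starts is
  // returned.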
564 std::pair<std::vector<Item*>, RematStrategy> PickRematerializationCandidates(
565 const InstructionList& instruction_list, int64 memory_limit_bytes,
566 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
567 int min_block_size, int max_block_size);
568
569 // Returns whether the given instruction has been placed (BeginInstruction
570 // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
572 return instruction_list_.GetItem(instruction)->placed;
573 }
574
575 // Returns whether 'item' has any unplaced users.
576 bool HasUnplacedUsers(Item* item) const;
577
578 // Returns the list of uses for a specific 'item'.
579 const UsesList GetItemUses(Item* item) const;
580
581 // Returns whether 'item' is currently in progress.
  bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
583
584 // Returns the current memory usage. This is the sum of sizes of all live
585 // values.
  int64 memory_usage() const { return memory_usage_; }
587
  // Returns the total allocated size of the buffers defined by 'item'.
  int64 AllocatedSize(Item* item) const {
590 int64 size = 0;
591 for (auto buffer_id : item->buffers_defined) {
592 size += AllocatedSize(buffer_id);
593 }
594 return size;
595 }
596
597 // Check invariants of the data structure. This is expensive to call.
598 bool Check() const;
599
600 string ToString() const;
601
602 private:
603 // A Buffer represents a single LogicalBuffer in the computation including
604 // various metadata useful for tracking liveness of the value. A LogicalBuffer
605 // is not used directly because the HLO graph is transformed and
606 // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after
607 // HLO graph transformations.
608 struct Buffer {
609 // The unique id of this Buffer. This value is equal to the buffer's index
610 // in the vector buffers_.
611 const BufferId id;
612
613 // The instruction which defines this buffer.
614 Item* defining_instruction;
615
616 // The materialized size of the buffer in bytes.
617 const int64 size;
618
619 // Shape of the buffer.
620 Shape shape;
621
622 // Whether this buffer is live-out of the computation.
623 bool live_out;
624
625 // Whether this buffer has indirect uses. Ie, an instruction which is not a
626 // user of defining_instruction uses this buffer. This can occur due to
627 // buffer aliasing (eg, tuples).
628 bool has_indirect_uses;
629
630 // Position in the tuple this buffer definition lives in.
631 ShapeIndex index;
632
633 // The instructions which use this buffer.
634 UsesList users;
635
636 // The number of users (HloInstructions) of this buffer which have not yet
637 // been placed in the sequence.
638 int64 unfinished_user_count;
639
    string ToString() const {
641 return absl::StrCat("Buffer ", id, " (defined by ",
642 defining_instruction->instruction->name(), ", size ",
643 size, " bytes)");
644 }
645 };
646
  // Gets the compact shape of the given hlo instruction. An internal cache is
  // used to avoid computing the shape multiple times.
649 StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
650
651 // Creates a Buffer representing the given logical buffer. The buffer is added
652 // to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
654 const LogicalBuffer* logical_buffer,
655 const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
656 bool has_indirect_uses = false;
657 UsesList users = GetUsers(instruction_list_, logical_buffer,
658 points_to_analysis, &has_indirect_uses);
659 return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
660 logical_buffer->shape(), logical_buffer->index(),
661 std::move(users), live_out, has_indirect_uses);
662 }
663
664 // Create a new buffer representing a rematerialization of given buffer for
665 // the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
667 UsesList&& rematerialized_uses) {
668 CHECK(original_buffer.defining_instruction->placed)
669 << original_buffer.defining_instruction->instruction->name();
670 CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
671 CHECK(!original_buffer.live_out) << original_buffer.ToString();
672 for (ItemUse& use : rematerialized_uses) {
673 CHECK(!use.user->placed) << use.user->instruction->name();
674 }
675 return NewBuffer(remat_item, original_buffer.shape, original_buffer.index,
676 std::move(rematerialized_uses), /*live_out=*/false,
677 /*has_indirect_uses=*/false);
678 }
679
  // Returns the number of bytes allocated for the buffer with the given id.
  // Buffers allocated by the calling computation (e.g., parameter and output
  // buffers) are considered to have zero bytes because the memory is accounted
  // for in a different computation.
  int64 AllocatedSize(BufferId buffer_id) const {
685 const Buffer& buffer = buffers_.at(buffer_id);
686 HloInstruction* inst = buffer.defining_instruction->instruction;
687 HloOpcode def_opcode = inst->opcode();
688 if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
689 return 0;
690 } else {
691 return buffer.size;
692 }
693 }
694
  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
698 return item->placed && item != in_progress_item_;
699 }
700
701 // Returns whether the given buffer is being used by the in-progress
702 // instruction.
  bool IsInUse(BufferId buffer_id) const {
704 if (in_progress_item_ == nullptr) {
705 return false;
706 }
707 const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
708 return absl::c_linear_search(in_progress_uses, buffer_id);
709 }
710
  bool IsCurrentlyLive(BufferId buffer_id) const {
712 const Buffer& buffer = buffers_[buffer_id];
713 return (buffer.defining_instruction->placed &&
714 buffer.unfinished_user_count > 0);
715 }
716
717 // Returns whether the given instruction is live at the current program
718 // point.
  bool IsInstructionCurrentlyLive(Item* instruction) const {
720 // If the instruction has not started yet, it is not alive.
721 if (!IsPlaced(instruction->instruction)) {
722 return false;
723 }
724 for (const HloInstruction* user : instruction->instruction->users()) {
725 if (!IsPlaced(user)) {
726 // If there is an unplaced user, consider this instruction currently
727 // live.
728 return true;
729 }
730 }
731 return false;
732 }
733
734 // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
736 const ShapeIndex& index, UsesList&& uses, bool live_out,
737 bool has_indirect_uses) {
738 int buffer_id = buffers_.size();
739 auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
740 absl::flat_hash_set<Item*> users_set;
741 for (const ItemUse& use : uses) {
742 users_set.insert(use.user);
743 }
744 return users_set.size();
745 };
746 buffers_.push_back(Buffer{
747 buffer_id, defining_instruction, size_function_(shape), shape, live_out,
748 has_indirect_uses, index, uses, get_num_of_unique_users(uses)});
749 return buffers_.back();
750 }
751
752 const HloComputation* computation_;
753
754 // Instruction list containing the ordering of instructions in
755 // computation_. This is the order in which instructions are placed
756 // (BeginInstruction/EndInstruction calls).
757 const InstructionList& instruction_list_;
758
759 // Size function returns the bytes of a given buffer.
760 const HloRematerialization::ShapeSizeFunction& size_function_;
761
762 // Converts a shape into compact form, returns the same shape if a shape is
763 // already considered compact.
764 const HloRematerialization::CompactShapeFunction& compact_shape_function_;
765
766 // A map that caches existing known compact shape for each instruction.
767 absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
768
769 // Memory usage at the currently placed instruction.
770 int64 memory_usage_ = 0;
771
772 // The instruction currently being placed. This value is non-null only
773 // between the calling of BeginInstruction and EndInstruction.
774 Item* in_progress_item_ = nullptr;
775
776 HloRematerialization::RematerializationMode mode_;
777 // All buffers in the computation.
778 std::vector<Buffer> buffers_;
779 };
780
MemoryUsageTracker::MemoryUsageTracker(
782 const HloComputation* computation,
783 const HloRematerialization::ShapeSizeFunction& size_function,
784 const HloRematerialization::CompactShapeFunction& compact_shape_function,
785 const TuplePointsToAnalysis& points_to_analysis,
786 const InstructionList& instruction_list,
787 HloRematerialization::RematerializationMode mode)
788 : computation_(computation),
789 instruction_list_(instruction_list),
790 size_function_(size_function),
791 compact_shape_function_(compact_shape_function),
792 mode_(mode) {
793 PointsToSet::BufferSet live_out_set =
794 points_to_analysis.GetPointsToSet(computation_->root_instruction())
795 .CreateFlattenedSet();
796 absl::flat_hash_map<const LogicalBuffer*, BufferId>
797 logical_buffer_to_buffer_id;
798 for (auto* item = instruction_list_.first(); item != nullptr;
799 item = instruction_list_.next(item)) {
800 const HloInstruction* const instruction = item->instruction;
801 for (const LogicalBuffer* logical_buffer :
802 points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
803 Buffer* buffer;
804 if (instruction->opcode() == HloOpcode::kWhile) {
805 // The while instruction defines no new buffers. Instead it reuses the
806 // buffers of its operand. Find the Buffer of its operand at the
807 // proper ShapeIndex.
808 const PointsToSet& operand_points_to =
809 points_to_analysis.GetPointsToSet(instruction->operand(0));
810 CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
811 const LogicalBuffer* source_logical_buffer =
812 operand_points_to.element(logical_buffer->index())[0];
813 buffer =
814 &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));
815
816 // Mark buffer as has indirect use and live out.
817 buffer->has_indirect_uses = true;
818 buffer->live_out =
819 buffer->live_out || ContainsKey(live_out_set, logical_buffer);
820
821 // Add users of while to Buffer users.
822 bool unused;
823 for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
824 points_to_analysis, &unused)) {
825 auto existing_user_it = absl::c_find_if(
826 buffer->users,
827 [&](const ItemUse& use) { return user_item.user == use.user; });
828 if (existing_user_it == buffer->users.end()) {
829 buffer->unfinished_user_count++;
830 user_item.user->buffers_used.push_back(buffer->id);
831 buffer->users.push_back(user_item);
832 }
833 }
834 } else {
835 buffer = &CreateBufferFromLogicalBuffer(
836 logical_buffer, points_to_analysis,
837 ContainsKey(live_out_set, logical_buffer));
838 item->buffers_defined.push_back(buffer->id);
839 for (ItemUse& user : buffer->users) {
840 if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
841 user.user->buffers_used.push_back(buffer->id);
842 }
843 }
844 }
845
846 logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
847 }
848
    // Trace the output of each instruction. This is so that we can properly
    // track which outputs GTE instructions have.
851 for (const LogicalBuffer* logical_buffer :
852 points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
853 item->buffers_output.push_back(
854 logical_buffer_to_buffer_id[logical_buffer]);
855 }
856 }
857 XLA_VLOG_LINES(10, ToString());
858 DCHECK(Check());
859 }
860
Status MemoryUsageTracker::BeginInstruction(Item* item) {
862 const HloInstruction* instruction = item->instruction;
863 VLOG(3) << "BeginInstruction " << instruction->name();
864 TF_RET_CHECK(in_progress_item_ == nullptr);
865 in_progress_item_ = item;
866
867 item->placed = true;
868
869 // All buffers defined by this instruction need memory.
870 for (BufferId buffer_id : item->buffers_defined) {
871 VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString()
872 << " is now live.";
873 memory_usage_ += AllocatedSize(buffer_id);
874 }
875
876 // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead)
877 // operand. Account for this potential reuse here.
878
879 VLOG(3) << " memory usage = " << memory_usage_;
880 VLOG(10) << ToString();
881
882 if (VLOG_IS_ON(1)) {
883 DCHECK(Check());
884 }
885 return Status::OK();
886 }
887
Status MemoryUsageTracker::EndInstruction() {
889 TF_RET_CHECK(in_progress_item_ != nullptr);
890 VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
891
892 for (BufferId buffer_id : in_progress_item_->buffers_used) {
893 Buffer& buffer = buffers_.at(buffer_id);
894 buffer.unfinished_user_count--;
895 TF_RET_CHECK(buffer.unfinished_user_count >= 0)
896 << buffer.ToString() << " has negative unfinished user count.";
897 if (buffer.unfinished_user_count == 0) {
898 // Buffer is now dead.
899 VLOG(3) << " " << buffer.ToString() << " is now dead.";
900 memory_usage_ -= AllocatedSize(buffer_id);
901 // The memory usage can become negative inside the computation as we can
902 // free up the parameter space and reuse it for other tensors.
903 }
904 }
905
906 // If any buffer defined by this instruction has no uses, then memory can be
907 // reclaimed immediately.
908 for (BufferId buffer_id : in_progress_item_->buffers_defined) {
909 const Buffer& buffer = buffers_.at(buffer_id);
910 if (buffer.unfinished_user_count == 0) {
911 VLOG(3) << " " << buffer.ToString() << " is immediately dead.";
912 memory_usage_ -= AllocatedSize(buffer_id);
913 // The memory usage can become negative inside the computation as we can
914 // free up the parameter space and reuse it for other tensors.
915 }
916 }
917
918 in_progress_item_ = nullptr;
919
920 VLOG(3) << " memory usage = " << memory_usage_;
921 VLOG(10) << ToString();
922
923 if (VLOG_IS_ON(1)) {
924 DCHECK(Check());
925 }
926 return Status::OK();
927 }
928
int64 MemoryUsageTracker::MemoryReducedIfCompressed(
930 Item* item, const Shape& compact_shape) const {
931 CHECK_NE(in_progress_item_, nullptr);
932 if (!item->placed || item == in_progress_item_) {
933 return 0;
934 }
935
936 int64 memory_reduced = 0;
937
938 // We only compress a single piece of an output at one time.
939 CHECK_EQ(item->buffers_output.size(), 1);
940 BufferId buffer_id = item->buffers_output[0];
941 if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
942 IsInstructionCurrentlyLive(item)) {
943 const Buffer& buffer = buffers_.at(buffer_id);
944 memory_reduced += buffer.size;
945
946 int64 compact_shape_size = size_function_(compact_shape);
947 // Account for buffers that are compressed after instruction.
948 memory_reduced -= compact_shape_size;
949 }
950 return memory_reduced;
951 }
952
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
954 absl::Span<const Item* const> items) const {
955 CHECK_NE(in_progress_item_, nullptr);
956 int64 memory_reduced = 0;
957 absl::flat_hash_set<const Item*> remat_candidates;
958
959 for (const Item* item : items) {
960 if (!item->placed || item == in_progress_item_) {
961 LOG(WARNING) << "Unplaced item or in progress item being checked for "
962 "rematerialization.";
963 return 0;
964 }
965
966 // Compute the amount of memory reduced (if any) by rematerializing
967 // 'item->instruction'. The LogicalBuffers defined by 'item->instruction'
968 // will no longer be live at this program point, so initially set
969 // memory_reduced to the size of its defined values.
970 for (BufferId buffer_id : item->buffers_defined) {
971 const Buffer& buffer = buffers_.at(buffer_id);
972 // Avoid rematerializing instructions with indirect uses as it is
973 // difficult to reason about liveness after rematerializing the
974 // instruction.
975 // Avoid rematerializing instructions with live out buffers.
976 // Avoid rematerializing buffers that are in nested tuples.
977 // TODO(mpurohit): Check why live_out buffers are an issue here.
978 if (buffer.has_indirect_uses || buffer.live_out ||
979 buffer.index.size() > 1) {
980 return 0;
981 }
982 if (IsInUse(buffer_id)) {
983 return 0;
984 }
985 if (IsCurrentlyLive(buffer_id)) {
986 memory_reduced += AllocatedSize(buffer_id);
987 }
988 }
989
990 // Account for any logical buffers whose live range must be extended across
991 // this program point.
992 for (BufferId buffer_id : item->buffers_used) {
993 if (!IsCurrentlyLive(buffer_id)) {
994 // This logical buffer is used by 'item->instruction' but is not live at
995 // this program point. Rematerializing 'item->instruction' will extend
996 // the buffer's live range across this program point unless it is
997 // defined by an instruction that is also being rematerialized.
998 Item* defining_instruction =
999 buffers_.at(buffer_id).defining_instruction;
1000 if (!remat_candidates.contains(defining_instruction)) {
1001 memory_reduced -= AllocatedSize(buffer_id);
1002 }
1003 }
1004 }
1005 remat_candidates.insert(item);
1006 }
1007
1008 return memory_reduced;
1009 }
1010
Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
1012 Item* compressed_item,
1013 Item* uncompressed_item) {
1014 // Original buffer is now dead.
1015 memory_usage_ -= size_function_(original_item->instruction->shape());
1016 // Compressed buffer is now alive.
1017 memory_usage_ += size_function_(compressed_item->instruction->shape());
1018
1019 UsesList placed_users;
1020 UsesList unplaced_users;
1021 CHECK_EQ(original_item->buffers_output.size(), 1);
1022 BufferId original_buffer_id = original_item->buffers_output[0];
1023 Buffer& original_buffer = buffers_.at(original_buffer_id);
1024 for (ItemUse& user : original_buffer.users) {
1025 if (user.user->placed) {
1026 CHECK(IsFinished(user.user)) << user.user->instruction->name();
1027 placed_users.push_back(user);
1028 } else {
1029 unplaced_users.push_back(user);
1030 }
1031 }
1032 original_buffer.users = std::move(placed_users);
1033 original_buffer.unfinished_user_count = 0;
1034 original_buffer.users.push_back(ItemUse{compressed_item, 0, absl::nullopt});
  // The NewBuffer calls below may reallocate buffers_, invalidating the
  // original_buffer reference, so copy the index that we need before making
  // those calls.
1038 ShapeIndex copied_index = original_buffer.index;
1039 Buffer& compressed_buffer =
1040 NewBuffer(compressed_item, compressed_item->instruction->shape(),
1041 copied_index, {ItemUse{uncompressed_item, 0, absl::nullopt}},
1042 /*live_out=*/false,
1043 /*has_indirect_uses=*/false);
1044 compressed_item->buffers_used = original_item->buffers_output;
1045 compressed_item->buffers_output = {compressed_buffer.id};
1046 compressed_item->buffers_defined.push_back(compressed_buffer.id);
1047
1048 Buffer& uncompressed_buffer =
1049 NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
1050 copied_index, std::move(unplaced_users), /*live_out=*/false,
1051 /*has_indirect_uses=*/false);
1052
1053 uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
1054 uncompressed_item->buffers_output = {uncompressed_buffer.id};
1055 uncompressed_item->buffers_defined = {uncompressed_buffer.id};
1056
1057 for (ItemUse& user : uncompressed_buffer.users) {
1058 BufferIdList& buffers_used = user.user->buffers_used;
1059 std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
1060 uncompressed_buffer.id);
1061 }
1062
1063 return Status::OK();
1064 }
1065
Status MemoryUsageTracker::AddRematerializedInstruction(
1067 Item* original_item, Item* remat_item, absl::Span<Item*> indirect_users) {
1068 VLOG(3) << "AddRematerializedInstruction: original_instruction = "
1069 << original_item->instruction->name()
1070 << ", remat_instruction = " << remat_item->instruction->name();
1071
1072 TF_RET_CHECK(in_progress_item_ != nullptr);
1073 TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
1074 TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();
1075
1076 // Construct the list of buffers used and defined by the rematerialization.
1077 remat_item->buffers_used = original_item->buffers_used;
1078
1079 // Account for the additional buffer uses created by the new rematerialization
1080 // instruction. Update memory usage if the rematerialization makes a dead
1081 // buffer live again.
1082 for (BufferId buffer_id : original_item->buffers_used) {
1083 Buffer& buffer = buffers_.at(buffer_id);
1084 if (buffer.unfinished_user_count == 0) {
1085 // Buffer used by this instruction was dead, now is alive.
1086 memory_usage_ += AllocatedSize(buffer.id);
1087 }
1088 buffer.unfinished_user_count++;
1089 absl::InlinedVector<ItemUse, 2> filtered_users;
1090 std::copy_if(buffer.users.begin(), buffer.users.end(),
1091 std::back_inserter(filtered_users),
1092 [&](const ItemUse& iu) { return iu.user == original_item; });
1093 for (ItemUse& u : filtered_users) {
1094 buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index});
1095 }
1096 }
1097
1098 const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
1099 indirect_users.end());
1100 // Create a new set of Buffers defined by the new rematerialization
1101 // instruction. Update the internal data structures and memory use to account
1102 // for them.
1103 for (BufferId old_buffer_id : original_item->buffers_defined) {
1104 Buffer& old_buffer = buffers_.at(old_buffer_id);
1105
1106 UsesList placed_users;
1107 UsesList unplaced_users;
1108 for (ItemUse& user : old_buffer.users) {
1109 if (user.user->placed) {
1110 placed_users.push_back(user);
1111 } else {
        // Keep only the indirect users that are in the provided list. Consider
        // all the others dead, and drop any buffer uses they might perform
        // from the buffer's user list.
1115 if (!IsSupportedIndirectUser(user.user->instruction) ||
1116 indirect_users_set.contains(user.user)) {
1117 unplaced_users.push_back(user);
1118 } else {
1119 CHECK(user.user->buffers_defined.empty())
1120 << "Buffers defined expected to be empty for use passthrough "
1121 "instructions";
1122 user.user->buffers_output.clear();
1123 user.user->buffers_used.clear();
1124 }
1125 }
1126 }
1127 old_buffer.users = std::move(placed_users);
1128 old_buffer.unfinished_user_count = 0;
1129
1130 // Buffer is now dead.
1131 memory_usage_ -= AllocatedSize(old_buffer.id);
1132
1133 Buffer& new_buffer =
1134 RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
1135
1136 remat_item->buffers_defined.push_back(new_buffer.id);
1137 auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id](
1138 BufferIdList& to_update) {
1139 std::replace(to_update.begin(), to_update.end(), old_buffer_id,
1140 new_buffer_id);
1141 };
1142 // Update users with the id of the new buffer.
1143 for (ItemUse& user : new_buffer.users) {
1144 update_buffers(user.user->buffers_used);
1145 update_buffers(user.user->buffers_output);
1146 }
1147 }
1148
1149 // Update the indirect users with the id of the new buffers.
1150 for (Item* indirect_user : indirect_users) {
    // Source of the buffers that are going to be passed through.
1152 const Item* source_item =
1153 instruction_list_.GetItem(indirect_user->instruction->operand(0));
1154 switch (indirect_user->instruction->opcode()) {
1155 case HloOpcode::kBitcast: {
        // If the source is another indirect user, copy its output buffers into
        // the bitcast's used and output lists, as bitcasts don't define any
        // buffers.
1159 if (IsSupportedIndirectUser(source_item->instruction)) {
1160 indirect_user->buffers_used = source_item->buffers_output;
1161 indirect_user->buffers_output = source_item->buffers_output;
1162 } else {
1163 // If it's a real instruction producing a buffer then copy the defined
1164 // buffers into used and output.
1165 indirect_user->buffers_used = source_item->buffers_defined;
1166 indirect_user->buffers_output = source_item->buffers_defined;
1167 }
1168 break;
1169 }
1170 case HloOpcode::kGetTupleElement: {
1171 // GTEs just use the tuple buffer and output the buffer they actually
1172 // extract from the tuple.
1173 const HloGetTupleElementInstruction* gte =
1174 Cast<HloGetTupleElementInstruction>(indirect_user->instruction);
1175 for (BufferId buffer_id : source_item->buffers_defined) {
1176 const Buffer& def_buffer = buffers_.at(buffer_id);
1177 if (def_buffer.index == ShapeIndex{gte->tuple_index()}) {
1178 indirect_user->buffers_output.push_back(buffer_id);
1179 }
1180 // This is the tuple buffer.
1181 if (def_buffer.index.empty()) {
1182 indirect_user->buffers_used.push_back(buffer_id);
1183 }
1184 }
1185 break;
1186 }
1187 default: {
1188 LOG(FATAL) << "Unsupported indirect instruction with opcode "
1189 << HloOpcodeString(indirect_user->instruction->opcode());
1190 break;
1191 }
1192 }
    // Fix up buffer users for the indirect instructions. For GTEs this is only
    // the tuple buffer, while for bitcasts it is the buffer they pass through.
1195 for (BufferId buffer_id : indirect_user->buffers_used) {
1196 Buffer& buffer = buffers_.at(buffer_id);
1197 buffer.unfinished_user_count++;
1198 buffer.users.push_back(ItemUse{indirect_user, 0, absl::nullopt});
1199 }
1200 }
1201
1202 VLOG(3) << " memory usage = " << memory_usage_;
1203 XLA_VLOG_LINES(10, ToString());
1204
1205 DCHECK(Check());
1206
1207 return Status::OK();
1208 }
1209
string MemoryUsageTracker::ToString() const {
1211 string output =
1212 absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
1213 absl::StrAppend(&output,
1214 "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
1215 memory_usage(), " bytes)");
1216 for (auto* item = instruction_list_.first(); item != nullptr;
1217 item = instruction_list_.next(item)) {
1218 const HloInstruction* instruction = item->instruction;
1219 string inprogress = item == in_progress_item_ ? " in-progress" : "";
1220 string placed = item->placed ? " placed" : "";
1221 absl::StrAppend(&output, " ", instruction->name(), inprogress, placed,
1222 "\n Defines:\n");
1223 for (BufferId buffer_id : item->buffers_defined) {
1224 const Buffer& buffer = buffers_[buffer_id];
1225 string live = IsCurrentlyLive(buffer_id) ? " live" : "";
1226 absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
1227 buffer.unfinished_user_count, " unfinished uses\n");
1228 }
1229 absl::StrAppend(&output, " Outputs:\n");
1230 for (BufferId buffer_id : item->buffers_output) {
1231 absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
1232 }
1233 absl::StrAppend(&output, " Uses:\n");
1234 for (BufferId buffer_id : item->buffers_used) {
1235 absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
1236 }
1237 }
1238 return output;
1239 }
1240
StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
1242 auto it = compact_shape_.find(hlo);
1243 if (it != compact_shape_.end()) {
1244 return it->second;
1245 }
1246 const Shape& original_shape = hlo->shape();
1247 TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
1248 compact_shape_[hlo] = min_shape;
1249 return min_shape;
1250 }
1251
bool MemoryUsageTracker::Check() const {
1253 auto elements_are_unique = [](const BufferIdList& vec) {
1254 return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
1255 };
1256
1257 // Verify buffers_defined per instruction.
1258 for (auto* instruction : computation_->instructions()) {
1259 const BufferIdList& defined_buffers =
1260 instruction_list_.GetItem(instruction)->buffers_defined;
1261 CHECK(elements_are_unique(defined_buffers))
1262 << "Instruction " << instruction->name()
1263 << " does not have unique defined buffers: "
1264 << absl::StrJoin(
1265 defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
1266 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
1267 });
1268
1269 for (const Buffer& buffer : buffers_) {
1270 if (buffer.defining_instruction->instruction == instruction) {
1271 CHECK(absl::c_linear_search(defined_buffers, buffer.id))
1272 << "Instruction " << instruction->name()
1273 << " defined buffers is missing: " << buffer.ToString();
1274 }
1275 }
1276 }
1277
1278 // Verify buffers_used per instruction.
1279 for (auto* instruction : computation_->instructions()) {
1280 const BufferIdList& used_buffers =
1281 instruction_list_.GetItem(instruction)->buffers_used;
1282 CHECK(elements_are_unique(used_buffers))
1283 << "Instruction " << instruction->name()
1284 << " does not have unique used buffers: "
1285 << absl::StrJoin(
1286 used_buffers, ", ", [this](string* out, BufferId buffer_id) {
1287 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
1288 });
1289 }
1290 for (const Buffer& buffer : buffers_) {
1291 int64 unfinished_uses = 0;
1292 absl::flat_hash_set<Item*> already_counted_user;
1293 for (const ItemUse& user : buffer.users) {
1294 const BufferIdList& used_buffers = user.user->buffers_used;
1295 CHECK(absl::c_linear_search(used_buffers, buffer.id))
1296 << "Instruction " << user.user->instruction->name()
1297 << " used buffers is missing " << buffer.ToString();
1298 if (!IsFinished(user.user) &&
1299 already_counted_user.insert(user.user).second) {
1300 unfinished_uses++;
1301 }
1302 }
1303 CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
1304 << "Incorrect unplaced use count for " << buffer.ToString();
1305 }
1306 return true;
1307 }
1308
1309 // Computes and returns the cost of rematerializing the given instruction.
1310 // Cost per rematerialized instruction is defined as:
1311 //
1312 // memory_limit_bytes / memory_reduced
1313 //
// The idea is to choose the operation that will save the most memory for
// rematerialization and not to worry about how much the compute costs, since
// running out of memory is more harmful than taking longer to get the answer.
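// For example (illustrative): with a 1 GiB limit, an instruction whose
// rematerialization frees 256 MiB gets cost 4, while one freeing only 64 MiB
// gets cost 16, so the former is preferred.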
int64 RematerializationCost(const HloInstruction* instruction,
1318 const MemoryUsageTracker& memory_tracker,
1319 int64 memory_reduced, int64 memory_limit_bytes) {
1320 // If none of the users of 'instruction' have been placed in the sequence (as
1321 // tracked by memory_tracker), then rematerialization of 'instruction' is a
1322 // zero-cost move of 'instruction' in the sequence.
1323 if (!absl::c_any_of(instruction->users(),
1324 [&memory_tracker](const HloInstruction* inst) {
1325 return memory_tracker.IsPlaced(inst);
1326 })) {
1327 return 0;
1328 }
1329
1330 CHECK_GT(memory_reduced, 0);
1331 // Return the inverse of the benefit of rematerialization.
1332 return memory_limit_bytes / memory_reduced;
1333 }
1334
1335 // Returns a block of up to min_block_size consecutive candidate instructions
1336 // from instruction_list starting from start_item. Returns fewer than
1337 // min_block_size instructions if the block of unplaced instructions starting
1338 // from start_item is smaller than min_block_size.
std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
1340 const MemoryUsageTracker& tracker,
1341 Item* start_item, int min_block_size) {
1342 std::vector<Item*> item_block;
1343 Item* curr_item = start_item;
1344 for (int i = 0; i < min_block_size; ++i) {
1345 if (curr_item == nullptr || !curr_item->placed ||
1346 tracker.IsInProgressItem(curr_item)) {
1347 break;
1348 }
1349 item_block.push_back(curr_item);
1350 curr_item = instruction_list.next(curr_item);
1351 }
1352 return item_block;
1353 }
1354
1355 // Returns whether any instruction in 'block' is denylisted or
1356 // non-rematerializable.
bool AnyDenylistedOrNonRematerializable(
1358 const std::vector<Item*>& block,
1359 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
1360 for (auto* item : block) {
1361 if (item->denylisted) {
1362 return true;
1363 }
1364 if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
1365 return true;
1366 }
1367 }
1368 return false;
1369 }
1370
1371 std::pair<std::vector<Item*>, RematStrategy>
MemoryUsageTracker::PickRematerializationCandidates(
1373 const InstructionList& instruction_list, int64 memory_limit_bytes,
1374 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1375 int min_block_size, int max_block_size) {
1376 std::vector<Item*> best_items;
1377 int64 best_cost = 0;
1378 RematStrategy best_strategy;
1379
1380 VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
1381 << max_block_size << "]";
1382
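// Scan candidate blocks starting at each skip node (an item considered large
// enough to be worth examining), growing each block one instruction at a time
// up to max_block_size and keeping the candidate with the lowest cost per
// byte of memory freed.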
1383 for (auto* start_item = instruction_list.first_skip_node();
1384 start_item != nullptr;
1385 start_item = instruction_list.next_skip_node(start_item)) {
1386 std::vector<Item*> block =
1387 GetInitialBlock(instruction_list, *this, start_item, min_block_size);
1388 if (block.size() < min_block_size) {
1389 // There are no more blocks of size at least min_block_size with unplaced
1390 // instructions.
1391 break;
1392 }
1393 // If any item in the starting block is denylisted or non-rematerializable,
1394 // then skip this block and move on to the next start_item (we could jump
1395 // ahead to the last invalid item in this block, but ignore that optimization for now).
1396 if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) {
1397 continue;
1398 }
1399 while (block.size() <= max_block_size) {
1400 // A block of size 1 is treated separately since compression is considered
1401 // only in this case.
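// Compression rematerializes the value in a more compact layout and converts
// it back at each remaining use, trading compute for memory. It is only
// attempted for placed, single-output, non-live-out array values, as checked
// below.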
1402 if (block.size() == 1) {
1403 auto* item = block[0];
1404 auto* candidate = item->instruction;
1405 if (item->buffers_output.size() == 1 &&
1406 (mode_ ==
1407 HloRematerialization::RematerializationMode::kCompressOnly ||
1408 mode_ == HloRematerialization::RematerializationMode::
1409 kRecomputeAndCompress)) {
1410 // Only consider compressing a single-output instruction.
1411 const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
1412
1413 if (item->placed && item != in_progress_item_ &&
1414 !output_buffer.live_out) {
1415 const Shape& original_shape = item->instruction->shape();
1416 if (original_shape.IsArray()) {
1417 Shape compact_shape =
1418 GetCompactShape(item->instruction).ValueOrDie();
1419 const int64 memory_reduced =
1420 MemoryReducedIfCompressed(item, compact_shape);
1421 if (memory_reduced > 0) {
1422 const int64 cost = memory_limit_bytes / memory_reduced;
1423 if (best_items.empty() || cost < best_cost) {
1424 VLOG(3) << "candidate " << candidate->name() << "("
1425 << candidate->ToShortString() << ")"
1426 << " now best when compressed into "
1427 << compact_shape.ToString(true);
1428 RematStrategy strategy;
1429 strategy.kind = RematStrategy::kCompress;
1430 best_strategy = strategy;
1431 best_strategy.compact_shape = compact_shape;
1432 best_items = block;
1433 best_cost = cost;
1434 }
1435 }
1436 }
1437 }
1438 }
1439 }
1440 // Do not consider recomputation in compress-only mode.
1441 if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
1442 // break out of this loop. Move on to the next start_item.
1443 break;
1444 }
1445 // If any of the candidates' control successors has already been placed, we
1446 // need to skip this block; otherwise we would violate a control dependency.
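// (Rematerializing an instruction after one of its control successors has
// already run would invert that control edge in the new schedule.)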
1447 bool control_successor_placed = false;
1448 for (auto* item : block) {
1449 HloInstruction* candidate = item->instruction;
1450 if (std::any_of(candidate->control_successors().begin(),
1451 candidate->control_successors().end(),
1452 [this](const HloInstruction* inst) {
1453 return IsPlaced(inst);
1454 })) {
1455 control_successor_placed = true;
1456 break;
1457 }
1458 }
1459 if (control_successor_placed) {
1460 // break out of this loop. Move on to the next start_item.
1461 break;
1462 }
1463 VLOG(5) << "Block contains:";
1464 for (auto* hlo : block) {
1465 VLOG(5) << hlo->instruction->name();
1466 }
1467 const int64 memory_reduced = MemoryReducedIfRematerialized(block);
1468
1469 if (memory_reduced > 0) {
1470 const int cost =
1471 RematerializationCost(block, memory_reduced, memory_limit_bytes);
1472
1473 VLOG(5) << "Candidate block of size " << block.size()
1474 << " starting from " << block[0]->instruction->name()
1475 << ", memory reduced " << memory_reduced << ", cost per byte "
1476 << cost;
1477
1478 if (best_items.empty() || cost < best_cost) {
1479 VLOG(5) << "Candidate block of size " << block.size()
1480 << " starting from " << block[0]->instruction->name()
1481 << " now best";
1482 best_strategy.kind = RematStrategy::kRecompute;
1483 best_items = block;
1484 best_cost = cost;
1485 }
1486 }
1487
1488 // Time to update the block to include the next instruction.
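// The block can only grow while the next instruction is itself a placed,
// non-denylisted, not-in-progress, rematerializable candidate; otherwise the
// outer loop advances to the next skip node.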
1489 auto* last_item = block[block.size() - 1];
1490 auto* next_item = instruction_list.next(last_item);
1491 if (next_item == nullptr || next_item->denylisted || !next_item->placed ||
1492 next_item == in_progress_item_ ||
1493 !CanBeRematerialized(next_item->instruction, rematerializable_map)) {
1494 break;
1495 }
1496 block.push_back(next_item);
1497 }
1498 }
1499 return {best_items, best_strategy};
1500 }
1501
1502 bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
1503 for (BufferId buffer_id : item->buffers_defined) {
1504 const Buffer& buffer = buffers_.at(buffer_id);
1505 for (const ItemUse& user : buffer.users) {
1506 if (!user.user->placed) {
1507 return true;
1508 }
1509 }
1510 }
1511 return false;
1512 }
1513
1514 const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
1515 UsesList combined_users;
1516 for (BufferId buffer_id : item->buffers_defined) {
1517 const Buffer& buffer = buffers_.at(buffer_id);
1518 for (const ItemUse& user : buffer.users) {
1519 combined_users.push_back(user);
1520 }
1521 }
1522 return combined_users;
1523 }
1524
1525 StatusOr<int64> RematerializeInstructions(
1526 MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
1527 absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
1528 InstructionList* instruction_list) {
1529 int64 net_instructions_added = 0;
1530 int64 total_memory_saved =
1531 memory_tracker->MemoryReducedIfRematerialized(*best_items);
1532 std::vector<string> instruction_names(best_items->size());
1533 // Rematerialize the block of instructions in the reverse order to account for
1534 // dependencies between instructions in best_items.
1535 for (int i = best_items->size() - 1; i >= 0; --i) {
1536 Item* best_item = (*best_items)[i];
1537 HloInstruction* best = best_item->instruction;
1538 instruction_names[i] = best->name();
1539 HloComputation* computation = best->parent();
1540
1541 // If the item to remat has no unplaced users, then skip the
1542 // rematerialization. Such an instruction can appear in best_items because
1543 // it is part of a good block, but does not itself add any benefit.
1544 if (!memory_tracker->HasUnplacedUsers(best_item)) {
1545 continue;
1546 }
1547
1548 HloInstruction* remat =
1549 computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
1550
1551 // Add control dependencies to the new operation.
1552 for (auto successor : best->control_successors()) {
1553 TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
1554 }
1555 for (auto predecessor : best->control_predecessors()) {
1556 TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
1557 }
1558
1559 Item* remat_item = instruction_list->CreateItem(remat);
1560
1561 // Replace each remaining use of 'best' with the rematerialization.
1562 absl::InlinedVector<Item*, 4> indirect_users;
1563 absl::flat_hash_map<int64, HloInstruction*> gte_cache;
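// Cache get-tuple-element instructions by tuple index so that multiple uses
// of the same element of the rematerialized tuple share one GTE.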
1564 for (auto& user : memory_tracker->GetItemUses(best_item)) {
1565 if (!memory_tracker->IsPlaced(user.user->instruction)) {
1566 VLOG(2) << " Replacing use of " << best->name() << " in "
1567 << user.user->instruction->name() << " with " << remat->name();
1568 const int64 op_idx = user.operand_number;
1569 HloInstruction* remat_use = remat;
1570 if (user.index) {
1571 auto cached_gte = gte_cache.find(*user.index);
1572 if (cached_gte == gte_cache.end()) {
1573 remat_use = computation->AddInstruction(
1574 HloInstruction::CreateGetTupleElement(
1575 ShapeUtil::GetTupleElementShape(remat_use->shape(),
1576 *user.index),
1577 remat_use, *user.index));
1578 indirect_users.push_back(instruction_list->CreateItem(remat_use));
1579 gte_cache[*user.index] = remat_use;
1580 } else {
1581 remat_use = cached_gte->second;
1582 }
1583 }
1584 if (user.user->instruction->operand(op_idx)->shape() !=
1585 remat_use->shape()) {
1586 remat_use = computation->AddInstruction(HloInstruction::CreateBitcast(
1587 user.user->instruction->operand(op_idx)->shape(), remat_use));
1588 indirect_users.push_back(instruction_list->CreateItem(remat_use));
1589 }
1590 TF_RETURN_IF_ERROR(
1591 user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
1592 }
1593 }
1594
1595 // Account for the rematerialization in the memory tracker.
1596 TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
1597 best_item, remat_item, absl::MakeSpan(indirect_users)));
1598
1599 // Insert rematerialized instruction right before the earliest unplaced
1600 // use of the instruction *and* the earliest unplaced last use of any
1601 // operands of remat. Unplaced uses of the remat's operands are included
1602 // because we don't want to extend the live range of remat's operands as
1603 // this could increase memory usage.
1604 ItemList place_before;
1605 const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
1606 indirect_users.end());
1607 for (auto user : remat->users()) {
1608 if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1609 place_before.push_back(instruction_list->GetItem(user));
1610 }
1611 }
1612 for (auto* indirect_user : indirect_users) {
1613 for (auto user : indirect_user->instruction->users()) {
1614 if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1615 place_before.push_back(instruction_list->GetItem(user));
1616 }
1617 }
1618 }
1619 for (auto* operand : remat->operands()) {
1620 for (auto* operand_user : operand->users()) {
1621 if (operand_user != remat) {
1622 Item* operand_user_item = instruction_list->GetItem(operand_user);
1623 if (!operand_user_item->placed) {
1624 place_before.push_back(operand_user_item);
1625 }
1626 }
1627 }
1628 }
1629 // Insert rematerialized instruction before any of its successors to
1630 // preserve ordering regarding control dependency.
1631 for (auto successor : remat->control_successors()) {
1632 Item* successor_item = instruction_list->GetItem(successor);
1633 // Assert to make sure we never remat an operation with control
1634 // successor already placed.
1635 CHECK(!successor_item->placed) << successor_item->instruction->name();
1636 place_before.push_back(successor_item);
1637 }
1638 instruction_list->InsertBeforeInstructions(remat_item, place_before);
1639
1640 for (auto* bitcast : indirect_users) {
1641 instruction_list->InsertBeforeInstructions(bitcast, place_before);
1642 }
1643 // Helper function that looks through indirect users when determining if
1644 // there is an active user for an HloInstruction.
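// 'best' counts as dead here if every transitive user is a supported
// indirect user (e.g. a bitcast or get-tuple-element inserted above) that
// itself has no remaining real consumers.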
1645 std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
1646 for (auto* u : i->users()) {
1647 if (!IsSupportedIndirectUser(u) || !uses_empty(u)) {
1648 return false;
1649 }
1650 }
1651 return true;
1652 };
1653 // If the rematerialized instruction is dead then rematerialization is
1654 // essentially a move. Don't delete the instruction now because we don't
1655 // want duplicate HloInstruction* values during the course of the
1656 // transformation because we keep maps with HloInstruction* values as
1657 // keys.
1658 if (uses_empty(best)) {
1659 VLOG(2) << best->name() << " is now dead";
1660 if (ContainsKey(*remat_move_instructions, best)) {
1661 // Previously, 'best' was a rematerialization which killed the
1662 // instruction it was a copy of. Now 'remat' is a rematerialization
1663 // of 'best' and kills 'best'. Stop rematerializing this instruction
1664 // to avoid an infinite loop.
1665 instruction_list->Denylist(remat);
1666 }
1667 remat_move_instructions->insert(remat);
1668 net_instructions_added += indirect_users.size();
1669 } else {
1670 net_instructions_added += indirect_users.size() + 1;
1671 }
1672 for (auto* indirect_user : indirect_users) {
1673 instruction_list->Denylist(indirect_user->instruction);
1674 }
1675 }
1676 VLOG(1) << "Rematerializing instructions ["
1677 << absl::StrJoin(instruction_names, ", ") << "] (saving "
1678 << HumanReadableNumBytes(total_memory_saved) << ")";
1679 return net_instructions_added;
1680 }
1681
1682 StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
1683 Item* best_item, const Shape& compact_shape,
1684 InstructionList* instruction_list) {
1685 HloInstruction* best = best_item->instruction;
1686 VLOG(5) << "Transposing instruction " << best->name() << " (saving "
1687 << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1688 best_item, compact_shape))
1689 << ") to" << compact_shape.ToString(true);
1690
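// Compression is modeled as a pair of copies: 'best' is copied into the
// compact shape ('compressed') and copied back to the original shape
// ('uncompressed'); unplaced uses of 'best' are then redirected to the
// uncompressed copy.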
1691 HloComputation* computation = best->parent();
1692 HloInstruction* compressed = computation->AddInstruction(
1693 HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best),
1694 /*new_name=*/best->name() + ".remat_compressed");
1695
1696 HloInstruction* uncompressed = computation->AddInstruction(
1697 HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed),
1698 /*new_name=*/best->name() + ".remat_uncompressed");
1699
1700 Item* compressed_item = instruction_list->CreateItem(compressed);
1701 compressed_item->placed = true;
1702
1703 Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
1704
1705 // Replace each remaining use of 'best' with the uncompressed.
1706 std::vector<HloInstruction*> best_users_copy = best->users();
1707 for (HloInstruction* user : best_users_copy) {
1708 if (!memory_tracker->IsPlaced(user)) {
1709 VLOG(5) << " Replacing use of " << best->name() << " in " << user->name()
1710 << " with " << uncompressed->name();
1711 TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
1712 }
1713 }
1714
1715 // Account for the rematerialization in the memory tracker.
1716 TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
1717 best_item, compressed_item, uncompressed_item));
1718
1719 // Insert rematerialized instruction right before the earliest unplaced
1720 // use of the instruction.
1721 ItemList place_before;
1722 for (auto user : uncompressed->users()) {
1723 place_before.push_back(instruction_list->GetItem(user));
1724 }
1725
1726 instruction_list->Denylist(compressed_item->instruction);
1727 instruction_list->Denylist(uncompressed_item->instruction);
1728
1729 instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
1730
1731 instruction_list->InsertAfterInstructions(compressed_item, {best_item});
1732
1733 return 2;
1734 }
1735
1736 // A simple struct to encapsulate the number of instructions added during
1737 // rematerialization.
1738 struct InstructionsAdded {
1739 // Total count of instructions rematerialized.
1740 int remat_count;
1741 // Total count of instructions rematerialized minus number of original
1742 // instructions that are now dead.
1743 int net_instructions_added;
1744 };
1745
1746 // Rematerializes the best block of instructions of size between min_block_size
1747 // and max_block_size (both inclusive) if at least one candidate block of
1748 // instructions can be found. Returns the rematerialized and net instruction counts.
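// If the returned remat_count is zero, no block in [min_block_size,
// max_block_size] reduced memory; the caller reacts by widening that window.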
1749 StatusOr<InstructionsAdded> RematerializeBestBlock(
1750 int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
1751 InstructionList* instruction_list, int64 memory_limit_bytes,
1752 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1753 absl::flat_hash_set<const HloInstruction*>* remat_move_instructions) {
1754 CHECK_GT(min_block_size, 0) << "Non-positive block size.";
1755
1756 std::vector<Item*> best_items;
1757 RematStrategy best_strategy;
1758 std::tie(best_items, best_strategy) =
1759 memory_tracker->PickRematerializationCandidates(
1760 *instruction_list, memory_limit_bytes, rematerializable_map,
1761 min_block_size, max_block_size);
1762 InstructionsAdded num_instructions_added;
1763 num_instructions_added.remat_count = best_items.size();
1764 if (best_items.empty()) {
1765 num_instructions_added.net_instructions_added = 0;
1766 return num_instructions_added;
1767 }
1768
1769 if (best_strategy.kind == RematStrategy::kCompress) {
1770 CHECK(best_items.size() == 1)
1771 << "More than one instruction compressed simultaneously.";
1772 HloInstruction* best = best_items[0]->instruction;
1773 VLOG(1) << "Compressing instruction " << best->name() << " (saving "
1774 << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1775 best_items[0], best_strategy.compact_shape))
1776 << ")";
1777
1778 TF_ASSIGN_OR_RETURN(
1779 num_instructions_added.net_instructions_added,
1780 CompressInstruction(memory_tracker, best_items[0],
1781 best_strategy.compact_shape, instruction_list));
1782 } else {
1783 TF_ASSIGN_OR_RETURN(
1784 num_instructions_added.net_instructions_added,
1785 RematerializeInstructions(memory_tracker, &best_items,
1786 remat_move_instructions, instruction_list));
1787 }
1788 return num_instructions_added;
1789 }
1790 } // namespace
1791
1792 StatusOr<int64> HloRematerialization::ComputePeakMemory(
1793 const HloComputation* computation,
1794 const HloInstructionSequence& order) const {
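// Simulates the schedule with a fresh MemoryUsageTracker: at each program
// point, peak memory is the tracker's live-buffer usage plus the peak usage
// of any computations called at that point in a sequential context.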
1795 InstructionList instruction_list(order);
1796 MemoryUsageTracker tracker(computation, size_function_,
1797 compact_shape_function_, *points_to_analysis_,
1798 instruction_list, mode_);
1799 int64 peak_memory = tracker.memory_usage();
1800 for (auto* item = instruction_list.first(); item != nullptr;
1801 item = instruction_list.next(item)) {
1802 const HloInstruction* instruction = item->instruction;
1803 TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
1804 TF_ASSIGN_OR_RETURN(int64 callee_usage,
1805 CalledComputationsMemoryUsage(instruction));
1806 peak_memory =
1807 std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
1808 TF_RETURN_IF_ERROR(tracker.EndInstruction());
1809 }
1810 VLOG(1) << "Peak memory for " << computation->name() << ": "
1811 << HumanReadableNumBytes(peak_memory);
1812 return peak_memory;
1813 }
1814
1815 StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
1816 const HloInstruction* instruction) const {
1817 const CallSite* callsite =
1818 call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
1819 if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
1820 return 0;
1821 }
1822 int64 callee_usage = 0;
1823 for (const HloComputation* computation : callsite->called_computations()) {
1824 TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
1825 callee_usage += computation_peak_memory_.at(computation);
1826 }
1827 return callee_usage;
1828 }
1829
1830 StatusOr<bool> HloRematerialization::RematerializeComputation(
1831 HloComputation* computation, HloSchedule* schedule,
1832 int64 memory_limit_bytes, int64 min_remat_size) {
1833 VLOG(1) << "Rematerializing computation " << computation->name()
1834 << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
1835 VLOG(1) << "peak memory usage is "
1836 << HumanReadableNumBytes(computation_peak_memory_.at(computation));
1837 CHECK(!ContainsKey(rematerialized_computations_, computation));
1838
1839 InstructionList instruction_list(schedule->sequence(computation));
1840 MemoryUsageTracker memory_tracker(
1841 computation, size_function_, compact_shape_function_,
1842 *points_to_analysis_, instruction_list, mode_);
1843
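// Only items whose defined buffers occupy at least min_remat_size bytes are
// promoted to skip nodes, i.e. used as starting points when searching for
// rematerialization candidates.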
1844 instruction_list.PromoteNodesToSkip([&](Item* item) {
1845 return memory_tracker.AllocatedSize(item) >= min_remat_size;
1846 });
1847 bool changed = false;
1848
1849 // If a rematerialization makes the source instruction dead, then that
1850 // rematerialization was essentially a move, and the clone is recorded in
1851 // 'remat_move_instructions'. If such a clone is later itself
1852 // rematerialized as another move, the new clone is denylisted to avoid
1853 // moving the same value around indefinitely.
1854 absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
1855
1856 // The map from instructions to their rematerializable status.
1857 absl::flat_hash_map<const HloInstruction*, bool> rematerializable_map;
1858
1859 // The peak memory usage of the computation over all points in the
1860 // instruction sequence.
1861 int64 peak_memory = memory_tracker.memory_usage();
1862
1863 // Total count of instructions rematerialized.
1864 int64 remat_count = 0;
1865 // Total count of clones created minus number of original rematerialized
1866 // instructions which are dead.
1867 int64 net_instructions_added = 0;
1868
1869 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1870
1871 // Iterate through all instructions in the sequence. At each instruction
1872 // (program point) if memory_usage exceeds the specified limit then
1873 // rematerialize HLO instructions until memory_usage is reduced.
1874 int64 instruction_index = 0;
1875 for (auto* item = instruction_list.first(); item != nullptr;
1876 item = instruction_list.next(item)) {
1877 const HloInstruction* instruction = item->instruction;
1878 TF_ASSIGN_OR_RETURN(int64 callee_usage,
1879 CalledComputationsMemoryUsage(instruction));
1880 TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
1881
1882 VLOG(2) << "Program point at " << instruction->name()
1883 << ", memory usage = " << memory_tracker.memory_usage()
1884 << ", callee usage = " << callee_usage << ", [" << instruction_index
1885 << "/" << instruction_list.size() << "]";
1886 instruction_index++;
1887
1888 // Initialize both min_block_size and max_block_size to 1 so that only
1889 // single instruction rematerialization is considered first.
1890 int min_block_size = 1;
1891 int max_block_size = 1;
1892 // Only trigger rematerialization when the memory usage changes.
1893 if (memory_tracker.AllocatedSize(item) + callee_usage > 0) {
1894 while (memory_tracker.memory_usage() + callee_usage >
1895 memory_limit_bytes) {
1896 VLOG(2) << "Over memory limit at instruction " << instruction->name()
1897 << ", using "
1898 << HumanReadableNumBytes(memory_tracker.memory_usage() +
1899 callee_usage)
1900 << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
1901
1902 TF_ASSIGN_OR_RETURN(
1903 InstructionsAdded instructions_added,
1904 RematerializeBestBlock(min_block_size, max_block_size,
1905 &memory_tracker, &instruction_list,
1906 memory_limit_bytes, &rematerializable_map,
1907 &remat_move_instructions));
1908 net_instructions_added += instructions_added.net_instructions_added;
1909 remat_count += instructions_added.remat_count;
1910
1911 VLOG(1) << "memory_usage after rematerialization = "
1912 << HumanReadableNumBytes(memory_tracker.memory_usage());
1913 if (instructions_added.remat_count == 0) {
1914 // Unable to find a block to rematerialize within the current size
1915 // window; widen the window (roughly doubling it) and retry.
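// With this scheme the window of block sizes tried is [1,1], [2,2], [3,4],
// [5,8], [9,16], ... until a block is found or block_size_limit_ is
// exceeded.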
1916 min_block_size = max_block_size + 1;
1917 max_block_size = 2 * max_block_size;
1918 } else {
1919 // Found a valid block. Reset to start looking for single instructions
1920 // again.
1921 max_rematerialized_block_size_ =
1922 std::max(max_rematerialized_block_size_, max_block_size);
1923 changed = true;
1924 min_block_size = 1;
1925 max_block_size = 1;
1926 }
1927 if (max_block_size > block_size_limit_) {
1928 break;
1929 }
1930 }
1931 }
1932 const CallSite* callsite = call_graph_node.GetCallSite(instruction);
1933 if (callsite != nullptr &&
1934 callsite->context() == CallContext::kSequential &&
1935 memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
1936 // Memory usage exceeds the limit. Try to rematerialize any
1937 // subcomputation(s) that this instruction calls.
1938 VLOG(1) << "Memory usage still over the limit ("
1939 << (memory_tracker.memory_usage() + callee_usage) << " > "
1940 << memory_limit_bytes
1941 << "). Rematerializing computations called by "
1942 << instruction->name();
1943
1944 // Recompute callee usage to account for any rematerialization performed
1945 // in the callee computations.
1946 for (HloComputation* called_computation :
1947 callsite->called_computations()) {
1948 if (!ContainsKey(rematerialized_computations_, called_computation)) {
1949 // Memory limit for the subcomputation is the memory limit less the
1950 // amount of memory used at this point in the computation.
1951 int64 subcomputation_memory_limit_bytes = std::max<int64>(
1952 0, memory_limit_bytes - memory_tracker.memory_usage());
1953 TF_ASSIGN_OR_RETURN(
1954 bool subcomputation_changed,
1955 RematerializeComputation(called_computation, schedule,
1956 subcomputation_memory_limit_bytes,
1957 min_remat_size));
1958 changed |= subcomputation_changed;
1959 }
1960 }
1961
1962 TF_ASSIGN_OR_RETURN(callee_usage,
1963 CalledComputationsMemoryUsage(instruction));
1964 }
1965
1966 peak_memory = std::max<int64>(peak_memory,
1967 memory_tracker.memory_usage() + callee_usage);
1968 VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);
1969
1970 TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
1971 }
1972
1973 // Verify some invariants on the memory tracker.
1974 for (auto* instruction : computation->instructions()) {
1975 CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
1976 }
1977
1978 VLOG(1) << "In computation " << computation->name() << " rematerialized "
1979 << remat_count << " instructions; " << net_instructions_added
1980 << " net instructions added";
1981 VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory)
1982 << " (was "
1983 << HumanReadableNumBytes(computation_peak_memory_.at(computation))
1984 << ")";
1985
1986 // Update peak memory used by computation.
1987 computation_peak_memory_.at(computation) = peak_memory;
1988
1989 // Update order to include rematerialized instructions.
1990 HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
1991 sequence.clear();
1992 for (auto* item = instruction_list.first(); item != nullptr;
1993 item = instruction_list.next(item)) {
1994 HloInstruction* instruction = item->instruction;
1995 sequence.push_back(instruction);
1996 }
1997 rematerialized_computations_.insert(computation);
1998
1999 instructions_rematerialized_ += remat_count;
2000 net_instructions_added_ += net_instructions_added;
2001
2002 return changed;
2003 }
2004
2005 StatusOr<bool> HloRematerialization::Run(HloModule* module) {
2006 VLOG(1) << "HloRematerialization() with memory limit of "
2007 << HumanReadableNumBytes(memory_limit_bytes_);
2008 XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
2009
2010 // Initialize pass object state.
2011 computation_peak_memory_.clear();
2012 rematerialized_computations_.clear();
2013 instructions_rematerialized_ = 0;
2014 net_instructions_added_ = 0;
2015
2016 TF_RET_CHECK(module->has_schedule());
2017 TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
2018
2019 // Adjust the memory limit to account for the output of the entry
2020 // computation. This is necessary because the per-computation accounting in
2021 // MemoryUsageTracker does not include the output, as it is typically
2022 // allocated by the caller.
2023 int64 module_output_size = 0;
2024 ShapeUtil::ForEachSubshape(
2025 module->result_shape(),
2026 [&module_output_size, module, this](const Shape& subshape,
2027 const ShapeIndex& output_index) {
2028 module_output_size += size_function_(subshape);
2029 });
2030
2031 const int64 adjusted_memory_limit_bytes =
2032 memory_limit_bytes_ - module_output_size;
2033 VLOG(1) << "Adjusted memory limit accounting for output ("
2034 << HumanReadableNumBytes(module_output_size)
2035 << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
2036
2037 // Compute peak memory usage of all computations in the module called in a
2038 // sequential context.
2039 call_graph_ = CallGraph::Build(module);
2040 TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
2041 [this, module](const CallGraphNode& node) -> Status {
2042 if (node.context() == CallContext::kSequential) {
2043 TF_ASSIGN_OR_RETURN(
2044 computation_peak_memory_[node.computation()],
2045 ComputePeakMemory(node.computation(), module->schedule().sequence(
2046 node.computation())));
2047 }
2048 return Status::OK();
2049 },
2050 /*visit_unreachable_nodes=*/false));
2051
2052 // The peak memory usage of the module equals the peak memory use of the entry
2053 // computation plus the output size of the computation. This is because the
2054 // peak memory for a computation does not include the output as this is
2055 // typically accounted for in the caller.
2056 const int64 before_peak_memory =
2057 computation_peak_memory_.at(module->entry_computation()) +
2058 module_output_size;
2059 VLOG(1) << "Peak memory usage of module (before): "
2060 << HumanReadableNumBytes(before_peak_memory);
2061 // Subcomputations called by the entry computation will also be
2062 // rematerialized.
2063 TF_ASSIGN_OR_RETURN(
2064 bool changed,
2065 RematerializeComputation(module->entry_computation(), &module->schedule(),
2066 adjusted_memory_limit_bytes, min_remat_size_));
2067 // Rematerialization can introduce dead code. This occurs if all uses of an
2068 // instruction are replaced with rematerializations of the instruction.
2069
2070 // Stash away the schedule while running DCE, to avoid validation failures
2071 // while the module is in flux.
2072 HloSchedule saved_schedule = module->schedule();
2073 module->clear_schedule();
2074 TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
2075 changed |= dead_code_removed;
2076
2077 // After DCE, the module sequence may include instructions which no longer
2078 // exist. Update the schedule and restore it.
2079 TF_RETURN_IF_ERROR(saved_schedule.Update());
2080 TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
2081 VLOG(1) << "Rematerialized " << instructions_rematerialized_
2082 << " instructions in module " << module->name() << "; "
2083 << net_instructions_added_ << " net instructions added";
2084 const int64 current_peak_memory =
2085 computation_peak_memory_.at(module->entry_computation()) +
2086 module_output_size;
2087 VLOG(1) << "Peak memory usage of module now "
2088 << HumanReadableNumBytes(current_peak_memory) << " ("
2089 << current_peak_memory << " bytes), was "
2090 << HumanReadableNumBytes(before_peak_memory) << " ("
2091 << before_peak_memory << " bytes)";
2092 const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
2093 VLOG(1) << "Reduced peak memory by "
2094 << HumanReadableNumBytes(reduced_peak_memory) << " ("
2095 << reduced_peak_memory << " bytes)";
2096
2097 if (sizes_ != nullptr) {
2098 sizes_->before_bytes = before_peak_memory;
2099 sizes_->after_bytes = current_peak_memory;
2100 }
2101
2102 XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
2103
2104 if (current_peak_memory > memory_limit_bytes_) {
2105 LOG(WARNING) << absl::StrFormat(
2106 "Can't reduce memory use below %s (%d bytes) by rematerialization; "
2107 "only reduced to %s (%d bytes)",
2108 HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
2109 HumanReadableNumBytes(current_peak_memory), current_peak_memory);
2110 }
2111 return changed;
2112 }
2113
2114 } // namespace xla
2115