1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <set>
22 #include <string>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "absl/strings/str_join.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/primitive_util.h"
33 #include "tensorflow/compiler/xla/service/buffer_value.h"
34 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
35 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_dce.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
44 #include "tensorflow/compiler/xla/service/logical_buffer.h"
45 #include "tensorflow/compiler/xla/status_macros.h"
46 #include "tensorflow/compiler/xla/statusor.h"
47 #include "tensorflow/compiler/xla/types.h"
48 #include "tensorflow/compiler/xla/util.h"
49 #include "tensorflow/core/platform/logging.h"
50 
51 namespace xla {
52 namespace {
53 
54 using ::tensorflow::strings::HumanReadableNumBytes;
55 
56 // Potential optimizations:
57 // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
58 //   of candidates.
59 // . Cache IsRematerializable in Item?  Only correct if control
60 //   predecessors and successors don't change.
61 
62 // Returns true if the given instruction is rematerializable.
63 bool IsRematerializable(const HloInstruction* instruction) {
64   if (instruction->opcode() == HloOpcode::kCopy) {
65     if (LayoutUtil::Equal(instruction->shape().layout(),
66                           instruction->operand(0)->shape().layout())) {
67       // Don't rematerialize copies added by copy insertion (layout doesn't
68       // change).
69       return false;
70     }
71   }
72 
73   // Don't rematerialize instructions with side effects or instructions which
74   // cannot be cloned safely.
75   switch (instruction->opcode()) {
76     case HloOpcode::kCall:
77     case HloOpcode::kConstant:
78     case HloOpcode::kConditional:
79     case HloOpcode::kAllReduce:
80     case HloOpcode::kCustomCall:
81     case HloOpcode::kParameter:
82     case HloOpcode::kWhile:
83       return false;
84     default:
85       return !instruction->HasSideEffect();
86   }
87 }
88 
89 // Checks whether an instruction can be rematerialized, first consulting the
90 // cache and calling the IsRematerializable() API only on a cache miss.
91 bool CanBeRematerialized(
92     const HloInstruction* instruction,
93     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
94   auto it = rematerializable_map->find(instruction);
95   if (it != rematerializable_map->end()) {
96     return it->second;
97   }
98   bool rematerializable = IsRematerializable(instruction);
99   (*rematerializable_map)[instruction] = rematerializable;
100   return rematerializable;
101 }
102 
103 // Returns true if this instruction relays the buffers it uses to its own
104 // users (an indirect use) and is one of the instruction kinds whose
105 // rematerialization we support.
106 bool IsSupportedIndirectUser(const HloInstruction* instruction) {
107   return instruction->opcode() == HloOpcode::kBitcast ||
108          instruction->opcode() == HloOpcode::kGetTupleElement;
109 }
110 
111 // Type holding a unique identifier for each Buffer object.
112 using BufferId = int64;
113 using BufferIdList = absl::InlinedVector<BufferId, 3>;
114 
115 struct RematStrategy {
116   enum {
117     // Recompute the node at a later program point.
118     kRecompute,
119     // Change the layout into a compact form and uncompress it back at a later
120     // program point.
121     kCompress,
122   } kind;
123   Shape compact_shape;
124 };
125 
126 // We wrap HloInstruction* with an Item that holds auxiliary
127 // per-instruction state.
128 struct Item {
129   HloInstruction* instruction;
130 
131   // True once the instruction is marked as placed (when BeginInstruction
132   // has been called for this instruction).
133   bool placed = false;
134 
135   // To avoid an infinite loop rematerializing the same set of
136   // instructions ad infinitum, keep a denylist of instructions
137   // which should not be rematerialized.
138   bool denylisted = false;
139 
140   // The buffers defined by this instruction.
141   BufferIdList buffers_defined;
142 
143   // Output buffers of this instruction. This is used to track outputs by GTE
144   // instructions (where the instruction doesn't define a buffer).
145   BufferIdList buffers_output;
146 
147   // The buffers used by this instruction.
148   BufferIdList buffers_used;
149 
150   bool is_skip_node = false;
151 
152  private:
153   friend class InstructionList;
154 
155   // Items are arranged in a doubly linked list.
156   Item* next = nullptr;
157   Item* prev = nullptr;
158 
159   Item* prev_skip_node = nullptr;
160   Item* next_skip_node = nullptr;
161 
162   // List is ordered by position, which can however be duplicated as
163   // new instructions are inserted.  See InsertBeforeInstructions
164   // comment for details.
165   int64 position;
166 };
167 
168 // Data structure meant to record the user of the buffer defined from an Item.
169 // It records also the operand_number from where such use derives, so that
170 // indirect uses can be better identified (like for example a buffer used
171 // through a bitcast).
172 struct ItemUse {
173   Item* user;
174   int64 operand_number;
175   absl::optional<int64> index;
176 
177   ItemUse(Item* user, int64 op_num, absl::optional<int64> index)
178       : user(user), operand_number(op_num), index(index) {}
179   bool operator==(const ItemUse& other) const {
180     return user == other.user && operand_number == other.operand_number &&
181            index == other.index;
182   }
183 };
184 
185 using ItemList = absl::InlinedVector<Item*, 3>;
186 using UsesList = absl::InlinedVector<ItemUse, 3>;
187 
188 // Class which maintains an ordered list of instructions with fast insertion
189 // before arbitrary elements.
190 //
191 // This is a skip-list-like structure with two lanes: an express lane and a
192 // slow lane. All nodes are present on the slow lane, but a node can be
193 // promoted onto the express lane for fast iteration.
194 //
195 // In the following case, node 2 and node n+1 are connected via an express lane.
196 //                    +--------------------------+----------->: Express lane
197 //                    |                          |
198 //       node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...: Slow lane
199 //
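// Illustrative sketch of walking the express lane (assuming nodes have already
// been promoted with PromoteNodesToSkip):
//
//   for (auto* item = list.first_skip_node(); item != nullptr;
//        item = list.next_skip_node(item)) {
//     ...  // visits only promoted (skip) nodes
//   }
//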
200 class InstructionList {
201  public:
202   explicit InstructionList(const HloInstructionSequence& order) {
203     int64 position = 0;
204     Item* last = nullptr;
205     last_skip_node_ = nullptr;
206     first_skip_node_ = nullptr;
207     for (HloInstruction* inst : order.instructions()) {
208       // Add a new item to the linked list.
209       Item* item = new Item;
210       item->next = nullptr;
211       item->prev = last;
212       if (last == nullptr) {
213         first_ = item;
214       } else {
215         last->next = item;
216       }
217       last = item;
218 
219       // Initially position numbers are uniquely assigned in order. Later as
220       // instructions are added with InsertBefore* methods, some instructions
221       // may have duplicate position numbers, but the values will be guaranteed
222       // to be monotonically increasing through the list, and so are still useful
223       // for quickly(-ish) determining the order of arbitrary instructions in
224       // the list.
225       item->instruction = inst;
226       item->position = position;
227       position++;
228 
229       item_map_[inst] = item;
230     }
231   }
232 
233   ~InstructionList() {
234     for (Item* item = first_; item != nullptr;) {
235       Item* next = item->next;
236       delete item;
237       item = next;
238     }
239   }
240 
241   size_t size() const { return item_map_.size(); }
242 
243   // For ordered iteration over items.
244   //    for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
245   Item* first() const { return first_; }
246   Item* next(Item* item) const { return item->next; }
247 
248   Item* first_skip_node() const { return first_skip_node_; }
249   Item* next_skip_node(Item* item) const { return item->next_skip_node; }
250 
251   // Creates an Item for the given instruction, but doesn't add it to the list.
252   // (Use InsertBeforeInstructions to add the Item to the list.)
253   Item* CreateItem(HloInstruction* inst) {
254     Item* item = new Item;
255     item->instruction = inst;
256     CHECK(item_map_.insert({inst, item}).second)
257         << "inserting inst twice " << inst->name();
258     return item;
259   }
260 
261   // Return the Item corresponding to inst.
262   Item* GetItem(const HloInstruction* inst) const {
263     auto iter = item_map_.find(inst);
264     CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
265     return iter->second;
266   }
267 
268   // Insert instruction 'to_insert' immediately before the earliest instruction
269   // in 'before_instructions'.
270   //
271   // Each instruction gets a non-decreasing ordinal number. We use this to let
272   // InsertBeforeInstructions quickly insert an instruction before the earliest
273   // instruction in a set of instructions.  If position_number_[a] <
274   // position_number_[b] then 'a' comes before 'b' in the list. If the position
275   // numbers are the same then nothing can be said about their order without
276   // examining the list.
277   //
278   // On object construction this ordinal is precisely the instruction's index
279   // in the list. Later, instructions inserted via InsertBefore receive
280   // duplicate values. However, monotonicity is preserved.
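  // Illustrative example with assumed positions: starting from positions
  // [0, 1, 2, 3], inserting an item before the item at position 2 yields
  // positions [0, 1, 2, 2, 3] -- still monotonic, but no longer unique, so
  // ties are broken by scanning the list.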
281   void InsertBeforeInstructions(Item* to_insert,
282                                 absl::Span<Item* const> before_instructions) {
283     VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
284             << " before {"
285             << absl::StrJoin(before_instructions, ", ",
286                              [](string* out, Item* item) {
287                                absl::StrAppend(out, item->instruction->name());
288                              })
289             << "}";
290 
291     // Find the minimal position number of any instruction in
292     // 'before_instructions'.
293     CHECK(!before_instructions.empty());
294     Item* min_position_item = nullptr;
295     for (Item* item : before_instructions) {
296       if (min_position_item == nullptr ||
297           item->position < min_position_item->position) {
298         min_position_item = item;
299       }
300     }
301 
302     // Because more than one instruction in 'before_instructions' may have a
303     // position number of 'min_position_number', find the first such instruction
304     // with position number 'min_position_number'.
305 
306     // First find first instruction with the min position.
307     while (min_position_item->prev != nullptr &&
308            min_position_item->position == min_position_item->prev->position) {
309       min_position_item = min_position_item->prev;
310     }
311 
312     // Now scan forwards until we find one of the before_instructions.
313     while (!absl::c_linear_search(before_instructions, min_position_item)) {
314       min_position_item = min_position_item->next;
315     }
316     return InsertBefore(to_insert, min_position_item);
317   }
318 
319   // Scans the list and promotes nodes to the express lane if
320   // should_promote(Item) returns true.
321   void PromoteNodesToSkip(std::function<bool(Item*)> should_promote) {
322     int64 count = 0;
323     for (auto* item = first(); item != nullptr; item = next(item)) {
324       if (should_promote(item)) {
325         count += 1;
326         if (first_skip_node_ == nullptr) {
327           first_skip_node_ = item;
328         }
329         item->is_skip_node = true;
330         item->prev_skip_node = last_skip_node_;
331         if (last_skip_node_ != nullptr) {
332           last_skip_node_->next_skip_node = item;
333         }
334         last_skip_node_ = item;
335       }
336     }
337     VLOG(1) << " Rematerialization has " << count << " items in express lane";
338   }
339 
340   void InsertAfterInstructions(Item* to_insert,
341                                absl::Span<Item* const> after_instructions) {
342     VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
343             << " after {"
344             << absl::StrJoin(after_instructions, ", ",
345                              [](string* out, Item* item) {
346                                absl::StrAppend(out, item->instruction->name());
347                              })
348             << "}";
349 
350     // Find the max position number of any instruction in
351     // 'after_instructions'.
352     CHECK(!after_instructions.empty());
353     Item* max_position_item = nullptr;
354     for (Item* item : after_instructions) {
355       if (max_position_item == nullptr ||
356           item->position > max_position_item->position) {
357         max_position_item = item;
358       }
359     }
360     // No rematerializable instruction should be inserted at the end of the
361     // computation.
362     CHECK(max_position_item->next != nullptr);
363     InsertBeforeInstructions(to_insert, {max_position_item->next});
364   }
365 
366   void Denylist(const HloInstruction* inst) {
367     GetItem(inst)->denylisted = true;
368   }
369 
370  private:
371   // Insert instruction 'item' immediately before 'before' in the list.
372   void InsertBefore(Item* item, Item* before) {
373     VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
374             << before->instruction->name();
375     // Always place new nodes on the express lane for ease of implementation.
376     item->is_skip_node = true;
377     // Find the next express node starting from 'before'. Set up the node's
378     // express pointers.
379     Item* cursor = before;
380     while (cursor != nullptr && !cursor->is_skip_node) {
381       cursor = cursor->next;
382     }
383     CHECK(cursor == nullptr || cursor->is_skip_node);
384     if (cursor == nullptr) {
385       //
386       // last_skip_node_<---+                              : express lane
387       //                    |
388       //           ...<->`item`<-> .. <-> `cursor`(null)   : slow lane
389       //
390       // Reached the end. Set prev_skip_node to last_skip_node_, and make this
391       // item the new last_skip_node_.
392       item->prev_skip_node = last_skip_node_;
393       item->next_skip_node = nullptr;
394       last_skip_node_ = item;
395     } else {
396       //
397       //     <-+------------+----------------+--------->   : express lane
398       //       |            |                |
399       // prev_express..<->`item`<-> .. <-> `cursor` <-> ...: slow lane
400       //
401       // Reached the next skip node; set up the express pointers accordingly.
402       CHECK(cursor->is_skip_node);
403       item->prev_skip_node = cursor->prev_skip_node;
404       if (item->prev_skip_node != nullptr) {
405         item->prev_skip_node->next_skip_node = item;
406       }
407       item->next_skip_node = cursor;
408       cursor->prev_skip_node = item;
409     }
410     if (first_skip_node_ == cursor) {
411       first_skip_node_ = item;
412     }
413     // Insert new item into linked list.
414     item->prev = before->prev;
415     item->next = before;
416     before->prev = item;
417     if (item->prev != nullptr) {
418       item->prev->next = item;
419     } else {
420       first_ = item;
421     }
422 
423     // Assign the same position number to the newly added instruction as
424     // 'before'. This guarantees monotonicity of the position numbers, but not
425     // uniqueness.
426     item->position = before->position;
427   }
428 
429   Item* first_;
430 
431   // First skip node of this list.
432   Item* first_skip_node_;
433 
434   // Last skip node of this list.
435   Item* last_skip_node_;
436 
437   // Item for each instruction.
438   absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
439 };
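// Illustrative usage sketch (instruction names assumed): splicing a newly
// created instruction into the ordering just before its first remaining user.
//
//   Item* remat_item = instruction_list.CreateItem(remat_instruction);
//   instruction_list.InsertBeforeInstructions(
//       remat_item, {instruction_list.GetItem(first_user)});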
440 
441 // Return the items which use the given LogicalBuffer. Sets
442 // has_indirect_users to whether any of the uses is indirect. A use is indirect
443 // if the instruction defining logical_buffer is not an operand of the use. This
444 // can happen via buffer aliasing (eg, tuples).
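// For example (names assumed): given %t = tuple(%a, %b) and
// %g = get-tuple-element(%t), index=0, an instruction consuming %g uses the
// buffer defined by %a indirectly, since %a is not one of its operands.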
445 UsesList GetUsers(const InstructionList& instruction_list,
446                   const LogicalBuffer* logical_buffer,
447                   const TuplePointsToAnalysis& points_to_analysis,
448                   bool* has_indirect_users) {
449   UsesList users;
450   // To identify uses iterate through all HloInstruction users of the
451   // BufferAliases of the logical buffer.
452   *has_indirect_users = false;
453   for (const BufferAlias& buffer_alias :
454        points_to_analysis.GetBufferAliases(*logical_buffer)) {
455     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
456       if (points_to_analysis.DoesNotUseOperandBuffer(
457               buffer_alias.instruction(), buffer_alias.index(), user)) {
458         // The alias may be an operand of 'user', but the LogicalBuffer cannot
459         // possibly be used by the instruction so ignore 'user'. This is the
460         // case, for example, for the tuple element buffers in a GetTupleElement
461         // instruction (the GTE instruction only uses the pointer vector).
462         continue;
463       }
464       if (buffer_alias.instruction() != logical_buffer->instruction() &&
465           !IsSupportedIndirectUser(buffer_alias.instruction())) {
466         *has_indirect_users = true;
467       }
468       // A buffer may be used by the instruction via more than one alias. For
469       // example, a buffer which appears in more than one element of a tuple.
470       Item* user_item = instruction_list.GetItem(user);
471       absl::optional<int64> user_index =
472           logical_buffer->index().size() != 1
473               ? absl::nullopt
474               : absl::make_optional(logical_buffer->index().back());
475       for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
476         if (!absl::c_linear_search(
477                 users,
478                 ItemUse{user_item, static_cast<int>(op_idx), user_index})) {
479           users.push_back(
480               ItemUse{user_item, static_cast<int>(op_idx), user_index});
481         }
482       }
483     }
484   }
485   return users;
486 }
487 
488 // Class for tracking memory usage of a computation as the instructions are
489 // placed sequentially. Memory usage is the sum of the sizes of live values
490 // (LogicalBuffers) at the current point in the instruction sequence.
491 class MemoryUsageTracker {
492  public:
493   MemoryUsageTracker(
494       const HloComputation* computation,
495       const HloRematerialization::ShapeSizeFunction& size_function,
496       const HloRematerialization::CompactShapeFunction& compact_shape_function,
497       const TuplePointsToAnalysis& points_to_analysis,
498       const InstructionList& instruction_list,
499       HloRematerialization::RematerializationMode mode);
500 
501   // Starts the placement of the given instruction. This adds the sizes of the
502   // LogicalBuffers defined by the instruction to the current memory
503   // usage. Placement is broken into two steps (BeginInstruction and
504   // EndInstruction) to accurately model memory usage. At BeginInstruction the
505   // memory for the output value(s) of the current instruction is allocated. At
506   // EndInstruction memory for dead operand(s) is freed.
507   Status BeginInstruction(Item* item);
508 
509   int64 RematerializationCost(const std::vector<Item*>& items,
510                               int64 memory_reduced, int64 memory_limit_bytes) {
511     // If none of the users of any 'item' have been placed in the
512     // sequence (as tracked by memory_tracker), then rematerialization of
513     // 'item' is a zero-cost move of 'item->instruction' in the sequence.
514     bool zero_cost_move = true;
515     for (auto* item : items) {
516       auto* instruction = item->instruction;
517       if (absl::c_any_of(
518               instruction->users(),
519               [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
520         zero_cost_move = false;
521         break;
522       }
523     }
524     if (zero_cost_move) {
525       return 0;
526     }
527 
528     CHECK_GT(memory_reduced, 0);
529     // Return the inverse of the benefit of rematerialization.
530     return memory_limit_bytes / memory_reduced;
531   }
532 
533   // Finishes the placement of the current instruction. This frees any dead
534   // operands or dead result of the instruction. This must be called after
535   // each call to BeginInstruction.
536   Status EndInstruction();
537 
538   // Returns the number of bytes by which the current memory usage will be
539   // reduced if the given instruction's output is compressed.
540   int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
541 
542   // Returns the number of bytes by which the current memory usage will be
543   // reduced if the given sequence of instructions is rematerialized.
544   int64 MemoryReducedIfRematerialized(
545       absl::Span<const Item* const> items) const;
546 
547   Status AddCompressInstructions(Item* original_item, Item* compressed_item,
548                                  Item* uncompressed_item);
549 
550   // Adjusts memory usage to account for the rematerialization of
551   // original_item for all remaining unplaced uses. The rematerialization
552   // is remat_item. This method should be called after the HLO graph has
553   // been transformed (rematerialization instruction created and connected
554   // to uses).
555   Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
556                                       absl::Span<Item*> indirect_users);
557 
558   // Selects and returns the best candidate instructions for rematerialization.
559   // A sequence of candidate instructions of length between min_block_size and
560   // max_block_size (both inclusive) with the lowest rematerialization cost is
561   // selected among those candidates which reduce memory use at the program
562   // point of the current instruction as indicated by memory_tracker. Returns an
563   // empty vector if no candidates are found.
564   std::pair<std::vector<Item*>, RematStrategy> PickRematerializationCandidates(
565       const InstructionList& instruction_list, int64 memory_limit_bytes,
566       absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
567       int min_block_size, int max_block_size);
568 
569   // Returns whether the given instruction has been placed (BeginInstruction
570   // has been called with 'instruction' as the argument).
571   bool IsPlaced(const HloInstruction* instruction) const {
572     return instruction_list_.GetItem(instruction)->placed;
573   }
574 
575   // Returns whether 'item' has any unplaced users.
576   bool HasUnplacedUsers(Item* item) const;
577 
578   // Returns the list of uses for a specific 'item'.
579   const UsesList GetItemUses(Item* item) const;
580 
581   // Returns whether 'item' is currently in progress.
582   bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
583 
584   // Returns the current memory usage. This is the sum of sizes of all live
585   // values.
586   int64 memory_usage() const { return memory_usage_; }
587 
588   // Returns the total allocated size of the buffers defined by 'item'.
589   int64 AllocatedSize(Item* item) const {
590     int64 size = 0;
591     for (auto buffer_id : item->buffers_defined) {
592       size += AllocatedSize(buffer_id);
593     }
594     return size;
595   }
596 
597   // Check invariants of the data structure. This is expensive to call.
598   bool Check() const;
599 
600   string ToString() const;
601 
602  private:
603   // A Buffer represents a single LogicalBuffer in the computation including
604   // various metadata useful for tracking liveness of the value. A LogicalBuffer
605   // is not used directly because the HLO graph is transformed and
606   // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after
607   // HLO graph transformations.
608   struct Buffer {
609     // The unique id of this Buffer. This value is equal to the buffer's index
610     // in the vector buffers_.
611     const BufferId id;
612 
613     // The instruction which defines this buffer.
614     Item* defining_instruction;
615 
616     // The materialized size of the buffer in bytes.
617     const int64 size;
618 
619     // Shape of the buffer.
620     Shape shape;
621 
622     // Whether this buffer is live-out of the computation.
623     bool live_out;
624 
625     // Whether this buffer has indirect uses. Ie, an instruction which is not a
626     // user of defining_instruction uses this buffer. This can occur due to
627     // buffer aliasing (eg, tuples).
628     bool has_indirect_uses;
629 
630     // Position in the tuple this buffer definition lives in.
631     ShapeIndex index;
632 
633     // The instructions which use this buffer.
634     UsesList users;
635 
636     // The number of users (HloInstructions) of this buffer which have not yet
637     // been placed in the sequence.
638     int64 unfinished_user_count;
639 
640     string ToString() const {
641       return absl::StrCat("Buffer ", id, " (defined by ",
642                           defining_instruction->instruction->name(), ", size ",
643                           size, " bytes)");
644     }
645   };
646 
647   // Gets the compact shape of the given HLO instruction. An internal cache is used
648   // to avoid computing the shape multiple times.
649   StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
650 
651   // Creates a Buffer representing the given logical buffer. The buffer is added
652   // to buffers_ and a reference is returned.
653   Buffer& CreateBufferFromLogicalBuffer(
654       const LogicalBuffer* logical_buffer,
655       const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
656     bool has_indirect_uses = false;
657     UsesList users = GetUsers(instruction_list_, logical_buffer,
658                               points_to_analysis, &has_indirect_uses);
659     return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
660                      logical_buffer->shape(), logical_buffer->index(),
661                      std::move(users), live_out, has_indirect_uses);
662   }
663 
664   // Creates a new buffer representing a rematerialization of the given buffer
665   // for the given uses.
666   Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
667                               UsesList&& rematerialized_uses) {
668     CHECK(original_buffer.defining_instruction->placed)
669         << original_buffer.defining_instruction->instruction->name();
670     CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
671     CHECK(!original_buffer.live_out) << original_buffer.ToString();
672     for (ItemUse& use : rematerialized_uses) {
673       CHECK(!use.user->placed) << use.user->instruction->name();
674     }
675     return NewBuffer(remat_item, original_buffer.shape, original_buffer.index,
676                      std::move(rematerialized_uses), /*live_out=*/false,
677                      /*has_indirect_uses=*/false);
678   }
679 
680   // Return number of bytes allocated for the buffer with the given id. Buffers
681   // allocated by the calling computation (eg, parameter and output buffers) are
682   // considered to have zero bytes because the memory is accounted for in a
683   // different computation.
684   int64 AllocatedSize(BufferId buffer_id) const {
685     const Buffer& buffer = buffers_.at(buffer_id);
686     HloInstruction* inst = buffer.defining_instruction->instruction;
687     HloOpcode def_opcode = inst->opcode();
688     if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
689       return 0;
690     } else {
691       return buffer.size;
692     }
693   }
694 
695   // Returns true if BeginInstruction and EndInstruction have been called for
696   // the given instruction.
697   bool IsFinished(Item* item) const {
698     return item->placed && item != in_progress_item_;
699   }
700 
701   // Returns whether the given buffer is being used by the in-progress
702   // instruction.
703   bool IsInUse(BufferId buffer_id) const {
704     if (in_progress_item_ == nullptr) {
705       return false;
706     }
707     const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
708     return absl::c_linear_search(in_progress_uses, buffer_id);
709   }
710 
711   bool IsCurrentlyLive(BufferId buffer_id) const {
712     const Buffer& buffer = buffers_[buffer_id];
713     return (buffer.defining_instruction->placed &&
714             buffer.unfinished_user_count > 0);
715   }
716 
717   // Returns whether the given instruction is live at the current program
718   // point.
719   bool IsInstructionCurrentlyLive(Item* instruction) const {
720     // If the instruction has not started yet, it is not alive.
721     if (!IsPlaced(instruction->instruction)) {
722       return false;
723     }
724     for (const HloInstruction* user : instruction->instruction->users()) {
725       if (!IsPlaced(user)) {
726         // If there is an unplaced user, consider this instruction currently
727         // live.
728         return true;
729       }
730     }
731     return false;
732   }
733 
734   // Create a new buffer, add it to buffers_, and return a reference.
735   Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
736                     const ShapeIndex& index, UsesList&& uses, bool live_out,
737                     bool has_indirect_uses) {
738     int buffer_id = buffers_.size();
739     auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
740       absl::flat_hash_set<Item*> users_set;
741       for (const ItemUse& use : uses) {
742         users_set.insert(use.user);
743       }
744       return users_set.size();
745     };
746     buffers_.push_back(Buffer{
747         buffer_id, defining_instruction, size_function_(shape), shape, live_out,
748         has_indirect_uses, index, uses, get_num_of_unique_users(uses)});
749     return buffers_.back();
750   }
751 
752   const HloComputation* computation_;
753 
754   // Instruction list containing the ordering of instructions in
755   // computation_. This is the order in which instructions are placed
756   // (BeginInstruction/EndInstruction calls).
757   const InstructionList& instruction_list_;
758 
759   // Size function that returns the size in bytes of a given buffer.
760   const HloRematerialization::ShapeSizeFunction& size_function_;
761 
762   // Converts a shape into a compact form; returns the same shape if the shape
763   // is already considered compact.
764   const HloRematerialization::CompactShapeFunction& compact_shape_function_;
765 
766   // A map that caches the known compact shape for each instruction.
767   absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
768 
769   // Memory usage at the currently placed instruction.
770   int64 memory_usage_ = 0;
771 
772   // The instruction currently being placed. This value is non-null only
773   // between the calling of BeginInstruction and EndInstruction.
774   Item* in_progress_item_ = nullptr;
775 
776   HloRematerialization::RematerializationMode mode_;
777   // All buffers in the computation.
778   std::vector<Buffer> buffers_;
779 };
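// Illustrative usage sketch (argument names assumed from the declarations
// above): instructions are placed one at a time, and memory usage can be
// inspected between BeginInstruction and EndInstruction.
//
//   MemoryUsageTracker tracker(computation, size_function,
//                              compact_shape_function, points_to_analysis,
//                              instruction_list, mode);
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
//     // ... pick and apply rematerialization candidates while
//     //     tracker.memory_usage() exceeds the limit ...
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());
//   }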
780 
781 MemoryUsageTracker::MemoryUsageTracker(
782     const HloComputation* computation,
783     const HloRematerialization::ShapeSizeFunction& size_function,
784     const HloRematerialization::CompactShapeFunction& compact_shape_function,
785     const TuplePointsToAnalysis& points_to_analysis,
786     const InstructionList& instruction_list,
787     HloRematerialization::RematerializationMode mode)
788     : computation_(computation),
789       instruction_list_(instruction_list),
790       size_function_(size_function),
791       compact_shape_function_(compact_shape_function),
792       mode_(mode) {
793   PointsToSet::BufferSet live_out_set =
794       points_to_analysis.GetPointsToSet(computation_->root_instruction())
795           .CreateFlattenedSet();
796   absl::flat_hash_map<const LogicalBuffer*, BufferId>
797       logical_buffer_to_buffer_id;
798   for (auto* item = instruction_list_.first(); item != nullptr;
799        item = instruction_list_.next(item)) {
800     const HloInstruction* const instruction = item->instruction;
801     for (const LogicalBuffer* logical_buffer :
802          points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
803       Buffer* buffer;
804       if (instruction->opcode() == HloOpcode::kWhile) {
805         // The while instruction defines no new buffers. Instead it reuses the
806         // buffers of its operand. Find the Buffer of its operand at the
807         // proper ShapeIndex.
808         const PointsToSet& operand_points_to =
809             points_to_analysis.GetPointsToSet(instruction->operand(0));
810         CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
811         const LogicalBuffer* source_logical_buffer =
812             operand_points_to.element(logical_buffer->index())[0];
813         buffer =
814             &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));
815 
816         // Mark the buffer as having indirect uses and update its live-out flag.
817         buffer->has_indirect_uses = true;
818         buffer->live_out =
819             buffer->live_out || ContainsKey(live_out_set, logical_buffer);
820 
821         // Add users of while to Buffer users.
822         bool unused;
823         for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
824                                            points_to_analysis, &unused)) {
825           auto existing_user_it = absl::c_find_if(
826               buffer->users,
827               [&](const ItemUse& use) { return user_item.user == use.user; });
828           if (existing_user_it == buffer->users.end()) {
829             buffer->unfinished_user_count++;
830             user_item.user->buffers_used.push_back(buffer->id);
831             buffer->users.push_back(user_item);
832           }
833         }
834       } else {
835         buffer = &CreateBufferFromLogicalBuffer(
836             logical_buffer, points_to_analysis,
837             ContainsKey(live_out_set, logical_buffer));
838         item->buffers_defined.push_back(buffer->id);
839         for (ItemUse& user : buffer->users) {
840           if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
841             user.user->buffers_used.push_back(buffer->id);
842           }
843         }
844       }
845 
846       logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
847     }
848 
849     // Trace the outputs of each instruction. This is so that we can properly
850     // track which buffers GTE instructions output.
851     for (const LogicalBuffer* logical_buffer :
852          points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
853       item->buffers_output.push_back(
854           logical_buffer_to_buffer_id[logical_buffer]);
855     }
856   }
857   XLA_VLOG_LINES(10, ToString());
858   DCHECK(Check());
859 }
860 
861 Status MemoryUsageTracker::BeginInstruction(Item* item) {
862   const HloInstruction* instruction = item->instruction;
863   VLOG(3) << "BeginInstruction " << instruction->name();
864   TF_RET_CHECK(in_progress_item_ == nullptr);
865   in_progress_item_ = item;
866 
867   item->placed = true;
868 
869   // All buffers defined by this instruction need memory.
870   for (BufferId buffer_id : item->buffers_defined) {
871     VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
872             << " is now live.";
873     memory_usage_ += AllocatedSize(buffer_id);
874   }
875 
876   // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead)
877   // operand. Account for this potential reuse here.
878 
879   VLOG(3) << "  memory usage = " << memory_usage_;
880   VLOG(10) << ToString();
881 
882   if (VLOG_IS_ON(1)) {
883     DCHECK(Check());
884   }
885   return Status::OK();
886 }
887 
888 Status MemoryUsageTracker::EndInstruction() {
889   TF_RET_CHECK(in_progress_item_ != nullptr);
890   VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
891 
892   for (BufferId buffer_id : in_progress_item_->buffers_used) {
893     Buffer& buffer = buffers_.at(buffer_id);
894     buffer.unfinished_user_count--;
895     TF_RET_CHECK(buffer.unfinished_user_count >= 0)
896         << buffer.ToString() << " has negative unfinished user count.";
897     if (buffer.unfinished_user_count == 0) {
898       // Buffer is now dead.
899       VLOG(3) << "  " << buffer.ToString() << " is now dead.";
900       memory_usage_ -= AllocatedSize(buffer_id);
901       // The memory usage can become negative inside the computation as we can
902       // free up the parameter space and reuse it for other tensors.
903     }
904   }
905 
906   // If any buffer defined by this instruction has no uses, then memory can be
907   // reclaimed immediately.
908   for (BufferId buffer_id : in_progress_item_->buffers_defined) {
909     const Buffer& buffer = buffers_.at(buffer_id);
910     if (buffer.unfinished_user_count == 0) {
911       VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
912       memory_usage_ -= AllocatedSize(buffer_id);
913       // The memory usage can become negative inside the computation as we can
914       // free up the parameter space and reuse it for other tensors.
915     }
916   }
917 
918   in_progress_item_ = nullptr;
919 
920   VLOG(3) << "  memory usage = " << memory_usage_;
921   VLOG(10) << ToString();
922 
923   if (VLOG_IS_ON(1)) {
924     DCHECK(Check());
925   }
926   return Status::OK();
927 }
928 
929 int64 MemoryUsageTracker::MemoryReducedIfCompressed(
930     Item* item, const Shape& compact_shape) const {
931   CHECK_NE(in_progress_item_, nullptr);
932   if (!item->placed || item == in_progress_item_) {
933     return 0;
934   }
935 
936   int64 memory_reduced = 0;
937 
938   // We only compress a single piece of an output at one time.
939   CHECK_EQ(item->buffers_output.size(), 1);
940   BufferId buffer_id = item->buffers_output[0];
941   if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
942       IsInstructionCurrentlyLive(item)) {
943     const Buffer& buffer = buffers_.at(buffer_id);
944     memory_reduced += buffer.size;
945 
946     int64 compact_shape_size = size_function_(compact_shape);
947     // Account for buffers that are compressed after instruction.
948     memory_reduced -= compact_shape_size;
949   }
950   return memory_reduced;
951 }
952 
953 int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
954     absl::Span<const Item* const> items) const {
955   CHECK_NE(in_progress_item_, nullptr);
956   int64 memory_reduced = 0;
957   absl::flat_hash_set<const Item*> remat_candidates;
958 
959   for (const Item* item : items) {
960     if (!item->placed || item == in_progress_item_) {
961       LOG(WARNING) << "Unplaced item or in progress item being checked for "
962                       "rematerialization.";
963       return 0;
964     }
965 
966     // Compute the amount of memory reduced (if any) by rematerializing
967     // 'item->instruction'. The LogicalBuffers defined by 'item->instruction'
968     // will no longer be live at this program point, so initially set
969     // memory_reduced to the size of its defined values.
970     for (BufferId buffer_id : item->buffers_defined) {
971       const Buffer& buffer = buffers_.at(buffer_id);
972       // Avoid rematerializing instructions with indirect uses as it is
973       // difficult to reason about liveness after rematerializing the
974       // instruction.
975       // Avoid rematerializing instructions with live out buffers.
976       // Avoid rematerializing buffers that are in nested tuples.
977       // TODO(mpurohit): Check why live_out buffers are an issue here.
978       if (buffer.has_indirect_uses || buffer.live_out ||
979           buffer.index.size() > 1) {
980         return 0;
981       }
982       if (IsInUse(buffer_id)) {
983         return 0;
984       }
985       if (IsCurrentlyLive(buffer_id)) {
986         memory_reduced += AllocatedSize(buffer_id);
987       }
988     }
989 
990     // Account for any logical buffers whose live range must be extended across
991     // this program point.
992     for (BufferId buffer_id : item->buffers_used) {
993       if (!IsCurrentlyLive(buffer_id)) {
994         // This logical buffer is used by 'item->instruction' but is not live at
995         // this program point. Rematerializing 'item->instruction' will extend
996         // the buffer's live range across this program point unless it is
997         // defined by an instruction that is also being rematerialized.
998         Item* defining_instruction =
999             buffers_.at(buffer_id).defining_instruction;
1000         if (!remat_candidates.contains(defining_instruction)) {
1001           memory_reduced -= AllocatedSize(buffer_id);
1002         }
1003       }
1004     }
1005     remat_candidates.insert(item);
1006   }
1007 
1008   return memory_reduced;
1009 }
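// Worked example with assumed sizes: if a candidate defines one 4 MiB buffer
// that is live at the current program point and uses one 1 MiB buffer that is
// no longer live, rematerializing it frees the 4 MiB buffer but extends the
// 1 MiB operand's live range across this point, for a net reduction of 3 MiB.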
1010 
1011 Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
1012                                                    Item* compressed_item,
1013                                                    Item* uncompressed_item) {
1014   // Original buffer is now dead.
1015   memory_usage_ -= size_function_(original_item->instruction->shape());
1016   // Compressed buffer is now alive.
1017   memory_usage_ += size_function_(compressed_item->instruction->shape());
1018 
1019   UsesList placed_users;
1020   UsesList unplaced_users;
1021   CHECK_EQ(original_item->buffers_output.size(), 1);
1022   BufferId original_buffer_id = original_item->buffers_output[0];
1023   Buffer& original_buffer = buffers_.at(original_buffer_id);
1024   for (ItemUse& user : original_buffer.users) {
1025     if (user.user->placed) {
1026       CHECK(IsFinished(user.user)) << user.user->instruction->name();
1027       placed_users.push_back(user);
1028     } else {
1029       unplaced_users.push_back(user);
1030     }
1031   }
1032   original_buffer.users = std::move(placed_users);
1033   original_buffer.unfinished_user_count = 0;
1034   original_buffer.users.push_back(ItemUse{compressed_item, 0, absl::nullopt});
1035   // The NewBuffer calls below may reallocate the vector containing the
1036   // buffers, invalidating the original_buffer reference, so copy the index
1037   // that we need before calling NewBuffer.
1038   ShapeIndex copied_index = original_buffer.index;
1039   Buffer& compressed_buffer =
1040       NewBuffer(compressed_item, compressed_item->instruction->shape(),
1041                 copied_index, {ItemUse{uncompressed_item, 0, absl::nullopt}},
1042                 /*live_out=*/false,
1043                 /*has_indirect_uses=*/false);
1044   compressed_item->buffers_used = original_item->buffers_output;
1045   compressed_item->buffers_output = {compressed_buffer.id};
1046   compressed_item->buffers_defined.push_back(compressed_buffer.id);
1047 
1048   Buffer& uncompressed_buffer =
1049       NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
1050                 copied_index, std::move(unplaced_users), /*live_out=*/false,
1051                 /*has_indirect_uses=*/false);
1052 
1053   uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
1054   uncompressed_item->buffers_output = {uncompressed_buffer.id};
1055   uncompressed_item->buffers_defined = {uncompressed_buffer.id};
1056 
1057   for (ItemUse& user : uncompressed_buffer.users) {
1058     BufferIdList& buffers_used = user.user->buffers_used;
1059     std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
1060                  uncompressed_buffer.id);
1061   }
1062 
1063   return Status::OK();
1064 }
1065 
1066 Status MemoryUsageTracker::AddRematerializedInstruction(
1067     Item* original_item, Item* remat_item, absl::Span<Item*> indirect_users) {
1068   VLOG(3) << "AddRematerializedInstruction: original_instruction = "
1069           << original_item->instruction->name()
1070           << ", remat_instruction = " << remat_item->instruction->name();
1071 
1072   TF_RET_CHECK(in_progress_item_ != nullptr);
1073   TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
1074   TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();
1075 
1076   // Construct the list of buffers used and defined by the rematerialization.
1077   remat_item->buffers_used = original_item->buffers_used;
1078 
1079   // Account for the additional buffer uses created by the new rematerialization
1080   // instruction. Update memory usage if the rematerialization makes a dead
1081   // buffer live again.
1082   for (BufferId buffer_id : original_item->buffers_used) {
1083     Buffer& buffer = buffers_.at(buffer_id);
1084     if (buffer.unfinished_user_count == 0) {
1085       // Buffer used by this instruction was dead, now is alive.
1086       memory_usage_ += AllocatedSize(buffer.id);
1087     }
1088     buffer.unfinished_user_count++;
1089     absl::InlinedVector<ItemUse, 2> filtered_users;
1090     std::copy_if(buffer.users.begin(), buffer.users.end(),
1091                  std::back_inserter(filtered_users),
1092                  [&](const ItemUse& iu) { return iu.user == original_item; });
1093     for (ItemUse& u : filtered_users) {
1094       buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index});
1095     }
1096   }
1097 
1098   const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
1099                                                       indirect_users.end());
1100   // Create a new set of Buffers defined by the new rematerialization
1101   // instruction. Update the internal data structures and memory use to account
1102   // for them.
1103   for (BufferId old_buffer_id : original_item->buffers_defined) {
1104     Buffer& old_buffer = buffers_.at(old_buffer_id);
1105 
1106     UsesList placed_users;
1107     UsesList unplaced_users;
1108     for (ItemUse& user : old_buffer.users) {
1109       if (user.user->placed) {
1110         placed_users.push_back(user);
1111       } else {
1112         // We keep only the indirect users that are in the provided list.
1113         // All others are considered dead, so we clear any buffer uses they
1114         // perform and drop them from the buffer's user list.
1115         if (!IsSupportedIndirectUser(user.user->instruction) ||
1116             indirect_users_set.contains(user.user)) {
1117           unplaced_users.push_back(user);
1118         } else {
1119           CHECK(user.user->buffers_defined.empty())
1120               << "Buffers defined expected to be empty for use passthrough "
1121                  "instructions";
1122           user.user->buffers_output.clear();
1123           user.user->buffers_used.clear();
1124         }
1125       }
1126     }
1127     old_buffer.users = std::move(placed_users);
1128     old_buffer.unfinished_user_count = 0;
1129 
1130     // Buffer is now dead.
1131     memory_usage_ -= AllocatedSize(old_buffer.id);
1132 
1133     Buffer& new_buffer =
1134         RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
1135 
1136     remat_item->buffers_defined.push_back(new_buffer.id);
1137     auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id](
1138                               BufferIdList& to_update) {
1139       std::replace(to_update.begin(), to_update.end(), old_buffer_id,
1140                    new_buffer_id);
1141     };
1142     // Update users with the id of the new buffer.
1143     for (ItemUse& user : new_buffer.users) {
1144       update_buffers(user.user->buffers_used);
1145       update_buffers(user.user->buffers_output);
1146     }
1147   }
1148 
1149   // Update the indirect users with the id of the new buffers.
1150   for (Item* indirect_user : indirect_users) {
1151     // Source of the buffers that will be passed through.
1152     const Item* source_item =
1153         instruction_list_.GetItem(indirect_user->instruction->operand(0));
1154     switch (indirect_user->instruction->opcode()) {
1155       case HloOpcode::kBitcast: {
1156         // If the source is another indirect user, copy its output buffers
1157         // into the used and output lists of the bitcast, since indirect users
1158         // do not define any buffers.
1159         if (IsSupportedIndirectUser(source_item->instruction)) {
1160           indirect_user->buffers_used = source_item->buffers_output;
1161           indirect_user->buffers_output = source_item->buffers_output;
1162         } else {
1163           // If it's a real instruction producing a buffer then copy the defined
1164           // buffers into used and output.
1165           indirect_user->buffers_used = source_item->buffers_defined;
1166           indirect_user->buffers_output = source_item->buffers_defined;
1167         }
1168         break;
1169       }
1170       case HloOpcode::kGetTupleElement: {
1171         // GTEs just use the tuple buffer and output the buffer they actually
1172         // extract from the tuple.
1173         const HloGetTupleElementInstruction* gte =
1174             Cast<HloGetTupleElementInstruction>(indirect_user->instruction);
1175         for (BufferId buffer_id : source_item->buffers_defined) {
1176           const Buffer& def_buffer = buffers_.at(buffer_id);
1177           if (def_buffer.index == ShapeIndex{gte->tuple_index()}) {
1178             indirect_user->buffers_output.push_back(buffer_id);
1179           }
1180           // This is the tuple buffer.
1181           if (def_buffer.index.empty()) {
1182             indirect_user->buffers_used.push_back(buffer_id);
1183           }
1184         }
1185         break;
1186       }
1187       default: {
1188         LOG(FATAL) << "Unsupported indirect instruction with opcode "
1189                    << HloOpcodeString(indirect_user->instruction->opcode());
1190         break;
1191       }
1192     }
1193     // Fix up buffer users for the indirect instructions. For GTEs this is only
1194     // the tuple buffer, while for bitcasts it is the buffer they pass through.
1195     for (BufferId buffer_id : indirect_user->buffers_used) {
1196       Buffer& buffer = buffers_.at(buffer_id);
1197       buffer.unfinished_user_count++;
1198       buffer.users.push_back(ItemUse{indirect_user, 0, absl::nullopt});
1199     }
1200   }
1201 
1202   VLOG(3) << "  memory usage = " << memory_usage_;
1203   XLA_VLOG_LINES(10, ToString());
1204 
1205   DCHECK(Check());
1206 
1207   return Status::OK();
1208 }
1209 
1210 string MemoryUsageTracker::ToString() const {
1211   string output =
1212       absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
1213   absl::StrAppend(&output,
1214                   "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
1215                   memory_usage(), " bytes)");
1216   for (auto* item = instruction_list_.first(); item != nullptr;
1217        item = instruction_list_.next(item)) {
1218     const HloInstruction* instruction = item->instruction;
1219     string inprogress = item == in_progress_item_ ? " in-progress" : "";
1220     string placed = item->placed ? " placed" : "";
1221     absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
1222                     "\n    Defines:\n");
1223     for (BufferId buffer_id : item->buffers_defined) {
1224       const Buffer& buffer = buffers_[buffer_id];
1225       string live = IsCurrentlyLive(buffer_id) ? " live" : "";
1226       absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
1227                       buffer.unfinished_user_count, " unfinished uses\n");
1228     }
1229     absl::StrAppend(&output, "    Outputs:\n");
1230     for (BufferId buffer_id : item->buffers_output) {
1231       absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
1232     }
1233     absl::StrAppend(&output, "    Uses:\n");
1234     for (BufferId buffer_id : item->buffers_used) {
1235       absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
1236     }
1237   }
1238   return output;
1239 }
1240 
1241 StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
1242   auto it = compact_shape_.find(hlo);
1243   if (it != compact_shape_.end()) {
1244     return it->second;
1245   }
1246   const Shape& original_shape = hlo->shape();
1247   TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
1248   compact_shape_[hlo] = min_shape;
1249   return min_shape;
1250 }
1251 
1252 bool MemoryUsageTracker::Check() const {
1253   auto elements_are_unique = [](const BufferIdList& vec) {
1254     return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
1255   };
1256 
1257   // Verify buffers_defined per instruction.
1258   for (auto* instruction : computation_->instructions()) {
1259     const BufferIdList& defined_buffers =
1260         instruction_list_.GetItem(instruction)->buffers_defined;
1261     CHECK(elements_are_unique(defined_buffers))
1262         << "Instruction " << instruction->name()
1263         << " does not have unique defined buffers: "
1264         << absl::StrJoin(
1265                defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
1266                  absl::StrAppend(out, buffers_.at(buffer_id).ToString());
1267                });
1268 
1269     for (const Buffer& buffer : buffers_) {
1270       if (buffer.defining_instruction->instruction == instruction) {
1271         CHECK(absl::c_linear_search(defined_buffers, buffer.id))
1272             << "Instruction " << instruction->name()
1273             << " defined buffers is missing: " << buffer.ToString();
1274       }
1275     }
1276   }
1277 
1278   // Verify buffers_used per instruction.
1279   for (auto* instruction : computation_->instructions()) {
1280     const BufferIdList& used_buffers =
1281         instruction_list_.GetItem(instruction)->buffers_used;
1282     CHECK(elements_are_unique(used_buffers))
1283         << "Instruction " << instruction->name()
1284         << " does not have unique used buffers: "
1285         << absl::StrJoin(
1286                used_buffers, ", ", [this](string* out, BufferId buffer_id) {
1287                  absl::StrAppend(out, buffers_.at(buffer_id).ToString());
1288                });
1289   }
1290   for (const Buffer& buffer : buffers_) {
1291     int64 unfinished_uses = 0;
1292     absl::flat_hash_set<Item*> already_counted_user;
1293     for (const ItemUse& user : buffer.users) {
1294       const BufferIdList& used_buffers = user.user->buffers_used;
1295       CHECK(absl::c_linear_search(used_buffers, buffer.id))
1296           << "Instruction " << user.user->instruction->name()
1297           << " used buffers is missing " << buffer.ToString();
1298       if (!IsFinished(user.user) &&
1299           already_counted_user.insert(user.user).second) {
1300         unfinished_uses++;
1301       }
1302     }
1303     CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
1304         << "Incorrect unplaced use count for " << buffer.ToString();
1305   }
1306   return true;
1307 }
1308 
1309 // Computes and returns the cost of rematerializing the given instruction.
1310 // Cost per rematerialized instruction is defined as:
1311 //
1312 // memory_limit_bytes / memory_reduced
1313 //
1314 // The idea is to choose the operation that will save the most memory for
1315 // rematerialization and not worry about how much the compute costs, since
1316 // running out of memory is more harmful than taking longer to get the answer.
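// For example, with an 8 GiB limit a candidate that frees 2 GiB of memory has
// cost 4 while one that frees 4 GiB has cost 2, so the latter is preferred.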
1317 int64 RematerializationCost(const HloInstruction* instruction,
1318                             const MemoryUsageTracker& memory_tracker,
1319                             int64 memory_reduced, int64 memory_limit_bytes) {
1320   // If none of the users of 'instruction' have been placed in the sequence (as
1321   // tracked by memory_tracker), then rematerialization of 'instruction' is a
1322   // zero-cost move of 'instruction' in the sequence.
1323   if (!absl::c_any_of(instruction->users(),
1324                       [&memory_tracker](const HloInstruction* inst) {
1325                         return memory_tracker.IsPlaced(inst);
1326                       })) {
1327     return 0;
1328   }
1329 
1330   CHECK_GT(memory_reduced, 0);
1331   // Return the inverse of the benefit of rematerialization.
1332   return memory_limit_bytes / memory_reduced;
1333 }
1334 
1335 // Returns a block of up to min_block_size consecutive candidate instructions
1336 // from instruction_list starting from start_item. Returns fewer than
1337 // min_block_size instructions if the block of placed instructions starting
1338 // from start_item is smaller than min_block_size.
1339 std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
1340                                    const MemoryUsageTracker& tracker,
1341                                    Item* start_item, int min_block_size) {
1342   std::vector<Item*> item_block;
1343   Item* curr_item = start_item;
1344   for (int i = 0; i < min_block_size; ++i) {
1345     if (curr_item == nullptr || !curr_item->placed ||
1346         tracker.IsInProgressItem(curr_item)) {
1347       break;
1348     }
1349     item_block.push_back(curr_item);
1350     curr_item = instruction_list.next(curr_item);
1351   }
1352   return item_block;
1353 }
1354 
1355 // Returns whether any instruction in 'block' is denylisted or
1356 // non-rematerializable.
1357 bool AnyDenylistedOrNonRematerializable(
1358     const std::vector<Item*>& block,
1359     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
1360   for (auto* item : block) {
1361     if (item->denylisted) {
1362       return true;
1363     }
1364     if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
1365       return true;
1366     }
1367   }
1368   return false;
1369 }
1370 
1371 std::pair<std::vector<Item*>, RematStrategy>
1372 MemoryUsageTracker::PickRematerializationCandidates(
1373     const InstructionList& instruction_list, int64 memory_limit_bytes,
1374     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1375     int min_block_size, int max_block_size) {
1376   std::vector<Item*> best_items;
1377   int64 best_cost = 0;
1378   RematStrategy best_strategy;
1379 
1380   VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
1381           << max_block_size << "]";
1382 
1383   for (auto* start_item = instruction_list.first_skip_node();
1384        start_item != nullptr;
1385        start_item = instruction_list.next_skip_node(start_item)) {
1386     std::vector<Item*> block =
1387         GetInitialBlock(instruction_list, *this, start_item, min_block_size);
1388     if (block.size() < min_block_size) {
1389       // There are no more blocks of at least min_block_size placed
1390       // instructions, so stop searching.
1391       break;
1392     }
1393     // If any item in the starting block is denylisted or non-rematerializable,
1394     // skip this block and move on to the next start_item (we could actually move
1395     // to the last invalid item in this block, but let's ignore that optimization for now).
1396     if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) {
1397       continue;
1398     }
1399     while (block.size() <= max_block_size) {
1400       // A block of size 1 is treated separately since we consider compression
1401       // only in this case.
1402       if (block.size() == 1) {
1403         auto* item = block[0];
1404         auto* candidate = item->instruction;
1405         if (item->buffers_output.size() == 1 &&
1406             (mode_ ==
1407                  HloRematerialization::RematerializationMode::kCompressOnly ||
1408              mode_ == HloRematerialization::RematerializationMode::
1409                           kRecomputeAndCompress)) {
1410           // Only consider compressing a single-output instruction.
1411           const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
1412 
1413           if (item->placed && item != in_progress_item_ &&
1414               !output_buffer.live_out) {
1415             const Shape& original_shape = item->instruction->shape();
1416             if (original_shape.IsArray()) {
1417               Shape compact_shape =
1418                   GetCompactShape(item->instruction).ValueOrDie();
1419               const int64 memory_reduced =
1420                   MemoryReducedIfCompressed(item, compact_shape);
1421               if (memory_reduced > 0) {
1422                 const int64 cost = memory_limit_bytes / memory_reduced;
1423                 if (best_items.empty() || cost < best_cost) {
1424                   VLOG(3) << "candidate " << candidate->name() << "("
1425                           << candidate->ToShortString() << ")"
1426                           << " now best when compressed into "
1427                           << compact_shape.ToString(true);
1428                   RematStrategy strategy;
1429                   strategy.kind = RematStrategy::kCompress;
1430                   best_strategy = strategy;
1431                   best_strategy.compact_shape = compact_shape;
1432                   best_items = block;
1433                   best_cost = cost;
1434                 }
1435               }
1436             }
1437           }
1438         }
1439       }
1440       // Do not consider recomputation in compress-only mode.
1441       if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
1442         // break out of this loop. Move on to the next start_item.
1443         break;
1444       }
1445       // If any of the candidate's control successors has been placed, we need
1446       // to skip this candidate. Otherwise we would violate a control dependency.
1447       bool control_successor_placed = false;
1448       for (auto* item : block) {
1449         HloInstruction* candidate = item->instruction;
1450         if (std::any_of(candidate->control_successors().begin(),
1451                         candidate->control_successors().end(),
1452                         [this](const HloInstruction* inst) {
1453                           return IsPlaced(inst);
1454                         })) {
1455           control_successor_placed = true;
1456           break;
1457         }
1458       }
1459       if (control_successor_placed) {
1460         // break out of this loop. Move on to the next start_item.
1461         break;
1462       }
1463       VLOG(5) << "Block contains:";
1464       for (auto* hlo : block) {
1465         VLOG(5) << hlo->instruction->name();
1466       }
1467       const int64 memory_reduced = MemoryReducedIfRematerialized(block);
1468 
1469       if (memory_reduced > 0) {
1470         const int cost =
1471             RematerializationCost(block, memory_reduced, memory_limit_bytes);
1472 
1473         VLOG(5) << "Candidate block of size " << block.size()
1474                 << " starting from " << block[0]->instruction->name()
1475                 << ", memory reduced " << memory_reduced << ", cost per byte "
1476                 << cost;
1477 
1478         if (best_items.empty() || cost < best_cost) {
1479           VLOG(5) << "Candidate block of size " << block.size()
1480                   << " starting from " << block[0]->instruction->name()
1481                   << " now best";
1482           best_strategy.kind = RematStrategy::kRecompute;
1483           best_items = block;
1484           best_cost = cost;
1485         }
1486       }
1487 
1488       // Time to update the block to include the next instruction.
1489       auto* last_item = block[block.size() - 1];
1490       auto* next_item = instruction_list.next(last_item);
1491       if (next_item == nullptr || next_item->denylisted || !next_item->placed ||
1492           next_item == in_progress_item_ ||
1493           !CanBeRematerialized(next_item->instruction, rematerializable_map)) {
1494         break;
1495       }
1496       block.push_back(next_item);
1497     }
1498   }
1499   return {best_items, best_strategy};
1500 }
1501 
1502 bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
1503   for (BufferId buffer_id : item->buffers_defined) {
1504     const Buffer& buffer = buffers_.at(buffer_id);
1505     for (const ItemUse& user : buffer.users) {
1506       if (!user.user->placed) {
1507         return true;
1508       }
1509     }
1510   }
1511   return false;
1512 }
1513 
1514 const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
1515   UsesList combined_users;
1516   for (BufferId buffer_id : item->buffers_defined) {
1517     const Buffer& buffer = buffers_.at(buffer_id);
1518     for (const ItemUse& user : buffer.users) {
1519       combined_users.push_back(user);
1520     }
1521   }
1522   return combined_users;
1523 }
1524 
1525 StatusOr<int64> RematerializeInstructions(
1526     MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
1527     absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
1528     InstructionList* instruction_list) {
1529   int64 net_instructions_added = 0;
1530   int64 total_memory_saved =
1531       memory_tracker->MemoryReducedIfRematerialized(*best_items);
1532   std::vector<string> instruction_names(best_items->size());
1533   // Rematerialize the block of instructions in the reverse order to account for
1534   // dependencies between instructions in best_items.
1535   for (int i = best_items->size() - 1; i >= 0; --i) {
1536     Item* best_item = (*best_items)[i];
1537     HloInstruction* best = best_item->instruction;
1538     instruction_names[i] = best->name();
1539     HloComputation* computation = best->parent();
1540 
1541     // If the item to remat has no unplaced users, then skip the
1542     // rematerialization. Such an instruction can appear in best_items because
1543     // it is part of a good block, but does not itself add any benefit.
1544     if (!memory_tracker->HasUnplacedUsers(best_item)) {
1545       continue;
1546     }
1547 
1548     HloInstruction* remat =
1549         computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
1550 
1551     // Add control dependencies to the new operation.
1552     for (auto successor : best->control_successors()) {
1553       TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
1554     }
1555     for (auto predecessor : best->control_predecessors()) {
1556       TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
1557     }
1558 
1559     Item* remat_item = instruction_list->CreateItem(remat);
1560 
1561     // Replace each remaining use of 'best' with the rematerialization.
1562     absl::InlinedVector<Item*, 4> indirect_users;
1563     absl::flat_hash_map<int64, HloInstruction*> gte_cache;
1564     for (auto& user : memory_tracker->GetItemUses(best_item)) {
1565       if (!memory_tracker->IsPlaced(user.user->instruction)) {
1566         VLOG(2) << "  Replacing use of " << best->name() << " in "
1567                 << user.user->instruction->name() << " with " << remat->name();
1568         const int64 op_idx = user.operand_number;
1569         HloInstruction* remat_use = remat;
1570         if (user.index) {
1571           auto cached_gte = gte_cache.find(*user.index);
1572           if (cached_gte == gte_cache.end()) {
1573             remat_use = computation->AddInstruction(
1574                 HloInstruction::CreateGetTupleElement(
1575                     ShapeUtil::GetTupleElementShape(remat_use->shape(),
1576                                                     *user.index),
1577                     remat_use, *user.index));
1578             indirect_users.push_back(instruction_list->CreateItem(remat_use));
1579             gte_cache[*user.index] = remat_use;
1580           } else {
1581             remat_use = cached_gte->second;
1582           }
1583         }
1584         if (user.user->instruction->operand(op_idx)->shape() !=
1585             remat_use->shape()) {
1586           remat_use = computation->AddInstruction(HloInstruction::CreateBitcast(
1587               user.user->instruction->operand(op_idx)->shape(), remat_use));
1588           indirect_users.push_back(instruction_list->CreateItem(remat_use));
1589         }
1590         TF_RETURN_IF_ERROR(
1591             user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
1592       }
1593     }
1594 
1595     // Account for the rematerialization in the memory tracker.
1596     TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
1597         best_item, remat_item, absl::MakeSpan(indirect_users)));
1598 
1599     // Insert rematerialized instruction right before the earliest unplaced
1600     // use of the instruction *and* the earliest unplaced last use of any
1601     // operands of remat. Unplaced uses of the remat's operands are included
1602     // because we don't want to extend the live range of remat's operands as
1603     // this could increase memory usage.
1604     ItemList place_before;
1605     const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
1606                                                         indirect_users.end());
1607     for (auto user : remat->users()) {
1608       if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1609         place_before.push_back(instruction_list->GetItem(user));
1610       }
1611     }
1612     for (auto* indirect_user : indirect_users) {
1613       for (auto user : indirect_user->instruction->users()) {
1614         if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1615           place_before.push_back(instruction_list->GetItem(user));
1616         }
1617       }
1618     }
1619     for (auto* operand : remat->operands()) {
1620       for (auto* operand_user : operand->users()) {
1621         if (operand_user != remat) {
1622           Item* operand_user_item = instruction_list->GetItem(operand_user);
1623           if (!operand_user_item->placed) {
1624             place_before.push_back(operand_user_item);
1625           }
1626         }
1627       }
1628     }
1629     // Insert the rematerialized instruction before any of its control
1630     // successors to preserve ordering with respect to control dependencies.
1631     for (auto successor : remat->control_successors()) {
1632       Item* successor_item = instruction_list->GetItem(successor);
1633       // Assert to make sure we never remat an operation with a control
1634       // successor that has already been placed.
1635       CHECK(!successor_item->placed) << successor_item->instruction->name();
1636       place_before.push_back(successor_item);
1637     }
1638     instruction_list->InsertBeforeInstructions(remat_item, place_before);
1639 
1640     for (auto* bitcast : indirect_users) {
1641       instruction_list->InsertBeforeInstructions(bitcast, place_before);
1642     }
1643     // Helper function that looks through indirect users when determining if
1644     // there is an active user for an HloInstruction.
1645     std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
1646       for (auto* u : i->users()) {
1647         if (!IsSupportedIndirectUser(u) || !uses_empty(u)) {
1648           return false;
1649         }
1650       }
1651       return true;
1652     };
1653     // If the rematerialized instruction is dead then rematerialization is
1654     // essentially a move. Don't delete the instruction now, because we don't
1655     // want duplicate HloInstruction* values during the course of the
1656     // transformation, since we keep maps with HloInstruction* values as
1657     // keys.
1658     if (uses_empty(best)) {
1659       VLOG(2) << best->name() << " is now dead";
1660       if (ContainsKey(*remat_move_instructions, best)) {
1661         // Previously, 'best' was a rematerialization which killed the
1662         // instruction it was a copy of. Now 'remat' is a rematerialization
1663         // of 'best' and kills 'best'. Stop rematerializing this instruction
1664         // to avoid an infinite loop.
1665         instruction_list->Denylist(remat);
1666       }
1667       remat_move_instructions->insert(remat);
1668       net_instructions_added += indirect_users.size();
1669     } else {
1670       net_instructions_added += indirect_users.size() + 1;
1671     }
1672     for (auto* indirect_user : indirect_users) {
1673       instruction_list->Denylist(indirect_user->instruction);
1674     }
1675   }
1676   VLOG(1) << "Rematerializing instructions ["
1677           << absl::StrJoin(instruction_names, ", ") << "] (saving "
1678           << HumanReadableNumBytes(total_memory_saved) << ")";
1679   return net_instructions_added;
1680 }
1681 
1682 StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
1683                                     Item* best_item, const Shape& compact_shape,
1684                                     InstructionList* instruction_list) {
1685   HloInstruction* best = best_item->instruction;
1686   VLOG(5) << "Transposing instruction " << best->name() << " (saving "
1687           << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1688                  best_item, compact_shape))
1689           << ") to" << compact_shape.ToString(true);
1690 
1691   HloComputation* computation = best->parent();
1692   HloInstruction* compressed = computation->AddInstruction(
1693       HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best),
1694       /*new_name=*/best->name() + ".remat_compressed");
1695 
1696   HloInstruction* uncompressed = computation->AddInstruction(
1697       HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed),
1698       /*new_name=*/best->name() + ".remat_uncompressed");
1699 
1700   Item* compressed_item = instruction_list->CreateItem(compressed);
1701   compressed_item->placed = true;
1702 
1703   Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
1704 
1705   // Replace each remaining use of 'best' with the uncompressed.
1706   std::vector<HloInstruction*> best_users_copy = best->users();
1707   for (HloInstruction* user : best_users_copy) {
1708     if (!memory_tracker->IsPlaced(user)) {
1709       VLOG(5) << "  Replacing use of " << best->name() << " in " << user->name()
1710               << " with " << uncompressed->name();
1711       TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
1712     }
1713   }
1714 
1715   // Account for the rematerialization in the memory tracker.
1716   TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
1717       best_item, compressed_item, uncompressed_item));
1718 
1719   // Insert rematerialized instruction right before the earliest unplaced
1720   // use of the instruction.
1721   ItemList place_before;
1722   for (auto user : uncompressed->users()) {
1723     place_before.push_back(instruction_list->GetItem(user));
1724   }
1725 
1726   instruction_list->Denylist(compressed_item->instruction);
1727   instruction_list->Denylist(uncompressed_item->instruction);
1728 
1729   instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
1730 
1731   instruction_list->InsertAfterInstructions(compressed_item, {best_item});
1732 
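  // The net instruction count increases by two: the compressed copy and the
  // uncompressed copy inserted above.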
1733   return 2;
1734 }
1735 
1736 // A simple struct to encapsulate the number of instructions added during
1737 // rematerialization.
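// For example, rematerializing a block of three instructions of which two
// originals become dead yields remat_count = 3 and net_instructions_added = 1.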
1738 struct InstructionsAdded {
1739   // Total count of instructions rematerialized.
1740   int remat_count;
1741   // Total count of instructions rematerialized minus number of original
1742   // instructions that are now dead.
1743   int net_instructions_added;
1744 };
1745 
1746 // Rematerializes the best block of instructions of size between min_block_size
1747 // and max_block_size (both inclusive) if at least one candidate block of
1748 // instructions can be found. Returns the number of instructions rematerialized.
1749 StatusOr<InstructionsAdded> RematerializeBestBlock(
1750     int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
1751     InstructionList* instruction_list, int64 memory_limit_bytes,
1752     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1753     absl::flat_hash_set<const HloInstruction*>* remat_move_instructions) {
1754   CHECK(min_block_size > 0) << "Non-positive block size.";
1755 
1756   std::vector<Item*> best_items;
1757   RematStrategy best_strategy;
1758   std::tie(best_items, best_strategy) =
1759       memory_tracker->PickRematerializationCandidates(
1760           *instruction_list, memory_limit_bytes, rematerializable_map,
1761           min_block_size, max_block_size);
1762   InstructionsAdded num_instructions_added;
1763   num_instructions_added.remat_count = best_items.size();
1764   if (best_items.empty()) {
1765     num_instructions_added.net_instructions_added = 0;
1766     return num_instructions_added;
1767   }
1768 
1769   if (best_strategy.kind == RematStrategy::kCompress) {
1770     CHECK(best_items.size() == 1)
1771         << "More than one instruction compressed simultaneously.";
1772     HloInstruction* best = best_items[0]->instruction;
1773     VLOG(1) << "Compressing instruction " << best->name() << " (saving "
1774             << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1775                    best_items[0], best_strategy.compact_shape))
1776             << ")";
1777 
1778     TF_ASSIGN_OR_RETURN(
1779         num_instructions_added.net_instructions_added,
1780         CompressInstruction(memory_tracker, best_items[0],
1781                             best_strategy.compact_shape, instruction_list));
1782   } else {
1783     TF_ASSIGN_OR_RETURN(
1784         num_instructions_added.net_instructions_added,
1785         RematerializeInstructions(memory_tracker, &best_items,
1786                                   remat_move_instructions, instruction_list));
1787   }
1788   return num_instructions_added;
1789 }
1790 }  // namespace
1791 
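// Walks the given instruction sequence with a MemoryUsageTracker and returns
// the peak memory usage, including the usage of any computations called in a
// sequential context at each program point.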
1792 StatusOr<int64> HloRematerialization::ComputePeakMemory(
1793     const HloComputation* computation,
1794     const HloInstructionSequence& order) const {
1795   InstructionList instruction_list(order);
1796   MemoryUsageTracker tracker(computation, size_function_,
1797                              compact_shape_function_, *points_to_analysis_,
1798                              instruction_list, mode_);
1799   int64 peak_memory = tracker.memory_usage();
1800   for (auto* item = instruction_list.first(); item != nullptr;
1801        item = instruction_list.next(item)) {
1802     const HloInstruction* instruction = item->instruction;
1803     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
1804     TF_ASSIGN_OR_RETURN(int64 callee_usage,
1805                         CalledComputationsMemoryUsage(instruction));
1806     peak_memory =
1807         std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
1808     TF_RETURN_IF_ERROR(tracker.EndInstruction());
1809   }
1810   VLOG(1) << "Peak memory for " << computation->name() << ": "
1811           << HumanReadableNumBytes(peak_memory);
1812   return peak_memory;
1813 }
1814 
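// Returns the sum of the precomputed peak memory of the computations called by
// 'instruction' in a sequential context, or zero if it calls none.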
1815 StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
1816     const HloInstruction* instruction) const {
1817   const CallSite* callsite =
1818       call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
1819   if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
1820     return 0;
1821   }
1822   int64 callee_usage = 0;
1823   for (const HloComputation* computation : callsite->called_computations()) {
1824     TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
1825     callee_usage += computation_peak_memory_.at(computation);
1826   }
1827   return callee_usage;
1828 }
1829 
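// Attempts to reduce the memory usage of 'computation' below
// memory_limit_bytes by rematerializing (or compressing) instructions,
// recursing into called computations where necessary. Updates the schedule in
// place and returns whether any change was made.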
1830 StatusOr<bool> HloRematerialization::RematerializeComputation(
1831     HloComputation* computation, HloSchedule* schedule,
1832     int64 memory_limit_bytes, int64 min_remat_size) {
1833   VLOG(1) << "Rematerializing computation " << computation->name()
1834           << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
1835   VLOG(1) << "peak memory usage is "
1836           << HumanReadableNumBytes(computation_peak_memory_.at(computation));
1837   CHECK(!ContainsKey(rematerialized_computations_, computation));
1838 
1839   InstructionList instruction_list(schedule->sequence(computation));
1840   MemoryUsageTracker memory_tracker(
1841       computation, size_function_, compact_shape_function_,
1842       *points_to_analysis_, instruction_list, mode_);
1843 
1844   instruction_list.PromoteNodesToSkip([&](Item* item) {
1845     return memory_tracker.AllocatedSize(item) >= min_remat_size;
1846   });
1847   bool changed = false;
1848 
1849   // If the rematerialization makes the source instruction dead, then the
1850   // rematerialization is added to 'remat_move_instructions' (the
1851   // rematerialization is essentially a move). If the next rematerialization of
1852   // the instruction is also a move then the rematerialization is added to the
1853   // denylist.
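  // For example, if instruction A is rematerialized as A' and A becomes dead,
  // the rematerialization was effectively a move and A' is recorded here; if A'
  // is later "moved" in the same way, its clone is denylisted so the same value
  // is not shuffled around indefinitely.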
1854   absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
1855 
1856   // The map from instructions to their rematerializable status.
1857   absl::flat_hash_map<const HloInstruction*, bool> rematerializable_map;
1858 
1859   // The peak memory of the computation at any point in the instruction
1860   // sequence.
1861   int64 peak_memory = memory_tracker.memory_usage();
1862 
1863   // Total count of instructions rematerialized.
1864   int64 remat_count = 0;
1865   // Total count of clones created minus number of original rematerialized
1866   // instructions which are dead.
1867   int64 net_instructions_added = 0;
1868 
1869   const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1870 
1871   // Iterate through all instructions in the sequence. At each instruction
1872   // (program point) if memory_usage exceeds the specified limit then
1873   // rematerialize HLO instructions until memory_usage is reduced.
1874   int64 instruction_index = 0;
1875   for (auto* item = instruction_list.first(); item != nullptr;
1876        item = instruction_list.next(item)) {
1877     const HloInstruction* instruction = item->instruction;
1878     TF_ASSIGN_OR_RETURN(int64 callee_usage,
1879                         CalledComputationsMemoryUsage(instruction));
1880     TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
1881 
1882     VLOG(2) << "Program point at " << instruction->name()
1883             << ", memory usage = " << memory_tracker.memory_usage()
1884             << ", callee usage = " << callee_usage << ", [" << instruction_index
1885             << "/" << instruction_list.size() << "]";
1886     instruction_index++;
1887 
1888     // Initialize both min_block_size and max_block_size to 1 so that only
1889     // single instruction rematerialization is considered first.
1890     int min_block_size = 1;
1891     int max_block_size = 1;
1892     // Only trigger rematerialization when the memory usage changes.
1893     if (memory_tracker.AllocatedSize(item) + callee_usage > 0) {
1894       while (memory_tracker.memory_usage() + callee_usage >
1895              memory_limit_bytes) {
1896         VLOG(2) << "Over memory limit at instruction " << instruction->name()
1897                 << ", using "
1898                 << HumanReadableNumBytes(memory_tracker.memory_usage() +
1899                                          callee_usage)
1900                 << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
1901 
1902         TF_ASSIGN_OR_RETURN(
1903             InstructionsAdded instructions_added,
1904             RematerializeBestBlock(min_block_size, max_block_size,
1905                                    &memory_tracker, &instruction_list,
1906                                    memory_limit_bytes, &rematerializable_map,
1907                                    &remat_move_instructions));
1908         net_instructions_added += instructions_added.net_instructions_added;
1909         remat_count += instructions_added.remat_count;
1910 
1911         VLOG(1) << "memory_usage after rematerialization = "
1912                 << HumanReadableNumBytes(memory_tracker.memory_usage());
1913         if (instructions_added.remat_count == 0) {
1914           // Unable to find a block to rematerialize.
1915           // Consider doubling the block size.
1916           min_block_size = max_block_size + 1;
1917           max_block_size = 2 * max_block_size;
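          // With this policy the search window grows as (1, 1) -> (2, 2) ->
          // (3, 4) -> (5, 8) -> ... until a block is found or max_block_size
          // exceeds block_size_limit_.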
1918         } else {
1919           // Found a valid block. Reset to start looking for single instructions
1920           // again.
1921           max_rematerialized_block_size_ =
1922               std::max(max_rematerialized_block_size_, max_block_size);
1923           changed = true;
1924           min_block_size = 1;
1925           max_block_size = 1;
1926         }
1927         if (max_block_size > block_size_limit_) {
1928           break;
1929         }
1930       }
1931     }
1932     const CallSite* callsite = call_graph_node.GetCallSite(instruction);
1933     if (callsite != nullptr &&
1934         callsite->context() == CallContext::kSequential &&
1935         memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
1936       // Memory usage exceeds the limit. Try to rematerialize any
1937       // subcomputation(s) that this instruction calls.
1938       VLOG(1) << "Memory usage still over the limit ("
1939               << (memory_tracker.memory_usage() + callee_usage) << " > "
1940               << memory_limit_bytes
1941               << "). Rematerializing computations called by "
1942               << instruction->name();
1943 
1944       // Recompute callee usage to account for any rematerialization performed
1945       // in the callee computations.
1946       for (HloComputation* called_computation :
1947            callsite->called_computations()) {
1948         if (!ContainsKey(rematerialized_computations_, called_computation)) {
1949           // Memory limit for the subcomputation is the memory limit less the
1950           // amount of memory used at this point in the computation.
1951           int64 subcomputation_memory_limit_bytes = std::max<int64>(
1952               0, memory_limit_bytes - memory_tracker.memory_usage());
1953           TF_ASSIGN_OR_RETURN(
1954               bool subcomputation_changed,
1955               RematerializeComputation(called_computation, schedule,
1956                                        subcomputation_memory_limit_bytes,
1957                                        min_remat_size));
1958           changed |= subcomputation_changed;
1959         }
1960       }
1961 
1962       TF_ASSIGN_OR_RETURN(callee_usage,
1963                           CalledComputationsMemoryUsage(instruction));
1964     }
1965 
1966     peak_memory = std::max<int64>(peak_memory,
1967                                   memory_tracker.memory_usage() + callee_usage);
1968     VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);
1969 
1970     TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
1971   }
1972 
1973   // Verify some invariants on the memory tracker.
1974   for (auto* instruction : computation->instructions()) {
1975     CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
1976   }
1977 
1978   VLOG(1) << "In computation " << computation->name() << " rematerialized "
1979           << remat_count << " instructions; " << net_instructions_added
1980           << " net instructions added";
1981   VLOG(1) << "  peak memory usage now " << HumanReadableNumBytes(peak_memory)
1982           << " (was "
1983           << HumanReadableNumBytes(computation_peak_memory_.at(computation))
1984           << ")";
1985 
1986   // Update peak memory used by computation.
1987   computation_peak_memory_.at(computation) = peak_memory;
1988 
1989   // Update order to include rematerialized instructions.
1990   HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
1991   sequence.clear();
1992   for (auto* item = instruction_list.first(); item != nullptr;
1993        item = instruction_list.next(item)) {
1994     HloInstruction* instruction = item->instruction;
1995     sequence.push_back(instruction);
1996   }
1997   rematerialized_computations_.insert(computation);
1998 
1999   instructions_rematerialized_ += remat_count;
2000   net_instructions_added_ += net_instructions_added;
2001 
2002   return changed;
2003 }
2004 
2005 StatusOr<bool> HloRematerialization::Run(HloModule* module) {
2006   VLOG(1) << "HloRematerialization() with memory limit of "
2007           << HumanReadableNumBytes(memory_limit_bytes_);
2008   XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
2009 
2010   // Initialize pass object state.
2011   computation_peak_memory_.clear();
2012   rematerialized_computations_.clear();
2013   instructions_rematerialized_ = 0;
2014   net_instructions_added_ = 0;
2015 
2016   TF_RET_CHECK(module->has_schedule());
2017   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
2018 
2019   // Adjust memory limit to account for the output of the entry
2020   // computation. This is necessary because the per-computation accounting in
2021   // MemoryUsageTracker does not include the output, as it is typically
2022   // allocated by the caller.
2023   int64 module_output_size = 0;
2024   ShapeUtil::ForEachSubshape(
2025       module->result_shape(),
2026       [&module_output_size, module, this](const Shape& subshape,
2027                                           const ShapeIndex& output_index) {
2028         module_output_size += size_function_(subshape);
2029       });
2030 
2031   const int64 adjusted_memory_limit_bytes =
2032       memory_limit_bytes_ - module_output_size;
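  // For example, a 16 GiB limit with a 1 GiB module output leaves an adjusted
  // limit of 15 GiB for rematerialization within the entry computation.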
2033   VLOG(1) << "Adjusted memory limit accounting for output ("
2034           << HumanReadableNumBytes(module_output_size)
2035           << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
2036 
2037   // Compute peak memory usage of all computations in the module called in a
2038   // sequential context.
2039   call_graph_ = CallGraph::Build(module);
2040   TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
2041       [this, module](const CallGraphNode& node) -> Status {
2042         if (node.context() == CallContext::kSequential) {
2043           TF_ASSIGN_OR_RETURN(
2044               computation_peak_memory_[node.computation()],
2045               ComputePeakMemory(node.computation(), module->schedule().sequence(
2046                                                         node.computation())));
2047         }
2048         return Status::OK();
2049       },
2050       /*visit_unreachable_nodes=*/false));
2051 
2052   // The peak memory usage of the module equals the peak memory use of the entry
2053   // computation plus the output size of the computation. This is because the
2054   // peak memory for a computation does not include the output as this is
2055   // typically accounted for in the caller.
2056   const int64 before_peak_memory =
2057       computation_peak_memory_.at(module->entry_computation()) +
2058       module_output_size;
2059   VLOG(1) << "Peak memory usage of module (before): "
2060           << HumanReadableNumBytes(before_peak_memory);
2061   // Subcomputations called by the entry computation will also be
2062   // rematerialized.
2063   TF_ASSIGN_OR_RETURN(
2064       bool changed,
2065       RematerializeComputation(module->entry_computation(), &module->schedule(),
2066                                adjusted_memory_limit_bytes, min_remat_size_));
2067   // Rematerialization can introduce dead code. This occurs if all uses of an
2068   // instruction are replaced with rematerializations of the instruction.
2069 
2070   // Stash away the schedule while running dead code elimination, to avoid
2071   // validation failures while the module is in flux.
2072   HloSchedule saved_schedule = module->schedule();
2073   module->clear_schedule();
2074   TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
2075   changed |= dead_code_removed;
2076 
2077   // After DCE, the module sequence may include instructions which no longer
2078   // exist. Update the schedule and restore it.
2079   TF_RETURN_IF_ERROR(saved_schedule.Update());
2080   TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
2081   VLOG(1) << "Rematerialized " << instructions_rematerialized_
2082           << " instructions in module " << module->name() << "; "
2083           << net_instructions_added_ << " net instructions added";
2084   const int64 current_peak_memory =
2085       computation_peak_memory_.at(module->entry_computation()) +
2086       module_output_size;
2087   VLOG(1) << "Peak memory usage of module now "
2088           << HumanReadableNumBytes(current_peak_memory) << " ("
2089           << current_peak_memory << " bytes), was "
2090           << HumanReadableNumBytes(before_peak_memory) << " ("
2091           << before_peak_memory << " bytes)";
2092   const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
2093   VLOG(1) << "Reduced peak memory by "
2094           << HumanReadableNumBytes(reduced_peak_memory) << " ("
2095           << reduced_peak_memory << " bytes)";
2096 
2097   if (sizes_ != nullptr) {
2098     sizes_->before_bytes = before_peak_memory;
2099     sizes_->after_bytes = current_peak_memory;
2100   }
2101 
2102   XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
2103 
2104   if (current_peak_memory > memory_limit_bytes_) {
2105     LOG(WARNING) << absl::StrFormat(
2106         "Can't reduce memory use below %s (%d bytes) by rematerialization; "
2107         "only reduced to %s (%d bytes)",
2108         HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
2109         HumanReadableNumBytes(current_peak_memory), current_peak_memory);
2110   }
2111   return changed;
2112 }
2113 
2114 }  // namespace xla
2115