/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
//   of candidates.
// . Cache IsRematerializable in Item?  Only correct if control
//   predecessors and successors don't change.

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kCopy) {
    if (LayoutUtil::Equal(instruction->shape().layout(),
                          instruction->operand(0)->shape().layout())) {
      // Don't rematerialize copies added by copy insertion (layout doesn't
      // change).
      return false;
    }
  }

  // Don't rematerialize instructions with side effects or instructions which
  // cannot be cloned safely.
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConditional:
    case HloOpcode::kAllReduce:
    case HloOpcode::kCustomCall:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
      return false;
    default:
      return !instruction->HasSideEffect();
  }
}

// Checks whether an instruction can be rematerialized, by first consulting
// the cache and, on a miss, calling IsRematerializable() and caching the
// result.
bool CanBeRematerialized(
    const HloInstruction* instruction,
    absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
  auto it = remat_able->find(instruction);
  if (it != remat_able->end()) {
    return it->second;
  }
  bool rematerializable = IsRematerializable(instruction);
  (*remat_able)[instruction] = rematerializable;
  return rematerializable;
}

// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = absl::InlinedVector<BufferId, 3>;

struct RematStrategy {
  enum {
    // Recompute the node at a later program point.
    kRecompute,
    // Change the layout into a compact form and uncompress it back at a later
    // program point.
    kCompress,
  } kind;
  Shape compact_shape;
};

// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
  HloInstruction* instruction;

  // True once the instruction is marked as placed (when BeginInstruction
  // has been called for this instruction).
  bool placed = false;

  // To avoid rematerializing the same set of instructions ad infinitum, keep
  // a blacklist of instructions which should not be rematerialized.
  bool blacklisted = false;

  // The buffers defined by this instruction.
  BufferIdList buffers_defined;

  // Output buffers of this instruction. This is used to track outputs of GTE
  // instructions (where the instruction doesn't itself define the buffer).
  BufferIdList buffers_output;

  // The buffers used by this instruction.
  BufferIdList buffers_used;

 private:
  friend class InstructionList;

  // Items are arranged in a doubly linked list.
  Item* next;
  Item* prev;

  // The list is ordered by position; positions may be duplicated as new
  // instructions are inserted. See the InsertBeforeInstructions comment for
  // details.
  int64 position;
};

using ItemList = absl::InlinedVector<Item*, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
class InstructionList {
 public:
  explicit InstructionList(const HloInstructionSequence& order) {
    int64 position = 0;
    Item* last = nullptr;
    for (HloInstruction* inst : order.instructions()) {
      // Add a new item to the linked list.
      Item* item = new Item;
      item->next = nullptr;
      item->prev = last;
      if (last == nullptr) {
        first_ = item;
      } else {
        last->next = item;
      }
      last = item;

      // Initially position numbers are uniquely assigned in order. Later, as
      // instructions are added with InsertBefore* methods, some instructions
      // may share a position number, but positions are guaranteed to be
      // monotonically increasing through the list, so they remain useful for
      // quickly(-ish) determining the relative order of arbitrary
      // instructions in the list.
      item->instruction = inst;
      item->position = position;
      position++;

      item_map_[inst] = item;
    }
  }

  ~InstructionList() {
    for (Item* item = first_; item != nullptr;) {
      Item* next = item->next;
      delete item;
      item = next;
    }
  }

  size_t size() const { return item_map_.size(); }

  // For ordered iteration over items.
  //    for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }

  // Creates an Item for the given instruction, but doesn't add it to the list.
  // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
    Item* item = new Item;
    item->instruction = inst;
    CHECK(item_map_.insert({inst, item}).second)
        << "inserting inst twice " << inst->name();
    return item;
  }

  // Return the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
    auto iter = item_map_.find(inst);
    CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
    return iter->second;
  }

  // Insert instruction 'to_insert' immediately before the earliest instruction
  // in 'before_instructions'.
  //
  // Each instruction gets a non-decreasing ordinal number. We use this to let
  // InsertBeforeInstructions quickly insert an instruction before the earliest
  // instruction in a set of instructions.  If position_number_[a] <
  // position_number_[b] then 'a' comes before 'b' in the list. If the position
  // numbers are the same then nothing can be said about their order without
  // examining the list.
  //
  // On object construction this ordinal is precisely the instruction's index
  // in the list. Later, instructions inserted via InsertBefore receive
  // duplicate values. However, monotonicity is preserved.
  void InsertBeforeInstructions(Item* to_insert,
                                absl::Span<Item* const> before_instructions) {
    VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
            << " before {"
            << absl::StrJoin(before_instructions, ", ",
                             [](string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the minimal position number of any instruction in
    // 'before_instructions'.
    CHECK(!before_instructions.empty());
    Item* min_position_item = nullptr;
    for (Item* item : before_instructions) {
      if (min_position_item == nullptr ||
          item->position < min_position_item->position) {
        min_position_item = item;
      }
    }

    // More than one instruction in 'before_instructions' may share the minimal
    // position number, and among items with equal position numbers the list
    // order is not determined by position. Find the earliest (in list order)
    // instruction in 'before_instructions' with the minimal position number.

    // First, scan backwards to the first instruction in the list with the min
    // position.
    while (min_position_item->prev != nullptr &&
           min_position_item->position == min_position_item->prev->position) {
      min_position_item = min_position_item->prev;
    }

    // Now scan forwards until we find one of the before_instructions.
    while (!absl::c_linear_search(before_instructions, min_position_item)) {
      min_position_item = min_position_item->next;
    }
    return InsertBefore(to_insert, min_position_item);
  }

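  // Insert instruction 'to_insert' immediately after the instruction in
  // 'after_instructions' with the highest position number.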
  void InsertAfterInstructions(Item* to_insert,
                               absl::Span<Item* const> after_instructions) {
    VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
            << " after {"
            << absl::StrJoin(after_instructions, ", ",
                             [](string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the max position number of any instruction in
    // 'after_instructions'.
    CHECK(!after_instructions.empty());
    Item* max_position_item = nullptr;
    for (Item* item : after_instructions) {
      if (max_position_item == nullptr ||
          item->position > max_position_item->position) {
        max_position_item = item;
      }
    }
    // No rematerializable instruction should be inserted at the end of the
    // computation.
    CHECK(max_position_item->next != nullptr);
    InsertBeforeInstructions(to_insert, {max_position_item->next});
  }

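  // Marks the instruction's Item as blacklisted so that it will not be chosen
  // for rematerialization again.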
  void Blacklist(const HloInstruction* inst) {
    GetItem(inst)->blacklisted = true;
  }

 private:
  // Insert instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
    VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
            << before->instruction->name();
    // Insert new item into linked list.
    item->prev = before->prev;
    item->next = before;
    before->prev = item;
    if (item->prev != nullptr) {
      item->prev->next = item;
    } else {
      first_ = item;
    }

    // Assign the same position number to the newly added instruction as
    // 'before'. This guarantees monotonicity of the position numbers, but not
    // uniqueness.
    item->position = before->position;
  }

  Item* first_;

  // Item for each instruction.
  absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};

// Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
ItemList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  ItemList users;
  // To identify uses iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;
  for (const BufferAlias& buffer_alias :
       points_to_analysis.GetBufferAliases(*logical_buffer)) {
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      if (points_to_analysis.DoesNotUseOperandBuffer(
              buffer_alias.instruction(), buffer_alias.index(), user)) {
        // The alias may be an operand of 'user', but the LogicalBuffer cannot
        // possibly be used by the instruction so ignore 'user'. This is the
        // case, for example, for the tuple element buffers in a GetTupleElement
        // instruction (the GTE instruction only uses the pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction()) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      if (!absl::c_linear_search(users, user_item)) {
        users.push_back(user_item);
      }
    }
  }
  return users;
}

// Class for tracking memory usage of a computation as the instructions are
// placed sequentially. Memory usage is the sum of the sizes of live values
// (LogicalBuffers) at the current point in the instruction sequence.
class MemoryUsageTracker {
 public:
  MemoryUsageTracker(
      const HloComputation* computation,
      const HloRematerialization::ShapeSizeFunction& size_function,
      const HloRematerialization::CompactShapeFunction& compact_shape_function,
      const TuplePointsToAnalysis& points_to_analysis,
      const InstructionList& instruction_list,
      HloRematerialization::RematerializationMode mode);

  // Starts the placement of the given instruction. This adds the sizes of the
  // LogicalBuffers defined by the instruction to the current memory
  // usage. Placement is broken into two steps (BeginInstruction and
  // EndInstruction) to accurately model memory usage. At BeginInstruction the
  // memory for the output value(s) of the current instruction is allocated. At
  // EndInstruction memory for dead operand(s) is freed.
  Status BeginInstruction(Item* item);

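  // Returns the cost of rematerializing 'instruction', defined as
  // memory_limit_bytes / memory_reduced (the inverse of the benefit), or zero
  // if none of the instruction's users have been placed yet, in which case
  // rematerialization is just a free move in the sequence.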
  int64 RematerializationCost(const HloInstruction* instruction,
                              int64 memory_reduced, int64 memory_limit_bytes) {
    // If none of the users of 'instruction' have been placed in the sequence
    // (as tracked by memory_tracker), then rematerialization of 'instruction'
    // is a zero-cost move of 'instruction' in the sequence.
    if (!absl::c_any_of(
            instruction->users(),
            [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
      return 0;
    }

    CHECK_GT(memory_reduced, 0);
    // Return the inverse of the benefit of rematerialization.
    return memory_limit_bytes / memory_reduced;
  }

  // Finishes the placement of the current instruction. This frees any dead
  // operands or dead result of the instruction. This must be called after
  // each call to BeginInstruction.
  Status EndInstruction();

  // Returns the number of bytes that the current memory usage will be reduced
  // if the given instruction is rematerialized.
  int64 MemoryReducedIfRematerialized(Item* item) const;

  // Returns the number of bytes that the current memory usage will be reduced
  // if the given instruction is compressed to the given compact shape.
  int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;

  // Returns the number of bytes that the current memory usage will be reduced
  // by if the given sequence of instructions is rematerialized.
  int64 MemoryReducedIfRematerialized(
      absl::Span<const Item* const> items) const;

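  // Adjusts memory usage to account for compressing the output of
  // original_item into compressed_item and later uncompressing it as
  // uncompressed_item. As with AddRematerializedInstruction, this should be
  // called after the HLO graph has been transformed.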
  Status AddCompressInstructions(Item* original_item, Item* compressed_item,
                                 Item* uncompressed_item);

  // Adjusts memory usage to account for the rematerialization of
  // original_item for all remaining unplaced uses. The rematerialization
  // is remat_item. This method should be called after the HLO graph has
  // been transformed (rematerialization instruction created and connected
  // to uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item);

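  // Selects the best candidate instruction and rematerialization strategy at
  // the current program point. See the comment at the definition below for
  // details.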
  std::pair<Item*, RematStrategy> PickRematerializationCandidate(
      const InstructionList& instruction_list, int64 memory_limit_bytes,
      absl::flat_hash_map<const HloInstruction*, bool>* remat_able);

  // Returns whether the given instruction has been placed (BeginInstruction
  // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
    return instruction_list_.GetItem(instruction)->placed;
  }

  // Returns whether 'item' has any unplaced users.
  bool HasUnplacedUsers(Item* item) const;

  // Returns the current memory usage. This is the sum of sizes of all live
  // values.
  int64 memory_usage() const { return memory_usage_; }

  // Check invariants of the data structure. This is expensive to call.
  bool Check() const;

  string ToString() const;

 private:
  // A Buffer represents a single LogicalBuffer in the computation including
  // various metadata useful for tracking liveness of the value. A LogicalBuffer
  // is not used directly because the HLO graph is transformed and
  // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after
  // HLO graph transformations.
  struct Buffer {
    // The unique id of this Buffer. This value is equal to the buffer's index
    // in the vector buffers_.
    const BufferId id;

    // The instruction which defines this buffer.
    Item* defining_instruction;

    // The materialized size of the buffer in bytes.
    const int64 size;

    // Shape of the buffer.
    Shape shape;

    // Whether this buffer is live-out of the computation.
    bool live_out;

    // Whether this buffer has indirect uses. Ie, an instruction which is not a
    // user of defining_instruction uses this buffer. This can occur due to
    // buffer aliasing (eg, tuples).
    bool has_indirect_uses;

    // The instructions which use this buffer.
    ItemList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.
    int64 unfinished_user_count;

    string ToString() const {
      return absl::StrCat("Buffer ", id, " (defined by ",
                          defining_instruction->instruction->name(), ", size ",
                          size, " bytes)");
    }
  };

  // Get the compact shape of the given hlo instruction. An internal cache is
  // used to avoid computing the shape multiple times.
  StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);

  // Creates a Buffer representing the given logical buffer. The buffer is added
  // to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
    bool has_indirect_uses = false;
    ItemList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     logical_buffer->shape(), std::move(users), live_out,
                     has_indirect_uses);
  }

  // Create a new buffer representing a rematerialization of given buffer for
  // the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              ItemList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed)
        << original_buffer.defining_instruction->instruction->name();
    CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
    CHECK(!original_buffer.live_out) << original_buffer.ToString();
    for (Item* use : rematerialized_uses) {
      CHECK(!use->placed) << use->instruction->name();
    }
    return NewBuffer(remat_item, original_buffer.shape,
                     std::move(rematerialized_uses), /*live_out=*/false,
                     /*has_indirect_uses=*/false);
  }

  // Return number of bytes allocated for the buffer with the given id. Buffers
  // allocated by the calling computation (eg, parameter and output buffers) are
  // considered to have zero bytes because the memory is accounted for in a
  // different computation.
  int64 AllocatedSize(BufferId buffer_id) const {
    const Buffer& buffer = buffers_.at(buffer_id);
    HloInstruction* inst = buffer.defining_instruction->instruction;
    HloOpcode def_opcode = inst->opcode();
    if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
      return 0;
    } else {
      return buffer.size;
    }
  }

  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
    return item->placed && item != in_progress_item_;
  }

  // Returns whether the given buffer is being used by the in-progress
  // instruction.
  bool IsInUse(BufferId buffer_id) const {
    if (in_progress_item_ == nullptr) {
      return false;
    }
    const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
    return absl::c_linear_search(in_progress_uses, buffer_id);
  }

  // Returns whether the given buffer is live at the current program
  // point.
  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&
            buffer.unfinished_user_count > 0);
  }

  // Returns whether the given instruction is live at the current program
  // point.
  bool IsInstructionCurrentlyLive(Item* instruction) const {
    // If the instruction has not started yet, it is not alive.
    if (!IsPlaced(instruction->instruction)) {
      return false;
    }
    for (const HloInstruction* user : instruction->instruction->users()) {
      if (!IsPlaced(user)) {
        // If there is an unplaced user, consider this instruction currently
        // live.
        return true;
      }
    }
    return false;
  }

  // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
                    ItemList&& users, bool live_out, bool has_indirect_uses) {
    int buffer_id = buffers_.size();
    buffers_.push_back(Buffer{
        buffer_id, defining_instruction, size_function_(shape), shape, live_out,
        has_indirect_uses, users, static_cast<int64>(users.size())});
    return buffers_.back();
  }

  const HloComputation* computation_;

  // Instruction list containing the ordering of instructions in
  // computation_. This is the order in which instructions are placed
  // (BeginInstruction/EndInstruction calls).
  const InstructionList& instruction_list_;

  // Size function which returns the bytes of a given buffer.
  const HloRematerialization::ShapeSizeFunction& size_function_;

  // Converts a shape into compact form; returns the same shape if the shape is
  // already considered compact.
  const HloRematerialization::CompactShapeFunction& compact_shape_function_;

  // A map that caches the known compact shape for each instruction.
  absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;

  // Memory usage at the currently placed instruction.
  int64 memory_usage_ = 0;

  // The instruction currently being placed. This value is non-null only
  // between the calling of BeginInstruction and EndInstruction.
  Item* in_progress_item_ = nullptr;

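  // Rematerialization mode: whether to recompute, compress, or both (see
  // HloRematerialization::RematerializationMode).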
  HloRematerialization::RematerializationMode mode_;

  // All buffers in the computation.
  std::vector<Buffer> buffers_;
};

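// Typical usage, as a rough sketch (the actual driver appears later in this
// file, in HloRematerialization::RematerializeComputation):
//
//   MemoryUsageTracker tracker(computation, size_fn, compact_fn, analysis,
//                              instruction_list, mode);
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
//     // While tracker.memory_usage() exceeds the limit, pick a candidate
//     // and apply it via AddRematerializedInstruction() or
//     // AddCompressInstructions().
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());
//   }
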
MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const HloRematerialization::CompactShapeFunction& compact_shape_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list,
    HloRematerialization::RematerializationMode mode)
    : computation_(computation),
      instruction_list_(instruction_list),
      size_function_(size_function),
      compact_shape_function_(compact_shape_function),
      mode_(mode) {
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  absl::flat_hash_map<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;

  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer =
            &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark the buffer as having indirect uses and update its live-out
        // status.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of the while to the Buffer's users.
        bool unused;
        for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
                                        points_to_analysis, &unused)) {
          if (!absl::c_linear_search(buffer->users, user_item)) {
            buffer->users.push_back(user_item);
            buffer->unfinished_user_count++;
            user_item->buffers_used.push_back(buffer->id);
          }
        }
      } else {
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (Item* user : buffer->users) {
          user->buffers_used.push_back(buffer->id);
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }

    // Trace the output of each instruction. This is so that we can properly
    // track which outputs GTE instructions have.
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
      item->buffers_output.push_back(
          logical_buffer_to_buffer_id[logical_buffer]);
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}

Status MemoryUsageTracker::BeginInstruction(Item* item) {
  const HloInstruction* instruction = item->instruction;
  VLOG(3) << "BeginInstruction " << instruction->name();
  TF_RET_CHECK(in_progress_item_ == nullptr);
  in_progress_item_ = item;

  item->placed = true;

  // All buffers defined by this instruction need memory.
  for (BufferId buffer_id : item->buffers_defined) {
    VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
            << " is now live.";
    memory_usage_ += AllocatedSize(buffer_id);
  }

  // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead)
  // operand. Account for this potential reuse here.

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

Status MemoryUsageTracker::EndInstruction() {
  TF_RET_CHECK(in_progress_item_ != nullptr);
  VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();

  for (BufferId buffer_id : in_progress_item_->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    buffer.unfinished_user_count--;
    CHECK_GE(buffer.unfinished_user_count, 0)
        << buffer.ToString() << " has negative unfinished use count.";
    if (buffer.unfinished_user_count == 0) {
      // Buffer is now dead.
      VLOG(3) << "  " << buffer.ToString() << " is now dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  // If any buffer defined by this instruction has no uses, then memory can be
  // reclaimed immediately.
  for (BufferId buffer_id : in_progress_item_->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  in_progress_item_ = nullptr;

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

int64 MemoryUsageTracker::MemoryReducedIfCompressed(
    Item* item, const Shape& compact_shape) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  int64 memory_reduced = 0;

  // We only compress a single piece of an output at one time.
  CHECK_EQ(item->buffers_output.size(), 1);
  BufferId buffer_id = item->buffers_output[0];
  if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
      IsInstructionCurrentlyLive(item)) {
    const Buffer& buffer = buffers_.at(buffer_id);
    memory_reduced += buffer.size;

    int64 compact_shape_size = size_function_(compact_shape);
    // Account for the compressed buffer that remains live after the
    // compression.
    memory_reduced -= compact_shape_size;
  }
  return memory_reduced;
}

int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  // TODO(b/37687140): Rematerialization can increase peak memory consumption at
  // an earlier point in the program if rematerialization extends the live range
  // of the operand of the instruction being rematerialized across the live
  // range of the value of the instruction being rematerialized. Don't
  // rematerialize in this case (ie, return 0 here).

  // Compute the amount of memory reduced (if any) by rematerializing
  // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
  // be live at this program point, so initially set memory_reduced to the
  // size of its defined values.
  int64 memory_reduced = 0;
  for (BufferId buffer_id : item->buffers_defined) {
    // Avoid rematerializing instructions with indirect uses as it is difficult
    // to reason about liveness after rematerializing the instruction.
    // TODO(b/37714814): Consider rematerializing instructions with indirect
    // uses.
    if (buffers_.at(buffer_id).has_indirect_uses) {
      return 0;
    }

    if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
      memory_reduced += AllocatedSize(buffer_id);
    }
  }

  // Account for any logical buffers whose live range must be extended across
  // this program point.
  for (BufferId buffer_id : item->buffers_used) {
    if (!IsCurrentlyLive(buffer_id)) {
      // This logical buffer is used by 'instruction' but is not live at this
      // program point. Rematerializing 'instruction' will extend the buffer's
      // live range across this program point.
      memory_reduced -= AllocatedSize(buffer_id);
    }
  }

  return memory_reduced;
}

int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
    absl::Span<const Item* const> items) const {
  CHECK_NE(in_progress_item_, nullptr);
  int64 memory_reduced = 0;
  absl::flat_hash_set<const Item*> remat_candidates;

  for (const Item* item : items) {
    if (!item->placed || item == in_progress_item_) {
      LOG(WARNING) << "Unplaced item or in progress item being checked for "
                      "rematerialization.";
      return 0;
    }

    // Compute the amount of memory reduced (if any) by rematerializing
    // 'item->instruction'. The LogicalBuffers defined by 'item->instruction'
    // will no longer be live at this program point, so initially set
    // memory_reduced to the size of its defined values.
    for (BufferId buffer_id : item->buffers_defined) {
      // Avoid rematerializing instructions with indirect uses as it is
      // difficult to reason about liveness after rematerializing the
      // instruction.
      // Avoid rematerializing instructions with live out buffers.
      // TODO(mpurohit): Check why live_out buffers are an issue here.
      if (buffers_.at(buffer_id).has_indirect_uses ||
          buffers_.at(buffer_id).live_out) {
        return 0;
      }

      if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
        memory_reduced += AllocatedSize(buffer_id);
      }
    }

    // Account for any logical buffers whose live range must be extended across
    // this program point.
    for (BufferId buffer_id : item->buffers_used) {
      if (!IsCurrentlyLive(buffer_id)) {
        // This logical buffer is used by 'item->instruction' but is not live at
        // this program point. Rematerializing 'item->instruction' will extend
        // the buffer's live range across this program point unless it is
        // defined by an instruction that is also being rematerialized.
        Item* defining_instruction =
            buffers_.at(buffer_id).defining_instruction;
        if (!remat_candidates.contains(defining_instruction)) {
          memory_reduced -= AllocatedSize(buffer_id);
        }
      }
    }
    remat_candidates.insert(item);
  }

  return memory_reduced;
}

Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
                                                   Item* compressed_item,
                                                   Item* uncompressed_item) {
  // Original buffer is now dead.
  memory_usage_ -= size_function_(original_item->instruction->shape());
  // Compressed buffer is now alive.
  memory_usage_ += size_function_(compressed_item->instruction->shape());

  ItemList placed_users;
  ItemList unplaced_users;
  CHECK_EQ(original_item->buffers_output.size(), 1);
  BufferId original_buffer_id = original_item->buffers_output[0];
  Buffer& original_buffer = buffers_.at(original_buffer_id);
  for (Item* user : original_buffer.users) {
    if (user->placed) {
      CHECK(IsFinished(user)) << user->instruction->name();
      placed_users.push_back(user);
    } else {
      unplaced_users.push_back(user);
    }
  }
  original_buffer.users = std::move(placed_users);
  original_buffer.unfinished_user_count = 0;
  original_buffer.users.push_back(compressed_item);
  Buffer& compressed_buffer =
      NewBuffer(compressed_item, compressed_item->instruction->shape(),
                {uncompressed_item}, /*live_out=*/false,
                /*has_indirect_uses=*/false);
  compressed_item->buffers_used = original_item->buffers_output;
  compressed_item->buffers_output = {compressed_buffer.id};
  compressed_item->buffers_defined.push_back(compressed_buffer.id);

  Buffer& uncompressed_buffer =
      NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
                std::move(unplaced_users), /*live_out=*/false,
                /*has_indirect_uses=*/false);

  uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
  uncompressed_item->buffers_output = {uncompressed_buffer.id};
  uncompressed_item->buffers_defined = {uncompressed_buffer.id};

  for (Item* user : uncompressed_buffer.users) {
    BufferIdList& buffers_used = user->buffers_used;
    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
                 uncompressed_buffer.id);
  }

  return Status::OK();
}

Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                        Item* remat_item) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
  TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();

  // Construct the list of buffers used and defined by the rematerialization.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new rematerialization
  // instruction. Update memory usage if the rematerialization makes a dead
  // buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }

    buffer.unfinished_user_count++;
    buffer.users.push_back(remat_item);
  }

  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to account
  // for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    ItemList placed_users;
    ItemList unplaced_users;
    for (Item* user : old_buffer.users) {
      if (user->placed) {
        CHECK(IsFinished(user)) << user->instruction->name();
        placed_users.push_back(user);
      } else {
        unplaced_users.push_back(user);
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    for (Item* user : new_buffer.users) {
      BufferIdList& buffers_used = user->buffers_used;
      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                   new_buffer.id);
    }
  }

  VLOG(3) << "  memory usage = " << memory_usage_;
  XLA_VLOG_LINES(10, ToString());

  DCHECK(Check());

  return Status::OK();
}

string MemoryUsageTracker::ToString() const {
  string output =
      absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
  absl::StrAppend(&output,
                  "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
                  memory_usage(), " bytes)");
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* instruction = item->instruction;
    string inprogress = item == in_progress_item_ ? " in-progress" : "";
    string placed = item->placed ? " placed" : "";
    absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
                    "\n    Defines:\n");
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_[buffer_id];
      string live = IsCurrentlyLive(buffer_id) ? " live" : "";
      absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
                      buffer.unfinished_user_count, " unfinished uses\n");
    }
    absl::StrAppend(&output, "    Uses:\n");
    for (BufferId buffer_id : item->buffers_used) {
      absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
    }
  }
  return output;
}

StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
  auto it = compact_shape_.find(hlo);
  if (it != compact_shape_.end()) {
    return it->second;
  }
  const Shape& original_shape = hlo->shape();
  TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
  compact_shape_[hlo] = min_shape;
  return min_shape;
}

bool MemoryUsageTracker::Check() const {
  auto elements_are_unique = [](const BufferIdList& vec) {
    return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
  };

  // Verify buffers_defined per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& defined_buffers =
        instruction_list_.GetItem(instruction)->buffers_defined;
    CHECK(elements_are_unique(defined_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique defined buffers: "
        << absl::StrJoin(
               defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
               });

    for (const Buffer& buffer : buffers_) {
      if (buffer.defining_instruction->instruction == instruction) {
        CHECK(absl::c_linear_search(defined_buffers, buffer.id))
            << "Instruction " << instruction->name()
            << " defined buffers is missing: " << buffer.ToString();
      }
    }
  }

  // Verify buffers_used per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& used_buffers =
        instruction_list_.GetItem(instruction)->buffers_used;
    CHECK(elements_are_unique(used_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique used buffers: "
        << absl::StrJoin(
               used_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
               });
  }
  for (const Buffer& buffer : buffers_) {
    int64 unfinished_uses = 0;
    for (Item* user : buffer.users) {
      const BufferIdList& used_buffers = user->buffers_used;
      CHECK(absl::c_linear_search(used_buffers, buffer.id))
          << "Instruction " << user->instruction->name()
          << " used buffers is missing " << buffer.ToString();
      if (!IsFinished(user)) {
        unfinished_uses++;
      }
    }
    CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
        << "Incorrect unplaced use count for " << buffer.ToString();
  }
  return true;
}

// Computes and returns the cost of rematerializing the given instruction.
// Cost per rematerialized instruction is defined as:
//
// memory_limit_bytes / memory_reduced
//
// The idea is to choose the operation that will save the most memory for
// rematerialization and not worry about how much the compute costs, since
// running out of memory is more harmful than taking longer to get the answer.
int64 RematerializationCost(const HloInstruction* instruction,
                            const MemoryUsageTracker& memory_tracker,
                            int64 memory_reduced, int64 memory_limit_bytes) {
  // If none of the users of 'instruction' have been placed in the sequence (as
  // tracked by memory_tracker), then rematerialization of 'instruction' is a
  // zero-cost move of 'instruction' in the sequence.
  if (!absl::c_any_of(instruction->users(),
                      [&memory_tracker](const HloInstruction* inst) {
                        return memory_tracker.IsPlaced(inst);
                      })) {
    return 0;
  }

  CHECK_GT(memory_reduced, 0);
  // Return the inverse of the benefit of rematerialization.
  return memory_limit_bytes / memory_reduced;
}

// Selects and returns the best candidate instruction for rematerialization.
// The instruction with the lowest rematerialization cost is selected among
// those candidates which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
std::pair<Item*, RematStrategy>
MemoryUsageTracker::PickRematerializationCandidate(
    const InstructionList& instruction_list, int64 memory_limit_bytes,
    absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
  Item* best_item = nullptr;
  int64 best_cost = 0;
  RematStrategy best_strategy;

  VLOG(5) << "Picking candidate";

  // TODO(b/35244891): This is currently quadratic in the number of HLO
  // instructions.
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    if (!item->placed) {
      // Only iterate up to the currently placed instruction.
      // We are trying to reduce memory usage at the placed
      // instruction so rematerializing later values is of no benefit.
      break;
    }
    HloInstruction* candidate = item->instruction;
    VLOG(5) << "considering rematerialization candidate " << candidate->name();

    if (item->blacklisted) {
      // Skip instructions on the blacklist to avoid infinite loops of
      // rematerializing the same instruction(s) repeatedly.
      VLOG(5) << "candidate " << candidate->name()
              << " is excluded from rematerialization";
      continue;
    }
    if (!CanBeRematerialized(candidate, remat_able)) {
      VLOG(5) << "candidate " << candidate->name()
              << " not viable: is not rematerializable";

      continue;
    }

    if (item->buffers_output.size() == 1 &&
        (mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
         mode_ == HloRematerialization::RematerializationMode::
                      kRecomputeAndCompress)) {
      // Only consider compressing instructions which produce a single output
      // buffer.
      const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);

      if (item->placed && item != in_progress_item_ &&
          !output_buffer.live_out) {
        const Shape& original_shape = item->instruction->shape();
        if (original_shape.IsArray()) {
          Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
          const int64 memory_reduced =
              MemoryReducedIfCompressed(item, compact_shape);
          if (memory_reduced > 0) {
            const int64 cost = memory_limit_bytes / memory_reduced;
            if (best_item == nullptr || cost < best_cost) {
              VLOG(3) << "candidate " << candidate->name() << "("
                      << candidate->ToShortString() << ")"
                      << " now best when compressed into "
                      << compact_shape.ToString(true);
              RematStrategy strategy;
              strategy.kind = RematStrategy::kCompress;
              best_strategy = strategy;
              best_strategy.compact_shape = compact_shape;
              best_item = item;
              best_cost = cost;
            }
          }
        }
      }
    }

    // If any of the candidate's control successors has been placed, we need
    // to skip this candidate. Otherwise we would violate the control
    // dependency.
    bool control_successor_placed = std::any_of(
        candidate->control_successors().begin(),
        candidate->control_successors().end(),
        [this](const HloInstruction* inst) { return IsPlaced(inst); });

    if (control_successor_placed) {
      continue;
    }

    // Do not consider recomputation in compress-only mode.
    if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
      continue;
    }

    const int64 memory_reduced = MemoryReducedIfRematerialized(item);

    if (memory_reduced > 0) {
      const int cost =
          RematerializationCost(candidate, memory_reduced, memory_limit_bytes);

      VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
              << memory_reduced << ", cost per byte " << cost;

      if (best_item == nullptr || cost < best_cost) {
        VLOG(5) << "candidate " << candidate->name() << " now best";
        best_strategy.kind = RematStrategy::kRecompute;
        best_item = item;
        best_cost = cost;
      }
    }
  }
  return {best_item, best_strategy};
}

bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
  for (BufferId buffer_id : item->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    for (Item* user : buffer.users) {
      if (!user->placed) {
        return true;
      }
    }
  }
  return false;
}

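// Rematerializes the given block of instructions (in reverse order, to honor
// dependencies within the block), replacing each remaining unplaced use of an
// original instruction with its rematerialized clone. Returns the net number
// of instructions added to the computation.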
StatusOr<int64> RematerializeInstructions(
    MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
    InstructionList* instruction_list) {
  int64 net_instructions_added = 0;
  int64 total_memory_saved =
      memory_tracker->MemoryReducedIfRematerialized(*best_items);
  std::vector<string> instruction_names(best_items->size());
  // Rematerialize the block of instructions in reverse order to account for
  // dependencies between instructions in best_items.
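  // E.g. if best_items holds {a, b} where b consumes a, b is cloned first. The
  // clone b' initially consumes a; when a is cloned next, b' is an unplaced
  // user of a and is therefore rewired below to consume a's clone instead.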
  for (int i = best_items->size() - 1; i >= 0; --i) {
    Item* best_item = (*best_items)[i];
    HloInstruction* best = best_item->instruction;
    instruction_names[i] = best->name();
    HloComputation* computation = best->parent();

    // If the item to remat has no unplaced users, then skip the
    // rematerialization. Such an instruction can appear in best_items because
    // it is part of a good block, but does not itself add any benefit.
    if (!memory_tracker->HasUnplacedUsers(best_item)) {
      continue;
    }

    HloInstruction* remat =
        computation->AddInstruction(best->Clone(/*suffix=*/"remat"));

    // Add control dependencies to the new operation.
    for (auto successor : best->control_successors()) {
      TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
    }
    for (auto predecessor : best->control_predecessors()) {
      TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
    }

    Item* remat_item = instruction_list->CreateItem(remat);

    // Replace each remaining use of 'best' with the rematerialization.
    std::vector<HloInstruction*> best_users_copy = best->users();
    for (HloInstruction* user : best_users_copy) {
      if (!memory_tracker->IsPlaced(user)) {
        VLOG(2) << "  Replacing use of " << best->name() << " in "
                << user->name() << " with " << remat->name();
        TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
      }
    }

    // Account for the rematerialization in the memory tracker.
    TF_RETURN_IF_ERROR(
        memory_tracker->AddRematerializedInstruction(best_item, remat_item));

    // Insert the rematerialized instruction right before the earliest unplaced
    // use of the instruction *and* the earliest unplaced last use of any
    // operands of remat. Unplaced uses of the remat's operands are included
    // because we don't want to extend the live range of remat's operands, as
    // this could increase memory usage.
    ItemList place_before;
    for (auto user : remat->users()) {
      place_before.push_back(instruction_list->GetItem(user));
    }
    for (auto* operand : remat->operands()) {
      for (auto* operand_user : operand->users()) {
        if (operand_user != remat) {
          Item* operand_user_item = instruction_list->GetItem(operand_user);
          if (!operand_user_item->placed) {
            place_before.push_back(operand_user_item);
          }
        }
      }
    }
    // Insert the rematerialized instruction before any of its control
    // successors to preserve ordering with respect to control dependencies.
    for (auto successor : remat->control_successors()) {
      Item* successor_item = instruction_list->GetItem(successor);
      // Assert to make sure we never remat an operation with a control
      // successor that has already been placed.
      CHECK(!successor_item->placed) << successor_item->instruction->name();
      place_before.push_back(successor_item);
    }
    instruction_list->InsertBeforeInstructions(remat_item, place_before);

    // If the original instruction is now dead (all of its remaining uses were
    // replaced by the rematerialization), then the rematerialization is
    // essentially a move. Don't delete the instruction now because we don't
    // want duplicate HloInstruction* values during the course of the
    // transformation, as we keep maps with HloInstruction* values as keys.
    if (best->users().empty()) {
      VLOG(2) << best->name() << " is now dead";
      if (ContainsKey(*remat_move_instructions, best)) {
        // Previously, 'best' was a rematerialization which killed the
        // instruction it was a copy of. Now 'remat' is a rematerialization
        // of 'best' and kills 'best'. Stop rematerializing this instruction
        // to avoid an infinite loop.
        instruction_list->Blacklist(remat);
      }
      remat_move_instructions->insert(remat);
    } else {
      net_instructions_added++;
    }
  }
  VLOG(1) << "Rematerializing instructions ["
          << absl::StrJoin(instruction_names, ", ") << "] (saving "
          << HumanReadableNumBytes(total_memory_saved) << ")";
  return net_instructions_added;
}

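// Compresses the output of 'best_item' by copying it into 'compact_shape' and
// copying it back to the original shape right before its earliest unplaced
// use. Returns the number of instructions added (always 2: the two copies).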
StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
                                    Item* best_item, const Shape& compact_shape,
                                    InstructionList* instruction_list) {
  HloInstruction* best = best_item->instruction;
  VLOG(5) << "Compressing instruction " << best->name() << " (saving "
          << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
                 best_item, compact_shape))
          << ") to " << compact_shape.ToString(true);

  HloComputation* computation = best->parent();

  HloInstruction* compressed = computation->AddInstruction(
      HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));

  HloInstruction* uncompressed = computation->AddInstruction(
      HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
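
  // The rewrite above has the shape:
  //
  //   best {original shape}
  //     -> compressed   = copy(best)        {compact_shape}
  //     -> uncompressed = copy(compressed)  {original shape}
  //
  // Remaining (unplaced) uses of 'best' are redirected to 'uncompressed'
  // below.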

  Item* compressed_item = instruction_list->CreateItem(compressed);
  compressed_item->placed = true;

  Item* uncompressed_item = instruction_list->CreateItem(uncompressed);

  // Replace each remaining use of 'best' with the uncompressed copy.
  std::vector<HloInstruction*> best_users_copy = best->users();
  for (HloInstruction* user : best_users_copy) {
    if (!memory_tracker->IsPlaced(user)) {
      VLOG(5) << "  Replacing use of " << best->name() << " in " << user->name()
              << " with " << uncompressed->name();
      TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
    }
  }

  // Account for the compression in the memory tracker.
  TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
      best_item, compressed_item, uncompressed_item));

  // Insert the uncompressed copy right before the earliest unplaced use of
  // the instruction.
  ItemList place_before;
  for (auto user : uncompressed->users()) {
    place_before.push_back(instruction_list->GetItem(user));
  }

  instruction_list->Blacklist(compressed_item->instruction);
  instruction_list->Blacklist(uncompressed_item->instruction);

  instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);

  instruction_list->InsertAfterInstructions(compressed_item, {best_item});

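  // Net instructions added: the compressing copy and the uncompressing copy.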
  return 2;
}

}  // namespace

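// Simulates the given schedule with a fresh MemoryUsageTracker and returns the
// high-water mark of live-buffer bytes, including the peak usage of any
// computations called in a sequential context.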
StatusOr<int64> HloRematerialization::ComputePeakMemory(
    const HloComputation* computation,
    const HloInstructionSequence& order) const {
  InstructionList instruction_list(order);
  MemoryUsageTracker tracker(computation, size_function_,
                             compact_shape_function_, *points_to_analysis_,
                             instruction_list, mode_);
  int64 peak_memory = tracker.memory_usage();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    peak_memory =
        std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
    TF_RETURN_IF_ERROR(tracker.EndInstruction());
  }
  VLOG(1) << "Peak memory for " << computation->name() << ": "
          << HumanReadableNumBytes(peak_memory);
  return peak_memory;
}

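// Returns the sum of the peak memory of all computations called by
// 'instruction' in a sequential context; returns 0 for non-callsites and for
// parallel-context calls (e.g. computations invoked via map or fusion).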
StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
    const HloInstruction* instruction) const {
  const CallSite* callsite =
      call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
  if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
    return 0;
  }
  int64 callee_usage = 0;
  for (const HloComputation* computation : callsite->called_computations()) {
    TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
    callee_usage += computation_peak_memory_.at(computation);
  }
  return callee_usage;
}

StatusOr<bool> HloRematerialization::RematerializeComputation(
    HloComputation* computation, HloSchedule* schedule,
    int64 memory_limit_bytes) {
  VLOG(1) << "Rematerializing computation " << computation->name()
          << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
  VLOG(1) << "peak memory usage is "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation));
  CHECK(!ContainsKey(rematerialized_computations_, computation));

  InstructionList instruction_list(schedule->sequence(computation));
  MemoryUsageTracker memory_tracker(
      computation, size_function_, compact_shape_function_,
      *points_to_analysis_, instruction_list, mode_);
  bool changed = false;

  // If a rematerialization makes the source instruction dead, the
  // rematerialization is essentially a move and is added to
  // 'remat_move_instructions'. If the next rematerialization of the
  // instruction is also a move, the rematerialization is blacklisted.
  absl::flat_hash_set<const HloInstruction*> remat_move_instructions;

  // The map from instructions to their rematerializable status.
  absl::flat_hash_map<const HloInstruction*, bool> remat_able;

  // The peak memory of the computation at any point in the instruction
  // sequence.
  int64 peak_memory = memory_tracker.memory_usage();

  // Total count of instructions rematerialized.
  int64 remat_count = 0;
  // Total count of clones created minus the number of original rematerialized
  // instructions which are dead.
  int64 net_instructions_added = 0;

  const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);

  // Iterate through all instructions in the sequence. At each instruction
  // (program point), if memory_usage exceeds the specified limit, then
  // rematerialize HLO instructions until memory_usage is reduced below it.
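  //
  // A sketch of one outer step, assuming a 1 GiB limit: if live memory at the
  // current program point is 1.2 GiB, candidates are picked and applied one
  // at a time (compress or recompute) until usage drops to at most 1 GiB or
  // no further candidate can be found.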
  int64 instruction_index = 0;
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));

    VLOG(2) << "Program point at " << instruction->name()
            << ", memory usage = " << memory_tracker.memory_usage()
            << ", callee usage = " << callee_usage << ", [" << instruction_index
            << "/" << instruction_list.size() << "]";
    instruction_index++;

    while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      VLOG(2) << "Over memory limit at instruction " << instruction->name()
              << ", using "
              << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                       callee_usage)
              << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);

      Item* best_item;
      RematStrategy best_strategy;
      std::tie(best_item, best_strategy) =
          memory_tracker.PickRematerializationCandidate(
              instruction_list, memory_limit_bytes, &remat_able);

      if (best_item == nullptr) {
        VLOG(3) << "Unable to find rematerialization candidate at program "
                   "point "
                << instruction->name() << ". Memory usage = "
                << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                         callee_usage);
        break;
      }

      HloInstruction* best = best_item->instruction;
      changed = true;
      remat_count++;

      int64 num_instructions_added = 0;
      if (best_strategy.kind == RematStrategy::kCompress) {
        VLOG(1) << "Compressing instruction " << best->name() << " (saving "
                << HumanReadableNumBytes(
                       memory_tracker.MemoryReducedIfCompressed(
                           best_item, best_strategy.compact_shape))
                << ")";

        TF_ASSIGN_OR_RETURN(num_instructions_added,
                            CompressInstruction(&memory_tracker, best_item,
                                                best_strategy.compact_shape,
                                                &instruction_list));
      } else {
        VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
                << HumanReadableNumBytes(
                       memory_tracker.MemoryReducedIfRematerialized(best_item))
                << ")";

        std::vector<Item*> best_items{best_item};
        TF_ASSIGN_OR_RETURN(num_instructions_added,
                            RematerializeInstructions(
                                &memory_tracker, &best_items,
                                &remat_move_instructions, &instruction_list));
      }
      net_instructions_added += num_instructions_added;

      VLOG(1) << "memory_usage after rematerialization = "
              << HumanReadableNumBytes(memory_tracker.memory_usage());
    }

    const CallSite* callsite = call_graph_node.GetCallSite(instruction);
    if (callsite != nullptr &&
        callsite->context() == CallContext::kSequential &&
        memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      // Memory usage exceeds the limit. Try to rematerialize any
      // subcomputation(s) that this instruction calls.
      VLOG(1) << "Memory usage still over the limit ("
              << (memory_tracker.memory_usage() + callee_usage) << " > "
              << memory_limit_bytes
              << "). Rematerializing computations called by "
              << instruction->name();

      // Recompute callee usage to account for any rematerialization performed
      // in the callee computations.
      for (HloComputation* called_computation :
           callsite->called_computations()) {
        if (!ContainsKey(rematerialized_computations_, called_computation)) {
          // Memory limit for the subcomputation is the memory limit less the
          // amount of memory used at this point in the computation.
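          // E.g. with a 1 GiB limit and 600 MiB live at this call site, the
          // callee is asked to fit within the remaining 424 MiB.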
          int64 subcomputation_memory_limit_bytes = std::max<int64>(
              0, memory_limit_bytes - memory_tracker.memory_usage());
          TF_ASSIGN_OR_RETURN(
              bool subcomputation_changed,
              RematerializeComputation(called_computation, schedule,
                                       subcomputation_memory_limit_bytes));
          changed |= subcomputation_changed;
        }
      }
      TF_ASSIGN_OR_RETURN(callee_usage,
                          CalledComputationsMemoryUsage(instruction));
    }

    peak_memory = std::max<int64>(peak_memory,
                                  memory_tracker.memory_usage() + callee_usage);
    VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);

    TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
  }

  // Verify some invariants on the memory tracker.
  for (auto* instruction : computation->instructions()) {
    CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
  }

  VLOG(1) << "In computation " << computation->name() << " rematerialized "
          << remat_count << " instructions; " << net_instructions_added
          << " net instructions added";
  VLOG(1) << "  peak memory usage now " << HumanReadableNumBytes(peak_memory)
          << " (was "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation))
          << ")";

  // Update peak memory used by computation.
  computation_peak_memory_.at(computation) = peak_memory;

  // Update order to include rematerialized instructions.
  HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
  sequence.clear();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    HloInstruction* instruction = item->instruction;
    sequence.push_back(instruction);
  }
  rematerialized_computations_.insert(computation);

  instructions_rematerialized_ += remat_count;
  net_instructions_added_ += net_instructions_added;

  return changed;
}

StatusOr<bool> HloRematerialization::Run(HloModule* module) {
  VLOG(1) << "HloRematerialization() with memory limit of "
          << HumanReadableNumBytes(memory_limit_bytes_);
  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());

  // Initialize pass object state.
  computation_peak_memory_.clear();
  rematerialized_computations_.clear();
  instructions_rematerialized_ = 0;
  net_instructions_added_ = 0;

  TF_RET_CHECK(module->has_schedule());
  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));

  // Adjust the memory limit to account for the output of the entry
  // computation. This is necessary because the per-computation accounting in
  // MemoryUsageTracker does not include the output, as it is typically
  // allocated by the caller.
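  // For example, with a 16 GiB limit and a 1 GiB entry output, computations
  // are rematerialized against a 15 GiB budget.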
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module->result_shape(),
      [&module_output_size, this](const Shape& subshape,
                                  const ShapeIndex& output_index) {
        module_output_size += size_function_(subshape);
      });

  const int64 adjusted_memory_limit_bytes =
      memory_limit_bytes_ - module_output_size;
  VLOG(1) << "Adjusted memory limit accounting for output ("
          << HumanReadableNumBytes(module_output_size)
          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);

  // Compute peak memory usage of all computations in the module called in a
  // sequential context.
  call_graph_ = CallGraph::Build(module);
  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
      [this, module](const CallGraphNode& node) -> Status {
        if (node.context() == CallContext::kSequential) {
          TF_ASSIGN_OR_RETURN(
              computation_peak_memory_[node.computation()],
              ComputePeakMemory(node.computation(), module->schedule().sequence(
                                                        node.computation())));
        }
        return Status::OK();
      },
      /*visit_unreachable_nodes=*/false));

  // The peak memory usage of the module equals the peak memory usage of the
  // entry computation plus the output size of the computation. This is
  // because the peak memory for a computation does not include the output, as
  // this is typically accounted for in the caller.
  const int64 before_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module (before): "
          << HumanReadableNumBytes(before_peak_memory);

  // Subcomputations called by the entry computation will also be
  // rematerialized.
  TF_ASSIGN_OR_RETURN(
      bool changed,
      RematerializeComputation(module->entry_computation(), &module->schedule(),
                               adjusted_memory_limit_bytes));

  // Rematerialization can introduce dead code. This occurs if all uses of an
  // instruction are replaced with rematerializations of the instruction.

  // Stash away the schedule while running dead code elimination, to avoid
  // validation failures while the module is in flux.
  HloSchedule saved_schedule = module->schedule();
  module->clear_schedule();
  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
  changed |= dead_code_removed;

  // After DCE, the module sequence may include instructions which no longer
  // exist. Update the schedule and restore it.
  TF_RETURN_IF_ERROR(saved_schedule.Update());
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
  VLOG(1) << "Rematerialized " << instructions_rematerialized_
          << " instructions in module " << module->name() << "; "
          << net_instructions_added_ << " net instructions added";
  const int64 current_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module now "
          << HumanReadableNumBytes(current_peak_memory) << " ("
          << current_peak_memory << " bytes), was "
          << HumanReadableNumBytes(before_peak_memory) << " ("
          << before_peak_memory << " bytes)";
  const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
  VLOG(1) << "Reduced peak memory by "
          << HumanReadableNumBytes(reduced_peak_memory) << " ("
          << reduced_peak_memory << " bytes)";

  if (sizes_ != nullptr) {
    sizes_->before_bytes = before_peak_memory;
    sizes_->after_bytes = current_peak_memory;
  }

  XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());

  if (current_peak_memory > memory_limit_bytes_) {
    LOG(WARNING) << absl::StrFormat(
        "Can't reduce memory use below %s (%d bytes) by rematerialization; "
        "only reduced to %s (%d bytes)",
        HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
        HumanReadableNumBytes(current_peak_memory), current_peak_memory);
  }

  return changed;
}

}  // namespace xla
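
// Illustrative invocation (a sketch only; the exact constructor parameters --
// shape-size function, memory limit, optional result-sizes struct, compaction
// function, and mode -- are declared in hlo_rematerialization.h; 'size_fn'
// and 'memory_limit_bytes' below are assumed to be supplied by the caller):
//
//   xla::HloRematerialization remat(size_fn, memory_limit_bytes,
//                                   /*sizes=*/nullptr);
//   TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));
//
// The module must already carry a schedule (module->has_schedule()), since
// rematerialization decisions are made against a sequential order.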