/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
//   of candidates.
// . Cache IsRematerializable in Item? Only correct if control
//   predecessors and successors don't change.

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kCopy) {
    if (LayoutUtil::Equal(instruction->shape().layout(),
                          instruction->operand(0)->shape().layout())) {
      // Don't rematerialize copies added by copy insertion (layout doesn't
      // change).
      return false;
    }
  }

  // Don't rematerialize instructions with side effects or instructions which
  // cannot be cloned safely.
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConditional:
    case HloOpcode::kAllReduce:
    case HloOpcode::kCustomCall:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
      return false;
    default:
      return !instruction->HasSideEffect();
  }
}

// Checks whether an instruction can be rematerialized, consulting the cache
// first and falling back to IsRematerializable() on a cache miss.
bool CanBeRematerialized(
    const HloInstruction* instruction,
    absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
  auto it = remat_able->find(instruction);
  if (it != remat_able->end()) {
    return it->second;
  }
  bool rematerializable = IsRematerializable(instruction);
  (*remat_able)[instruction] = rematerializable;
  return rematerializable;
}

// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = absl::InlinedVector<BufferId, 3>;

struct RematStrategy {
  enum {
    // Recompute the node at a later program point.
    kRecompute,
    // Change the layout into a compact form and uncompress it back at a later
    // program point.
    kCompress,
  } kind;
  Shape compact_shape;
};

// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
  HloInstruction* instruction;

  // True once the instruction is marked as placed (when BeginInstruction
  // has been called for this instruction).
  bool placed = false;

  // To avoid an infinite loop rematerializing the same set of
  // instructions ad infinitum, keep a blacklist of instructions
  // which should not be rematerialized.
  bool blacklisted = false;

  // The buffers defined by this instruction.
  BufferIdList buffers_defined;

  // Output buffers of this instruction. This is used to track outputs by GTE
  // instructions (where the instruction doesn't define a buffer).
  BufferIdList buffers_output;

  // The buffers used by this instruction.
  BufferIdList buffers_used;

 private:
  friend class InstructionList;

  // Items are arranged in a doubly linked list.
  Item* next;
  Item* prev;

  // The list is ordered by position; positions may be duplicated as new
  // instructions are inserted. See the InsertBeforeInstructions comment
  // for details.
  int64 position;
};

using ItemList = absl::InlinedVector<Item*, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
class InstructionList {
 public:
  explicit InstructionList(const HloInstructionSequence& order) {
    int64 position = 0;
    Item* last = nullptr;
    for (HloInstruction* inst : order.instructions()) {
      // Add a new item to the linked list.
      Item* item = new Item;
      item->next = nullptr;
      item->prev = last;
      if (last == nullptr) {
        first_ = item;
      } else {
        last->next = item;
      }
      last = item;

      // Initially position numbers are uniquely assigned in order. Later, as
      // instructions are added with the InsertBefore* methods, some
      // instructions may share a position number, but the values are
      // guaranteed to be monotonically increasing through the list, so they
      // remain useful for quickly(-ish) determining the relative order of
      // arbitrary instructions in the list.
      item->instruction = inst;
      item->position = position;
      position++;

      item_map_[inst] = item;
    }
  }

  ~InstructionList() {
    for (Item* item = first_; item != nullptr;) {
      Item* next = item->next;
      delete item;
      item = next;
    }
  }

  size_t size() const { return item_map_.size(); }

  // For ordered iteration over items.
  //   for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }

  // Creates an Item for the given instruction, but doesn't add it to the list.
  // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
    Item* item = new Item;
    item->instruction = inst;
    CHECK(item_map_.insert({inst, item}).second)
        << "inserting inst twice " << inst->name();
    return item;
  }

  // Return the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
    auto iter = item_map_.find(inst);
    CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
    return iter->second;
  }

  // Insert instruction 'to_insert' immediately before the earliest instruction
  // in 'before_instructions'.
  //
  // Each instruction gets a non-decreasing ordinal number. We use this to let
  // InsertBeforeInstructions quickly insert an instruction before the earliest
  // instruction in a set of instructions. If position_number_[a] <
  // position_number_[b] then 'a' comes before 'b' in the list. If the position
  // numbers are the same then nothing can be said about their order without
  // examining the list.
  //
  // On object construction this ordinal is precisely the instruction's index
  // in the list. Later, instructions inserted via InsertBefore receive
  // duplicate values. However, monotonicity is preserved.
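  //
  // A hypothetical worked example: if the list initially holds items with
  // positions a:0, b:1, c:2, then inserting x before c yields
  // a:0, b:1, x:2, c:2. Items x and c now share position 2, so their relative
  // order can only be recovered by scanning the list, but positions remain
  // monotonically non-decreasing.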
  void InsertBeforeInstructions(Item* to_insert,
                                absl::Span<Item* const> before_instructions) {
    VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
            << " before {"
            << absl::StrJoin(before_instructions, ", ",
                             [](string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the minimal position number of any instruction in
    // 'before_instructions'.
    CHECK(!before_instructions.empty());
    Item* min_position_item = nullptr;
    for (Item* item : before_instructions) {
      if (min_position_item == nullptr ||
          item->position < min_position_item->position) {
        min_position_item = item;
      }
    }

    // Because more than one instruction in the list may share the minimal
    // position number, we cannot simply insert before 'min_position_item'.
    // First scan backwards to the first item in the list with that position
    // number.
    while (min_position_item->prev != nullptr &&
           min_position_item->position == min_position_item->prev->position) {
      min_position_item = min_position_item->prev;
    }

    // Now scan forwards until we find one of the before_instructions; this is
    // the earliest of them in list order.
    while (!absl::c_linear_search(before_instructions, min_position_item)) {
      min_position_item = min_position_item->next;
    }
    return InsertBefore(to_insert, min_position_item);
  }

  void InsertAfterInstructions(Item* to_insert,
                               absl::Span<Item* const> after_instructions) {
    VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
            << " after {"
            << absl::StrJoin(after_instructions, ", ",
                             [](string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the max position number of any instruction in
    // 'after_instructions'.
    CHECK(!after_instructions.empty());
    Item* max_position_item = nullptr;
    for (Item* item : after_instructions) {
      if (max_position_item == nullptr ||
          item->position > max_position_item->position) {
        max_position_item = item;
      }
    }
    // No rematerializable instruction should be inserted at the end of the
    // computation.
    CHECK(max_position_item->next != nullptr);
    InsertBeforeInstructions(to_insert, {max_position_item->next});
  }

  void Blacklist(const HloInstruction* inst) {
    GetItem(inst)->blacklisted = true;
  }

 private:
  // Insert instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
    VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
            << before->instruction->name();
    // Insert new item into linked list.
    item->prev = before->prev;
    item->next = before;
    before->prev = item;
    if (item->prev != nullptr) {
      item->prev->next = item;
    } else {
      first_ = item;
    }

    // Assign the same position number to the newly added instruction as
    // 'before'. This guarantees monotonicity of the position numbers, but not
    // uniqueness.
    item->position = before->position;
  }

  Item* first_;

  // Item for each instruction.
  absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};

// Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
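//
// A hypothetical example: given t = tuple(a, b) and add = add(gte(t, 0), c),
// 'add' uses the buffer defined by 'a' (via the alias gte(t, 0)) even though
// 'a' is not an operand of 'add'; that use is indirect.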
ItemList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  ItemList users;
  // To identify uses iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;
  for (const BufferAlias& buffer_alias :
       points_to_analysis.GetBufferAliases(*logical_buffer)) {
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      if (points_to_analysis.DoesNotUseOperandBuffer(
              buffer_alias.instruction(), buffer_alias.index(), user)) {
        // The alias may be an operand of 'user', but the LogicalBuffer cannot
        // possibly be used by the instruction so ignore 'user'. This is the
        // case, for example, for the tuple element buffers in a
        // GetTupleElement instruction (the GTE instruction only uses the
        // pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction()) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      if (!absl::c_linear_search(users, user_item)) {
        users.push_back(user_item);
      }
    }
  }
  return users;
}

// Class for tracking memory usage of a computation as the instructions are
// placed sequentially. Memory usage is the sum of the sizes of live values
// (LogicalBuffers) at the current point in the instruction sequence.
class MemoryUsageTracker {
 public:
  MemoryUsageTracker(
      const HloComputation* computation,
      const HloRematerialization::ShapeSizeFunction& size_function,
      const HloRematerialization::CompactShapeFunction& compact_shape_function,
      const TuplePointsToAnalysis& points_to_analysis,
      const InstructionList& instruction_list,
      HloRematerialization::RematerializationMode mode);

  // Starts the placement of the given instruction. This adds the sizes of the
  // LogicalBuffers defined by the instruction to the current memory
  // usage. Placement is broken into two steps (BeginInstruction and
  // EndInstruction) to accurately model memory usage. At BeginInstruction the
  // memory for the output value(s) of the current instruction is allocated. At
  // EndInstruction memory for dead operand(s) is freed.
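  //
  // A minimal usage sketch (assuming 'tracker' and 'instruction_list' have
  // already been constructed; see ComputePeakMemory below for a real caller):
  //   for (auto* item = instruction_list.first(); item != nullptr;
  //        item = instruction_list.next(item)) {
  //     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
  //     // ... inspect tracker.memory_usage() at this program point ...
  //     TF_RETURN_IF_ERROR(tracker.EndInstruction());
  //   }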
  Status BeginInstruction(Item* item);

  int64 RematerializationCost(const HloInstruction* instruction,
                              int64 memory_reduced, int64 memory_limit_bytes) {
    // If none of the users of 'instruction' have been placed in the sequence
    // (as tracked by memory_tracker), then rematerialization of 'instruction'
    // is a zero-cost move of 'instruction' in the sequence.
    if (!absl::c_any_of(
            instruction->users(),
            [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
      return 0;
    }

    CHECK_GT(memory_reduced, 0);
    // Return the inverse of the benefit of rematerialization.
    return memory_limit_bytes / memory_reduced;
  }

  // Finishes the placement of the current instruction. This frees any dead
  // operands or dead result of the instruction. This must be called after
  // each call to BeginInstruction.
  Status EndInstruction();

  // Returns the number of bytes by which the current memory usage will be
  // reduced if the given instruction is rematerialized.
  int64 MemoryReducedIfRematerialized(Item* item) const;

  // Returns the number of bytes by which the current memory usage will be
  // reduced if the given instruction is compressed to the given compact shape.
  int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;

  // Returns the number of bytes by which the current memory usage will be
  // reduced if the given sequence of instructions is rematerialized.
  int64 MemoryReducedIfRematerialized(
      absl::Span<const Item* const> items) const;

  Status AddCompressInstructions(Item* original_item, Item* compressed_item,
                                 Item* uncompressed_item);

  // Adjusts memory usage to account for the rematerialization of
  // original_item for all remaining unplaced uses. The rematerialization
  // is remat_item. This method should be called after the HLO graph has
  // been transformed (rematerialization instruction created and connected
  // to uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item);

  std::pair<Item*, RematStrategy> PickRematerializationCandidate(
      const InstructionList& instruction_list, int64 memory_limit_bytes,
      absl::flat_hash_map<const HloInstruction*, bool>* remat_able);

  // Returns whether the given instruction has been placed (BeginInstruction
  // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
    return instruction_list_.GetItem(instruction)->placed;
  }

  // Returns whether 'item' has any unplaced users.
  bool HasUnplacedUsers(Item* item) const;

  // Returns the current memory usage. This is the sum of sizes of all live
  // values.
  int64 memory_usage() const { return memory_usage_; }

  // Check invariants of the data structure. This is expensive to call.
  bool Check() const;

  string ToString() const;

 private:
  // A Buffer represents a single LogicalBuffer in the computation including
  // various metadata useful for tracking liveness of the value. A
  // LogicalBuffer is not used directly because the HLO graph is transformed
  // and TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated
  // after HLO graph transformations.
  struct Buffer {
    // The unique id of this Buffer. This value is equal to the buffer's index
    // in the vector buffers_.
    const BufferId id;

    // The instruction which defines this buffer.
    Item* defining_instruction;

    // The materialized size of the buffer in bytes.
    const int64 size;

    // Shape of the buffer.
    Shape shape;

    // Whether this buffer is live-out of the computation.
    bool live_out;

    // Whether this buffer has indirect uses. Ie, an instruction which is not a
    // user of defining_instruction uses this buffer. This can occur due to
    // buffer aliasing (eg, tuples).
    bool has_indirect_uses;

    // The instructions which use this buffer.
    ItemList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.
    int64 unfinished_user_count;

    string ToString() const {
      return absl::StrCat("Buffer ", id, " (defined by ",
                          defining_instruction->instruction->name(), ", size ",
                          size, " bytes)");
    }
  };

  // Get the compact shape of given hlo instruction. An internal cache is used
  // to avoid computing the shape multiple times.
  StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);

  // Creates a Buffer representing the given logical buffer. The buffer is
  // added to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
    bool has_indirect_uses = false;
    ItemList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     logical_buffer->shape(), std::move(users), live_out,
                     has_indirect_uses);
  }

  // Create a new buffer representing a rematerialization of given buffer for
  // the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              ItemList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed)
        << original_buffer.defining_instruction->instruction->name();
    CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
    CHECK(!original_buffer.live_out) << original_buffer.ToString();
    for (Item* use : rematerialized_uses) {
      CHECK(!use->placed) << use->instruction->name();
    }
    return NewBuffer(remat_item, original_buffer.shape,
                     std::move(rematerialized_uses), /*live_out=*/false,
                     /*has_indirect_uses=*/false);
  }

  // Return number of bytes allocated for the buffer with the given id. Buffers
  // allocated by the calling computation (eg, parameter and output buffers)
  // are considered to have zero bytes because the memory is accounted for in a
  // different computation.
  int64 AllocatedSize(BufferId buffer_id) const {
    const Buffer& buffer = buffers_.at(buffer_id);
    HloInstruction* inst = buffer.defining_instruction->instruction;
    HloOpcode def_opcode = inst->opcode();
    if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
      return 0;
    } else {
      return buffer.size;
    }
  }

  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
    return item->placed && item != in_progress_item_;
  }

  // Returns whether the given buffer is being used by the in-progress
  // instruction.
  bool IsInUse(BufferId buffer_id) const {
    if (in_progress_item_ == nullptr) {
      return false;
    }
    const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
    return absl::c_linear_search(in_progress_uses, buffer_id);
  }

  // Returns whether the given buffer is live at the current program
  // point.
  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&
            buffer.unfinished_user_count > 0);
  }

  // Returns whether the given instruction is live at the current program
  // point.
  bool IsInstructionCurrentlyLive(Item* instruction) const {
    // If the instruction has not started yet, it is not alive.
    if (!IsPlaced(instruction->instruction)) {
      return false;
    }
    for (const HloInstruction* user : instruction->instruction->users()) {
      if (!IsPlaced(user)) {
        // If there is an unplaced user, consider this instruction currently
        // live.
        return true;
      }
    }
    return false;
  }

  // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
                    ItemList&& users, bool live_out, bool has_indirect_uses) {
    int buffer_id = buffers_.size();
    buffers_.push_back(Buffer{
        buffer_id, defining_instruction, size_function_(shape), shape, live_out,
        has_indirect_uses, users, static_cast<int64>(users.size())});
    return buffers_.back();
  }

  const HloComputation* computation_;

  // Instruction list containing the ordering of instructions in
  // computation_. This is the order in which instructions are placed
  // (BeginInstruction/EndInstruction calls).
  const InstructionList& instruction_list_;

  // Size function returns the bytes of a given buffer.
  const HloRematerialization::ShapeSizeFunction& size_function_;

  // Converts a shape into compact form, returns the same shape if a shape is
  // already considered compact.
  const HloRematerialization::CompactShapeFunction& compact_shape_function_;

  // A map that caches existing known compact shape for each instruction.
  absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;

  // Memory usage at the currently placed instruction.
  int64 memory_usage_ = 0;

  // The instruction currently being placed. This value is non-null only
  // between the calling of BeginInstruction and EndInstruction.
  Item* in_progress_item_ = nullptr;

  HloRematerialization::RematerializationMode mode_;

  // All buffers in the computation.
  std::vector<Buffer> buffers_;
};

MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const HloRematerialization::CompactShapeFunction& compact_shape_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list,
    HloRematerialization::RematerializationMode mode)
    : computation_(computation),
      instruction_list_(instruction_list),
      size_function_(size_function),
      compact_shape_function_(compact_shape_function),
      mode_(mode) {
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  absl::flat_hash_map<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;

  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer =
            &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark the buffer as having indirect uses and as live out.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of the while to the Buffer's users.
        bool unused;
        for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
                                        points_to_analysis, &unused)) {
          if (!absl::c_linear_search(buffer->users, user_item)) {
            buffer->users.push_back(user_item);
            buffer->unfinished_user_count++;
            user_item->buffers_used.push_back(buffer->id);
          }
        }
      } else {
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (Item* user : buffer->users) {
          user->buffers_used.push_back(buffer->id);
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }

    // Trace the output of each instruction so that we can properly track
    // which outputs the GTE instructions have.
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
      item->buffers_output.push_back(
          logical_buffer_to_buffer_id[logical_buffer]);
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}

Status MemoryUsageTracker::BeginInstruction(Item* item) {
  const HloInstruction* instruction = item->instruction;
  VLOG(3) << "BeginInstruction " << instruction->name();
  TF_RET_CHECK(in_progress_item_ == nullptr);
  in_progress_item_ = item;

  item->placed = true;

  // All buffers defined by this instruction need memory.
  for (BufferId buffer_id : item->buffers_defined) {
    VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
            << " is now live.";
    memory_usage_ += AllocatedSize(buffer_id);
  }

  // TODO(b/37686934): Elementwise instructions can share the buffer of a
  // (dead) operand. Account for this potential reuse here.

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

Status MemoryUsageTracker::EndInstruction() {
  TF_RET_CHECK(in_progress_item_ != nullptr);
  VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();

  for (BufferId buffer_id : in_progress_item_->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    buffer.unfinished_user_count--;
    CHECK_GE(buffer.unfinished_user_count, 0)
        << buffer.ToString() << " has negative unfinished use count.";
    if (buffer.unfinished_user_count == 0) {
      // Buffer is now dead.
      VLOG(3) << "  " << buffer.ToString() << " is now dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  // If any buffer defined by this instruction has no uses, then memory can be
  // reclaimed immediately.
  for (BufferId buffer_id : in_progress_item_->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  in_progress_item_ = nullptr;

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

int64 MemoryUsageTracker::MemoryReducedIfCompressed(
    Item* item, const Shape& compact_shape) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  int64 memory_reduced = 0;

  // We only compress a single piece of an output at one time.
  CHECK_EQ(item->buffers_output.size(), 1);
  BufferId buffer_id = item->buffers_output[0];
  if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
      IsInstructionCurrentlyLive(item)) {
    const Buffer& buffer = buffers_.at(buffer_id);
    memory_reduced += buffer.size;

    int64 compact_shape_size = size_function_(compact_shape);
    // Account for buffers that are compressed after the instruction.
    memory_reduced -= compact_shape_size;
  }
  return memory_reduced;
}

int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  // TODO(b/37687140): Rematerialization can increase peak memory consumption
  // at an earlier point in the program if rematerialization extends the live
  // range of the operand of the instruction being rematerialized across the
  // live range of the value of the instruction being rematerialized. Don't
  // rematerialize in this case (ie, return 0 here).
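  //
  // As a hypothetical illustration: if y = f(x) is rematerialized just before
  // y's last use, then x must stay live up to that point as an input to the
  // rematerialized f, which can raise the peak at program points where x
  // would otherwise have been dead.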

  // Compute the amount of memory reduced (if any) by rematerializing
  // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
  // be live at this program point, so initially set memory_reduced to the
  // size of its defined values.
  int64 memory_reduced = 0;
  for (BufferId buffer_id : item->buffers_defined) {
    // Avoid rematerializing instructions with indirect uses as it is difficult
    // to reason about liveness after rematerializing the instruction.
    // TODO(b/37714814): Consider rematerializing instructions with indirect
    // uses.
    if (buffers_.at(buffer_id).has_indirect_uses) {
      return 0;
    }

    if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
      memory_reduced += AllocatedSize(buffer_id);
    }
  }

  // Account for any logical buffers whose live range must be extended across
  // this program point.
  for (BufferId buffer_id : item->buffers_used) {
    if (!IsCurrentlyLive(buffer_id)) {
      // This logical buffer is used by 'instruction' but is not live at this
      // program point. Rematerializing 'instruction' will extend the buffer's
      // live range across this program point.
      memory_reduced -= AllocatedSize(buffer_id);
    }
  }

  return memory_reduced;
}

int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
    absl::Span<const Item* const> items) const {
  CHECK_NE(in_progress_item_, nullptr);
  int64 memory_reduced = 0;
  absl::flat_hash_set<const Item*> remat_candidates;

  for (const Item* item : items) {
    if (!item->placed || item == in_progress_item_) {
      LOG(WARNING) << "Unplaced item or in progress item being checked for "
                      "rematerialization.";
      return 0;
    }

    // Compute the amount of memory reduced (if any) by rematerializing
    // 'item->instruction'. The LogicalBuffers defined by 'item->instruction'
    // will no longer be live at this program point, so initially set
    // memory_reduced to the size of its defined values.
    for (BufferId buffer_id : item->buffers_defined) {
      // Avoid rematerializing instructions with indirect uses as it is
      // difficult to reason about liveness after rematerializing the
      // instruction.
      // Avoid rematerializing instructions with live out buffers.
      // TODO(mpurohit): Check why live_out buffers are an issue here.
      if (buffers_.at(buffer_id).has_indirect_uses ||
          buffers_.at(buffer_id).live_out) {
        return 0;
      }

      if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
        memory_reduced += AllocatedSize(buffer_id);
      }
    }

    // Account for any logical buffers whose live range must be extended across
    // this program point.
    for (BufferId buffer_id : item->buffers_used) {
      if (!IsCurrentlyLive(buffer_id)) {
        // This logical buffer is used by 'item->instruction' but is not live
        // at this program point. Rematerializing 'item->instruction' will
        // extend the buffer's live range across this program point unless it
        // is defined by an instruction that is also being rematerialized.
        Item* defining_instruction =
            buffers_.at(buffer_id).defining_instruction;
        if (!remat_candidates.contains(defining_instruction)) {
          memory_reduced -= AllocatedSize(buffer_id);
        }
      }
    }
    remat_candidates.insert(item);
  }

  return memory_reduced;
}

Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
                                                   Item* compressed_item,
                                                   Item* uncompressed_item) {
  // Original buffer is now dead.
  memory_usage_ -= size_function_(original_item->instruction->shape());
  // Compressed buffer is now alive.
  memory_usage_ += size_function_(compressed_item->instruction->shape());

  ItemList placed_users;
  ItemList unplaced_users;
  CHECK_EQ(original_item->buffers_output.size(), 1);
  BufferId original_buffer_id = original_item->buffers_output[0];
  Buffer& original_buffer = buffers_.at(original_buffer_id);
  for (Item* user : original_buffer.users) {
    if (user->placed) {
      CHECK(IsFinished(user)) << user->instruction->name();
      placed_users.push_back(user);
    } else {
      unplaced_users.push_back(user);
    }
  }
  original_buffer.users = std::move(placed_users);
  original_buffer.unfinished_user_count = 0;
  original_buffer.users.push_back(compressed_item);
  Buffer& compressed_buffer =
      NewBuffer(compressed_item, compressed_item->instruction->shape(),
                {uncompressed_item}, /*live_out=*/false,
                /*has_indirect_uses=*/false);
  compressed_item->buffers_used = original_item->buffers_output;
  compressed_item->buffers_output = {compressed_buffer.id};
  compressed_item->buffers_defined.push_back(compressed_buffer.id);

  Buffer& uncompressed_buffer =
      NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
                std::move(unplaced_users), /*live_out=*/false,
                /*has_indirect_uses=*/false);

  uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
  uncompressed_item->buffers_output = {uncompressed_buffer.id};
  uncompressed_item->buffers_defined = {uncompressed_buffer.id};

  for (Item* user : uncompressed_buffer.users) {
    BufferIdList& buffers_used = user->buffers_used;
    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
                 uncompressed_buffer.id);
  }

  return Status::OK();
}

Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                        Item* remat_item) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
  TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();

  // Construct the list of buffers used and defined by the rematerialization.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new
  // rematerialization instruction. Update memory usage if the
  // rematerialization makes a dead buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }

    buffer.unfinished_user_count++;
    buffer.users.push_back(remat_item);
  }

  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to account
  // for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    ItemList placed_users;
    ItemList unplaced_users;
    for (Item* user : old_buffer.users) {
      if (user->placed) {
        CHECK(IsFinished(user)) << user->instruction->name();
        placed_users.push_back(user);
      } else {
        unplaced_users.push_back(user);
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    for (Item* user : new_buffer.users) {
      BufferIdList& buffers_used = user->buffers_used;
      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                   new_buffer.id);
    }
  }

  VLOG(3) << "  memory usage = " << memory_usage_;
  XLA_VLOG_LINES(10, ToString());

  DCHECK(Check());

  return Status::OK();
}

string MemoryUsageTracker::ToString() const {
  string output =
      absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
  absl::StrAppend(&output,
                  "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
                  memory_usage(), " bytes)");
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* instruction = item->instruction;
    string inprogress = item == in_progress_item_ ? " in-progress" : "";
    string placed = item->placed ? " placed" : "";
    absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
                    "\n    Defines:\n");
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_[buffer_id];
      string live = IsCurrentlyLive(buffer_id) ? " live" : "";
      absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
                      buffer.unfinished_user_count, " unfinished uses\n");
    }
    absl::StrAppend(&output, "    Uses:\n");
    for (BufferId buffer_id : item->buffers_used) {
      absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
    }
  }
  return output;
}

StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
  auto it = compact_shape_.find(hlo);
  if (it != compact_shape_.end()) {
    return it->second;
  }
  const Shape& original_shape = hlo->shape();
  TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
  compact_shape_[hlo] = min_shape;
  return min_shape;
}

bool MemoryUsageTracker::Check() const {
  auto elements_are_unique = [](const BufferIdList& vec) {
    return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
  };

  // Verify buffers_defined per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& defined_buffers =
        instruction_list_.GetItem(instruction)->buffers_defined;
    CHECK(elements_are_unique(defined_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique defined buffers: "
        << absl::StrJoin(
               defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
               });

    for (const Buffer& buffer : buffers_) {
      if (buffer.defining_instruction->instruction == instruction) {
        CHECK(absl::c_linear_search(defined_buffers, buffer.id))
            << "Instruction " << instruction->name()
            << " defined buffers is missing: " << buffer.ToString();
      }
    }
  }

  // Verify buffers_used per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& used_buffers =
        instruction_list_.GetItem(instruction)->buffers_used;
    CHECK(elements_are_unique(used_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique used buffers: "
        << absl::StrJoin(
               used_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
               });
  }
  for (const Buffer& buffer : buffers_) {
    int64 unfinished_uses = 0;
    for (Item* user : buffer.users) {
      const BufferIdList& used_buffers = user->buffers_used;
      CHECK(absl::c_linear_search(used_buffers, buffer.id))
          << "Instruction " << user->instruction->name()
          << " used buffers is missing " << buffer.ToString();
      if (!IsFinished(user)) {
        unfinished_uses++;
      }
    }
    CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
        << "Incorrect unplaced use count for " << buffer.ToString();
  }
  return true;
}

// Computes and returns the cost of rematerializing the given instruction.
// Cost per rematerialized instruction is defined as:
//
// memory_limit_bytes / memory_reduced
//
// The idea is to choose the operation that will save the most memory,
// without worrying about compute cost, since running out of memory is more
// harmful than taking longer to get the answer.
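//
// For example (hypothetical numbers): with a 1 GiB limit, a candidate that
// frees 256 MiB has cost 4, while one that frees 16 MiB has cost 64, so the
// candidate with the larger memory savings is preferred.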
int64 RematerializationCost(const HloInstruction* instruction,
                            const MemoryUsageTracker& memory_tracker,
                            int64 memory_reduced, int64 memory_limit_bytes) {
  // If none of the users of 'instruction' have been placed in the sequence (as
  // tracked by memory_tracker), then rematerialization of 'instruction' is a
  // zero-cost move of 'instruction' in the sequence.
  if (!absl::c_any_of(instruction->users(),
                      [&memory_tracker](const HloInstruction* inst) {
                        return memory_tracker.IsPlaced(inst);
                      })) {
    return 0;
  }

  CHECK_GT(memory_reduced, 0);
  // Return the inverse of the benefit of rematerialization.
  return memory_limit_bytes / memory_reduced;
}

// Selects and returns the best candidate instruction for rematerialization.
// The instruction with the lowest rematerialization cost is selected among
// those candidates which reduce memory use at the program point of the
// current instruction, as indicated by memory_tracker. Returns nullptr if no
// candidate can be found.
std::pair<Item*, RematStrategy>
MemoryUsageTracker::PickRematerializationCandidate(
    const InstructionList& instruction_list, int64 memory_limit_bytes,
    absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
  Item* best_item = nullptr;
  int64 best_cost = 0;
  RematStrategy best_strategy;

  VLOG(5) << "Picking candidate";

  // TODO(b/35244891): This is currently quadratic in the number of HLO
  // instructions.
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    if (!item->placed) {
      // Only iterate up to the currently placed instruction.
      // We are trying to reduce memory usage at the placed
      // instruction so rematerializing later values is of no benefit.
      break;
    }
    HloInstruction* candidate = item->instruction;
    VLOG(5) << "considering rematerialization candidate " << candidate->name();

    if (item->blacklisted) {
      // Skip instructions on the blacklist to avoid infinite loops of
      // rematerializing the same instruction(s) repeatedly.
      VLOG(5) << "candidate " << candidate->name()
              << " is excluded from rematerialization";
      continue;
    }
    if (!CanBeRematerialized(candidate, remat_able)) {
      VLOG(5) << "candidate " << candidate->name()
              << " not viable: is not rematerializable";
      continue;
    }

    if (item->buffers_output.size() == 1 &&
        (mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
         mode_ == HloRematerialization::RematerializationMode::
                      kRecomputeAndCompress)) {
      // Only consider compressing a single-output instruction.
      const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);

      if (item->placed && item != in_progress_item_ &&
          !output_buffer.live_out) {
        const Shape& original_shape = item->instruction->shape();
        if (original_shape.IsArray()) {
          Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
          const int64 memory_reduced =
              MemoryReducedIfCompressed(item, compact_shape);
          if (memory_reduced > 0) {
            const int64 cost = memory_limit_bytes / memory_reduced;
            if (best_item == nullptr || cost < best_cost) {
              VLOG(3) << "candidate " << candidate->name() << " ("
                      << candidate->ToShortString() << ")"
                      << " now best when compressed into "
                      << compact_shape.ToString(true);
              RematStrategy strategy;
              strategy.kind = RematStrategy::kCompress;
              best_strategy = strategy;
              best_strategy.compact_shape = compact_shape;
              best_item = item;
              best_cost = cost;
            }
          }
        }
      }
    }

    // If any of the candidate's control successors has been placed, we need
    // to skip this candidate. Otherwise we will violate control dependencies.
    bool control_successor_placed = std::any_of(
        candidate->control_successors().begin(),
        candidate->control_successors().end(),
        [this](const HloInstruction* inst) { return IsPlaced(inst); });

    if (control_successor_placed) {
      continue;
    }

    // Do not consider recomputation in compress-only mode.
    if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
      continue;
    }

    const int64 memory_reduced = MemoryReducedIfRematerialized(item);

    if (memory_reduced > 0) {
      const int cost =
          RematerializationCost(candidate, memory_reduced, memory_limit_bytes);

      VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
              << memory_reduced << ", cost per byte " << cost;

      if (best_item == nullptr || cost < best_cost) {
        VLOG(5) << "candidate " << candidate->name() << " now best";
        best_strategy.kind = RematStrategy::kRecompute;
        best_item = item;
        best_cost = cost;
      }
    }
  }
  return {best_item, best_strategy};
}

bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
  for (BufferId buffer_id : item->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    for (Item* user : buffer.users) {
      if (!user->placed) {
        return true;
      }
    }
  }
  return false;
}

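// Rematerializes the block of instructions in 'best_items' and replaces each
// remaining (unplaced) use of an original instruction with its clone. Returns
// the net number of instructions added; rematerializations that turn out to
// be pure moves (where the original becomes dead) are not counted.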
StatusOr<int64> RematerializeInstructions(
    MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
    absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
    InstructionList* instruction_list) {
  int64 net_instructions_added = 0;
  int64 total_memory_saved =
      memory_tracker->MemoryReducedIfRematerialized(*best_items);
  std::vector<string> instruction_names(best_items->size());
  // Rematerialize the block of instructions in reverse order to account for
  // dependencies between instructions in best_items.
  for (int i = best_items->size() - 1; i >= 0; --i) {
    Item* best_item = (*best_items)[i];
    HloInstruction* best = best_item->instruction;
    instruction_names[i] = best->name();
    HloComputation* computation = best->parent();

    // If the item to remat has no unplaced users, then skip the
    // rematerialization. Such an instruction can appear in best_items because
    // it is part of a good block, but does not itself add any benefit.
    if (!memory_tracker->HasUnplacedUsers(best_item)) {
      continue;
    }

    HloInstruction* remat =
        computation->AddInstruction(best->Clone(/*suffix=*/"remat"));

    // Add control dependencies to the new operation.
    for (auto successor : best->control_successors()) {
      TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
    }
    for (auto predecessor : best->control_predecessors()) {
      TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
    }

    Item* remat_item = instruction_list->CreateItem(remat);

    // Replace each remaining use of 'best' with the rematerialization.
    std::vector<HloInstruction*> best_users_copy = best->users();
    for (HloInstruction* user : best_users_copy) {
      if (!memory_tracker->IsPlaced(user)) {
        VLOG(2) << "  Replacing use of " << best->name() << " in "
                << user->name() << " with " << remat->name();
        TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
      }
    }

    // Account for the rematerialization in the memory tracker.
    TF_RETURN_IF_ERROR(
        memory_tracker->AddRematerializedInstruction(best_item, remat_item));

    // Insert the rematerialized instruction right before the earliest unplaced
    // use of the instruction *and* the earliest unplaced last use of any
    // operands of remat. Unplaced uses of the remat's operands are included
    // because we don't want to extend the live range of remat's operands as
    // this could increase memory usage.
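    //
    // A hypothetical illustration: if remat's operand 'op' still has an
    // unplaced user 'u', inserting remat after 'u' would force 'op' to stay
    // live past 'u'; inserting remat before 'u' leaves op's live range
    // unchanged.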
    ItemList place_before;
    for (auto user : remat->users()) {
      place_before.push_back(instruction_list->GetItem(user));
    }
    for (auto* operand : remat->operands()) {
      for (auto* operand_user : operand->users()) {
        if (operand_user != remat) {
          Item* operand_user_item = instruction_list->GetItem(operand_user);
          if (!operand_user_item->placed) {
            place_before.push_back(operand_user_item);
          }
        }
      }
    }
    // Insert the rematerialized instruction before any of its control
    // successors to preserve ordering with respect to control dependencies.
    for (auto successor : remat->control_successors()) {
      Item* successor_item = instruction_list->GetItem(successor);
      // Assert to make sure we never remat an operation whose control
      // successor has already been placed.
      CHECK(!successor_item->placed) << successor_item->instruction->name();
      place_before.push_back(successor_item);
    }
    instruction_list->InsertBeforeInstructions(remat_item, place_before);

    // If the rematerialized instruction is dead then rematerialization is
    // essentially a move. Don't delete the instruction now because we don't
    // want duplicate HloInstruction* values during the course of the
    // transformation because we keep maps with HloInstruction* values as
    // keys.
    if (best->users().empty()) {
      VLOG(2) << best->name() << " is now dead";
      if (ContainsKey(*remat_move_instructions, best)) {
        // Previously, 'best' was a rematerialization which killed the
        // instruction it was a copy of. Now 'remat' is a rematerialization
        // of 'best' and kills 'best'. Stop rematerializing this instruction
        // to avoid an infinite loop.
        instruction_list->Blacklist(remat);
      }
      remat_move_instructions->insert(remat);
    } else {
      net_instructions_added++;
    }
  }
  VLOG(1) << "Rematerializing instructions ["
          << absl::StrJoin(instruction_names, ", ") << "] (saving "
          << HumanReadableNumBytes(total_memory_saved) << ")";
  return net_instructions_added;
}
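// Compresses the output of 'best_item' by inserting a copy to 'compact_shape'
// immediately after it, and a copy back to the original shape before the
// earliest unplaced use:
//   best -> copy(compact_shape) -> ... -> copy(original shape) -> users
// Returns the number of instructions added (always 2).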
StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
                                    Item* best_item, const Shape& compact_shape,
                                    InstructionList* instruction_list) {
  HloInstruction* best = best_item->instruction;
  VLOG(5) << "Transposing instruction " << best->name() << " (saving "
          << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
                 best_item, compact_shape))
          << ") to " << compact_shape.ToString(true);

  HloComputation* computation = best->parent();

  HloInstruction* compressed = computation->AddInstruction(
      HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));

  HloInstruction* uncompressed = computation->AddInstruction(
      HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));

  Item* compressed_item = instruction_list->CreateItem(compressed);
  compressed_item->placed = true;

  Item* uncompressed_item = instruction_list->CreateItem(uncompressed);

  // Replace each remaining use of 'best' with the uncompressed copy.
  std::vector<HloInstruction*> best_users_copy = best->users();
  for (HloInstruction* user : best_users_copy) {
    if (!memory_tracker->IsPlaced(user)) {
      VLOG(5) << "  Replacing use of " << best->name() << " in " << user->name()
              << " with " << uncompressed->name();
      TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
    }
  }

  // Account for the compression in the memory tracker.
  TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
      best_item, compressed_item, uncompressed_item));

  // Insert the uncompressed copy right before the earliest unplaced use of the
  // instruction.
  ItemList place_before;
  for (auto user : uncompressed->users()) {
    place_before.push_back(instruction_list->GetItem(user));
  }

  instruction_list->Blacklist(compressed_item->instruction);
  instruction_list->Blacklist(uncompressed_item->instruction);

  instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);

  instruction_list->InsertAfterInstructions(compressed_item, {best_item});

  return 2;
}

}  // namespace

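// Computes the peak memory of 'computation' under the schedule 'order': the
// maximum, over all program points, of the tracker's live memory plus the
// peak usage of any computation called at that point. Worked illustration
// (hypothetical numbers): if live memory at three program points is
// {10 MB, 25 MB, 15 MB} and the second instruction calls a subcomputation
// whose own peak is 30 MB, the result is 25 MB + 30 MB = 55 MB.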
StatusOr<int64> HloRematerialization::ComputePeakMemory(
    const HloComputation* computation,
    const HloInstructionSequence& order) const {
  InstructionList instruction_list(order);
  MemoryUsageTracker tracker(computation, size_function_,
                             compact_shape_function_, *points_to_analysis_,
                             instruction_list, mode_);
  int64 peak_memory = tracker.memory_usage();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    peak_memory =
        std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
    TF_RETURN_IF_ERROR(tracker.EndInstruction());
  }
  VLOG(1) << "Peak memory for " << computation->name() << ": "
          << HumanReadableNumBytes(peak_memory);
  return peak_memory;
}

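// Returns the sum of the precomputed peak memory of the computations called
// by 'instruction' in a sequential context (e.g. kWhile or kCall); calls in
// a parallel context (e.g. the subcomputation of a kMap) contribute zero.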
StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
    const HloInstruction* instruction) const {
  const CallSite* callsite =
      call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
  if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
    return 0;
  }
  int64 callee_usage = 0;
  for (const HloComputation* computation : callsite->called_computations()) {
    TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
    callee_usage += computation_peak_memory_.at(computation);
  }
  return callee_usage;
}

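// Rematerializes instructions within 'computation' until its memory usage at
// every program point fits in 'memory_limit_bytes', updating 'schedule' in
// place. A sketch of the loop implemented below:
//
//   for each instruction in the schedule:
//     while live memory + callee peak exceeds the limit:
//       pick a candidate and either compress it or recompute it at its uses
//     if still over the limit at a sequential callsite:
//       recurse into the called computations with the remaining budget
//
// Returns whether this computation (or any callee) was changed.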
StatusOr<bool> HloRematerialization::RematerializeComputation(
    HloComputation* computation, HloSchedule* schedule,
    int64 memory_limit_bytes) {
  VLOG(1) << "Rematerializing computation " << computation->name()
          << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
  VLOG(1) << "peak memory usage is "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation));
  CHECK(!ContainsKey(rematerialized_computations_, computation));

  InstructionList instruction_list(schedule->sequence(computation));
  MemoryUsageTracker memory_tracker(
      computation, size_function_, compact_shape_function_,
      *points_to_analysis_, instruction_list, mode_);
  bool changed = false;

  // If a rematerialization makes the source instruction dead, the
  // rematerialization is recorded in 'remat_move_instructions' (it was
  // essentially a move). If the next rematerialization of the instruction is
  // also a move, the instruction is added to the blacklist.
  absl::flat_hash_set<const HloInstruction*> remat_move_instructions;

  // The map from instructions to their rematerializable status.
  absl::flat_hash_map<const HloInstruction*, bool> remat_able;

  // The peak memory of the computation at any point in the instruction
  // sequence.
  int64 peak_memory = memory_tracker.memory_usage();

  // Total count of instructions rematerialized.
  int64 remat_count = 0;
  // Total count of clones created minus number of original rematerialized
  // instructions which are dead.
  int64 net_instructions_added = 0;
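  // For example (hypothetical): a rematerialization that leaves the original
  // instruction dead is effectively a move and contributes 0 here (the clone
  // replaces an instruction that DCE will later delete), while a compression
  // always contributes 2 (the compressing and decompressing copies).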

  const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);

  // Iterate through all instructions in the sequence. At each instruction
  // (program point) if memory_usage exceeds the specified limit then
  // rematerialize HLO instructions until memory_usage is reduced.
  int64 instruction_index = 0;
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));

    VLOG(2) << "Program point at " << instruction->name()
            << ", memory usage = " << memory_tracker.memory_usage()
            << ", callee usage = " << callee_usage << ", [" << instruction_index
            << "/" << instruction_list.size() << "]";
    instruction_index++;

    while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      VLOG(2) << "Over memory limit at instruction " << instruction->name()
              << ", using "
              << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                       callee_usage)
              << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);

      Item* best_item;
      RematStrategy best_strategy;
      std::tie(best_item, best_strategy) =
          memory_tracker.PickRematerializationCandidate(
              instruction_list, memory_limit_bytes, &remat_able);

      if (best_item == nullptr) {
        VLOG(3) << "Unable to find rematerialization candidate at program "
                   "point "
                << instruction->name() << ". Memory usage = "
                << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                         callee_usage);
        break;
      }

      HloInstruction* best = best_item->instruction;
      changed = true;
      remat_count++;

      int64 num_instructions_added = 0;
      if (best_strategy.kind == RematStrategy::kCompress) {
        VLOG(1) << "Compressing instruction " << best->name() << " (saving "
                << HumanReadableNumBytes(
                       memory_tracker.MemoryReducedIfCompressed(
                           best_item, best_strategy.compact_shape))
                << ")";

        TF_ASSIGN_OR_RETURN(num_instructions_added,
                            CompressInstruction(&memory_tracker, best_item,
                                                best_strategy.compact_shape,
                                                &instruction_list));
      } else {
        VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
                << HumanReadableNumBytes(
                       memory_tracker.MemoryReducedIfRematerialized(best_item))
                << ")";

        std::vector<Item*> best_items{best_item};
        TF_ASSIGN_OR_RETURN(num_instructions_added,
                            RematerializeInstructions(
                                &memory_tracker, &best_items,
                                &remat_move_instructions, &instruction_list));
      }
      net_instructions_added += num_instructions_added;

      VLOG(1) << "memory_usage after rematerialization = "
              << HumanReadableNumBytes(memory_tracker.memory_usage());
    }

    const CallSite* callsite = call_graph_node.GetCallSite(instruction);
    if (callsite != nullptr &&
        callsite->context() == CallContext::kSequential &&
        memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      // Memory usage exceeds the limit. Try to rematerialize any
      // subcomputation(s) that this instruction calls.
      VLOG(1) << "Memory usage still over the limit ("
              << (memory_tracker.memory_usage() + callee_usage) << " > "
              << memory_limit_bytes
              << "). Rematerializing computations called by "
              << instruction->name();

      for (HloComputation* called_computation :
           callsite->called_computations()) {
        if (!ContainsKey(rematerialized_computations_, called_computation)) {
          // Memory limit for the subcomputation is the memory limit less the
          // amount of memory used at this point in the computation.
          int64 subcomputation_memory_limit_bytes = std::max<int64>(
              0, memory_limit_bytes - memory_tracker.memory_usage());
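          // E.g. (hypothetical numbers): with a 1 GB limit and 600 MB live at
          // this callsite, the subcomputation is asked to fit in 400 MB; the
          // max with 0 guards against a negative budget when the caller is
          // already over the limit.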
          TF_ASSIGN_OR_RETURN(
              bool subcomputation_changed,
              RematerializeComputation(called_computation, schedule,
                                       subcomputation_memory_limit_bytes));
          changed |= subcomputation_changed;
        }
      }
      // Recompute callee usage to account for any rematerialization performed
      // in the callee computations.
      TF_ASSIGN_OR_RETURN(callee_usage,
                          CalledComputationsMemoryUsage(instruction));
    }

    peak_memory = std::max<int64>(peak_memory,
                                  memory_tracker.memory_usage() + callee_usage);
    VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);

    TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
  }

  // Verify some invariants on the memory tracker.
  for (auto* instruction : computation->instructions()) {
    CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
  }

  VLOG(1) << "In computation " << computation->name() << " rematerialized "
          << remat_count << " instructions; " << net_instructions_added
          << " net instructions added";
  VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory)
          << " (was "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation))
          << ")";

  // Update peak memory used by computation.
  computation_peak_memory_.at(computation) = peak_memory;

  // Update order to include rematerialized instructions.
  HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
  sequence.clear();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    HloInstruction* instruction = item->instruction;
    sequence.push_back(instruction);
  }
  rematerialized_computations_.insert(computation);

  instructions_rematerialized_ += remat_count;
  net_instructions_added_ += net_instructions_added;

  return changed;
}

StatusOr<bool> HloRematerialization::Run(HloModule* module) {
  VLOG(1) << "HloRematerialization() with memory limit of "
          << HumanReadableNumBytes(memory_limit_bytes_);
  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());

  // Initialize pass object state.
  computation_peak_memory_.clear();
  rematerialized_computations_.clear();
  instructions_rematerialized_ = 0;
  net_instructions_added_ = 0;

  TF_RET_CHECK(module->has_schedule());
  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));

  // Adjust the memory limit to account for the output of the entry
  // computation. This is necessary because the per-computation accounting in
  // MemoryUsageTracker does not include the output, as it is typically
  // allocated by the caller.
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module->result_shape(),
      [&module_output_size, this](const Shape& subshape,
                                  const ShapeIndex& output_index) {
        module_output_size += size_function_(subshape);
      });

  const int64 adjusted_memory_limit_bytes =
      memory_limit_bytes_ - module_output_size;
  VLOG(1) << "Adjusted memory limit accounting for output ("
          << HumanReadableNumBytes(module_output_size)
          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);

  // Compute peak memory usage of all computations in the module called in a
  // sequential context.
  call_graph_ = CallGraph::Build(module);
  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
      [this, module](const CallGraphNode& node) -> Status {
        if (node.context() == CallContext::kSequential) {
          TF_ASSIGN_OR_RETURN(
              computation_peak_memory_[node.computation()],
              ComputePeakMemory(node.computation(), module->schedule().sequence(
                                                        node.computation())));
        }
        return Status::OK();
      },
      /*visit_unreachable_nodes=*/false));

  // The peak memory usage of the module equals the peak memory use of the
  // entry computation plus the output size of the computation. This is because
  // the peak memory for a computation does not include the output as this is
  // typically accounted for in the caller.
  const int64 before_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
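  // E.g. (hypothetical numbers): an entry-computation peak of 1.2 GB plus a
  // 0.3 GB module output yields a module-level peak of 1.5 GB, mirroring the
  // limit adjustment above.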
  VLOG(1) << "Peak memory usage of module (before): "
          << HumanReadableNumBytes(before_peak_memory);

  // Subcomputations called by the entry computation will also be
  // rematerialized.
  TF_ASSIGN_OR_RETURN(
      bool changed,
      RematerializeComputation(module->entry_computation(), &module->schedule(),
                               adjusted_memory_limit_bytes));

  // Rematerialization can introduce dead code. This occurs if all uses of an
  // instruction are replaced with rematerializations of the instruction.

  // Stash away the schedule while running dead code elimination, to avoid
  // validation failures while the module is in flux.
  HloSchedule saved_schedule = module->schedule();
  module->clear_schedule();
  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
  changed |= dead_code_removed;

  // After DCE, the module sequence may include instructions which no longer
  // exist. Update the schedule and restore it.
  TF_RETURN_IF_ERROR(saved_schedule.Update());
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
  VLOG(1) << "Rematerialized " << instructions_rematerialized_
          << " instructions in module " << module->name() << "; "
          << net_instructions_added_ << " net instructions added";
  const int64 current_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module now "
          << HumanReadableNumBytes(current_peak_memory) << " ("
          << current_peak_memory << " bytes), was "
          << HumanReadableNumBytes(before_peak_memory) << " ("
          << before_peak_memory << " bytes)";
  const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
  VLOG(1) << "Reduced peak memory by "
          << HumanReadableNumBytes(reduced_peak_memory) << " ("
          << reduced_peak_memory << " bytes)";

  if (sizes_ != nullptr) {
    sizes_->before_bytes = before_peak_memory;
    sizes_->after_bytes = current_peak_memory;
  }

  XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());

  if (current_peak_memory > memory_limit_bytes_) {
    LOG(WARNING) << absl::StrFormat(
        "Can't reduce memory use below %s (%d bytes) by rematerialization; "
        "only reduced to %s (%d bytes)",
        HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
        HumanReadableNumBytes(current_peak_memory), current_peak_memory);
  }

  return changed;
}

}  // namespace xla