1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20
21 namespace xla {
22
IsValueAllowedInAlternateMemory(const HloValue * value)23 bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
24 const HloValue* value) {
25 // If the buffer is a tuple, don't use this algorithm for now. The buffers
26 // that are pointed to by the tuple will still use this algorithm. Because
27 // tuples are cheap to place in the alternate memory (they are just pointers)
28 // we don't need to use prefetch/evict logic.
29 if (value->shape().IsTuple()) {
30 VLOG(4) << "Keeping value " << value->ToShortString()
31 << " in default mem because it is a tuple.";
32 return false;
33 }
34
35 // Don't place scalars in the alternate memory.
36 if (ShapeUtil::IsEffectiveScalar(value->shape())) {
37 VLOG(4) << "Keeping value " << value->ToShortString()
38 << " in default mem because it is a scalar.";
39 return false;
40 }
41
42 // The semantics of TupleSelect are weird: TupleSelect doesn't define a
43 // buffer, but just forwards the buffers in the either left or right side.
44 // This means the two different inputs to TupleSelect must not alias, yet they
45 // should be allocated in the same memory space, and both buffers must be kept
46 // alive for the entire live range of TupleSelect. Instead, just don't
47 // allocate TupleSelect in the alternate memory space.
48 // TODO(berkin): Not allocating add-dependencies either since they need to be
49 // treated specially. We should revisit this later.
50 for (const HloPosition& position : value->positions()) {
51 if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
52 position.instruction->opcode() == HloOpcode::kAddDependency) {
53 VLOG(4) << "Keeping value " << value->ToShortString()
54 << " in default mem because it has a tuple-select or "
55 << "add-dependency position.";
56 return false;
57 }
58 }
59
60 // Send and Recv HLOs return a request identifier. These should not be
61 // allocated in the alternate memory.
62 for (const HloPosition& position : value->positions()) {
63 if ((position.instruction->opcode() == HloOpcode::kSend ||
64 position.instruction->opcode() == HloOpcode::kRecv)) {
65 // TODO(berkin): Send/recv buffers need a stable buffer allocation
66 // throughout sending/receiving. Disable memory space allocation for these
67 // for now.
68 if (position.index == ShapeIndex({0})) {
69 VLOG(4) << "Keeping value " << value->ToShortString()
70 << " in default mem because it is a send/recv buffer.";
71 return false;
72 } else if (position.index == ShapeIndex({1})) {
73 VLOG(4) << "Keeping value " << value->ToShortString()
74 << " in default mem because it is a request identifier for "
75 "send/recv.";
76 return false;
77 }
78 }
79
80 if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
81 position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
82 // Disable memory space allocation for these for now.
83 if (position.index == ShapeIndex({0})) {
84 VLOG(4) << "Keeping value " << value->ToShortString()
85 << " in default mem because it is a collective-permute buffer.";
86 return false;
87 } else if (position.index == ShapeIndex({1})) {
88 VLOG(4) << "Keeping value " << value->ToShortString()
89 << " in default mem because it is a collective-permute buffer.";
90 return false;
91 }
92 }
93 if (auto* custom_call =
94 DynCast<HloCustomCallInstruction>(position.instruction)) {
95 for (const auto& pair : custom_call->output_to_operand_aliasing()) {
96 if (position.index == pair.first) {
97 VLOG(4) << "Keeping value " << value->ToShortString()
98 << " in default mem because it is a custom-call output that "
99 "aliases an operand buffer.";
100 return false;
101 }
102 }
103 }
104 }
105
106 return true;
107 }
108
IsIntervalAllowedInAlternateMemory(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)109 bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
110 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
111 return IsValueAllowedInAlternateMemory(interval.buffer) &&
112 absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory);
113 }
114
115 } // namespace xla
116