/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

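// Registers a single ScopedShapedBuffer under a new GlobalDataHandle by
// wrapping it in a one-element replica vector and forwarding to
// RegisterInternal.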
StatusOr<GlobalDataHandle> AllocationTracker::Register(
    ScopedShapedBuffer shaped_buffer, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Register";
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.emplace_back(std::move(shaped_buffer));
  return RegisterInternal(std::move(replicated_buffers), tag);
}

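// Registers one ScopedShapedBuffer per replica under a single
// GlobalDataHandle.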
StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "RegisterReplicatedBuffers";
  return RegisterInternal(std::move(replicated_buffers), tag);
}

// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
// unmodified.
static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
  return b.release();
}

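// Common implementation behind Register and RegisterReplicatedBuffers:
// allocates the next handle, reference-counts every device buffer of every
// replica, and stores the (unscoped) ShapedBuffers under that handle.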
template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
  static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
                    std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
                "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" with " << replicated_buffers.size()
          << " shaped_buffers.";

  int64 handle = next_handle_++;
  for (auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer.on_device_shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
    // them.
    for (const ShapeIndex& index : shape_indices) {
      AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
                                       shaped_buffer.device_ordinal());
    }
    // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
    // into a regular ShapedBuffer, which is stored in
    // handle_to_shaped_buffers_.
    handle_to_shaped_buffers_[handle].emplace_back(
        absl::make_unique<ShapedBuffer>(
            ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
  }

  GlobalDataHandle result;
  result.set_handle(handle);
  VLOG(2) << "handle: " << handle;
  return result;
}

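// Releases the references that `data` holds on its device buffers and marks
// the handle as deallocated.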
Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  for (const auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer->on_device_shape(),
        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    for (const ShapeIndex& index : shape_indices) {
      TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                           shaped_buffer->device_ordinal()));
    }
  }
  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  for (auto& shaped_buffer : it->second) {
    shaped_buffer.reset();
  }
  return Status::OK();
}

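// Registers each top-level element of the (non-nested) tuple `data` under its
// own GlobalDataHandle, aliasing rather than copying the device memory.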
StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  // We only need to care about replica id 0 here, since the GlobalDataHandle
  // is the same for all buffers across replicas.
  const ShapedBuffer* shaped_buffer = replicated_buffers[0];
  if (!shaped_buffer->on_device_shape().IsTuple()) {
    return InvalidArgument("global data handle %d is not a tuple",
                           data.handle());
  }

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("Deconstructing nested tuples is not implemented.");
  }

  std::vector<GlobalDataHandle> element_handles;
  for (int i = 0;
       i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
       ++i) {
    auto element_buffer = ShapedBuffer(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->device_ordinal());
    element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                              /*index=*/{});
    std::vector<ShapedBuffer> replicated_buffers;
    replicated_buffers.push_back(std::move(element_buffer));
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

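// Returns the registered per-replica ShapedBuffers for `data`.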
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
    const GlobalDataHandle& data) const {
  tensorflow::mutex_lock lock(mutex_);
  return AllocationTracker::ResolveInternal(data);
}

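// Returns the ShapedBuffer registered for `data` on the replica with the
// given id.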
StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
    const GlobalDataHandle& data, int replica_id) const {
  tensorflow::mutex_lock lock(mutex_);
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  if (replica_id >= replicated_buffers.size()) {
    return InvalidArgument(
        "Requesting buffer for replica %d, but found buffers only for %lu "
        "replicas.",
        replica_id, replicated_buffers.size());
  }
  return replicated_buffers[replica_id];
}

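// Shared lookup helper for the methods above; callers are expected to hold
// mutex_ already.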
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) const {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  std::vector<const ShapedBuffer*> replicated_buffers;
  for (const auto& shaped_buffer : it->second) {
    if (shaped_buffer == nullptr) {
      return InvalidArgument("global data handle %d was previously deallocated",
                             data.handle());
    }
    replicated_buffers.push_back(shaped_buffer.get());
  }

  return replicated_buffers;
}

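// Starts tracking `device_memory` in opaque_to_allocation_map_: inserts an
// owning entry with ref_count 1, or bumps the count if one already exists.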
void AllocationTracker::AddAllocationOrIncrementRefCount(
    se::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {
        se::OwningDeviceMemory(device_memory, device_ordinal,
                               backend_->memory_allocator()),
        /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

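// Decrements the reference count on `device_memory`, freeing it and erasing
// its map entry when the count reaches zero.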
Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
                                            int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    TF_RETURN_IF_ERROR(allocation.device_memory.Free());
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return Status::OK();
}

}  // namespace xla