/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

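// Registers a single (unreplicated) ScopedShapedBuffer under a fresh
// GlobalDataHandle, taking ownership of its device memory.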
StatusOr<GlobalDataHandle> AllocationTracker::Register(
    ScopedShapedBuffer shaped_buffer, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Register";
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.emplace_back(std::move(shaped_buffer));
  return RegisterInternal(std::move(replicated_buffers), tag);
}

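// Registers one buffer per replica under a single GlobalDataHandle; all
// replicas share the same handle and are released together.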
StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "RegisterReplicatedBuffers";
  return RegisterInternal(std::move(replicated_buffers), tag);
}

// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
// unmodified.
static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
  return b.release();
}

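// Shared implementation for Register and RegisterReplicatedBuffers: assigns a
// new handle, increments the reference count of every device buffer reachable
// from each shaped buffer, and stores the (now unowned) ShapedBuffers in
// handle_to_shaped_buffers_.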
template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
  static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
                    std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
                "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" with " << replicated_buffers.size()
          << " shaped_buffers.";
  for (const auto& shaped_buffer : replicated_buffers) {
    VLOG(2) << "shaped_buffer:" << shaped_buffer;
    if (shaped_buffer.platform() != backend_->platform()) {
      return InvalidArgument(
          "AllocationTracker for platform %s cannot register buffer from "
          "platform %s",
          backend_->platform()->Name(), shaped_buffer.platform()->Name());
    }
  }

  int64 handle = next_handle_++;
  for (auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer.on_device_shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
    // them.
    for (const ShapeIndex& index : shape_indices) {
      AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
                                       shaped_buffer.device_ordinal());
    }
    // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
    // into a regular ShapedBuffer, which is stored in
    // handle_to_shaped_buffers_.
    handle_to_shaped_buffers_[handle].emplace_back(
        absl::make_unique<ShapedBuffer>(
            ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
  }

  GlobalDataHandle result;
  result.set_handle(handle);
  VLOG(2) << "handle: " << handle;
  return result;
}

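// Drops the reference counts held by `data` on all of its device buffers and
// tombstones the handle entry so later lookups report "previously deallocated"
// rather than "not found".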
Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  for (const auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer->on_device_shape(),
        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    for (const ShapeIndex& index : shape_indices) {
      TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                           shaped_buffer->device_ordinal()));
    }
  }
  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  for (auto& shaped_buffer : it->second) {
    shaped_buffer.reset();
  }
  return Status::OK();
}

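// Creates one GlobalDataHandle per top-level tuple element. The element
// handles alias the original device memory; RegisterInternal bumps the
// per-buffer reference counts, so the elements remain valid even if the tuple
// handle is later unregistered.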
StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  // We only need to care about replica id 0 here, since the GlobalDataHandle
  // is the same for all buffers across replicas.
  const ShapedBuffer* shaped_buffer = replicated_buffers[0];
  if (!shaped_buffer->on_host_shape().IsTuple()) {
    return InvalidArgument("global data handle %d is not a tuple",
                           data.handle());
  }
  // If the on-host representation is a tuple, then the on-device one should be
  // as well.
  TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("Deconstructing nested tuples is not implemented.");
  }

  std::vector<GlobalDataHandle> element_handles;
  for (int i = 0;
       i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
       ++i) {
    auto element_buffer = ShapedBuffer(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->platform(), shaped_buffer->device_ordinal());
    element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                              /*index=*/{});
    std::vector<ShapedBuffer> replicated_buffers;
    replicated_buffers.push_back(std::move(element_buffer));
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

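// Returns the per-replica ShapedBuffers registered under `data`, or an error
// if the handle is unknown or has been deallocated.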
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
    const GlobalDataHandle& data) const {
  tensorflow::mutex_lock lock(mutex_);
  return AllocationTracker::ResolveInternal(data);
}

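// Resolves `data` and returns the ShapedBuffer registered for the given
// replica, checking that the replica id is in range.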
StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
    const GlobalDataHandle& data, int replica_id) const {
  tensorflow::mutex_lock lock(mutex_);
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  if (replica_id >= replicated_buffers.size()) {
    return InvalidArgument(
        "Requesting buffer for replica %d, but found buffers only for %lu "
        "replicas.",
        replica_id, replicated_buffers.size());
  }
  return replicated_buffers[replica_id];
}

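// Lookup shared by Resolve and ResolveForReplica; callers must hold mutex_.
// A nullptr entry is the tombstone left behind by Unregister.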
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) const {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  std::vector<const ShapedBuffer*> replicated_buffers;
  for (const auto& shaped_buffer : it->second) {
    if (shaped_buffer == nullptr) {
      return InvalidArgument("global data handle %d was previously deallocated",
                             data.handle());
    }
    replicated_buffers.push_back(shaped_buffer.get());
  }

  return replicated_buffers;
}

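// Records `device_memory` in opaque_to_allocation_map_ with a reference count
// of 1, or increments the count if the same opaque pointer is already tracked
// for this device ordinal.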
void AllocationTracker::AddAllocationOrIncrementRefCount(
    se::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {
        OwningDeviceMemory(device_memory, device_ordinal,
                           backend_->memory_allocator()),
        /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

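// Decrements the reference count for `device_memory`; when the count reaches
// zero the memory is freed through the backend allocator and the map entry is
// erased.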
Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
                                            int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    allocation.device_memory.Free();
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return Status::OK();
}

}  // namespace xla