• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/allocation_tracker.h"
17 
18 #include <utility>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/map_util.h"
23 #include "tensorflow/compiler/xla/service/transfer_manager.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/stream_executor/device_memory_allocator.h"
31 
32 namespace xla {
33 
Register(ScopedShapedBuffer shaped_buffer,const string & tag)34 StatusOr<GlobalDataHandle> AllocationTracker::Register(
35     ScopedShapedBuffer shaped_buffer, const string& tag) {
36   tensorflow::mutex_lock lock(mutex_);
37   VLOG(2) << "Register";
38   std::vector<ScopedShapedBuffer> replicated_buffers;
39   replicated_buffers.emplace_back(std::move(shaped_buffer));
40   return RegisterInternal(std::move(replicated_buffers), tag);
41 }
42 
RegisterReplicatedBuffers(std::vector<ScopedShapedBuffer> replicated_buffers,const string & tag)43 StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
44     std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
45   tensorflow::mutex_lock lock(mutex_);
46   VLOG(2) << "RegisterReplicatedBuffers";
47   return RegisterInternal(std::move(replicated_buffers), tag);
48 }
49 
50 // ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
51 // b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
52 // unmodified.
ReleaseIfScopedShapedBuffer(ShapedBuffer b)53 static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b)54 static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
55   return b.release();
56 }
57 
58 template <typename ShapedBufferTy>
RegisterInternal(std::vector<ShapedBufferTy> replicated_buffers,const string & tag)59 StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
60     std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
61   static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
62                     std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
63                 "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
64   VLOG(2) << "RegisterInternal("
65           << "tag: \"" << tag << "\" with " << replicated_buffers.size()
66           << " shaped_buffers.";
67 
68   int64 handle = next_handle_++;
69   for (auto& shaped_buffer : replicated_buffers) {
70     std::vector<ShapeIndex> shape_indices;
71     ShapeUtil::ForEachSubshape(
72         shaped_buffer.on_device_shape(),
73         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
74           shape_indices.push_back(index);
75         });
76     // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
77     // them.
78     for (const ShapeIndex& index : shape_indices) {
79       AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
80                                        shaped_buffer.device_ordinal());
81     }
82     // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
83     // into a regular ShapedBuffer, which is stored in
84     // handle_to_shaped_buffers_.
85     handle_to_shaped_buffers_[handle].emplace_back(
86         absl::make_unique<ShapedBuffer>(
87             ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
88   }
89 
90   GlobalDataHandle result;
91   result.set_handle(handle);
92   VLOG(2) << "handle: " << handle;
93   return result;
94 }
95 
Unregister(const GlobalDataHandle & data)96 Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
97   tensorflow::mutex_lock lock(mutex_);
98   VLOG(2) << "Unregister("
99           << "handle: " << data.handle() << ")";
100   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
101                       ResolveInternal(data));
102   for (const auto& shaped_buffer : replicated_buffers) {
103     std::vector<ShapeIndex> shape_indices;
104     ShapeUtil::ForEachSubshape(
105         shaped_buffer->on_device_shape(),
106         [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
107           shape_indices.push_back(index);
108         });
109     for (const ShapeIndex& index : shape_indices) {
110       TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
111                                            shaped_buffer->device_ordinal()));
112     }
113   }
114   // Keep a nullptr as a tombstone for unregistered handles. This enables
115   // better error messages. That is, "handle has been deallocated" versus
116   // "handle does not exist".
117   auto it = handle_to_shaped_buffers_.find(data.handle());
118   if (it == handle_to_shaped_buffers_.end()) {
119     return NotFound("no allocation record for global data handle: %d",
120                     data.handle());
121   }
122   for (auto& shaped_buffer : it->second) {
123     shaped_buffer.reset();
124   }
125   return Status::OK();
126 }
127 
DeconstructTuple(const GlobalDataHandle & data)128 StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
129     const GlobalDataHandle& data) {
130   tensorflow::mutex_lock lock(mutex_);
131 
132   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
133                       ResolveInternal(data));
134   // We only need to care about replica id 0 here, since the GlobalDataHandle is
135   // the same for all buffers across replicas.
136   const ShapedBuffer* shaped_buffer = replicated_buffers[0];
137   if (!shaped_buffer->on_device_shape().IsTuple()) {
138     return InvalidArgument("global data handle %d is not a tuple",
139                            data.handle());
140   }
141 
142   if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
143     return Unimplemented("Deconstructing nested tuples is not implemented.");
144   }
145 
146   std::vector<GlobalDataHandle> element_handles;
147   for (int i = 0;
148        i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
149        ++i) {
150     auto element_buffer = ShapedBuffer(
151         ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
152         shaped_buffer->device_ordinal());
153     element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
154                               /*index=*/{});
155     std::vector<ShapedBuffer> replicated_buffers;
156     replicated_buffers.push_back(std::move(element_buffer));
157     TF_ASSIGN_OR_RETURN(
158         GlobalDataHandle element_handle,
159         RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));
160 
161     element_handles.push_back(element_handle);
162   }
163   return std::move(element_handles);
164 }
165 
Resolve(const GlobalDataHandle & data) const166 StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
167     const GlobalDataHandle& data) const {
168   tensorflow::mutex_lock lock(mutex_);
169   return AllocationTracker::ResolveInternal(data);
170 }
171 
ResolveForReplica(const GlobalDataHandle & data,int replica_id) const172 StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
173     const GlobalDataHandle& data, int replica_id) const {
174   tensorflow::mutex_lock lock(mutex_);
175   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
176                       ResolveInternal(data));
177   if (replica_id >= replicated_buffers.size()) {
178     return InvalidArgument(
179         "Requesting buffer for replica %d, but found buffers only for %lu "
180         "replicas.",
181         replica_id, replicated_buffers.size());
182   }
183   return replicated_buffers[replica_id];
184 }
185 
ResolveInternal(const GlobalDataHandle & data) const186 StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
187     const GlobalDataHandle& data) const {
188   VLOG(2) << "resolve:" << data.handle();
189   auto it = handle_to_shaped_buffers_.find(data.handle());
190   if (it == handle_to_shaped_buffers_.end()) {
191     return NotFound("no allocation record for global data handle: %d",
192                     data.handle());
193   }
194   std::vector<const ShapedBuffer*> replicated_buffers;
195   for (const auto& shaped_buffer : it->second) {
196     if (shaped_buffer == nullptr) {
197       return InvalidArgument("global data handle %d was previously deallocated",
198                              data.handle());
199     }
200     replicated_buffers.push_back(shaped_buffer.get());
201   }
202 
203   return replicated_buffers;
204 }
205 
AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory,int device_ordinal)206 void AllocationTracker::AddAllocationOrIncrementRefCount(
207     se::DeviceMemoryBase device_memory, int device_ordinal) {
208   AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
209   auto it = allocation_map.find(device_memory.opaque());
210   if (it == allocation_map.end()) {
211     allocation_map[device_memory.opaque()] = {
212         se::OwningDeviceMemory(device_memory, device_ordinal,
213                                backend_->memory_allocator()),
214         /*ref_count=*/1};
215   } else {
216     it->second.ref_count++;
217   }
218 }
219 
DecrementRefCount(se::DeviceMemoryBase device_memory,int device_ordinal)220 Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
221                                             int device_ordinal) {
222   AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
223   auto it = allocation_map.find(device_memory.opaque());
224   TF_RET_CHECK(it != allocation_map.end());
225   Allocation& allocation = it->second;
226   TF_RET_CHECK(allocation.ref_count >= 1);
227   if (allocation.ref_count == 1) {
228     TF_RETURN_IF_ERROR(allocation.device_memory.Free());
229     allocation_map.erase(it);
230   } else {
231     allocation.ref_count--;
232   }
233   return Status::OK();
234 }
235 
236 }  // namespace xla
237