/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

#include <algorithm>
#include <list>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace {

// We use kDeviceBits to store the device ordinal in the handle. We store the
// device in the upper part of the int64 handle to make sure the random bits
// are in the lower part, which is better when storing the handle as a key for
// unordered maps.
const int kDeviceBits = 12;

int64 MakeDeviceHandle(int64 device_ordinal, int64 rnd_value) {
  const int64 kUidMask = (static_cast<int64>(1) << (64 - kDeviceBits)) - 1;
  return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}

int GetDeviceFromHandle(int64 handle) {
  return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}
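
// Illustrative sketch (not part of the original file): with kDeviceBits = 12,
// the upper 12 bits of a handle hold the device ordinal and the lower 52 bits
// hold the masked random uid, so GetDeviceFromHandle() simply shifts the
// device bits back down. For example:
//
//   int64 handle = MakeDeviceHandle(/*device_ordinal=*/3, random::New64());
//   // handle bits: [ 12 bits: device ordinal 3 ][ 52 bits: rnd_value & kUidMask ]
//   CHECK_EQ(GetDeviceFromHandle(handle), 3);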

}  // namespace

class XRTMemoryManager::DeviceContext {
  struct Alloc {
    explicit Alloc(RefPtr<XRTTupleAllocation> tuple)
        : tuple(std::move(tuple)) {}

    RefPtr<XRTTupleAllocation> tuple;
  };

  using AllocList = std::list<Alloc>;

 public:
  int64 Register(RefPtr<XRTTupleAllocation> tuple) {
    while (true) {
      int64 handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid());
      mutex_lock lock(lock_);
      allocs_.emplace_front(tuple);
      if (alloc_map_.emplace(handle, allocs_.begin()).second) {
        return handle;
      }
      // The chances of hitting an existing handle are so remote that it is
      // more convenient to add to the list up front and remove the entry
      // again in the rare case of a collision.
      allocs_.erase(allocs_.begin());
    }
  }

  bool Release(int64 handle) {
    mutex_lock lock(lock_);
    auto it = alloc_map_.find(handle);
    if (it == alloc_map_.end()) {
      return false;
    }
    allocs_.erase(it->second);
    alloc_map_.erase(it);
    return true;
  }

  RefPtr<XRTTupleAllocation> Lookup(int64 handle) {
    mutex_lock lock(lock_);
    auto it = alloc_map_.find(handle);
    if (it == alloc_map_.end()) {
      return nullptr;
    }
    // LRU: move the looked-up allocation to the front of the list.
    allocs_.splice(allocs_.begin(), allocs_, it->second);
    return it->second->tuple;
  }

  void Clear() {
    mutex_lock lock(lock_);
    alloc_map_.clear();
    allocs_.clear();
  }

  Status CompactAllocations(XRTMemoryManager* memory_manager,
                            xla::Backend* backend) {
    profiler::TraceMe trace_me("XRTMemoryManager::CompactAllocations",
                               /*level=*/2);
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMemoryCompactCell());
    VLOG(4) << "CompactAllocations started";
    mutex_lock lock(lock_);
    Status status;
    std::vector<AllocList::iterator> swapped;
    // We swap out starting from the most recently used allocations. This is
    // desirable since, after the swap-in below, the most recently used
    // allocations will end up at the bottom of the allocation space. Since
    // these are more likely to be pinned allocations, a further trim done by
    // a following TryFreeMemory() call will then drop the higher-located
    // allocations, with a better chance of reducing fragmentation.
    // Also, by swapping out the pinned allocations first, they will also be
    // the first to be restored, so if we ever hit OOM on the way back in, we
    // are more likely to be swapping in unpinned allocations.
    for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
      // We are compacting all the allocations, so we temporarily swap out
      // even pinned allocations.
      auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true);
      if (!swap_result_or.ok()) {
        status = swap_result_or.status();
        break;
      }
      if (swap_result_or.ValueOrDie()) {
        swapped.push_back(it);
      }
    }
    // At this point we have released all the device memory we could release.
    // Load back the tuple allocations we swapped out above.
    for (auto& it : swapped) {
      auto swap_result_or = it->tuple->SwapIn(memory_manager, backend);
      if (!swap_result_or.ok()) {
        // If we fail to restore a pinned allocation, it is better to CHECK
        // here than to wonder later why XRTTupleAllocation calls fail with
        // errors about missing buffers.
        CHECK(!it->tuple->IsPinned());  // Crash OK
        if (status.ok()) {
          status = swap_result_or.status();
        }
      }
    }
    VLOG(4) << "CompactAllocations finished: " << status;
    return status;
  }

  // Tries to free size bytes by freeing some unpinned device memory. Returns
  // the amount of memory it was able to free.
  xla::StatusOr<size_t> TryFreeMemory(xla::Backend* backend, size_t size) {
    profiler::TraceMe trace_me("XRTMemoryManager::TryFreeMemory", /*level=*/2);
    auto timed = monitoring::MakeTimed(xrt_metrics::GetTryFreeMemoryCell());
    mutex_lock lock(lock_);
    size_t swapped_size = 0;
    for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
      TF_ASSIGN_OR_RETURN(bool swap_result,
                          it->tuple->SwapOut(backend, /*swap_pinned=*/false));
      if (swap_result) {
        swapped_size += it->tuple->GetDeviceMemorySize();
        if (swapped_size >= size) {
          break;
        }
      }
    }
    VLOG(3) << "Swapped out " << swapped_size << " bytes";
    return swapped_size;
  }

 private:
  static int64 CreateUid() {
    int64 uid;
    do {
      uid = random::New64() & INT64_MAX;
    } while (uid == InvalidKey());
    return uid;
  }

  // We store Alloc records inside an std::list<Alloc> so we can keep them in
  // LRU order, and store the list iterators within the handle map, since list
  // iterators are not invalidated by removals of other elements or by
  // position swaps.
  mutex lock_;
  AllocList allocs_;
  std::unordered_map<int64, AllocList::iterator> alloc_map_;
};

XRTMemoryManager::WorkingSet::WorkingSet(
    RefPtr<XRTMemoryManager> memory_manager)
    : memory_manager_(std::move(memory_manager)) {}

XRTMemoryManager::WorkingSet::~WorkingSet() {
  for (auto& tuple : pinned_tuples_) {
    tuple->Unpin();
  }
}

Status XRTMemoryManager::WorkingSet::LookupAndPin(xla::Backend* backend,
                                                  int64 handle) {
  TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle));
  TF_RETURN_IF_ERROR(
      tuple->PinAndSwapIn(memory_manager_.get(), backend).status());
  pinned_tuples_.push_back(std::move(tuple));
  return Status::OK();
}
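
// Illustrative usage sketch (assumed caller code, not part of this file): a
// WorkingSet pins the tuple allocations it looks up, and the destructor above
// unpins them automatically.
//
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, handle));
//   // The pinned tuples stay swapped in while working_set is alive; they
//   // are unpinned when working_set goes out of scope.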

/* static */ RefPtr<XRTMemoryManager> XRTMemoryManager::Get(ResourceMgr* rm) {
  static string* container = new string("XrtState");
  static string* name = new string("MemoryManager");
  XRTMemoryManager* memory_manager = nullptr;
  TF_CHECK_OK(rm->LookupOrCreate<XRTMemoryManager>(
      *container, *name, &memory_manager, [](XRTMemoryManager** ret) {
        *ret = new XRTMemoryManager();
        return Status::OK();
      }));
  return memory_manager;
}
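
// Illustrative handle lifecycle (assumed caller code, not part of this file):
// the memory manager is fetched from the ResourceMgr, an allocation is
// registered for a handle, and the handle is later resolved and released
// using the methods defined below.
//
//   RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
//   int64 handle = memory_manager->Register(std::move(tuple));
//   TF_ASSIGN_OR_RETURN(auto looked_up, memory_manager->Lookup(handle));
//   TF_RETURN_IF_ERROR(memory_manager->Release(handle));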

int64 XRTMemoryManager::Register(RefPtr<XRTTupleAllocation> tuple) {
  DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(),
                                                   /*create_if_missing=*/true);
  return device_context->Register(std::move(tuple));
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> XRTMemoryManager::Lookup(
    int64 handle) {
  int device_ordinal = GetDeviceFromHandle(handle);
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  RefPtr<XRTTupleAllocation> tuple = device_context->Lookup(handle);
  if (tuple == nullptr) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  return std::move(tuple);
}

Status XRTMemoryManager::Release(int64 handle) {
  int device_ordinal = GetDeviceFromHandle(handle);
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr || !device_context->Release(handle)) {
    return errors::NotFound("XRT memory handle not found: ", handle);
  }
  return Status::OK();
}

Status XRTMemoryManager::CompactAllocations(xla::Backend* backend,
                                            int device_ordinal) {
  DeviceContext* device_context = GetDeviceContext(device_ordinal,
                                                   /*create_if_missing=*/false);
  return device_context != nullptr
             ? device_context->CompactAllocations(this, backend)
             : Status::OK();
}

void XRTMemoryManager::ReleaseAllAllocations() {
  mutex_lock lock(lock_);
  for (auto& device_context : device_contexts_) {
    if (device_context != nullptr) {
      device_context->Clear();
    }
  }
}

xla::StatusOr<se::OwningDeviceMemory> XRTMemoryManager::Allocate(
    xla::Backend* backend, int device_ordinal, size_t size) {
  se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
  auto memory_or =
      allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false);
  if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) {
    VLOG(4) << "Allocate of " << size << " bytes failed on device "
            << device_ordinal;

    DeviceContext* device_context =
        GetDeviceContext(device_ordinal,
                         /*create_if_missing=*/false);
    if (device_context != nullptr) {
      Status status = device_context->TryFreeMemory(backend, size).status();
      if (status.ok()) {
        // As long as there is no error, we retry the allocation even if the
        // TryFreeMemory() call freed less memory than the requested size:
        // because of fragmentation, the allocation could still succeed even
        // though the amount actually freed is lower.
        memory_or = allocator->Allocate(device_ordinal, size,
                                        /*retry_on_failure=*/false);
      } else if (status.code() != error::RESOURCE_EXHAUSTED) {
        VLOG(4) << "Allocate of " << size << " bytes on device "
                << device_ordinal << ": " << status;
        return status;
      }
    }
  }
  return memory_or;
}

string XRTMemoryManager::DebugString() const {
  // We might want to emit more detailed information here, like per device
  // memory allocations.
  return "XRTMemoryManager";
}

XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext(
    int device_ordinal, bool create_if_missing) {
  mutex_lock lock(lock_);
  if (device_ordinal >= device_contexts_.size()) {
    if (!create_if_missing) {
      return nullptr;
    }
    device_contexts_.resize(device_ordinal + 1);
  }
  DeviceContext* device_context = device_contexts_[device_ordinal].get();
  if (device_context == nullptr && create_if_missing) {
    device_contexts_[device_ordinal] = absl::make_unique<DeviceContext>();
    device_context = device_contexts_[device_ordinal].get();
  }
  return device_context;
}

Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx,
                                           const Status& status) {
  DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal,
                                                   /*create_if_missing=*/false);
  if (device_context == nullptr) {
    return status;
  }
  if (!mrctx->done_freeing) {
    // If the caller passed us a zero requested_free_size, we try to free
    // chunks of kMaxFreeSize memory, until either the run function succeeds,
    // or we run out of freeable memory.
    const size_t kMaxFreeSize = 1000000000;
    size_t free_size =
        (mrctx->requested_free_size > 0)
            ? std::min<size_t>(mrctx->requested_free_size - mrctx->free_size,
                               kMaxFreeSize)
            : kMaxFreeSize;
    if (free_size > 0) {
      auto free_size_or =
          device_context->TryFreeMemory(mrctx->backend, free_size);
      if (!free_size_or.ok()) {
        return status;
      }
      size_t size = free_size_or.ValueOrDie();
      mrctx->free_size += size;
      if (size > 0) {
        return Status::OK();
      }
    }
    mrctx->done_freeing = true;
  }
  if (!mrctx->done_compacting) {
    mrctx->done_compacting = true;
    if (device_context->CompactAllocations(this, mrctx->backend).ok()) {
      return Status::OK();
    }
  }
  return status;
}
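
// Illustrative sketch (assumed caller code, not part of this file) of the
// retry pattern TryFreeMemoryStep() supports: run an operation, and on
// RESOURCE_EXHAUSTED keep asking the memory manager to reclaim memory (first
// by swapping out unpinned allocations, then by compacting) until either the
// operation succeeds or TryFreeMemoryStep() returns the original error,
// meaning nothing more could be freed. The MemoryReclaimContext construction
// below assumes it can be built from the backend, device ordinal and
// requested free size whose fields are used above.
//
//   XRTMemoryManager::MemoryReclaimContext mrctx(backend, device_ordinal,
//                                                requested_free_size);
//   while (true) {
//     auto result_or = runfn();
//     if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
//       return result_or;
//     }
//     Status reclaim_status =
//         memory_manager->TryFreeMemoryStep(&mrctx, result_or.status());
//     if (!reclaim_status.ok()) {
//       return reclaim_status;
//     }
//   }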

}  // namespace tensorflow