/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
// object which lives within the ResourceMgr. There is only one XRT memory
// manager object within the ResourceMgr container.
class XRTMemoryManager : public ResourceBase {
  // The DeviceContext class, defined and implemented locally inside the
  // xrt_memory_manager.cc file, holds, for each device, all the information
  // related to the XRT memory management for that device.
  class DeviceContext;

 public:
  // A working set is a set of tuple allocations which are the input of a given
  // operation, and as such they must be pinned on the device memory. The tuple
  // allocations added to the WorkingSet will be unpinned at object
  // destruction.
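  //
  // A minimal usage sketch (illustrative only; the backend pointer and the
  // input_handle value are assumed to come from the caller):
  //
  //   XRTMemoryManager::WorkingSet working_set(memory_manager);
  //   TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, input_handle));
  //   for (const auto& tuple : working_set.PinnedTuples()) {
  //     // Use the pinned tuple allocations as operation inputs.
  //   }
  //   // All tuples are unpinned when working_set goes out of scope.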
  class WorkingSet {
   public:
    explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);

    ~WorkingSet();

    // Looks up the tuple handle within the memory manager, and pins it to the
    // device (if not already pinned).
    Status LookupAndPin(xla::Backend* backend, int64 handle);

    const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
      return pinned_tuples_;
    }

    const RefPtr<XRTMemoryManager>& MemoryManager() const {
      return memory_manager_;
    }

   private:
    RefPtr<XRTMemoryManager> memory_manager_;
    std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
  };

  // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
  static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);

  // Registers an XRTTupleAllocation and returns the unique handle identifying
  // it.
  int64 Register(RefPtr<XRTTupleAllocation> tuple);

  // Looks up a handle returned by the Register() API and returns the
  // XRTTupleAllocation behind it.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64 handle);

  Status Lookup(int64 handle, RefPtr<XRTTupleAllocation>* tuple) {
    TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
    return Status::OK();
  }

  // Releases a handle by dropping the reference count held on the
  // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
  // references will continue to be valid.
  Status Release(int64 handle);
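
  // Typical handle lifecycle, as a hedged sketch (the ResourceMgr pointer "rm"
  // and the "tuple" allocation are assumed to exist in the caller's scope):
  //
  //   RefPtr<XRTMemoryManager> manager = XRTMemoryManager::Get(rm);
  //   int64 handle = manager->Register(std::move(tuple));
  //   TF_ASSIGN_OR_RETURN(RefPtr<XRTTupleAllocation> alloc,
  //                       manager->Lookup(handle));
  //   TF_RETURN_IF_ERROR(manager->Release(handle));  // Drops the manager ref.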

  // Tries to compact all the memory allocations on a given device. This is
  // currently done by swapping-out all the existing allocations, and swapping
  // them back in.
  Status CompactAllocations(xla::Backend* backend, int device_ordinal);

  // Releases all the device memory allocated by XRT within the resource
  // manager.
  void ReleaseAllAllocations();

  // Tries to allocate size bytes of device memory from the device_ordinal
  // device. Might attempt to free some unpinned device memory, if the
  // underlying allocator call fails, and try the allocation again.
  xla::StatusOr<se::OwningDeviceMemory> Allocate(xla::Backend* backend,
                                                 int device_ordinal,
                                                 size_t size);
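
  // For example (a sketch; "backend", "device_ordinal" and kSize are
  // placeholders supplied by the caller):
  //
  //   TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
  //                       memory_manager->Allocate(backend, device_ordinal,
  //                                                /*size=*/kSize));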

  // Runs the specified function and handles the error::RESOURCE_EXHAUSTED
  // status code coming out of it. In such cases, we run different memory
  // freeing operations trying to make runfn succeed. The requested_free_size
  // argument represents a hint of the requested memory size which would make
  // runfn succeed.
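  //
  // A hedged usage sketch (the lambda body, MakeOutputBuffer() and output_size
  // are illustrative, not part of this API):
  //
  //   xla::StatusOr<xla::ScopedShapedBuffer> result =
  //       memory_manager->Run<xla::ScopedShapedBuffer>(
  //           [&]() { return MakeOutputBuffer(backend, device_ordinal); },
  //           backend, device_ordinal, /*requested_free_size=*/output_size);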
  template <typename T>
  xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
                       xla::Backend* backend, int device_ordinal,
                       size_t requested_free_size);

  string DebugString() const override;

  // Returns the invalid key value, which will never be generated by the
  // Intern() API.
  static int64 InvalidKey() { return 0; }

 private:
  // Structure used to track the progress of a try-to-free operation. It is
  // initialized and then passed to the TryFreeMemoryStep() API.
  struct MemoryReclaimContext {
    MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
                         size_t requested_free_size)
        : backend(backend),
          device_ordinal(device_ordinal),
          requested_free_size(requested_free_size) {}

    xla::Backend* const backend = nullptr;
    const int device_ordinal = 0;
    const size_t requested_free_size = 0;
    size_t free_size = 0;
    bool done_freeing = false;
    bool done_compacting = false;
  };

  DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);

  // Called multiple times while trying to make a memory consuming function
  // call to fit. Performs progressively more expensive memory reduction
  // operations, until returning error::RESOURCE_EXHAUSTED when no further
  // reductions are possible.
  Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status);

  mutex lock_;
  std::vector<std::unique_ptr<DeviceContext>> device_contexts_;
};

template <typename T>
xla::StatusOr<T> XRTMemoryManager::Run(
    const std::function<xla::StatusOr<T>()>& runfn, xla::Backend* backend,
    int device_ordinal, size_t requested_free_size) {
  MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size);
  while (true) {
    // We assume that runfn is a relatively fast-fail function compared to the
    // operations required to free up the required memory. Here we call into
    // the TryFreeMemoryStep() API multiple times, which will run progressively
    // more expensive operations.
    auto result_or = runfn();
    if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
      return result_or;
    }
    TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status()));
  }
}

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_