• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
17 #define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/service/backend.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/compiler/xrt/xrt_refptr.h"
26 #include "tensorflow/compiler/xrt/xrt_state.h"
27 #include "tensorflow/core/framework/resource_mgr.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/refcount.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/stream_executor/device_memory_allocator.h"
34 #include "tensorflow/stream_executor/stream_executor.h"
35 
36 namespace tensorflow {
37 
// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
// object which lives within the ResourceMgr. There is only one XRT memory
// manager object within the ResourceMgr container.
class XRTMemoryManager : public ResourceBase {
  // The DeviceContext class, defined and implemented locally inside the
  // xrt_memory_manager.cc file, holds, for each device, all the information
  // related to the XRT memory management for such device.
  class DeviceContext;

 public:
  // A working set is a set of tuple allocations which are the input of a given
  // operation, and as such they must be pinned on the device memory. The tuple
  // allocations added to the WorkingSet will be unpinned at object destruction.
  class WorkingSet {
   public:
    explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);

    // Unpins all the tuples which were pinned via LookupAndPin().
    ~WorkingSet();

    // Looks up the tuple handle within the memory manager, and pins it to the
    // device (if not already pinned).
    Status LookupAndPin(xla::Backend* backend, int64 handle);

    // Returns the tuple allocations currently pinned by this working set.
    const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
      return pinned_tuples_;
    }

    // Returns the memory manager this working set operates on.
    const RefPtr<XRTMemoryManager>& MemoryManager() const {
      return memory_manager_;
    }

   private:
    RefPtr<XRTMemoryManager> memory_manager_;
    std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
  };

  // Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
  static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);

  // Registers an XRTTupleAllocation and returns the unique handle identifying
  // it.
  int64 Register(RefPtr<XRTTupleAllocation> tuple);

  // Looks up a handle returned by the Register() API and returns the
  // XRTTupleAllocation behind it.
  xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64 handle);

  // Status-returning convenience overload of Lookup(): stores the resolved
  // XRTTupleAllocation into *tuple instead of returning a StatusOr.
  Status Lookup(int64 handle, RefPtr<XRTTupleAllocation>* tuple) {
    TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
    return Status::OK();
  }

  // Releases a handle by dropping the reference count held on the
  // XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
  // references will continue to be valid.
  Status Release(int64 handle);

  // Tries to compact all the memory allocations on a given device. This is
  // currently done by swapping-out all the existing allocation, and swapping
  // them back in.
  Status CompactAllocations(xla::Backend* backend, int device_ordinal);

  // Releases all the device memory allocated by XRT within the resource
  // manager.
  void ReleaseAllAllocations();

  // Tries to allocate size bytes of device memory from the device_ordinal
  // device. Might attempt to free some unpinned device memory, if the
  // underlying allocator call fails, and try the allocation again.
  xla::StatusOr<se::OwningDeviceMemory> Allocate(xla::Backend* backend,
                                                 int device_ordinal,
                                                 size_t size);

  // Runs the specified function and handles the error::RESOURCE_EXHAUSTED
  // status code coming out of it. In such cases, we run different memory
  // freeing operations trying to make runfn succeed. The requested_free_size
  // argument represents a hint of the requested memory size which would make
  // runfn succeed.
  template <typename T>
  xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
                       xla::Backend* backend, int device_ordinal,
                       size_t requested_free_size);

  string DebugString() const override;

  // Returns the invalid key value, which will be never generated by the
  // Intern() API.
  static int64 InvalidKey() { return 0; }

 private:
  // Structure used to track the progress of a try-to-free operation. It is
  // initialized and then passed to the TryFreeMemoryStep() API.
  struct MemoryReclaimContext {
    MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
                         size_t requested_free_size)
        : backend(backend),
          device_ordinal(device_ordinal),
          requested_free_size(requested_free_size) {}

    xla::Backend* const backend = nullptr;
    const int device_ordinal = 0;
    const size_t requested_free_size = 0;
    size_t free_size = 0;
    bool done_freeing = false;
    bool done_compacting = false;
  };

  DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);

  // Called multiple times while trying to make a memory consuming function call
  // to fit. Performs progressively more expensive memory reduction operations,
  // until returning error::RESOURCE_EXHAUSTED when no further reductions are
  // possible.
  Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status);

  mutex lock_;
  std::vector<std::unique_ptr<DeviceContext>> device_contexts_;
};
156 
157 template <typename T>
Run(const std::function<xla::StatusOr<T> ()> & runfn,xla::Backend * backend,int device_ordinal,size_t requested_free_size)158 xla::StatusOr<T> XRTMemoryManager::Run(
159     const std::function<xla::StatusOr<T>()>& runfn, xla::Backend* backend,
160     int device_ordinal, size_t requested_free_size) {
161   MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size);
162   while (true) {
163     // We assume that runfn is a relatively fast-fail function compared to the
164     // operations required to free up the required memory. Here we call into the
165     // TryFreeMemoryStep() API multiple times, which will run progressively more
166     // expensive operations.
167     auto result_or = runfn();
168     if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
169       return result_or;
170     }
171     TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status()));
172   }
173 }
174 
175 }  // namespace tensorflow
176 
177 #endif  // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
178