1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
18
#include <map>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform.h"
30
31 namespace stream_executor {
32
33 class DeviceMemoryAllocator;
34
35 // Owning pointer for memory on a device.
36 //
37 // ScopedDeviceMemory is an owning pointer like std::unique_ptr, but it can
38 // point to memory that resides on a "device" (e.g. a GPU). When a
39 // ScopedDeviceMemory goes out of scope, it frees the memory it owns.
40 //
41 // We say that an instance of ScopedDeviceMemory is "active" if it currently
42 // owns a (possibly empty) slice of memory on the device. Moving,
43 // Release()'ing, Free()'ing, and other actions can deactive an active object.
44 template <typename ElemT>
45 class ScopedDeviceMemory {
46 public:
47 // Default construction initializes the internal state to nullptr. This
48 // mirrors the std::unique_ptr<> functionality, where default construction
49 // produces a nullptr unique_ptr, which can be assigned later.
ScopedDeviceMemory()50 ScopedDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
51
52 // Construct a ScopedDeviceMemory from a custom allocator.
53 //
54 // Parameters:
55 // mem: Already-allocated device memory value for this scoped mechanism to
56 // deallocate. This memory must have been allocated by parent.
57 // device_ordinal: Device on which the memory was allocated.
58 // allocator: Allocator used to deallocate memory when this instance goes
59 // out of scope.
ScopedDeviceMemory(DeviceMemoryBase mem,int device_ordinal,DeviceMemoryAllocator * allocator)60 ScopedDeviceMemory(DeviceMemoryBase mem, int device_ordinal,
61 DeviceMemoryAllocator *allocator)
62 : wrapped_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {
63 DCHECK_GE(device_ordinal_, 0);
64 }
65
66 // A helper constructor to generate a scoped device memory given an already
67 // allocated memory and a stream executor.
68 //
69 // Precondition: memory was allocated by the stream executor `parent`.
70 ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);
71
72 // Constructor overload that places a literal array into device memory.
73 //
74 // Relies on the allocation function exposed by the stream executor `parent`,
75 // which will be also used for deallocating the memory
76 ScopedDeviceMemory(StreamExecutor *parent,
77 std::initializer_list<ElemT> values);
78
79 // Moves ownership of the memory from other to the constructed
80 // object.
81 //
82 // Postcondition: other == nullptr.
ScopedDeviceMemory(ScopedDeviceMemory && other)83 ScopedDeviceMemory(ScopedDeviceMemory &&other)
84 : wrapped_(other.Release()),
85 device_ordinal_(other.device_ordinal_),
86 allocator_(other.allocator_) {}
87
88 // Releases the memory that was provided in the constructor, through the
89 // "parent" StreamExecutor.
~ScopedDeviceMemory()90 ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); }
91
92 // Moves ownership of the memory from other to this object.
93 //
94 // Postcondition: other == nullptr.
95 ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) {
96 TF_CHECK_OK(Free());
97 wrapped_ = other.Release();
98 allocator_ = other.allocator_;
99 device_ordinal_ = other.device_ordinal_;
100 return *this;
101 }
102
103 // Returns the memory that backs this scoped allocation converted to
104 // DeviceMemory<T> apparent type. This is useful for cases where the
105 // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory doesn't
106 // allow copying, for scoped-object-lifetime reasons.
cref()107 const DeviceMemory<ElemT> &cref() const { return wrapped_; }
108
109 // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
110 // operations. The value returned should not be used outside the scope of this
111 // ScopedDeviceMemory object's lifetime.
ptr()112 DeviceMemory<ElemT> *ptr() { return &wrapped_; }
ptr()113 const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }
114
115 // Smart-pointer-like operators for the wrapped DeviceMemory.
116 // This reference must not be used outside the lifetime of this
117 // ScopedDeviceMemory.
118 const DeviceMemory<ElemT> &operator*() const { return cref(); }
119 DeviceMemory<ElemT> *operator->() { return ptr(); }
120 const DeviceMemory<ElemT> *operator->() const { return ptr(); }
121
is_null()122 bool is_null() const { return wrapped_.is_null(); }
123 bool operator==(std::nullptr_t other) const { return is_null(); }
124 bool operator!=(std::nullptr_t other) const { return !is_null(); }
125
126 // Analogous to std::unique_ptr::release, releases ownership of the held
127 // memory and transfers it to the caller.
128 //
129 // Postcondition: *this == nullptr
Release()130 DeviceMemory<ElemT> Release() {
131 DeviceMemory<ElemT> tmp = wrapped_;
132 wrapped_ = DeviceMemory<ElemT>{};
133 return tmp;
134 }
135
136 // The returned allocator is nonnull iff this object is active.
allocator()137 DeviceMemoryAllocator *allocator() const { return allocator_; }
138
device_ordinal()139 int device_ordinal() const { return device_ordinal_; }
140
141 // Frees the existing memory, resets the wrapped memory to null.
142 port::Status Free();
143
144 private:
145 DeviceMemory<ElemT> wrapped_; // Value we wrap with scoped-release.
146 int device_ordinal_; // Negative one for inactive object.
147 DeviceMemoryAllocator *allocator_; // Null if this object is inactive.
148
149 SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
150 };
151
152 // Type alias for compatibility with the previous managed memory implementation.
153 using OwningDeviceMemory = ScopedDeviceMemory<uint8>;
154
155 // Memory allocator interface for the device.
156 //
157 // Intended usage is through Allocate() functions which return an owning smart
158 // pointer.
// Abstract interface: implementations allocate/deallocate raw device memory
// and hand it out as owning ScopedDeviceMemory values.
class DeviceMemoryAllocator {
 public:
  // Parameter platform indicates which platform the allocator allocates memory
  // on. Must be non-null.
  explicit DeviceMemoryAllocator(const Platform* platform)
      : platform_(platform) {}
  virtual ~DeviceMemoryAllocator() {}

  // Allocates memory on the device.
  //
  // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
  // must not be null. If size == 0, must return a null OwningDeviceMemory.
  //
  // 'retry_on_failure': If false, and the first attempt to allocate the memory
  // fails, the allocation should return immediately without retrying. An
  // example use case is optional scratch spaces where a failure has only
  // performance impact.
  //
  // 'memory_space': implementation-defined memory category selector
  // (0 is the default; exact semantics depend on the subclass — confirm with
  // the concrete allocator).
  virtual port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
                                                      uint64 size,
                                                      bool retry_on_failure,
                                                      int64 memory_space) = 0;

  // Two-arg version of Allocate(), which sets retry-on-failure to true and
  // memory_space to default (0).
  //
  // (We don't simply use a default argument on the virtual Allocate function
  // because default args on virtual functions are disallowed by the Google
  // style guide.)
  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size) {
    return Allocate(device_ordinal, size, /*retry_on_failure=*/true,
                    /*memory_space=*/0);
  }

  // Three-arg version of Allocate(), which sets memory_space to default (0).
  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
                                              bool retry_on_failure) {
    return Allocate(device_ordinal, size, retry_on_failure,
                    /*memory_space=*/0);
  }

  // Typed version of the allocation, returning typed memory.
  //
  // Forwards to the virtual 4-arg Allocate (ElemT is not deducible from the
  // arguments, so overload resolution picks the non-template) and relies on
  // conversion of the untyped result into StatusOr<ScopedDeviceMemory<ElemT>>.
  template <typename ElemT>
  port::StatusOr<ScopedDeviceMemory<ElemT>> Allocate(
      int device_ordinal, uint64 size, bool retry_on_failure = true,
      int64 memory_space = 0) {
    return Allocate(device_ordinal, size, retry_on_failure, memory_space);
  }

  // Must be a nop for null pointers. Should not be used.
  //
  // TODO(cheshire): Add deprecation notice.
  virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;

  // Return the platform that the allocator allocates memory on.
  const Platform* platform() const { return platform_; }

  // Can we call Deallocate() as soon as a computation has been scheduled on
  // a stream, or do we have to wait for the computation to complete first?
  virtual bool AllowsAsynchronousDeallocation() const { return false; }

  // Returns a stream pointer on which it is always safe to access memory
  // allocated by this allocator. It is not necessary to use the returned stream
  // though, as clients may have additional information letting them safely use
  // a different stream.
  virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;

 protected:
  // Platform this allocator serves; set once at construction, never null.
  const Platform* platform_;
};
228
229 // Default memory allocator for a platform which uses
230 // StreamExecutor::Allocate/Deallocate.
// Default memory allocator for a platform which uses
// StreamExecutor::Allocate/Deallocate.
//
// Method bodies live in the corresponding .cc file; only declarations appear
// here.
class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
 public:
  // Create an allocator supporting a single device, corresponding to the passed
  // executor.
  explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);

  // Create an allocator supporting multiple stream executors.
  //
  // Precondition: all stream_executors have different device ordinals.
  StreamExecutorMemoryAllocator(
      const Platform *platform,
      absl::Span<StreamExecutor *const> stream_executors);

  // Allocates via the executor matching `device_ordinal`.
  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
                                              bool retry_on_failure,
                                              int64 memory_space) override;

  // Pull in two-arg overload that sets retry_on_failure to true.
  using DeviceMemoryAllocator::Allocate;

  port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override;

  bool AllowsAsynchronousDeallocation() const override;

  // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
  // stream executor.
  port::StatusOr<Stream *> GetStream(int device_ordinal) override;

  // Gets the stream executor for given device ordinal.
  port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;

 private:
  // Available stream executors. Each stream executor has a different device
  // ordinal.
  std::vector<StreamExecutor *> stream_executors_;

  // Serializes access to streams_ (GetStream may be called concurrently).
  absl::Mutex mutex_;

  // Cache of streams for GetStream, keyed by device ordinal.
  std::map<int, Stream> streams_ TF_GUARDED_BY(mutex_);
};
272
273 template <typename ElemT>
Free()274 port::Status ScopedDeviceMemory<ElemT>::Free() {
275 if (!wrapped_.is_null()) {
276 CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state";
277 TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_));
278 }
279 wrapped_ = DeviceMemory<ElemT>{};
280 return port::Status::OK();
281 }
282
283 } // namespace stream_executor
284
285 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
286