/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_

#include <map>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform.h"

namespace stream_executor {

class DeviceMemoryAllocator;

// Owning pointer for memory on a device.
//
// ScopedDeviceMemory is an owning pointer like std::unique_ptr, but it can
// point to memory that resides on a "device" (e.g. a GPU). When a
// ScopedDeviceMemory goes out of scope, it frees the memory it owns.
//
// We say that an instance of ScopedDeviceMemory is "active" if it currently
// owns a (possibly empty) slice of memory on the device. Moving,
// Release()'ing, Free()'ing, and other actions can deactivate an active
// object.
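//
// A minimal usage sketch (illustrative only; `allocator` is an assumed
// DeviceMemoryAllocator* for device ordinal 0, and error handling is elided):
//
//   auto result = allocator->Allocate(/*device_ordinal=*/0, /*size=*/1024);
//   if (result.ok()) {
//     OwningDeviceMemory mem = std::move(result).ValueOrDie();
//     // ... pass *mem or mem.ptr() to stream operations ...
//   }  // `mem` frees its device memory here, through `allocator`.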
template <typename ElemT>
class ScopedDeviceMemory {
 public:
  // Default construction initializes the internal state to nullptr. This
  // mirrors the std::unique_ptr<> functionality, where default construction
  // produces a nullptr unique_ptr, which can be assigned later.
  ScopedDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}

  // Construct a ScopedDeviceMemory from a custom allocator.
  //
  // Parameters:
  //  mem: Already-allocated device memory value for this scoped mechanism to
  //       deallocate. This memory must have been allocated by parent.
  //  device_ordinal: Device on which the memory was allocated.
  //  allocator: Allocator used to deallocate memory when this instance goes
  //             out of scope.
  ScopedDeviceMemory(DeviceMemoryBase mem, int device_ordinal,
                     DeviceMemoryAllocator *allocator)
      : wrapped_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {}

  // A helper constructor to generate a scoped device memory given an already
  // allocated memory and a stream executor.
  //
  // Precondition: memory was allocated by the stream executor `parent`.
  ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);

  // Constructor overload that places a literal array into device memory.
  //
  // Relies on the allocation function exposed by the stream executor `parent`,
  // which will also be used to deallocate the memory.
  ScopedDeviceMemory(StreamExecutor *parent,
                     std::initializer_list<ElemT> values);
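
  // For illustration, a sketch (hypothetical values; `parent` is an assumed
  // valid StreamExecutor*):
  //
  //   ScopedDeviceMemory<float> coeffs(parent, {1.0f, 2.0f, 3.0f});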

  // Moves ownership of the memory from other to the constructed object.
  //
  // Postcondition: other == nullptr.
  ScopedDeviceMemory(ScopedDeviceMemory &&other)
      : ScopedDeviceMemory(other.Release(), other.device_ordinal_,
                           other.allocator_) {}

  // Releases the memory that was provided in the constructor, through the
  // "parent" StreamExecutor.
  ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); }

  // Moves ownership of the memory from other to this object.
  //
  // Postcondition: other == nullptr.
  ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) {
    TF_CHECK_OK(Free());
    wrapped_ = other.Release();
    allocator_ = other.allocator_;
    device_ordinal_ = other.device_ordinal_;
    return *this;
  }

  // Returns the memory that backs this scoped allocation converted to
  // DeviceMemory<T> apparent type. This is useful for cases where the
  // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory
  // doesn't allow copying, for scoped-object-lifetime reasons.
  const DeviceMemory<ElemT> &cref() const { return wrapped_; }

  // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
  // operations. The value returned should not be used outside the scope of
  // this ScopedDeviceMemory object's lifetime.
  DeviceMemory<ElemT> *ptr() { return &wrapped_; }
  const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }

  // Smart-pointer-like operators for the wrapped DeviceMemory.
  // This reference must not be used outside the lifetime of this
  // ScopedDeviceMemory.
  const DeviceMemory<ElemT> &operator*() const { return cref(); }
  DeviceMemory<ElemT> *operator->() { return ptr(); }
  const DeviceMemory<ElemT> *operator->() const { return ptr(); }

  bool is_null() const { return wrapped_.is_null(); }
  bool operator==(std::nullptr_t other) const { return is_null(); }
  bool operator!=(std::nullptr_t other) const { return !is_null(); }

  // Analogous to std::unique_ptr::release, releases ownership of the held
  // memory and transfers it to the caller.
  //
  // Postcondition: *this == nullptr
  DeviceMemory<ElemT> Release() {
    DeviceMemory<ElemT> tmp = wrapped_;
    wrapped_ = DeviceMemory<ElemT>{};
    return tmp;
  }
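
  // A sketch of the transfer-of-ownership pattern (illustrative; `scoped` and
  // `allocator` are assumed to exist, and the manual deallocation becomes the
  // caller's responsibility):
  //
  //   DeviceMemory<ElemT> raw = scoped.Release();  // `scoped` == nullptr now.
  //   // ... use `raw`, then later:
  //   TF_CHECK_OK(allocator->Deallocate(device_ordinal, raw));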

  // The returned allocator is nonnull iff this object is active.
  DeviceMemoryAllocator *allocator() const { return allocator_; }

  int device_ordinal() const { return device_ordinal_; }

  // Frees the existing memory, resets the wrapped memory to null.
  port::Status Free();

 private:
  DeviceMemory<ElemT> wrapped_;       // Value we wrap with scoped-release.
  int device_ordinal_;                // Negative one for inactive object.
  DeviceMemoryAllocator *allocator_;  // Null if this object is inactive.

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
};

// Type alias for compatibility with the previous managed memory
// implementation.
using OwningDeviceMemory = ScopedDeviceMemory<uint8>;

// Memory allocator interface for the device.
//
// Intended usage is through Allocate() functions which return an owning smart
// pointer.
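//
// A sketch of calling the overloads declared below (illustrative; `allocator`
// stands for any concrete subclass instance):
//
//   auto a = allocator->Allocate(/*device_ordinal=*/0, /*size=*/1024);
//   auto b = allocator->Allocate(/*device_ordinal=*/0, /*size=*/1024,
//                                /*retry_on_failure=*/false);
//   auto c = allocator->Allocate(/*device_ordinal=*/0, /*size=*/1024,
//                                /*retry_on_failure=*/false,
//                                /*memory_space=*/0);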
class DeviceMemoryAllocator {
 public:
  // Parameter platform indicates which platform the allocator allocates memory
  // on. Must be non-null.
  explicit DeviceMemoryAllocator(const Platform* platform)
      : platform_(platform) {}
  virtual ~DeviceMemoryAllocator() {}

  // Allocates memory on the device.
  //
  // If size > 0 and the returned StatusOr is OK, the wrapped
  // OwningDeviceMemory must not be null. If size == 0, must return a null
  // OwningDeviceMemory.
  //
  // 'retry_on_failure': If false, and the first attempt to allocate the memory
  // fails, the allocation should return immediately without retrying. An
  // example use case is optional scratch spaces where a failure has only
  // performance impact.
  virtual port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
                                                      uint64 size,
                                                      bool retry_on_failure,
                                                      int64 memory_space) = 0;

  // Two-arg version of Allocate(), which sets retry-on-failure to true and
  // memory_space to default (0).
  //
  // (We don't simply use a default argument on the virtual Allocate function
  // because default args on virtual functions are disallowed by the Google
  // style guide.)
  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
                                              uint64 size) {
    return Allocate(device_ordinal, size, /*retry_on_failure=*/true,
                    /*memory_space=*/0);
  }

  // Three-arg version of Allocate(), which sets memory_space to default (0).
  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
                                              bool retry_on_failure) {
    return Allocate(device_ordinal, size, retry_on_failure,
                    /*memory_space=*/0);
  }

  // Typed version of the allocation, returning typed memory.
  template <typename ElemT>
  port::StatusOr<ScopedDeviceMemory<ElemT>> Allocate(
      int device_ordinal, uint64 size, bool retry_on_failure = true,
      int64 memory_space = 0) {
    return Allocate(device_ordinal, size, retry_on_failure, memory_space);
  }
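
  // E.g., a caller wanting float-typed device memory might write (a sketch;
  // `allocator`, `device_ordinal`, and `byte_size` are assumed values):
  //
  //   port::StatusOr<ScopedDeviceMemory<float>> mem =
  //       allocator->Allocate<float>(device_ordinal, byte_size);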

  // Must be a nop for null pointers. Should not be used.
  //
  // TODO(cheshire): Add deprecation notice.
  virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;

  // Return the platform that the allocator allocates memory on.
  const Platform* platform() const { return platform_; }

  // Can we call Deallocate() as soon as a computation has been scheduled on
  // a stream, or do we have to wait for the computation to complete first?
  virtual bool AllowsAsynchronousDeallocation() const { return false; }

  // Returns a stream pointer on which it is always safe to access memory
  // allocated by this allocator. It is not necessary to use the returned
  // stream though, as clients may have additional information letting them
  // safely use a different stream.
  virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;

 protected:
  const Platform* platform_;
};

// Default memory allocator for a platform which uses
// StreamExecutor::Allocate/Deallocate.
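//
// For example (a sketch; `executor` is an assumed valid StreamExecutor*):
//
//   StreamExecutorMemoryAllocator allocator(executor);
//   auto buffer_or = allocator.Allocate(executor->device_ordinal(), 1024);
//   // On success, buffer_or contains an OwningDeviceMemory of 1024 bytes.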
class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
 public:
  // Create an allocator supporting a single device, corresponding to the
  // passed executor.
  explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);

  // Create an allocator supporting multiple stream executors.
  //
  // Precondition: all stream_executors have different device ordinals.
  StreamExecutorMemoryAllocator(
      const Platform *platform,
      absl::Span<StreamExecutor *const> stream_executors);

  port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
                                              bool retry_on_failure,
                                              int64 memory_space) override;

  // Pull in the two-arg overload that sets retry_on_failure to true.
  using DeviceMemoryAllocator::Allocate;

  port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override;

  bool AllowsAsynchronousDeallocation() const override;

  // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
  // stream executor.
  port::StatusOr<Stream *> GetStream(int device_ordinal) override;

  // Gets the stream executor for the given device ordinal.
  port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;

 private:
  // Available stream executors. Each stream executor has a different device
  // ordinal.
  std::vector<StreamExecutor *> stream_executors_;

  absl::Mutex mutex_;

  // Cache of streams for GetStream.
  std::map<int, Stream> streams_ GUARDED_BY(mutex_);
};

template <typename ElemT>
port::Status ScopedDeviceMemory<ElemT>::Free() {
  if (!wrapped_.is_null()) {
    CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state";
    TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_));
  }
  wrapped_ = DeviceMemory<ElemT>{};
  return port::Status::OK();
}

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_