// (Removed: HTML code-browser navigation chrome captured during extraction.)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_

#include <initializer_list>
#include <map>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform.h"
29 
namespace stream_executor {

class DeviceMemoryAllocator;

34 // Owning pointer for memory on a device.
35 //
36 // ScopedDeviceMemory is an owning pointer like std::unique_ptr, but it can
37 // point to memory that resides on a "device" (e.g. a GPU).  When a
38 // ScopedDeviceMemory goes out of scope, it frees the memory it owns.
39 //
40 // We say that an instance of ScopedDeviceMemory is "active" if it currently
41 // owns a (possibly empty) slice of memory on the device.  Moving,
42 // Release()'ing, Free()'ing, and other actions can deactive an active object.
43 template <typename ElemT>
44 class ScopedDeviceMemory {
45  public:
46   // Default construction initializes the internal state to nullptr.  This
47   // mirrors the std::unique_ptr<> functionality, where default construction
48   // produces a nullptr unique_ptr, which can be assigned later.
ScopedDeviceMemory()49   ScopedDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
50 
51   // Construct a ScopedDeviceMemory from a custom allocator.
52   //
53   // Parameters:
54   //  mem: Already-allocated device memory value for this scoped mechanism to
55   //       deallocate. This memory must have been allocated by parent.
56   //  device_ordinal: Device on which the memory was allocated.
57   //  allocator: Allocator used to deallocate memory when this instance goes
58   //             out of scope.
ScopedDeviceMemory(DeviceMemoryBase mem,int device_ordinal,DeviceMemoryAllocator * allocator)59   ScopedDeviceMemory(DeviceMemoryBase mem, int device_ordinal,
60                      DeviceMemoryAllocator *allocator)
61       : wrapped_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {}
62 
63   // A helper constructor to generate a scoped device memory given an already
64   // allocated memory and a stream executor.
65   //
66   // Precondition: memory was allocated by the stream executor `parent`.
67   ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);
68 
69   // Constructor overload that places a literal array into device memory.
70   //
71   // Relies on the allocation function exposed by the stream executor `parent`,
72   // which will be also used for deallocating the memory
73   ScopedDeviceMemory(StreamExecutor *parent,
74                      std::initializer_list<ElemT> values);
75 
76   // Moves ownership of the memory from other to the constructed
77   // object.
78   //
79   // Postcondition: other == nullptr.
ScopedDeviceMemory(ScopedDeviceMemory && other)80   ScopedDeviceMemory(ScopedDeviceMemory &&other)
81       : ScopedDeviceMemory(other.Release(), other.device_ordinal_,
82                            other.allocator_) {}
83 
84   // Releases the memory that was provided in the constructor, through the
85   // "parent" StreamExecutor.
~ScopedDeviceMemory()86   ~ScopedDeviceMemory() { TF_CHECK_OK(Free()); }
87 
88   // Moves ownership of the memory from other to this object.
89   //
90   // Postcondition: other == nullptr.
91   ScopedDeviceMemory &operator=(ScopedDeviceMemory &&other) {
92     TF_CHECK_OK(Free());
93     wrapped_ = other.Release();
94     allocator_ = other.allocator_;
95     device_ordinal_ = other.device_ordinal_;
96     return *this;
97   }
98 
99   // Returns the memory that backs this scoped allocation converted to
100   // DeviceMemory<T> apparent type. This is useful for cases where the
101   // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory doesn't
102   // allow copying, for scoped-object-lifetime reasons.
cref()103   const DeviceMemory<ElemT> &cref() const { return wrapped_; }
104 
105   // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
106   // operations. The value returned should not be used outside the scope of this
107   // ScopedDeviceMemory object's lifetime.
ptr()108   DeviceMemory<ElemT> *ptr() { return &wrapped_; }
ptr()109   const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }
110 
111   // Smart-pointer-like operators for the wrapped DeviceMemory.
112   // This reference must not be used outside the lifetime of this
113   // ScopedDeviceMemory.
114   const DeviceMemory<ElemT> &operator*() const { return cref(); }
115   DeviceMemory<ElemT> *operator->() { return ptr(); }
116   const DeviceMemory<ElemT> *operator->() const { return ptr(); }
117 
is_null()118   bool is_null() const { return wrapped_.is_null(); }
119   bool operator==(std::nullptr_t other) const { return is_null(); }
120   bool operator!=(std::nullptr_t other) const { return !is_null(); }
121 
122   // Analogous to std::unique_ptr::release, releases ownership of the held
123   // memory and transfers it to the caller.
124   //
125   // Postcondition: *this == nullptr
Release()126   DeviceMemory<ElemT> Release() {
127     DeviceMemory<ElemT> tmp = wrapped_;
128     wrapped_ = DeviceMemory<ElemT>{};
129     return tmp;
130   }
131 
132   // The returned allocator is nonnull iff this object is active.
allocator()133   DeviceMemoryAllocator *allocator() const { return allocator_; }
134 
device_ordinal()135   int device_ordinal() const { return device_ordinal_; }
136 
137   // Frees the existing memory, resets the wrapped memory to null.
138   port::Status Free();
139 
140  private:
141   DeviceMemory<ElemT> wrapped_;       // Value we wrap with scoped-release.
142   int device_ordinal_;                // Negative one for inactive object.
143   DeviceMemoryAllocator *allocator_;  // Null if this object is inactive.
144 
145   SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
146 };
147 
148 // Type alias for compatibility with the previous managed memory implementation.
149 using OwningDeviceMemory = ScopedDeviceMemory<uint8>;
150 
151 // Memory allocator interface for the device.
152 //
153 // Intended usage is through Allocate() functions which return an owning smart
154 // pointer.
155 class DeviceMemoryAllocator {
156  public:
157   // Parameter platform indicates which platform the allocator allocates memory
158   // on. Must be non-null.
DeviceMemoryAllocator(const Platform * platform)159   explicit DeviceMemoryAllocator(const Platform* platform)
160       : platform_(platform) {}
~DeviceMemoryAllocator()161   virtual ~DeviceMemoryAllocator() {}
162 
163   // Allocates memory on the device.
164   //
165   // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
166   // must not be null.  If size == 0, must return a null OwningDeviceMemory.
167   //
168   // 'retry_on_failure': If false, and the first attempt to allocate the memory
169   // fails, the allocation should return immediately without retrying.  An
170   // example use case is optional scratch spaces where a failure has only
171   // performance impact.
172   virtual port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal,
173                                                       uint64 size,
174                                                       bool retry_on_failure,
175                                                       int64 memory_space) = 0;
176 
177   // Two-arg version of Allocate(), which sets retry-on-failure to true and
178   // memory_space to default (0).
179   //
180   // (We don't simply use a default argument on the virtual Allocate function
181   // because default args on virtual functions are disallowed by the Google
182   // style guide.)
Allocate(int device_ordinal,uint64 size)183   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size) {
184     return Allocate(device_ordinal, size, /*retry_on_failure=*/true,
185                     /*memory_space=*/0);
186   }
187 
188   // Three-arg version of Allocate(), which sets memory_space to default (0).
Allocate(int device_ordinal,uint64 size,bool retry_on_failure)189   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
190                                               bool retry_on_failure) {
191     return Allocate(device_ordinal, size, retry_on_failure,
192                     /*memory_space=*/0);
193   }
194 
195   // Typed version of the allocation, returning typed memory.
196   template <typename ElemT>
197   port::StatusOr<ScopedDeviceMemory<ElemT>> Allocate(
198       int device_ordinal, uint64 size, bool retry_on_failure = true,
199       int64 memory_space = 0) {
200     return Allocate(device_ordinal, size, retry_on_failure, memory_space);
201   }
202 
203   // Must be a nop for null pointers. Should not be used.
204   //
205   // TODO(cheshire): Add deprecation notice.
206   virtual port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) = 0;
207 
208   // Return the platform that the allocator allocates memory on.
platform()209   const Platform* platform() const { return platform_; }
210 
211   // Can we call Deallocate() as soon as a computation has been scheduled on
212   // a stream, or do we have to wait for the computation to complete first?
AllowsAsynchronousDeallocation()213   virtual bool AllowsAsynchronousDeallocation() const { return false; }
214 
215   // Returns a stream pointer on which it is always safe to access memory
216   // allocated by this allocator. It is not necessary to use the returned stream
217   // though, as clients may have additional information letting them safely use
218   // a different stream.
219   virtual port::StatusOr<Stream *> GetStream(int device_ordinal) = 0;
220 
221  protected:
222   const Platform* platform_;
223 };
224 
225 // Default memory allocator for a platform which uses
226 // StreamExecutor::Allocate/Deallocate.
227 class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
228  public:
229   // Create an allocator supporting a single device, corresponding to the passed
230   // executor.
231   explicit StreamExecutorMemoryAllocator(StreamExecutor *executor);
232 
233   // Create an allocator supporting multiple stream executors.
234   //
235   // Precondition: all stream_executors have different device ordinals.
236   StreamExecutorMemoryAllocator(
237       const Platform *platform,
238       absl::Span<StreamExecutor *const> stream_executors);
239 
240   port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
241                                               bool retry_on_failure,
242                                               int64 memory_space) override;
243 
244   // Pull in two-arg overload that sets retry_on_failure to true.
245   using DeviceMemoryAllocator::Allocate;
246 
247   port::Status Deallocate(int device_ordinal, DeviceMemoryBase mem) override;
248 
249   bool AllowsAsynchronousDeallocation() const override;
250 
251   // Gets-or-creates a stream for a given `device_ordinal` from an appropriate
252   // stream executor.
253   port::StatusOr<Stream *> GetStream(int device_ordinal) override;
254 
255   // Gets the stream executor for given device ordinal.
256   port::StatusOr<StreamExecutor *> GetStreamExecutor(int device_ordinal) const;
257 
258  private:
259   // Available stream executors. Each stream executor has a different device
260   // ordinal.
261   std::vector<StreamExecutor *> stream_executors_;
262 
263   absl::Mutex mutex_;
264 
265   // Cache of streams for GetStream.
266   std::map<int, Stream> streams_ GUARDED_BY(mutex_);
267 };
268 
269 template <typename ElemT>
Free()270 port::Status ScopedDeviceMemory<ElemT>::Free() {
271   if (!wrapped_.is_null()) {
272     CHECK(allocator_ != nullptr) << "Owning pointer in inconsistent state";
273     TF_RETURN_IF_ERROR(allocator_->Deallocate(device_ordinal_, wrapped_));
274   }
275   wrapped_ = DeviceMemory<ElemT>{};
276   return port::Status::OK();
277 }
278 
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_