/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_

#include <functional>
#include <memory>
#include <random>
#include <stack>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace xla {

// Class that encapsulates state relating to a device (e.g., a GPU) on which we
// can perform computation and transfers. LocalDeviceState objects only exist
// for devices local to this host.
class LocalDeviceState {
 public:
  // There are three different semantics used by memory allocators on different
  // devices.
  enum AllocationModel {
    // kSynchronous is used by CPU devices.
    //
    // A buffer returned from the allocator can be used immediately.
    //
    // A buffer cannot be freed until after the last stream operation
    // referencing the buffer has completed, so the client is responsible for
    // keeping buffers alive until all device-side activity that consumes those
    // buffers has completed.
    //
    // The client's use of the device allocator corresponds to a view of the
    // tail of the last stream using a buffer.
    kSynchronous,

    // kComputeSynchronized is used by GPU devices.
    //
    // A buffer returned from the allocator at time t can be used after the
    // compute stream has finished executing the last computation enqueued
    // before time t.
    //
    // A buffer b can be freed after:
    //   1) The last use of b on the compute stream has been enqueued, and
    //   2) For any non-compute stream s on which an operation o using b is
    //      enqueued, either:
    //      a) The host has been notified that o has completed, or
    //      b) The next operation to be enqueued on the compute stream is
    //         guaranteed to be started after o has completed.
    //
    // The client's use of the device allocator corresponds to a view of the
    // tail of the compute stream.
    kComputeSynchronized,

    // kAsynchronous is used by TPU devices.
    //
    // A buffer returned from the allocator can be used immediately.
    //
    // A buffer b can be freed as soon as the last stream operation using b
    // has been enqueued.
    //
    // The allocator and lower-level runtime are responsible for keeping
    // buffers alive (if that is needed) from the perspective of the device
    // until any device-side work actually completes.
    //
    // The only exception is when a buffer is transferred between devices,
    // since only one of the device executors knows about the transfer, so the
    // buffer must be manually kept alive from the perspective of the other
    // executor.
    kAsynchronous
  };
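
  // Illustrative sketch (not part of this API): under kComputeSynchronized,
  // a buffer may be returned to the allocator as soon as its last use has
  // been *enqueued* on the compute stream, even though the device may not
  // have executed that work yet. Assuming a hypothetical enqueue helper and
  // a se::DeviceMemoryAllocator* `allocator`:
  //
  //   // Enqueue the last computation that reads `buffer` on compute_stream().
  //   EnqueueComputation(device_state->compute_stream(), buffer);
  //   // Safe to free now: the allocator's view corresponds to the tail of
  //   // the compute stream, so the memory cannot be reused before the
  //   // enqueued computation has consumed it.
  //   TF_RETURN_IF_ERROR(
  //       allocator->Deallocate(device_state->device_ordinal(), buffer));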
  // If asynchronous is false, the host will synchronize to the device after
  // each execution or transfer. This is intended for debugging only.
  LocalDeviceState(se::StreamExecutor* executor, LocalClient* client,
                   AllocationModel allocation_model, bool asynchronous,
                   bool allow_event_reuse);
  virtual ~LocalDeviceState();

  se::StreamExecutor* executor() const { return executor_; }

  // StreamExecutor (local) device ordinal.
  int device_ordinal() const { return executor_->device_ordinal(); }

  LocalClient* client() const { return client_; }

  AllocationModel allocation_model() const { return allocation_model_; }

  EventPool& event_pool() { return event_pool_; }

  se::Stream* compute_stream() const { return compute_stream_.get(); }
  se::Stream* host_to_device_stream() const {
    return host_to_device_stream_.get();
  }
  se::Stream* callback_stream() const { return callback_stream_.get(); }

  // Returns a device-to-host stream. Allocates streams in a round-robin
  // fashion amongst the available streams.
  se::Stream* GetDeviceToHostStream();

  // Returns a device-to-device stream. Allocates streams in a round-robin
  // fashion amongst the available streams.
  se::Stream* GetDeviceToDeviceStream();

  // Returns a stream from a pool. The stream is guaranteed not to have any
  // currently outstanding work at its tail.
  std::unique_ptr<se::Stream> BorrowStreamFromPool();
  // Returns a stream to the pool. The caller must ensure the stream does not
  // have any outstanding work at its tail.
  void ReturnStreamToPool(std::unique_ptr<se::Stream> stream);

  // Enqueues a copy of `src_buffer` to `dst_buffer` onto `transfer_stream`.
  virtual Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
                                          se::Stream* dst_stream,
                                          se::DeviceMemoryBase src_buffer,
                                          se::DeviceMemoryBase dst_buffer);

  WorkerThread* execute_thread() const { return execute_thread_.get(); }

  // Enqueues a host callback on `stream`, to be executed by callback_thread_.
  // ThenDoHostCallback is often constrained in what it can do; in particular,
  // on GPU the callback runs on a thread belonging to the GPU runtime and
  // cannot perform GPU operations itself.
  void ThenExecuteOnCallbackThread(se::Stream* stream,
                                   std::function<void()> callback) const;

  // Helper for releasing values at the tail of a stream on a worker thread.
  // Copies `object`, and destroys the copy when the tail of the stream is
  // reached. The destruction happens either in the caller's thread or on the
  // worker thread (depending on thread schedules), not in a device callback,
  // so it is safe if the destructor frees device resources (e.g., GPU
  // objects).
  // TODO(phawkins): use move-capture when we can use C++14 features.
  template <typename T>
  void ThenRelease(se::Stream* stream, T object) const {
    // Releases happen at the tail of callback_stream_; make it wait for
    // `stream` first if the two differ.
    if (callback_stream_.get() != stream) {
      callback_stream_->ThenWaitFor(stream);
    }
    // The lambda captures `object` by value; the copy is destroyed once the
    // callback has run and been discarded on the callback thread.
    ThenExecuteOnCallbackThread(callback_stream_.get(),
                                [object]() { /* releases object */ });
  }
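
  // Illustrative sketch (not part of this API): ThenRelease can extend the
  // lifetime of a host-side owner until `stream` reaches this point. Assuming
  // a hypothetical std::shared_ptr owner whose destructor ultimately frees a
  // device buffer:
  //
  //   std::shared_ptr<OwnedBuffer> owner = ...;
  //   device_state->ThenRelease(stream, owner);
  //   // The caller may now drop its reference; the copy captured by
  //   // ThenRelease keeps the buffer alive until the stream's tail is
  //   // reached, and is destroyed on a worker thread rather than in a
  //   // device callback.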
  Semaphore& compute_semaphore() { return compute_semaphore_; }

  // Returns a fresh, PRNG-generated random seed for an XLA computation.
  int GetNewPrngSeed();

 private:
  Status SynchronizeAllActivity();

  AllocationModel allocation_model_;

  EventPool event_pool_;

  // Semaphore used to limit how many programs can be enqueued on the compute
  // stream by the host ahead of the device.
  Semaphore compute_semaphore_;

  se::StreamExecutor* const executor_;
  LocalClient* const client_;
  std::unique_ptr<se::Stream> compute_stream_;
  std::unique_ptr<se::Stream> host_to_device_stream_;
  std::vector<std::unique_ptr<se::Stream>> device_to_host_streams_;
  std::vector<std::unique_ptr<se::Stream>> device_to_device_streams_;

  // Number of device-to-host and device-to-device streams.
  static constexpr int kNumDeviceToHostStreams = 4;
  static constexpr int kNumDeviceToDeviceStreams = 4;

  absl::Mutex mu_;
  int next_device_to_host_stream_ TF_GUARDED_BY(mu_) = 0;
  int next_device_to_device_stream_ TF_GUARDED_BY(mu_) = 0;
  std::stack<std::unique_ptr<se::Stream>> usage_stream_pool_
      TF_GUARDED_BY(mu_);

  std::random_device prng_seed_device_ TF_GUARDED_BY(mu_);
  std::mt19937 prng_seed_generator_ TF_GUARDED_BY(mu_);
  std::uniform_int_distribution<> prng_seed_distribution_ TF_GUARDED_BY(mu_);

  // The callback stream is used for running short host-side callbacks after
  // device-side events, without preventing the device-side stream from doing
  // useful work.
  std::unique_ptr<se::Stream> callback_stream_;

  // A worker thread, used for replicated computation launches.
  std::unique_ptr<WorkerThread> execute_thread_;

  // A worker thread, used for callbacks. It must be a different thread from
  // the execute thread, because we acquire the compute semaphore during calls
  // to Execute but release it from a callback; if they were the same thread,
  // we might deadlock.
  std::unique_ptr<WorkerThread> callback_thread_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
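
// Illustrative construction sketch (documentation only, not part of this
// header). Assumes a LocalClient* `client` whose backend exposes a
// StreamExecutor for device ordinal 0; exact construction details vary by
// backend and PJRT client implementation:
//
//   se::StreamExecutor* executor =
//       client->backend().stream_executor(/*device_ordinal=*/0).ValueOrDie();
//   auto device_state = absl::make_unique<xla::LocalDeviceState>(
//       executor, client, xla::LocalDeviceState::kComputeSynchronized,
//       /*asynchronous=*/true, /*allow_event_reuse=*/true);
//   se::Stream* d2h_stream = device_state->GetDeviceToHostStream();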