/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_

#include <functional>
#include <memory>
#include <random>
#include <stack>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace xla {

// Class that encapsulates state relating to a device (e.g., a GPU) on which we
// can perform computation and transfers. LocalDeviceState objects only exist
// for devices local to this host.
class LocalDeviceState {
 public:
  // There are three different semantics used by memory allocators on different
  // devices.
  enum AllocationModel {
    // kSynchronous is used by CPU devices.
    //
    // A buffer returned from the allocator can be used immediately.
    //
    // A buffer cannot be freed until after the last stream operation
    // referencing the buffer has completed, so the client is responsible for
    // keeping buffers alive until all device-side activity that consumes those
    // buffers has completed.
    //
    // The client's use of the device allocator corresponds to a view of the
    // tail of the last stream using a buffer.
    kSynchronous,

    // kComputeSynchronized is used by GPU devices.
    //
    // A buffer returned from the allocator at time t can be used after the
    // compute stream has finished executing the last computation enqueued
    // before time t.
    //
    // A buffer b can be freed after:
    //   1) The last use of b on the compute stream has been enqueued, and
    //   2) For any non-compute stream s on which an operation o using b is
    //      enqueued, either:
    //     a) The host has been notified that o has completed, or
    //     b) The next operation to be enqueued on the compute stream is
    //        guaranteed to be started after o has completed.
    //
    // The client's use of the device allocator corresponds to a view of the
    // tail of the compute stream.
    kComputeSynchronized,

    // kAsynchronous is used by TPU devices.
    //
    // A buffer returned from the allocator can be used immediately.
    //
    // A buffer b can be freed as soon as the last stream operation using b has
    // been enqueued.
    //
    // The allocator and lower-level runtime are responsible for keeping
    // buffers alive (if that is needed) from the perspective of the device
    // until any device-side work actually completes.
    //
    // The only exception is when a buffer is transferred between devices since
    // only one of the device executors knows about the transfer, so the buffer
    // must be manually kept alive from the perspective of the other executor.
    kAsynchronous
  };
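
  // Hypothetical caller-side sketch of how the allocation model changes
  // buffer-release decisions (illustrative only; not part of this API):
  //
  //   switch (device_state->allocation_model()) {
  //     case LocalDeviceState::kSynchronous:
  //       // Free only after all device-side work using the buffer is done.
  //       break;
  //     case LocalDeviceState::kComputeSynchronized:
  //       // Free once the last compute-stream use is enqueued, provided any
  //       // non-compute-stream uses are known complete or ordered before the
  //       // next compute-stream operation.
  //       break;
  //     case LocalDeviceState::kAsynchronous:
  //       // Free as soon as the last use is enqueued; the runtime keeps the
  //       // buffer alive until device-side work completes.
  //       break;
  //   }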

  // If asynchronous is false, the host will synchronize to the device after
  // each execution or transfer. This is intended for debugging only.
  LocalDeviceState(se::StreamExecutor* executor, LocalClient* client,
                   AllocationModel allocation_model, bool asynchronous,
                   bool allow_event_reuse);
  virtual ~LocalDeviceState();
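
  // A hypothetical construction sketch (illustrative only), assuming a local
  // client obtained via ClientLibrary; real setups create LocalDeviceStates
  // as part of PJRT client construction:
  //
  //   LocalClient* client = ClientLibrary::LocalClientOrDie();
  //   se::StreamExecutor* executor =
  //       client->backend().stream_executor(0).ValueOrDie();
  //   auto device_state = absl::make_unique<LocalDeviceState>(
  //       executor, client, LocalDeviceState::kSynchronous,
  //       /*asynchronous=*/true, /*allow_event_reuse=*/false);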

  se::StreamExecutor* executor() const { return executor_; }
  // StreamExecutor (local) device ordinal.
  int device_ordinal() const { return executor_->device_ordinal(); }

  LocalClient* client() const { return client_; }

  AllocationModel allocation_model() const { return allocation_model_; }

  EventPool& event_pool() { return event_pool_; }

  se::Stream* compute_stream() const { return compute_stream_.get(); }
  se::Stream* host_to_device_stream() const {
    return host_to_device_stream_.get();
  }
  se::Stream* callback_stream() const { return callback_stream_.get(); }

  // Returns a device-to-host stream. Allocates streams in a round-robin
  // fashion amongst the available streams.
  se::Stream* GetDeviceToHostStream();

  // Returns a device-to-device stream. Allocates streams in a round-robin
  // fashion amongst the available streams.
  se::Stream* GetDeviceToDeviceStream();
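
  // A sketch of the round-robin selection described above (assumed shape of
  // the out-of-line definition; names match the members declared below):
  //
  //   se::Stream* LocalDeviceState::GetDeviceToHostStream() {
  //     absl::MutexLock lock(&mu_);
  //     int i = next_device_to_host_stream_;
  //     next_device_to_host_stream_ =
  //         (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
  //     return device_to_host_streams_[i].get();
  //   }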

  // Returns a stream from a pool. The stream is guaranteed not to have any
  // currently outstanding work at its tail.
  std::unique_ptr<se::Stream> BorrowStreamFromPool();
  // Returns a stream to the pool. The caller must ensure the stream does not
  // have any outstanding work at its tail.
  void ReturnStreamToPool(std::unique_ptr<se::Stream> stream);
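
  // Hypothetical borrow/use/return sequence; blocking on the stream is one
  // way to satisfy the "no outstanding work at its tail" requirement:
  //
  //   std::unique_ptr<se::Stream> stream =
  //       device_state->BorrowStreamFromPool();
  //   stream->ThenWaitFor(device_state->compute_stream());
  //   // ... enqueue work on `stream` ...
  //   stream->BlockHostUntilDone().IgnoreError();
  //   device_state->ReturnStreamToPool(std::move(stream));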

  // Enqueues a copy of `src_buffer` to `dst_buffer` onto `transfer_stream`.
  virtual Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
                                          se::Stream* dst_stream,
                                          se::DeviceMemoryBase src_buffer,
                                          se::DeviceMemoryBase dst_buffer);

  WorkerThread* execute_thread() const { return execute_thread_.get(); }

  // Enqueues a host callback on 'stream', to be executed by callback_thread_.
  // ThenDoHostCallback is often constrained in what it can do; in particular,
  // on GPU the callback runs on a thread belonging to the GPU runtime and
  // cannot itself perform GPU operations.
  void ThenExecuteOnCallbackThread(se::Stream* stream,
                                   std::function<void()> callback) const;
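
  // Hypothetical usage: release a host staging buffer once the stream reaches
  // this point. A std::shared_ptr copy is captured because std::function
  // requires a copyable callable (`num_bytes` is an assumed variable):
  //
  //   auto staging = std::make_shared<std::vector<char>>(num_bytes);
  //   // ... enqueue a host-to-device copy that reads staging->data() ...
  //   device_state->ThenExecuteOnCallbackThread(
  //       device_state->host_to_device_stream(),
  //       [staging]() { /* last reference dropped here */ });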

  // Helper for releasing values at the tail of a stream on a worker thread.
  // Copies `object`, and destroys the copy when the tail of the stream is
  // reached. The destruction happens either in the caller's thread or on the
  // worker thread (depending on thread schedules), not in a device callback,
  // so it is safe even if the destructor frees device resources (e.g., GPU
  // objects).
  // TODO(phawkins): use move-capture when we can use C++14 features.
  template <typename T>
  void ThenRelease(se::Stream* stream, T object) const {
    if (callback_stream_.get() != stream) {
      callback_stream_->ThenWaitFor(stream);
    }
    ThenExecuteOnCallbackThread(callback_stream_.get(),
                                [object]() { /* releases object */ });
  }
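
  // Hypothetical usage: keep a reference-counted buffer alive until the
  // compute stream has passed its last use, then drop the reference:
  //
  //   std::shared_ptr<se::DeviceMemoryBase> buffer = ...;
  //   // ... enqueue operations reading `*buffer` on compute_stream() ...
  //   device_state->ThenRelease(device_state->compute_stream(), buffer);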

  Semaphore& compute_semaphore() { return compute_semaphore_; }

  // Returns a fresh, PRNG-generated random seed for an XLA computation.
  int GetNewPrngSeed();

 private:
  Status SynchronizeAllActivity();

  AllocationModel allocation_model_;

  EventPool event_pool_;

  // Semaphore used to limit how many programs can be enqueued on the compute
  // stream by the host ahead of the device.
  Semaphore compute_semaphore_;
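
  // Assumed acquire/release pattern (illustrative; the real call sites live
  // in the PJRT client rather than in this class):
  //
  //   device_state->compute_semaphore().Acquire(1);  // may block the host
  //   // ... enqueue one execution on compute_stream() ...
  //   device_state->ThenExecuteOnCallbackThread(
  //       device_state->compute_stream(),
  //       [device_state]() { device_state->compute_semaphore().Release(1); });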

  se::StreamExecutor* const executor_;
  LocalClient* const client_;
  std::unique_ptr<se::Stream> compute_stream_;
  std::unique_ptr<se::Stream> host_to_device_stream_;
  std::vector<std::unique_ptr<se::Stream>> device_to_host_streams_;
  std::vector<std::unique_ptr<se::Stream>> device_to_device_streams_;

  // Number of device-to-host and device-to-device streams.
  static constexpr int kNumDeviceToHostStreams = 4;
  static constexpr int kNumDeviceToDeviceStreams = 4;

  absl::Mutex mu_;
  int next_device_to_host_stream_ TF_GUARDED_BY(mu_) = 0;
  int next_device_to_device_stream_ TF_GUARDED_BY(mu_) = 0;
  std::stack<std::unique_ptr<se::Stream>> usage_stream_pool_ TF_GUARDED_BY(mu_);

  std::random_device prng_seed_device_ TF_GUARDED_BY(mu_);
  std::mt19937 prng_seed_generator_ TF_GUARDED_BY(mu_);
  std::uniform_int_distribution<> prng_seed_distribution_ TF_GUARDED_BY(mu_);

  // The callback stream is used for running short host-side callbacks after
  // device-side events, without preventing the device-side stream from doing
  // useful work.
  std::unique_ptr<se::Stream> callback_stream_;

  // A worker thread, used for replicated computation launches.
  std::unique_ptr<WorkerThread> execute_thread_;

  // A worker thread, used for callbacks. It must be a different thread from
  // the execute thread, because we acquire the compute semaphore during calls
  // to Execute but release it from a callback; if they were the same thread,
  // we might deadlock.
  std::unique_ptr<WorkerThread> callback_thread_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_