/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

constexpr char kTpuPlatform[] = "tpu";

class TpuDevice : public PjRtDevice {
 public:
  TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
            int core_on_chip);

  const std::array<int, 3>& coords() const { return coords_; }
  int core_on_chip() const { return core_on_chip_; }

  std::string DebugString() const override;

  static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
  GetTpuDevices(const tpu_driver::SystemInfo& system_info);

  PjRtClient* client() const override { return nullptr; }

  bool IsAddressable() const override { return false; }

  int id() const override { return id_; }

  int task_id() const override { return task_id_; }

  int local_hardware_id() const override { return -1; }

  absl::string_view device_kind() const override { return device_kind_; }

  Status TransferToInfeed(const LiteralSlice& literal) override {
    return Unimplemented("Infeed not yet implemented via this API");
  }

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
    return Unimplemented("Outfeed not yet implemented via this API");
  }

 private:
  const int id_;
  const int task_id_;
  const std::array<int, 3> coords_;
  const std::string device_kind_ = "Cloud TPU";
  // Index of this core within its chip.
  int core_on_chip_;
};
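
// Example (illustrative sketch only, not part of the original header): a
// device obtained from PyTpuClient::local_devices() (declared below) can be
// downcast to TpuDevice to read its TPU-specific metadata. The `device`
// variable here is an assumption for illustration.
//
//   auto* tpu_device = tensorflow::down_cast<TpuDevice*>(device.get());
//   const std::array<int, 3>& coords = tpu_device->coords();
//   int core = tpu_device->core_on_chip();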

// Encapsulates the state of a Python session with XLA.
class PyTpuClient {
 public:
  // Initializes a local XLA client for `platform_name`. Returns an error if no
  // such platform exists, or if the platform has no visible devices.
  static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);

  explicit PyTpuClient(std::string platform_name,
                       std::unique_ptr<tpu_driver::TpuDriver> driver,
                       std::vector<std::shared_ptr<PjRtDevice>> devices,
                       int task_id);
  virtual ~PyTpuClient() = default;

  PyTpuClient(const PyTpuClient&) = delete;
  PyTpuClient(PyTpuClient&&) = delete;
  PyTpuClient& operator=(const PyTpuClient&) = delete;
  PyTpuClient& operator=(PyTpuClient&&) = delete;

  Status TransferToInfeed(const LiteralSlice& literal, int device_id);
  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_id);

  virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const;

  int device_count() const { return devices_.size(); }
  int local_device_count() const { return local_devices_.size(); }
  const std::vector<std::shared_ptr<PjRtDevice>>& devices() { return devices_; }
  const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() {
    return local_devices_;
  }
  const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
    return id_to_device_;
  }
  int task_id() const { return task_id_; }
  const std::string& platform_name() const { return platform_name_; }

  StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
    return Unimplemented("ChooseCompactLayoutForShape not implemented.");
  }

  // Returns a bad status containing `caller_name` if `device_id` doesn't
  // correspond to a valid device at the POD-slice boundary.
  Status CheckDeviceId(int device_id, absl::string_view caller_name);

  tpu_driver::TpuDriver* driver() { return driver_.get(); }

  tensorflow::thread::ThreadPool* GetThreadPool() { return pool_.get(); }

 protected:
  std::string platform_name_;
  std::unique_ptr<tpu_driver::TpuDriver> driver_;

  // Includes all devices, including non-local devices on multi-host platforms.
  std::vector<std::shared_ptr<PjRtDevice>> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
  int task_id_;

  // A thread pool for scheduling core executions in parallel.
  std::unique_ptr<tensorflow::thread::ThreadPool> pool_;
};
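
// Example (illustrative sketch only, not part of the original header):
// connecting to a TPU worker and enumerating its devices. The worker address
// "local" is an assumption; substitute the address of your TPU worker.
// TF_ASSIGN_OR_RETURN comes from XLA's status macros and LOG(INFO) from
// TensorFlow's logging helpers.
//
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<PyTpuClient> client,
//                       PyTpuClient::Get("local"));
//   for (const std::shared_ptr<PjRtDevice>& device : client->local_devices()) {
//     LOG(INFO) << device->DebugString();
//   }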

// Manages a buffer shared amongst multiple users. Buffers are asynchronously
// deallocated after the last use.
struct TpuSharedBuffer final {
 public:
  TpuSharedBuffer(tpu_driver::TpuDriver* driver,
                  std::unique_ptr<tpu_driver::BufferHandle> handle,
                  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use,
                  std::shared_ptr<PjRtDevice> src_device)
      : driver(driver),
        device(std::move(src_device)),
        handle(std::move(handle)),
        wait_for_use(std::move(wait_for_use)) {}

  ~TpuSharedBuffer() {
    std::vector<tpu_driver::Event*> events;
    for (const auto& e : wait_for_use) {
      events.push_back(e.get());
    }
    driver->Deallocate(std::move(handle), events);
  }

  tpu_driver::TpuDriver* const driver;
  const std::shared_ptr<PjRtDevice> device;

  std::unique_ptr<tpu_driver::BufferHandle> handle;
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
};

// Holds a reference from Python to one or more device buffers.
// A PyTpuBuffer can be either valid or invalid. An invalid buffer is one that
// has never been initialized, or a buffer that has been deleted (e.g., by
// calling Delete). We allow PyTpuBuffer objects to outlive the underlying
// device buffers so we can decouple buffer lifetimes from the corresponding
// Python references if needed.
// Thread-safe.
class PyTpuBuffer {
 public:
  // `tuple_shape` can be at most a one-level tuple combining non-tuple leaves.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> FromLiterals(
      std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
      std::shared_ptr<void> leaves_reference,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  // Supports nested tuple creation.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
      absl::Span<PyTpuBuffer* const> buffers,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  PyTpuBuffer() = delete;
  PyTpuBuffer(Shape on_host_shape,
              std::shared_ptr<TpuSharedBuffer> device_buffer,
              std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
              std::shared_ptr<PyTpuClient> client);

  PyTpuBuffer(const PyTpuBuffer&) = delete;
  PyTpuBuffer(PyTpuBuffer&&) = delete;
  PyTpuBuffer& operator=(const PyTpuBuffer&) = delete;
  PyTpuBuffer& operator=(PyTpuBuffer&&) = delete;

  const Shape& on_host_shape() const { return on_host_shape_; }
  std::shared_ptr<PjRtDevice> device() const { return device_; }
  const std::string& platform_name() const { return client_->platform_name(); }
  std::shared_ptr<PyTpuClient> client() const { return client_; }

  // Returns the buffer's value as a tuple DAG of Python arrays. If the value
  // has previously been prefetched to the host, then returns the prefetched
  // version, otherwise copies the buffer to the host. Blocks until the
  // value is ready.
  StatusOr<std::shared_ptr<Literal>> ToLiteral();

  // Initiates a copy of the buffer to the host. Does not block waiting for
  // the transfer to complete. The value can be retrieved by a later call to
  // ToLiteral().
  Status CopyToHostAsync();

  // Returns the associated device buffer. Returns a nullptr if the buffer is
  // invalid.
  std::shared_ptr<TpuSharedBuffer> DeviceBuffer() const;

  // Deletes the device memory associated with this buffer, leaving it in an
  // invalid state.
  void Delete();

  // Destructures a tuple-valued PyTpuBuffer into its constituent elements.
  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> DestructureTuple();

  // Copies the buffer to the target device `dst_device` and returns a
  // PyTpuBuffer object holding a handle to the buffer on the target device.
  StatusOr<std::unique_ptr<PyTpuBuffer>> CopyToDevice(
      std::shared_ptr<PjRtDevice> dst_device);

  // Blocks the host until the buffer's value has been computed and is ready
  // for immediate use on the device. Useful in particular for timing
  // benchmarks.
  Status BlockHostUntilReady();

  // Allocates uninitialized buffers on `device`. If `shape` is a tuple, the
  // returned buffer corresponds to the root tuple buffer.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> AllocateBuffer(
      const Shape& shape, std::shared_ptr<PyTpuClient> client,
      std::shared_ptr<PjRtDevice> device);

 private:
  // Initializes a just allocated device buffer. The returned event will be
  // placed into the buffer's `wait_for_use` list.
  using BufferInitializer = std::function<std::shared_ptr<tpu_driver::Event>(
      tpu_driver::BufferHandle*)>;
  // Allocates and optionally initializes a non-tuple buffer on the device.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> CreateBuffer(
      const Shape& non_tuple_shape,
      absl::optional<BufferInitializer> initializer,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  const std::shared_ptr<PyTpuClient> client_;
  const Shape on_host_shape_;
  const std::shared_ptr<PjRtDevice> device_;

  // If this is a tuple, `device_buffer_` stores the tuple buffer and
  // `child_buffers_` stores the child buffers; else, `device_buffer_` stores
  // the data content and `child_buffers_` is empty.
  mutable absl::Mutex mu_;
  std::shared_ptr<TpuSharedBuffer> device_buffer_ TF_GUARDED_BY(mu_);
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers_
      TF_GUARDED_BY(mu_);
  // The cached value of the buffer on the host, produced either from a call to
  // CopyToHostAsync or from a call to ToLiteral. Once a value has been fetched
  // to the host, it persists until Delete() is called or the PyTpuBuffer is
  // destroyed.
  struct HostValue {
    absl::Mutex mutex;
    absl::Notification ready;
    int pending_ops;
    // status and value are valid for reading only after `ready` has been
    // notified.
    Status status;
    std::shared_ptr<Literal> value;
  };
  std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
};
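
// Example (illustrative sketch only, not part of the original header): copying
// an existing buffer to another local device and reading the result back on
// the host. `buffer` and `client` are assumptions: a valid PyTpuBuffer and the
// PyTpuClient it belongs to. TF_ASSIGN_OR_RETURN and TF_RETURN_IF_ERROR come
// from XLA's status macros.
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<PyTpuBuffer> copy,
//                       buffer->CopyToDevice(client->local_devices()[1]));
//   TF_RETURN_IF_ERROR(copy->BlockHostUntilReady());
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> host_value,
//                       copy->ToLiteral());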

// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps an XLA LocalExecutable.
class PyTpuExecutable {
 public:
  static StatusOr<std::unique_ptr<PyTpuExecutable>> Compile(
      const XlaComputation& computation,
      absl::optional<std::vector<Shape>> argument_layouts,
      const ExecutableBuildOptions* build_options,
      std::shared_ptr<PyTpuClient> client, bool tuple_arguments);

  PyTpuExecutable(
      std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
      DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
      xla::Shape result_shape, bool tuple_arguments);
  virtual ~PyTpuExecutable() {
    for (auto it = executables_.begin(); it != executables_.end(); ++it) {
      client_->driver()->UnloadProgram(std::move(it->second), {});
    }
  }

  PyTpuExecutable(const PyTpuExecutable&) = delete;
  PyTpuExecutable(PyTpuExecutable&&) = delete;
  PyTpuExecutable& operator=(const PyTpuExecutable&) = delete;
  PyTpuExecutable& operator=(PyTpuExecutable&&) = delete;

  std::shared_ptr<PyTpuClient> client() const { return client_; }

  int num_replicas() const { return device_assignment_.replica_count(); }
  int num_partitions() const { return device_assignment_.computation_count(); }

  int64 SizeOfGeneratedCodeInBytes() const {
    CHECK_GE(executables_.size(), 1);
    return executables_.begin()->second->size_in_bytes();
  }

  const DeviceAssignment& device_assignment() const {
    return device_assignment_;
  }

  const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
    return local_logical_device_ids_;
  }

  const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
    return local_devices_;
  }

  // TODO(power): Both Execute and ExecuteOnLocalDevices block and wait inside
  // for the computation to finish. Coordinate with the JAX code change to see
  // if we can make both of them non-blocking.
  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
      absl::Span<PyTpuBuffer* const> argument_handles);

  // Execute on local devices. Takes a sequence of argument lists (one argument
  // list per local device) and returns a tuple of results (one result per
  // local device). The number of argument lists must be equal to the local
  // device count.
  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
  ExecuteOnLocalDevices(
      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);

  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
  ExecuteShardedOnLocalDevices(
      absl::Span<const std::vector<PyTpuBuffer*>> args);

  void Delete() { executables_.clear(); }

 private:
  struct ExecuteResult {
    std::unique_ptr<PyTpuBuffer> buffer;
    std::shared_ptr<tpu_driver::Event> on_execute_finished;
  };

  ExecuteResult ExecuteHelper(
      absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
      absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
      int partition, const RunId& run_id);

  std::shared_ptr<PyTpuClient> const client_;
  std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
  const DeviceAssignment device_assignment_;
  const bool tuple_arguments_;

  // The replica and partition indices of device_assignment_ to be run by this
  // client. On single-host platforms without partitioning, this is all replicas
  // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
  // on multi-host platforms.
  // If there are 4 replicas and 2 partitions on a single-host platform, the
  // size of local_logical_device_ids_ is 4*2 = 8.
  std::vector<std::pair<int, int>> local_logical_device_ids_;

  // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
  // assigned.
  // shared_ptrs instead of unique_ptrs to play well with the Python bindings
  // (see xla.cc).
  std::vector<std::shared_ptr<PjRtDevice>> local_devices_;

  xla::Shape result_shape_;
};
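
// Example (illustrative sketch only, not part of the original header):
// compiling an XlaComputation and running it with one argument buffer.
// `computation`, `client`, and `arg` are assumptions: a computation built
// elsewhere, a PyTpuClient, and a PyTpuBuffer holding the argument.
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PyTpuExecutable> executable,
//       PyTpuExecutable::Compile(computation,
//                                /*argument_layouts=*/absl::nullopt,
//                                /*build_options=*/nullptr, client,
//                                /*tuple_arguments=*/false));
//   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PyTpuBuffer>> outputs,
//                       executable->Execute({arg.get()}));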

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_