/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

constexpr char kTpuPlatform[] = "tpu";

class TpuDevice : public PjRtDevice {
 public:
  TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
            int core_on_chip);

  const std::array<int, 3>& coords() const { return coords_; }
  int core_on_chip() const { return core_on_chip_; }

  std::string DebugString() const override;

  static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
  GetTpuDevices(const tpu_driver::SystemInfo& system_info);

  PjRtClient* client() const override { return nullptr; }

  bool IsAddressable() const override { return false; }

  int id() const override { return id_; }

  int task_id() const override { return task_id_; }

  int local_hardware_id() const override { return -1; }

  absl::string_view device_kind() const override { return device_kind_; }

  Status TransferToInfeed(const LiteralSlice& literal) override {
    return Unimplemented("Infeed not yet implemented via this API");
  }

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
    return Unimplemented("Outfeed not yet implemented via this API");
  }

 private:
  const int id_;
  const int task_id_;
  const std::array<int, 3> coords_;
  const std::string device_kind_ = "Cloud TPU";
  // Index of this core within its chip.
  int core_on_chip_;
};

// Encapsulates the state of a Python session with XLA.
class PyTpuClient {
 public:
  // Initializes a local XLA client for `worker`. Returns an error if no such
  // platform exists, or if the platform has no visible devices.
  static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);

  explicit PyTpuClient(std::string platform_name,
                       std::unique_ptr<tpu_driver::TpuDriver> driver,
                       std::vector<std::shared_ptr<PjRtDevice>> devices,
                       int task_id);
  virtual ~PyTpuClient() = default;

  PyTpuClient(const PyTpuClient&) = delete;
  PyTpuClient(PyTpuClient&&) = delete;
  PyTpuClient& operator=(const PyTpuClient&) = delete;
  PyTpuClient& operator=(PyTpuClient&&) = delete;

  Status TransferToInfeed(const LiteralSlice& literal, int device_id);
  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_id);

  virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const;

  int device_count() const { return devices_.size(); }
  int local_device_count() const { return local_devices_.size(); }
  const std::vector<std::shared_ptr<PjRtDevice>>& devices() { return devices_; }
  const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() {
    return local_devices_;
  }
  const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
    return id_to_device_;
  }
  int task_id() const { return task_id_; }
  const std::string& platform_name() const { return platform_name_; }

  StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
    return Unimplemented("ChooseCompactLayoutForShape not implemented.");
  }

  // Returns a bad status containing `caller_name` if `device_id` doesn't
  // correspond to a valid device at the POD-slice boundary.
  Status CheckDeviceId(int device_id, absl::string_view caller_name);

  tpu_driver::TpuDriver* driver() { return driver_.get(); }

  tensorflow::thread::ThreadPool* GetThreadPool() { return pool_.get(); }

 protected:
  std::string platform_name_;
  std::unique_ptr<tpu_driver::TpuDriver> driver_;

  // Includes all devices, including non-local devices on multi-host platforms.
  std::vector<std::shared_ptr<PjRtDevice>> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
  int task_id_;

  // A thread pool for scheduling core executions in parallel.
  std::unique_ptr<tensorflow::thread::ThreadPool> pool_;
};
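
// A minimal usage sketch (illustrative only, not part of the API): connect to
// a TPU worker and enumerate its devices. The worker address below is a
// placeholder, and the snippet assumes it runs inside a function that returns
// Status.
//
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<PyTpuClient> client,
//                       PyTpuClient::Get("grpc://localhost:8470"));
//   for (const auto& device : client->devices()) {
//     LOG(INFO) << device->DebugString();
//   }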

// Manages a buffer shared amongst multiple users. Buffers are asynchronously
// deallocated after the last use.
struct TpuSharedBuffer final {
 public:
  TpuSharedBuffer(tpu_driver::TpuDriver* driver,
                  std::unique_ptr<tpu_driver::BufferHandle> handle,
                  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use,
                  std::shared_ptr<PjRtDevice> src_device)
      : driver(driver),
        device(std::move(src_device)),
        handle(std::move(handle)),
        wait_for_use(std::move(wait_for_use)) {}

  ~TpuSharedBuffer() {
    std::vector<tpu_driver::Event*> events;
    for (const auto& e : wait_for_use) {
      events.push_back(e.get());
    }
    driver->Deallocate(std::move(handle), events);
  }

  tpu_driver::TpuDriver* const driver;
  const std::shared_ptr<PjRtDevice> device;

  std::unique_ptr<tpu_driver::BufferHandle> handle;
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
};

// Holds a reference from Python to one or more device buffers.
// A PyTpuBuffer can be either valid or invalid. An invalid buffer is one that
// has never been initialized, or a buffer that has been deleted (e.g., by
// calling Delete). We allow PyTpuBuffer objects to outlive the underlying
// device buffers so we can decouple buffer lifetimes from the corresponding
// Python references if needed.
// Thread-safe.
class PyTpuBuffer {
 public:
  // `tuple_shape` can be at most a one-level tuple combining non-tuple leaves.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> FromLiterals(
      std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
      std::shared_ptr<void> leaves_reference,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  // Supports nested tuple creation.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
      absl::Span<PyTpuBuffer* const> buffers,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  PyTpuBuffer() = delete;
  PyTpuBuffer(Shape on_host_shape,
              std::shared_ptr<TpuSharedBuffer> device_buffer,
              std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
              std::shared_ptr<PyTpuClient> client);

  PyTpuBuffer(const PyTpuBuffer&) = delete;
  PyTpuBuffer(PyTpuBuffer&&) = delete;
  PyTpuBuffer& operator=(const PyTpuBuffer&) = delete;
  PyTpuBuffer& operator=(PyTpuBuffer&&) = delete;

  const Shape& on_host_shape() const { return on_host_shape_; }
  std::shared_ptr<PjRtDevice> device() const { return device_; }
  const std::string& platform_name() const { return client_->platform_name(); }
  std::shared_ptr<PyTpuClient> client() const { return client_; }

  // Returns the buffer's value as a tuple DAG of Python arrays. If the value
  // has previously been prefetched to the host, then returns the prefetched
  // version, otherwise copies the buffer to the host. Blocks until the
  // value is ready.
  StatusOr<std::shared_ptr<Literal>> ToLiteral();

  // Initiates a copy of the buffer to the host. Does not block waiting for
  // the transfer to complete. The value can be retrieved by a later call to
  // ToLiteral().
  Status CopyToHostAsync();

  // Returns the associated device buffer. Returns a nullptr if the buffer is
  // invalid.
  std::shared_ptr<TpuSharedBuffer> DeviceBuffer() const;

  // Deletes the device memory associated with this buffer, leaving it in an
  // invalid state.
  void Delete();

  // Destructures a tuple-valued PyTpuBuffer into its constituent elements.
  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> DestructureTuple();

  // Copies the buffer to target device `dst_device` and returns a PyTpuBuffer
  // object holding a reference to the target device buffer.
  StatusOr<std::unique_ptr<PyTpuBuffer>> CopyToDevice(
      std::shared_ptr<PjRtDevice> dst_device);

  // Blocks the host until the buffer's value has been computed and is ready
  // for immediate use on the device. Useful in particular for timing
  // benchmarks.
  Status BlockHostUntilReady();

  // Allocates uninitialized buffers on `device`. If `shape` is a tuple, the
  // returned buffer corresponds to the root tuple buffer.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> AllocateBuffer(
      const Shape& shape, std::shared_ptr<PyTpuClient> client,
      std::shared_ptr<PjRtDevice> device);

 private:
  // Initializes a just-allocated device buffer. The returned event will be
  // placed into the buffer's `wait_for_use` list.
  using BufferInitializer = std::function<std::shared_ptr<tpu_driver::Event>(
      tpu_driver::BufferHandle*)>;
  // Allocates and optionally initializes a non-tuple buffer on the device.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> CreateBuffer(
      const Shape& non_tuple_shape,
      absl::optional<BufferInitializer> initializer,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  const std::shared_ptr<PyTpuClient> client_;
  const Shape on_host_shape_;
  const std::shared_ptr<PjRtDevice> device_;

  // If this is a tuple, `device_buffer_` stores the tuple buffer and
  // `child_buffers_` stores the child buffers; else, `device_buffer_` stores
  // the data content and `child_buffers_` is empty.
  mutable absl::Mutex mu_;
  std::shared_ptr<TpuSharedBuffer> device_buffer_ TF_GUARDED_BY(mu_);
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers_
      TF_GUARDED_BY(mu_);

  // The cached value of the buffer on the host, produced either by a call to
  // CopyToHostAsync or by a call to ToLiteral. Once a value has been fetched
  // to the host, it persists until Delete() is called or the PyTpuBuffer is
  // destroyed.
  struct HostValue {
    absl::Mutex mutex;
    absl::Notification ready;
    int pending_ops;
    // `status` and `value` are valid for reading only after `ready` has been
    // notified.
    Status status;
    std::shared_ptr<Literal> value;
  };
  std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
};
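
// A minimal usage sketch (illustrative only): stage a host array on the first
// local device and read it back. It assumes `client` is a PyTpuClient obtained
// as above and that the code runs inside a function returning Status. In real
// code, `leaves_reference` should keep the host data alive until the transfer
// completes; nullptr is passed here only for brevity.
//
//   std::vector<float> data = {1.0f, 2.0f, 3.0f};
//   Shape shape = ShapeUtil::MakeShape(F32, {3});
//   std::vector<BorrowingLiteral> leaves;
//   leaves.emplace_back(reinterpret_cast<const char*>(data.data()), shape);
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PyTpuBuffer> buffer,
//       PyTpuBuffer::FromLiterals(std::move(leaves), shape,
//                                 /*leaves_reference=*/nullptr, client,
//                                 client->local_devices()[0]));
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> result, buffer->ToLiteral());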

// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps an XLA LocalExecutable. A usage sketch
// appears at the end of this file.
class PyTpuExecutable {
 public:
  static StatusOr<std::unique_ptr<PyTpuExecutable>> Compile(
      const XlaComputation& computation,
      absl::optional<std::vector<Shape>> argument_layouts,
      const ExecutableBuildOptions* build_options,
      std::shared_ptr<PyTpuClient> client, bool tuple_arguments);

  PyTpuExecutable(
      std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
      DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
      xla::Shape result_shape, bool tuple_arguments);
  virtual ~PyTpuExecutable() {
    for (auto it = executables_.begin(); it != executables_.end(); ++it) {
      client_->driver()->UnloadProgram(std::move(it->second), {});
    }
  }

  PyTpuExecutable(const PyTpuExecutable&) = delete;
  PyTpuExecutable(PyTpuExecutable&&) = delete;
  PyTpuExecutable& operator=(const PyTpuExecutable&) = delete;
  PyTpuExecutable& operator=(PyTpuExecutable&&) = delete;

  std::shared_ptr<PyTpuClient> client() const { return client_; }

  int num_replicas() const { return device_assignment_.replica_count(); }
  int num_partitions() const { return device_assignment_.computation_count(); }

  int64 SizeOfGeneratedCodeInBytes() const {
    CHECK_GE(executables_.size(), 1);
    return executables_.begin()->second->size_in_bytes();
  }

  const DeviceAssignment& device_assignment() const {
    return device_assignment_;
  }

  const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
    return local_logical_device_ids_;
  }

  const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
    return local_devices_;
  }

  // TODO(power): Both Execute and ExecuteOnLocalDevices block and wait inside
  // for the computation to finish. Coordinate with the JAX code change to see
  // if we can make both Execute and ExecuteOnLocalDevices non-blocking.
  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
      absl::Span<PyTpuBuffer* const> argument_handles);

  // Executes on local devices. Takes a sequence of argument lists (one
  // argument list per local device) and returns a tuple of results (one
  // result per local device). The number of argument lists must be equal to
  // the local device count.
  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
  ExecuteOnLocalDevices(
      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);

  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
  ExecuteShardedOnLocalDevices(
      absl::Span<const std::vector<PyTpuBuffer*>> args);

  void Delete() { executables_.clear(); }

 private:
  struct ExecuteResult {
    std::unique_ptr<PyTpuBuffer> buffer;
    std::shared_ptr<tpu_driver::Event> on_execute_finished;
  };

  ExecuteResult ExecuteHelper(
      absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
      absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
      int partition, const RunId& run_id);

  std::shared_ptr<PyTpuClient> const client_;
  std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
  const DeviceAssignment device_assignment_;
  const bool tuple_arguments_;

  // The replica and partition indices of device_assignment_ to be run by this
  // client.
  // On single-host platforms without partitioning, this is all replicas
  // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the
  // case on multi-host platforms. If there are 4 replicas and 2 partitions on
  // a single-host platform, the size of local_logical_device_ids_ is
  // 4 * 2 = 8.
  std::vector<std::pair<int, int>> local_logical_device_ids_;

  // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
  // assigned. We use shared_ptrs instead of unique_ptrs to play well with the
  // Python bindings (see xla.cc).
  std::vector<std::shared_ptr<PjRtDevice>> local_devices_;

  xla::Shape result_shape_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
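
// A rough end-to-end sketch of how the pieces in this header fit together
// (illustrative only, not a definitive recipe). `BuildComputation()` is a
// placeholder for code that produces an XlaComputation; error handling and
// argument buffers are elided, and the worker address is a placeholder.
//
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<PyTpuClient> client,
//                       PyTpuClient::Get("grpc://localhost:8470"));
//   XlaComputation computation = BuildComputation();
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PyTpuExecutable> executable,
//       PyTpuExecutable::Compile(computation,
//                                /*argument_layouts=*/absl::nullopt,
//                                /*build_options=*/nullptr, client,
//                                /*tuple_arguments=*/false));
//   TF_ASSIGN_OR_RETURN(auto outputs,
//                       executable->Execute(/*argument_handles=*/{}));
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> result,
//                       outputs.front()->ToLiteral());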