1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ 17 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ 18 19 #include <memory> 20 21 #include "absl/synchronization/mutex.h" 22 #include "tensorflow/compiler/jit/xla_tensor.h" 23 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 24 #include "tensorflow/compiler/xla/client/global_data.h" 25 #include "tensorflow/compiler/xla/client/local_client.h" 26 #include "tensorflow/core/framework/allocator.h" 27 #include "tensorflow/core/framework/device_base.h" 28 #include "tensorflow/core/lib/core/status.h" 29 30 namespace tensorflow { 31 32 // The allocator used for Tensors assigned to the XLA device. The allocator 33 // ignores the alignment and size of the request and always returns a new, 34 // empty, XlaTensor. 35 class XlaDeviceAllocator : public Allocator { 36 public: 37 XlaDeviceAllocator(se::StreamExecutor* stream_executor); 38 ~XlaDeviceAllocator() override; 39 40 string Name() override; 41 42 void* AllocateRaw(size_t alignment, size_t num_bytes) override; 43 void DeallocateRaw(void* ptr) override; 44 absl::optional<AllocatorStats> GetStats() override; 45 46 private: 47 // The stream executor of the device. 48 se::StreamExecutor* stream_executor_; 49 }; 50 51 // Helper class for managing data transfers between host and XLA devices. 52 class XlaDeviceContext : public DeviceContext { 53 public: 54 explicit XlaDeviceContext( 55 std::shared_ptr<se::Stream> compute_stream, 56 std::shared_ptr<se::Stream> host_to_device_stream, 57 std::shared_ptr<se::Stream> device_to_host_stream, 58 std::vector<std::shared_ptr<se::Stream>> device_to_device_streams, 59 xla::LocalClient* client, 60 XlaCompiler::ShapeRepresentationFn shape_representation_fn, 61 thread::ThreadPool* thread_pool); 62 63 void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, 64 Tensor* device_tensor, 65 StatusCallback done) const override; 66 void CopyDeviceTensorToCPU(const Tensor* device_tensor, 67 absl::string_view tensor_name, Device* device, 68 Tensor* cpu_tensor, StatusCallback done) override; 69 void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, 70 Tensor* output_tensor, 71 StatusCallback done) const override; 72 client()73 xla::LocalClient* client() const { return client_; } stream()74 se::Stream* stream() const { return stream_.get(); } host_to_device_stream()75 se::Stream* host_to_device_stream() const { 76 return host_to_device_stream_.get(); 77 } device_to_device_stream(int index)78 se::Stream* device_to_device_stream(int index) const { 79 return device_to_device_streams_.at(index).get(); 80 } transfer_manager()81 xla::TransferManager* transfer_manager() const { return transfer_manager_; } shape_representation_fn()82 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { 83 return shape_representation_fn_; 84 } 85 86 // Returns a device-to-device stream, in round-robin fashion. 87 se::Stream* GetDeviceToDeviceStream(); 88 89 private: UseMultipleStreams()90 bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } 91 92 // The main compute stream of the device, used to synchronize the transfer 93 // streams if they are set. 94 std::shared_ptr<se::Stream> stream_; 95 // The stream to use for transferring data from host to device. Can be 96 // idential to stream_, but must not be nullptr. 97 std::shared_ptr<se::Stream> host_to_device_stream_; 98 // The stream to use for transferring data from device to host. Can be 99 // idential to stream_. If nullptr, borrow a stream from backend for each 100 // transfer request to support out-of-order requests. 101 std::shared_ptr<se::Stream> device_to_host_stream_; 102 // Streams to use for transferring data directly between different devices, 103 // e.g., over NVLINK. 104 std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_; 105 106 // For the underlying memory allocator and XLA's TransferManager. 107 xla::LocalClient* client_; 108 // Transfer manager, for marshalling data to and from the device. 109 xla::TransferManager* transfer_manager_; 110 111 XlaCompiler::ShapeRepresentationFn shape_representation_fn_; 112 113 // Thread pool used for running closures 114 thread::ThreadPool* thread_pool_; 115 116 absl::Mutex mu_; 117 int next_stream_ GUARDED_BY(mu_) = 0; 118 }; 119 120 } // namespace tensorflow 121 122 #endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ 123