1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/device/device_event_mgr.h" 25 #include "tensorflow/core/common_runtime/device/device_id.h" 26 #include "tensorflow/core/common_runtime/device/device_id_manager.h" 27 #include "tensorflow/core/common_runtime/device/device_id_utils.h" 28 #include "tensorflow/core/common_runtime/local_device.h" 29 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" 30 #include "tensorflow/core/common_runtime/shared_counter.h" 31 #include "tensorflow/core/framework/allocator.h" 32 #include "tensorflow/core/framework/device_base.h" 33 #include "tensorflow/core/framework/op_kernel.h" 34 #include "tensorflow/core/framework/tensor.h" 35 #include "tensorflow/core/lib/core/status.h" 36 #include "tensorflow/core/lib/gtl/inlined_vector.h" 37 #include "tensorflow/core/platform/mutex.h" 38 #include "tensorflow/core/platform/stream_executor.h" 39 #include "tensorflow/core/platform/types.h" 40 #include "tensorflow/core/public/session_options.h" 41 42 namespace tensorflow { 43 44 class PluggableDevice : public LocalDevice { 45 public: 46 PluggableDevice(const SessionOptions& options, const std::string& name, 47 const string& device_type, const string& platform_name, 48 Bytes memory_limit, const DeviceLocality& locality, 49 TfDeviceId tf_device_id, 50 const std::string& physical_device_desc, 51 Allocator* device_allocator, Allocator* cpu_allocator, 52 bool sync_every_op); 53 54 ~PluggableDevice() override; 55 56 // Initialize the device and return the status of initialization. 57 Status Init(const SessionOptions& options); 58 59 void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, 60 AsyncOpKernel::DoneCallback done) override; 61 62 void Compute(OpKernel* op_kernel, OpKernelContext* context) override; 63 64 Status Sync() override; 65 66 Allocator* GetAllocator(AllocatorAttributes attr) override; 67 68 Status MakeTensorFromProto(const TensorProto& tensor_proto, 69 const AllocatorAttributes alloc_attrs, 70 Tensor* tensor) override; 71 72 void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, 73 const DeviceContext* device_context, 74 StatusCallback done) override; 75 76 // The executor that provides control for the pluggable device; executor()77 se::StreamExecutor* executor() const { return executor_; } 78 79 private: 80 Allocator* device_allocator_; 81 Allocator* cpu_allocator_; 82 83 se::StreamExecutor* executor_ = nullptr; 84 struct StreamGroup { 85 se::Stream* compute = nullptr; 86 se::Stream* host_to_device = nullptr; 87 se::Stream* device_to_host = nullptr; 88 gtl::InlinedVector<se::Stream*, 4> device_to_device; 89 }; 90 91 class StreamGroupFactory; 92 93 StreamGroup* stream_; 94 PluggableDeviceContext* device_context_; 95 // TODO(penpornk): Investigate renaming `GpuDeviceInfo` to `DeviceInfo`. 96 DeviceBase::AcceleratorDeviceInfo* pluggable_device_info_ = nullptr; 97 TfDeviceId tf_device_id_; 98 const string platform_name_; 99 const bool sync_every_op_ = false; 100 EventMgr* em_ = nullptr; 101 std::unique_ptr<thread::ThreadPool> thread_pool_; 102 bool force_gpu_compatible_ = false; 103 std::string ComputeOpKernelDebugString(const OpKernel& op_kernel, 104 const int stream_id); 105 106 // This method returns an initialization status, in addition to 107 // calling the "done" StatusCallback, if there is a failure to 108 // allocate memory or if the tensor "from" is not DMA-copyable. 109 // If there is no error prior to enqueueing the copy, an OK status 110 // is returned. 111 Status MaybeCopyTensorToPluggableDevice( 112 const AllocatorAttributes& alloc_attrs, const Tensor& from, Tensor* to, 113 StatusCallback done); 114 }; 115 116 } // namespace tensorflow 117 118 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_ 119