/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// Wraps a device with a new name, delegating work to the wrapped device.
//
// This class is used to wrap local devices when using clusterspec propagation
// where the name of a particular device may change in the context of a given
// session.
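//
// A minimal usage sketch, for illustration only: the local device pointer
// `local_cpu` and the base name "/job:worker/replica:0/task:0" below are
// assumptions made for the example, not values defined in this header.
//
//   std::unique_ptr<Device> renamed = RenamedDevice::NewRenamedDevice(
//       "/job:worker/replica:0/task:0", local_cpu,
//       /*owns_underlying=*/false, /*isolate_session_state=*/true);
//
// The returned device reports the renamed attributes while forwarding
// Compute(), allocator lookups, and the other calls below to `local_cpu`.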
class RenamedDevice : public Device {
 public:
  static std::unique_ptr<Device> NewRenamedDevice(
      const string& new_base, Device* underlying, bool owns_underlying,
      bool isolate_session_state,
      thread::ThreadPoolInterface* underlying_threadpool = nullptr);

  ~RenamedDevice() override;

  const DeviceBase* UnderlyingDevice() const override {
    return underlying_device_->UnderlyingDevice();
  }
  DeviceBase* UnderlyingDevice() override {
    return underlying_device_->UnderlyingDevice();
  }

  const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
    if (underlying_threadpool_) {
      return Device::tensorflow_cpu_worker_threads();
    }
    return underlying_device_->tensorflow_cpu_worker_threads();
  }

  const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
    return underlying_device_->tensorflow_gpu_device_info();
  }

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    return underlying_device_->GetAllocator(attr);
  }

  Allocator* GetScopedAllocator(AllocatorAttributes attr,
                                int64 step_id) override {
    return underlying_device_->GetScopedAllocator(attr, step_id);
  }

  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
    return underlying_device_->GetScopedAllocatorMgr();
  }

  const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
    // Use the underlying threadpool only if the underlying device supports
    // eigen_cpu_device.
    if (underlying_threadpool_ && underlying_device_->has_eigen_cpu_device()) {
      return Device::eigen_cpu_device();
    }
    return underlying_device_->eigen_cpu_device();
  }

  thread::ThreadPool* tensorflow_device_thread_pool() override {
    // Use the underlying threadpool instead of the underlying device's
    // tensorflow_device_thread_pool only if the underlying device defines one.
    if (underlying_threadpool_ &&
        underlying_device_->tensorflow_device_thread_pool() != nullptr) {
      return Device::tensorflow_device_thread_pool();
    }
    return underlying_device_->tensorflow_device_thread_pool();
  }

  bool has_eigen_cpu_device() const override {
    return underlying_device_->has_eigen_cpu_device();
  }

  PerOpGpuDevice* MakeGpuDevice() override {
    return underlying_device_->MakeGpuDevice();
  }

  Status ReinitializeGpuDevice(OpKernelContext* context,
                               PerOpGpuDevice* device, DeviceContext* dc,
                               Allocator* allocator) override {
    return underlying_device_->ReinitializeGpuDevice(context, device, dc,
                                                     allocator);
  }

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override {
    return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs,
                                                   tensor);
  }

  void CopyTensorInSameDevice(const Tensor* input_tensor,
                              Tensor* output_tensor,
                              const DeviceContext* device_context,
                              StatusCallback done) override {
    underlying_device_->CopyTensorInSameDevice(
        input_tensor, output_tensor, device_context, std::move(done));
  }

  // Below are virtual methods defined on Device.

  void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
    underlying_device_->Compute(op_kernel, context);
  }

  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override {
    underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
  }

  Status Sync() override { return underlying_device_->Sync(); }

  Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
    return underlying_device_->MaybeRewriteGraph(graph);
  }

  Status TryGetDeviceContext(DeviceContext** out_context) override {
    return underlying_device_->TryGetDeviceContext(out_context);
  }

  // Returns the resource manager associated with this device.
  ResourceMgr* resource_manager() override {
    if (isolate_session_state_) {
      return Device::resource_manager();
    } else {
      return underlying_device_->resource_manager();
    }
  }

  bool IsLocal() const override { return underlying_device_->IsLocal(); }

  bool IsRemoteCallAllowed() const override {
    return underlying_device_->IsRemoteCallAllowed();
  }

 private:
  RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
                bool owns_underlying, bool isolate_session_state,
                thread::ThreadPoolInterface* underlying_threadpool);
  Device* const underlying_device_;
  const bool owns_underlying_device_;
  const bool isolate_session_state_;

  std::unique_ptr<thread::ThreadPool> underlying_threadpool_;
  // eigen_worker_threads_ is stored here so that we can pass the pointer
  // of eigen_worker_threads_.workers to the parent class.
  DeviceBase::CpuWorkerThreads eigen_worker_threads_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_