1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/util/device_name_utils.h" 21 22 namespace tensorflow { 23 24 // Wraps a device with a new name, delegating work to the wrapped device. 25 // 26 // This class is used to wrap local devices when using clusterspec propagation 27 // where the name of a particular device may change in the context of a given 28 // session. 29 class RenamedDevice : public Device { 30 public: 31 static std::unique_ptr<Device> NewRenamedDevice(const string& new_base, 32 Device* underlying, 33 bool owns_underlying, 34 bool isolate_session_state); 35 36 ~RenamedDevice() override; 37 38 // Below are virtual methods defined on DeviceBase RequiresRecordingAccessedTensors()39 bool RequiresRecordingAccessedTensors() const override { 40 return underlying_->RequiresRecordingAccessedTensors(); 41 } 42 UnderlyingDevice()43 const DeviceBase* UnderlyingDevice() const override { 44 return underlying_->UnderlyingDevice(); 45 } UnderlyingDevice()46 DeviceBase* UnderlyingDevice() override { 47 return underlying_->UnderlyingDevice(); 48 } 49 tensorflow_cpu_worker_threads()50 const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { 51 return underlying_->tensorflow_cpu_worker_threads(); 52 } 53 tensorflow_gpu_device_info()54 const GpuDeviceInfo* tensorflow_gpu_device_info() const override { 55 return underlying_->tensorflow_gpu_device_info(); 56 } 57 GetAllocator(AllocatorAttributes attr)58 Allocator* GetAllocator(AllocatorAttributes attr) override { 59 return underlying_->GetAllocator(attr); 60 } 61 GetScopedAllocator(AllocatorAttributes attr,int64 step_id)62 Allocator* GetScopedAllocator(AllocatorAttributes attr, 63 int64 step_id) override { 64 return underlying_->GetScopedAllocator(attr, step_id); 65 } 66 GetScopedAllocatorMgr()67 ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { 68 return underlying_->GetScopedAllocatorMgr(); 69 } 70 eigen_cpu_device()71 const Eigen::ThreadPoolDevice* eigen_cpu_device() override { 72 return underlying_->eigen_cpu_device(); 73 } 74 75 #ifdef TENSORFLOW_USE_SYCL eigen_sycl_device()76 const Eigen::SyclDevice* eigen_sycl_device() const override { 77 return underlying_->eigen_sycl_device(); 78 } 79 #endif 80 MakeGpuDevice()81 PerOpGpuDevice* MakeGpuDevice() override { 82 return underlying_->MakeGpuDevice(); 83 } 84 ReinitializeGpuDevice(OpKernelContext * context,PerOpGpuDevice * device,DeviceContext * dc,Allocator * allocator)85 Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, 86 DeviceContext* dc, 87 Allocator* allocator) override { 88 return underlying_->ReinitializeGpuDevice(context, device, dc, allocator); 89 } 90 MakeTensorFromProto(const TensorProto & tensor_proto,const AllocatorAttributes alloc_attrs,Tensor * tensor)91 Status MakeTensorFromProto(const TensorProto& tensor_proto, 92 const AllocatorAttributes alloc_attrs, 93 Tensor* tensor) override { 94 return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); 95 } 96 97 // Below are virtual methods defined on Device 98 Compute(OpKernel * op_kernel,OpKernelContext * context)99 void Compute(OpKernel* op_kernel, OpKernelContext* context) override { 100 underlying_->Compute(op_kernel, context); 101 } 102 ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)103 void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, 104 AsyncOpKernel::DoneCallback done) override { 105 underlying_->ComputeAsync(op_kernel, context, std::move(done)); 106 } 107 ConsumeListOfAccessedTensors(DeviceContext * context,const TensorReferenceVector & tensors)108 void ConsumeListOfAccessedTensors( 109 DeviceContext* context, const TensorReferenceVector& tensors) override { 110 underlying_->ConsumeListOfAccessedTensors(context, tensors); 111 } 112 Sync()113 Status Sync() override { return underlying_->Sync(); } 114 MaybeRewriteGraph(std::unique_ptr<Graph> * graph)115 Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override { 116 return underlying_->MaybeRewriteGraph(graph); 117 } 118 FillContextMap(const Graph * graph,DeviceContextMap * device_context_map)119 Status FillContextMap(const Graph* graph, 120 DeviceContextMap* device_context_map) override { 121 return underlying_->FillContextMap(graph, device_context_map); 122 } 123 124 // Returns the resource manager associated w/ this device. resource_manager()125 ResourceMgr* resource_manager() override { 126 if (isolate_session_state_) { 127 return Device::resource_manager(); 128 } else { 129 return underlying_->resource_manager(); 130 } 131 } 132 133 private: 134 RenamedDevice(Device* underlying, const DeviceAttributes& attributes, 135 bool owns_underlying, bool isolate_session_state); 136 Device* const underlying_; 137 const bool owns_underlying_; 138 const bool isolate_session_state_; 139 }; 140 141 } // namespace tensorflow 142 143 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ 144