• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/lib/core/threadpool_interface.h"
21 #include "tensorflow/core/util/device_name_utils.h"
22 
23 namespace tensorflow {
24 
25 // Wraps a device with a new name, delegating work to the wrapped device.
26 //
27 // This class is used to wrap local devices when using clusterspec propagation
28 // where the name of a particular device may change in the context of a given
29 // session.
30 class RenamedDevice : public Device {
31  public:
32   static std::unique_ptr<Device> NewRenamedDevice(
33       const string& new_base, Device* underlying, bool owns_underlying,
34       bool isolate_session_state,
35       thread::ThreadPoolInterface* underlying_threadpool = nullptr);
36 
37   ~RenamedDevice() override;
38 
UnderlyingDevice()39   const DeviceBase* UnderlyingDevice() const override {
40     return underlying_device_->UnderlyingDevice();
41   }
UnderlyingDevice()42   DeviceBase* UnderlyingDevice() override {
43     return underlying_device_->UnderlyingDevice();
44   }
45 
tensorflow_cpu_worker_threads()46   const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
47     if (underlying_threadpool_) {
48       return Device::tensorflow_cpu_worker_threads();
49     }
50     return underlying_device_->tensorflow_cpu_worker_threads();
51   }
52 
tensorflow_gpu_device_info()53   const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
54     return underlying_device_->tensorflow_gpu_device_info();
55   }
56 
GetAllocator(AllocatorAttributes attr)57   Allocator* GetAllocator(AllocatorAttributes attr) override {
58     return underlying_device_->GetAllocator(attr);
59   }
60 
GetScopedAllocator(AllocatorAttributes attr,int64 step_id)61   Allocator* GetScopedAllocator(AllocatorAttributes attr,
62                                 int64 step_id) override {
63     return underlying_device_->GetScopedAllocator(attr, step_id);
64   }
65 
GetScopedAllocatorMgr()66   ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
67     return underlying_device_->GetScopedAllocatorMgr();
68   }
69 
eigen_cpu_device()70   const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
71     // Use the underlying threadpool only if the underlying device supports
72     // eigen_cpu_device.
73     if (underlying_threadpool_ && underlying_device_->has_eigen_cpu_device()) {
74       return Device::eigen_cpu_device();
75     }
76     return underlying_device_->eigen_cpu_device();
77   }
78 
tensorflow_device_thread_pool()79   thread::ThreadPool* tensorflow_device_thread_pool() override {
80     // Use the underlying threadpool instead of tensorflow_device_thread_pool
81     // of the underlying device only if tensorflow_device_thread_pool is defined
82     // for the underlying device.
83     if (underlying_threadpool_ &&
84         underlying_device_->tensorflow_device_thread_pool() != nullptr) {
85       return Device::tensorflow_device_thread_pool();
86     }
87     return underlying_device_->tensorflow_device_thread_pool();
88   }
89 
has_eigen_cpu_device()90   bool has_eigen_cpu_device() const override {
91     return underlying_device_->has_eigen_cpu_device();
92   }
93 
94 
MakeGpuDevice()95   PerOpGpuDevice* MakeGpuDevice() override {
96     return underlying_device_->MakeGpuDevice();
97   }
98 
ReinitializeGpuDevice(OpKernelContext * context,PerOpGpuDevice * device,DeviceContext * dc,Allocator * allocator)99   Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
100                                DeviceContext* dc,
101                                Allocator* allocator) override {
102     return underlying_device_->ReinitializeGpuDevice(context, device, dc,
103                                                      allocator);
104   }
105 
MakeTensorFromProto(const TensorProto & tensor_proto,const AllocatorAttributes alloc_attrs,Tensor * tensor)106   Status MakeTensorFromProto(const TensorProto& tensor_proto,
107                              const AllocatorAttributes alloc_attrs,
108                              Tensor* tensor) override {
109     return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs,
110                                                    tensor);
111   }
112 
CopyTensorInSameDevice(const Tensor * input_tensor,Tensor * output_tensor,const DeviceContext * device_context,StatusCallback done)113   void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
114                               const DeviceContext* device_context,
115                               StatusCallback done) override {
116     underlying_device_->CopyTensorInSameDevice(input_tensor, output_tensor,
117                                                device_context, std::move(done));
118   }
119 
120   // Below are virtual methods defined on Device
121 
Compute(OpKernel * op_kernel,OpKernelContext * context)122   void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
123     underlying_device_->Compute(op_kernel, context);
124   }
125 
ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)126   void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
127                     AsyncOpKernel::DoneCallback done) override {
128     underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
129   }
130 
Sync()131   Status Sync() override { return underlying_device_->Sync(); }
132 
MaybeRewriteGraph(std::unique_ptr<Graph> * graph)133   Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
134     return underlying_device_->MaybeRewriteGraph(graph);
135   }
136 
TryGetDeviceContext(DeviceContext ** out_context)137   Status TryGetDeviceContext(DeviceContext** out_context) override {
138     return underlying_device_->TryGetDeviceContext(out_context);
139   }
140 
141   // Returns the resource manager associated w/ this device.
resource_manager()142   ResourceMgr* resource_manager() override {
143     if (isolate_session_state_) {
144       return Device::resource_manager();
145     } else {
146       return underlying_device_->resource_manager();
147     }
148   }
149 
IsLocal()150   bool IsLocal() const override { return underlying_device_->IsLocal(); }
151 
IsRemoteCallAllowed()152   bool IsRemoteCallAllowed() const override {
153     return underlying_device_->IsRemoteCallAllowed();
154   }
155 
156  private:
157   RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
158                 bool owns_underlying, bool isolate_session_state,
159                 thread::ThreadPoolInterface* underlying_threadpool);
160   Device* const underlying_device_;
161   const bool owns_underlying_device_;
162   const bool isolate_session_state_;
163 
164   std::unique_ptr<thread::ThreadPool> underlying_threadpool_;
165   // eigen_worker_threads_ is stored here so that we can pass the pointer
166   // of eigen_worker_threads_.workers to the parent class.
167   DeviceBase::CpuWorkerThreads eigen_worker_threads_;
168 };
169 
170 }  // namespace tensorflow
171 
172 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
173