• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// A Device is something that can perform computations as part of a
// model.  Devices can be local (runs computation on this machine), or
// remote (contacts a device local to another machine using an RPC to
// do the work).  Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
//
// Device names
// * Every Device should have a unique name with the format:
//     /job:___/replica:___/task:___/device:___:___
//   An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
//   many "task zeros" as replicas.
28 
29 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
30 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
31 
32 #include <memory>
33 #include <string>
34 
35 #include "tensorflow/core/framework/allocator.h"
36 #include "tensorflow/core/framework/control_flow.h"
37 #include "tensorflow/core/framework/device_attributes.pb.h"
38 #include "tensorflow/core/framework/device_base.h"
39 #include "tensorflow/core/framework/graph.pb.h"
40 #include "tensorflow/core/framework/op_kernel.h"
41 #include "tensorflow/core/framework/op_segment.h"
42 #include "tensorflow/core/framework/resource_mgr.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/types.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/platform/macros.h"
49 #include "tensorflow/core/platform/types.h"
50 #include "tensorflow/core/util/device_name_utils.h"
51 
52 namespace tensorflow {
53 
54 class Device : public DeviceBase {
55  public:
56   // Callback type that takes a Status and returns void.
57   typedef std::function<void(const Status&)> DoneCallback;
58 
59   Device(Env* env, const DeviceAttributes& device_attributes);
60   ~Device() override;
61 
62   // Full name of this device (see top comment).
name()63   const string& name() const override { return device_attributes_.name(); }
64 
65   // Parsed name of this device
parsed_name()66   const DeviceNameUtils::ParsedName& parsed_name() const {
67     return parsed_name_;
68   }
69 
70   // Describes what kind of device this is.  This is intended to be
71   // human-readable and not computer-parsed, except that two devices
72   // with the same device_type() are expected to perform similarly
73   // (both from a computation and communication perspective).
device_type()74   const string& device_type() const { return device_attributes_.device_type(); }
75 
76   // Returns an aggregation of device attributes.
attributes()77   const DeviceAttributes& attributes() const override {
78     return device_attributes_;
79   }
80 
81   // Performs the actual compute function.
82   //
83   // Subclasses may override this function if they wish to perform
84   // some initialization before each compute.
Compute(OpKernel * op_kernel,OpKernelContext * context)85   virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
86     op_kernel->Compute(context);
87   }
88 
89   // Asynchronous kernel's compute.
ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)90   virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
91                             AsyncOpKernel::DoneCallback done) {
92     op_kernel->ComputeAsync(context, std::move(done));
93   }
94 
95   // Takes ownership of the references in tensors. If necessary, a
96   // device may override this method to keep a reference to the
97   // accessed tensors until the async computation has completed.
ConsumeListOfAccessedTensors(DeviceContext * context,const TensorReferenceVector & tensors)98   virtual void ConsumeListOfAccessedTensors(
99       DeviceContext* context, const TensorReferenceVector& tensors) {
100     for (const auto& ref : tensors) {
101       ref.Unref();
102     }
103   }
104 
105   // Blocks until all operations queued on the device at the time of
106   // the call have completed.  Returns any error pending on the device
107   // at completion.
108   virtual Status Sync() = 0;
109 
110   // Calls the given callback when all operations queued on the device at the
111   // time of the call have completed. The callback is passed any error pending
112   // on the device at completion.
113   // TODO(b/112409994): Consolidate these two APIs, removing the synchronous
114   // version.
115   virtual void Sync(const DoneCallback& done);
116 
117   // On session completion, the executor may call Device::Sync() depending on
118   // flag settings. Override this to return false for devices that don't allow
119   // such calls. Instead, these devices must use other mechanisms (such as
120   // num_deferred_ops) to ensure the device has finished processing necessary
121   // work at session completion. In addition, for these devices, RefreshStatus
122   // must be called at session completion to retrieve execution result status.
123   //
124   // Devices that override this function must also implement RefreshStatus.
AllowsSyncOnCompletion()125   virtual bool AllowsSyncOnCompletion() const { return true; }
126 
127   // This is used in conjunction with AllowsSyncOnCompletion to allow the
128   // executor to get execution result status at session completion.
129   //
130   // For supported devices, this call returns the underlying device stream's
131   // current status in a non-blocking way, without using blocking calls such as
132   // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
133   // status is also updated with the retrieved stream status.
RefreshStatus()134   virtual Status RefreshStatus() {
135     return errors::Unimplemented(
136         "RefreshStatus is not supported on this device.");
137   }
138 
139   // Optionally modify the device's GraphDef before execution.
140   //
141   // This method should be considered experimental and is supplied to enable
142   // prototyping of TensorFlow device implementations that need to modify
143   // the GraphDef before execution.
144   //
145   // 'graph' supplies the partition of the graph assigned to this
146   // device.
MaybeRewriteGraph(std::unique_ptr<Graph> *)147   virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
148     return Status::OK();
149   }
150 
151   // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr
152   // if the device does not support contexts. Returns an error status if any
153   // error occurred while trying to create a context, otherwise OK.
154   //
155   // The caller takes ownership of one reference on the output DeviceContext*,
156   // and should call Unref().
TryGetDeviceContext(DeviceContext ** out_context)157   virtual Status TryGetDeviceContext(DeviceContext** out_context) {
158     *out_context = nullptr;
159     return Status::OK();
160   }
161 
162   // Returns the op segment of this device.  The caller can reuse op
163   // kernels registered for the same session running on this device.
op_segment()164   OpSegment* op_segment() { return &op_seg_; }
165 
166   // Returns the resource manager associated w/ this device.
resource_manager()167   virtual ResourceMgr* resource_manager() { return rmgr_; }
168 
169   // Summarizes the status of this Device, for debugging.
DebugString()170   string DebugString() const { return device_attributes_.DebugString(); }
171 
172   // Assembles the parameter components into a complete DeviceAttributes value.
173   static DeviceAttributes BuildDeviceAttributes(
174       const string& name, DeviceType device, Bytes memory_limit,
175       const DeviceLocality& locality, const string& physical_device_desc);
176 
BuildDeviceAttributes(const string & name,DeviceType device,Bytes memory_limit,const DeviceLocality & locality)177   static DeviceAttributes BuildDeviceAttributes(
178       const string& name, DeviceType device, Bytes memory_limit,
179       const DeviceLocality& locality) {
180     // Pass in an empty string as physical device name.
181     return BuildDeviceAttributes(name, device, memory_limit, locality, "");
182   }
183 
184   // Clears the resource manager associated with this device.
ClearResourceMgr()185   void ClearResourceMgr() { rmgr_->Clear(); }
186 
IsLocal()187   virtual bool IsLocal() const { return true; }
188 
189  protected:
DeleteResourceMgr()190   void DeleteResourceMgr() {
191     delete rmgr_;
192     rmgr_ = nullptr;
193   }
194 
195  private:
196   const DeviceAttributes device_attributes_;
197   DeviceNameUtils::ParsedName parsed_name_;
198 
199   // op_seg_ maps session handle and op name to OpKernel objects.
200   OpSegment op_seg_;
201 
202   // Resources associated w/ this device. E.g., shared variables, etc.
203   ResourceMgr* rmgr_ = nullptr;
204 
205   TF_DISALLOW_COPY_AND_ASSIGN(Device);
206 };
207 
208 }  // namespace tensorflow
209 
210 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
211