/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

struct TF_Operation;

namespace tensorflow {

// This struct is isomorphic to TF_Output, but we cannot use the latter here due
// to layering concerns (TF_Output is defined at the C API layer).
struct OutputGraphNode {
  TF_Operation* oper;
  int index;  // The index of the output within oper.
};

// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public core::RefCounted {
 public:
  TensorHandle(const Tensor& t, Device* d, Device* op_device,
               EagerContext* ctx);
  TensorHandle(uint64 node_id, Device* d, Device* op_device,
               Device* resource_device, DataType dtype, EagerContext* ctx);

  // Remote tensor handle constructor.
  TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id,
               DataType dtype, std::function<void()> call_on_destroy, Device* d,
               Device* op_device, Device* resource_device, EagerContext* ctx);

  // Symbolic tensor constructor.
  TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype);

  ~TensorHandle() override {
    VLOG(1) << "Deleting internal TensorHandle " << this;
    if (call_on_destroy_) {
      call_on_destroy_();
    }
  }

  Status Tensor(const tensorflow::Tensor** t);

  Status TensorValue(tensorflow::TensorValue* t);

  tensorflow::Device* device() const { return device_; }
  tensorflow::Device* op_device() const { return op_device_; }
  tensorflow::Device* resource_device() const { return resource_device_; }

  Status TensorAndDevice(const tensorflow::Tensor** tensor,
                         tensorflow::Device** device,
                         tensorflow::Device** op_device);

  Status Shape(tensorflow::TensorShape* shape);

  Status NumDims(int* num_dims);
  Status Dim(int dim_index, int64* dim);
  Status NumElements(int64* num_elements);

  // Return the op_id and output num if the handle refers to a remote tensor.
  Status RemoteAddress(int64* op_id, int32* output_num);

  // Note that this can be called at most once, and only on non-ready handles,
  // and makes them ready.
  void SetTensor(const tensorflow::Tensor& tensor);
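  //
  // Example (an illustrative sketch, not code from this file; `node_id`, `d`,
  // `op_device`, `resource_device`, `ctx`, and `t` are assumed to be in
  // scope): the lifecycle of a non-ready handle produced by async execution.
  //
  //   auto* h = new TensorHandle(node_id, d, op_device, resource_device,
  //                              DT_FLOAT, ctx);  // handle is not ready yet
  //   // ... the EagerNode identified by node_id computes tensor `t` ...
  //   h->SetTensor(t);  // called exactly once; the handle becomes ready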

  Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                      TensorHandle** output);

  // Warning: can return nullptr for CPU tensors.
  EagerContext* Context() {
    mutex_lock ml(ctx_mutex_);
    return ctx_;
  }

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const DataType dtype;

  void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) {
    remote_shape_ = std::move(remote_shape);
  }

  bool OnHostCPU() {
    mutex_lock ml(ctx_mutex_);
    return device_ == nullptr ||
           (ctx_ == nullptr || ctx_->HostCPU() == device_);
  }

  bool IsRemote();

  OutputGraphNode* getSymbolicTensor() const { return symbolic_tensor.get(); }

  string DebugString() const;

 private:
  // If the contents of the Tensor pointed to by this handle are yet to be
  // computed by an EagerNode, this function will block until that computation
  // is done and the handle is "ready".
  Status WaitReady();
  Status WaitForNode(uint64 node_id, bool return_if_is_ready);

  bool IsReady();

  // ID of the EagerNode that will compute the value pointed to by this handle.
  // A value of 0 means the handle is already ready; the converse does not
  // hold, since a handle with a nonzero node_id may also have become ready.
  const uint64 node_id_;

  tensorflow::Tensor tensor_;

  // TODO(ashankar): device_ == nullptr iff local CPU
  // This was expedient, but perhaps worth revisiting ('device_' should always
  // be a valid pointer?)
  // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
  // provided with the appropriate TFE_Context.
  //
  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
  tensorflow::Device* const device_;

  // Device on which the op that produced this tensor was executed. Equals
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with the function library runtime, or if this tensor represents a
  // symbolic tensor.
  tensorflow::Device* const op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Otherwise, resource_device_ is nullptr.
  tensorflow::Device* const resource_device_;

  // IDs required when this class is representing a remote tensor handle.
  const int64 remote_op_id_;
  const int32 remote_output_num_;
  std::unique_ptr<TensorShape> remote_shape_;
  const uint64 remote_shape_node_id_;

  // A callback that is executed when the class is destroyed.
  //
  // This is currently used for remote tensor handles.
  const std::function<void()> call_on_destroy_;

  mutex ctx_mutex_;

  // `ctx` is only guaranteed to be set if the handle is not "ready"; this is
  // typically the case when the handle was produced during async execution.
  // The `ctx` object is not owned and should outlive this handle.
  EagerContext* ctx_ GUARDED_BY(ctx_mutex_);
  bool is_ready_ GUARDED_BY(ctx_mutex_);

  // When non-null, this tensor handle instance represents a symbolic tensor
  // (corresponding to a graph node), whose concrete value is to be produced
  // by executing that graph node.
  std::unique_ptr<OutputGraphNode> symbolic_tensor;
};
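
// Example (a minimal sketch, not part of this header's API; assumes a valid
// `EagerContext* ctx` is available): constructing a ready handle from a local
// CPU tensor and reading its contents back. A null `d` denotes the local CPU,
// per the TODO on `device_` above.
//
//   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
//   TensorHandle* h =
//       new TensorHandle(t, /*d=*/nullptr, /*op_device=*/nullptr, ctx);
//   const tensorflow::Tensor* out = nullptr;
//   Status s = h->Tensor(&out);  // Returns immediately; the handle is ready.
//   h->Unref();                  // TensorHandle is ref-counted.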

// If the tensor's dtype is DT_RESOURCE, returns the device backing the
// resource. Otherwise, returns nullptr.
Device* GetResourceDevice(const Tensor& t, EagerContext* ctx);
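//
// Example (sketch; `resource_t` and `float_t` are hypothetical tensors of
// dtype DT_RESOURCE and DT_FLOAT respectively):
//
//   Device* backing = GetResourceDevice(resource_t, ctx);  // backing device
//   Device* none = GetResourceDevice(float_t, ctx);        // nullptr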

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_