/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/managed_stack_trace.h"

namespace tensorflow {

class EagerOperation : public ImmediateExecutionOperation {
 public:
  explicit EagerOperation(tensorflow::EagerContext* ctx)
      : ImmediateExecutionOperation(kEager), ctx_(*ctx) {}
  ~EagerOperation() override {
    for (ImmediateExecutionTensorHandle* h : inputs_) {
      h->Unref();
    }
  }

  void Release() override { delete this; }

  void Clear() override;
  Status Reset(const char* op, const char* raw_device_name) override {
    return Reset(op, raw_device_name, false, nullptr);
  }

  const string& Name() const override { return attrs_.op_name(); }

  const string& DeviceName() const override { return device_name_; }

  ImmediateExecutionContext* GetContext() const override { return &ctx_; }

  const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
    return device_parsed_name_;
  }

  // Replaces the previous device name with the given one (see
  // AbstractOperation::SetDeviceName for more details).
  //
  // This also resets the internal device pointer, unless the given name refers
  // to a known custom device, in which case the internal device pointer is
  // updated to that device.
  Status SetDeviceName(const char* name) override;
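  //
  // Illustrative sketch only; the device string below is an example value,
  // not one this header prescribes:
  //
  //   TF_RETURN_IF_ERROR(
  //       op.SetDeviceName("/job:localhost/replica:0/task:0/device:GPU:0"));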

  void SetDevice(VariantDevice device) {
    device_ = device;
    device_name_ = absl::visit(
        [](auto* device) { return device == nullptr ? "" : device->name(); },
        device);
    DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
    // TODO(b/154133594): Due to intricacies of external logic, we cannot
    // set this to device_name_, as would be natural, because we need the
    // next call to SetDeviceName to reset the device pointer.
    last_set_device_name_ = "\177";  // DEL (an invalid value)
  }

  Status SetAttrValue(const char* attr_name, const AttrValue& value);

  Status AddInput(AbstractTensorHandle* input) override;
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
  Status SetInput(size_t index,
                  ImmediateExecutionTensorHandle* input) override;
  absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
  bool HasCustomDeviceInput() const override {
    return custom_device_tensor_handles_count_ > 0;
  }
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override;
  const tensorflow::OpDef* OpDef() const override { return op_def_; }

  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override;
  Status SetAttrInt(const char* attr_name, int64_t value) override;
  Status SetAttrFloat(const char* attr_name, float value) override;
  Status SetAttrBool(const char* attr_name, bool value) override;
  Status SetAttrType(const char* attr_name, DataType value) override;
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override;
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override;
  Status SetAttrFunctionName(const char* attr_name, const char* data,
                             size_t length) override;
  Status SetAttrTensor(const char* attr_name,
                       AbstractTensorInterface* tensor) override;
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override;
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override;
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override;
  Status SetAttrTypeList(const char* attr_name, const DataType* values,
                         int num_values) override;
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override;
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override;
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override;

  Status InputLength(const char* input_name, int* length) override;
  Status OutputLength(const char* output_name, int* length) override;

  const AbstractOpAttrs* GetOpAttrs() const override;
  void AddAttrs(const AbstractOpAttrs* op_attrs) override;

  void SetStackTrace(ManagedStackTrace stack_trace) override {
    stack_trace_ = stack_trace;
  }

  absl::optional<ManagedStackTrace> GetStackTrace() override {
    return stack_trace_;
  }

  Status Reset(const char* op, const char* device_name, bool remote,
               EagerExecutor* executor,
               const absl::optional<EagerRemoteFunctionParams>
                   remote_func_params = absl::nullopt);
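  //
  // Illustrative lifecycle sketch, not a prescribed contract; `ctx`, `a`,
  // `b`, and `retvals` are assumed to exist at the call site:
  //
  //   EagerOperation op(ctx);
  //   TF_RETURN_IF_ERROR(op.Reset("MatMul", /*raw_device_name=*/nullptr));
  //   TF_RETURN_IF_ERROR(op.AddInput(a));
  //   TF_RETURN_IF_ERROR(op.AddInput(b));
  //   int num_retvals = 1;
  //   TF_RETURN_IF_ERROR(op.Execute(absl::MakeSpan(retvals), &num_retvals));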

  bool is_function() const { return is_function_; }
  bool colocation_exempt() const { return colocation_exempt_; }

  tensorflow::EagerContext& EagerContext() const { return ctx_; }

  AttrBuilder* MutableAttrs() { return &attrs_; }
  const AttrBuilder& Attrs() const { return attrs_; }

  // TensorHandleInputs and MutableTensorHandleInputs first check that all
  // inputs are TensorHandles, i.e. that there are no custom device inputs.
  // They return a bad status otherwise.
  Status TensorHandleInputs(
      const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
  Status MutableTensorHandleInputs(
      absl::InlinedVector<TensorHandle*, 4>** inputs);
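  //
  // Illustrative sketch; assumes `op` has already been populated with inputs:
  //
  //   const absl::InlinedVector<TensorHandle*, 4>* handles = nullptr;
  //   // Returns a bad status if any input lives on a custom device.
  //   TF_RETURN_IF_ERROR(op.TensorHandleInputs(&handles));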

  const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
      const {
    return inputs_;
  }

  void UpdateInput(int i, TensorHandle* h);

  // Like TensorHandles, EagerOperations may be placed either on a virtual
  // CustomDevice or on a physical Device.
  VariantDevice Device() const { return device_; }
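  //
  // Illustrative sketch of inspecting the placement; it mirrors the
  // absl::visit pattern used in SetDevice above and is an assumption, not a
  // required idiom:
  //
  //   absl::visit(
  //       [](auto* d) { LOG(INFO) << (d ? d->name() : "<unplaced>"); },
  //       op.Device());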

  // Indicates whether the op is assigned to a device that is local to the
  // current host.
  bool IsLocal() const;

  CancellationManager* GetCancellationManager() const {
    return cancellation_manager_;
  }
  void SetCancellationManager(CancellationManager* cancellation_manager) {
    cancellation_manager_ = cancellation_manager;
  }

  EagerExecutor& Executor() { return *executor_; }

  string DebugString() const;

  const absl::optional<EagerRemoteFunctionParams>& remote_func_params() const {
    return remote_func_params_;
  }

  // Op name recorded for memory debugging purposes.
  const char* op_name() const { return op_name_; }

  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kEager;
  }

 private:
  void AddTensorHandle(ImmediateExecutionTensorHandle* h);

  const tensorflow::OpDef* GetOpDef(Status* status);

  void ClearInferenceState() {
    op_def_ = nullptr;
    inference_arg_idx_ = 0;
    inference_attrs_.clear_no_resize();
  }

  Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
  Status InferInputListAttrs(int num_inputs);

  void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                     const DataType dtype, int num_inputs);
  void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                    const std::vector<DataType>& dtypes);

  tensorflow::EagerContext& ctx_;
  const char* op_name_ = nullptr;
  AttrBuilder attrs_;
  const AttrTypeMap* attr_types_;

  // The number of custom device TensorHandle inputs. These inputs need to be
  // processed by CustomDeviceOpHandler first.
  int custom_device_tensor_handles_count_ = 0;
  absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;

  // The last device name given to SetDeviceName.
  // This is used to avoid having to re-process the same device in repeated
  // calls to SetDeviceName.
  string last_set_device_name_;

  // The operation's device name.
  // This contains the name passed to SetDeviceName until device_ is set,
  // at which point it contains the device_ name.
  string device_name_;

  // The parsed device name.
  // This will always contain the result of
  // DeviceNameUtils::ParseFullName(device_name_).
  DeviceNameUtils::ParsedName device_parsed_name_;

  // The operation's device.
  // This is set by the execution device placement logic, and should conform
  // with the contents of device_name_. Once it is set, the device_name_ is
  // updated accordingly.
  VariantDevice device_;

  absl::optional<ManagedStackTrace> stack_trace_;
  bool is_function_;  // Conceptually const, but can't be because of Reset
  bool colocation_exempt_;
  CancellationManager* cancellation_manager_ = nullptr;  // Not owned.
  EagerExecutor* executor_;                               // Not owned.
  absl::optional<EagerRemoteFunctionParams> remote_func_params_;

  // Inference information
  const tensorflow::OpDef* op_def_;  // op definition from protobuf
  int inference_arg_idx_;  // arg definition index for the next input to be
                           // added
  gtl::FlatSet<std::string> inference_attrs_;  // attributes inferred so far
};

inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
  ImmediateExecutionTensorHandle** slot = &inputs_[i];
  ImmediateExecutionTensorHandle* existing = *slot;
  if (existing != h) {
    h->Ref();
    existing->Unref();
    *slot = h;  // Update inputs_[i] to h
  }
}

inline EagerOperation* OperationFromInterface(
    ImmediateExecutionOperation* operation) {
  return down_cast<EagerOperation*>(operation);
}

inline const EagerOperation* OperationFromInterface(
    const ImmediateExecutionOperation* operation) {
  return down_cast<const EagerOperation*>(operation);
}
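
// Illustrative sketch of the downcast helpers above; assumes the interface
// pointer actually wraps an EagerOperation, which `classof` can verify:
//
//   ImmediateExecutionOperation* iface = /* obtained from the C API layer */;
//   if (EagerOperation::classof(iface)) {
//     EagerOperation* op = OperationFromInterface(iface);
//     VLOG(1) << op->DebugString();
//   }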

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_