1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
17
18 #include "absl/container/inlined_vector.h"
19 #include "absl/types/optional.h"
20 #include "absl/types/span.h"
21 #include "absl/types/variant.h"
22 #include "tensorflow/c/eager/abstract_tensor_handle.h"
23 #include "tensorflow/c/eager/immediate_execution_operation.h"
24 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
25 #include "tensorflow/core/common_runtime/eager/context.h"
26 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
27 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
28 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
29 #include "tensorflow/core/framework/cancellation.h"
30 #include "tensorflow/core/framework/device_attributes.pb.h"
31 #include "tensorflow/core/framework/op_def.pb.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33 #include "tensorflow/core/util/managed_stack_trace.h"
34
35 namespace tensorflow {
36
37 class EagerOperation : public ImmediateExecutionOperation {
38 public:
EagerOperation(tensorflow::EagerContext * ctx)39 explicit EagerOperation(tensorflow::EagerContext* ctx)
40 : ImmediateExecutionOperation(kEager), ctx_(*ctx) {}
~EagerOperation()41 ~EagerOperation() override {
42 for (ImmediateExecutionTensorHandle* h : inputs_) {
43 h->Unref();
44 }
45 }
46
Release()47 void Release() override { delete this; }
48
49 void Clear() override;
Reset(const char * op,const char * raw_device_name)50 Status Reset(const char* op, const char* raw_device_name) override {
51 return Reset(op, raw_device_name, false, nullptr);
52 }
53
Name()54 const string& Name() const override { return attrs_.op_name(); }
55
DeviceName()56 const string& DeviceName() const override { return device_name_; }
57
GetContext()58 ImmediateExecutionContext* GetContext() const override { return &ctx_; }
59
GetDeviceParsedName()60 const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
61 return device_parsed_name_;
62 }
63
64 // Replaces the previous device name with the given one (see
65 // AbstractOperation::SetDeviceName for more details).
66 //
67 // This also resets the internal device pointer, unless the given name refers
68 // to a known custom device, in which case the internal device pointer is
69 // updated to that device.
70 Status SetDeviceName(const char* name) override;
71
SetDevice(VariantDevice device)72 void SetDevice(VariantDevice device) {
73 device_ = device;
74 device_name_ = absl::visit(
75 [](auto* device) { return device == nullptr ? "" : device->name(); },
76 device);
77 DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
78 // TODO(b/154133594): Due to intricacies of external logic, we can not
79 // set this do device_name_ as it would be natural, because we need the
80 // next call to SetDeviceName to reset the device pointer.
81 last_set_device_name_ = "\177"; // DEL (an invalid value)
82 }
83
84 Status SetAttrValue(const char* attr_name, const AttrValue& value);
85
86 Status AddInput(AbstractTensorHandle* input) override;
87 Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
88 Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
89 absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
HasCustomDeviceInput()90 bool HasCustomDeviceInput() const override {
91 return custom_device_tensor_handles_count_ > 0;
92 }
93 Status Execute(absl::Span<AbstractTensorHandle*> retvals,
94 int* num_retvals) override;
OpDef()95 const tensorflow::OpDef* OpDef() const override { return op_def_; };
96
97 Status SetAttrString(const char* attr_name, const char* data,
98 size_t length) override;
99 Status SetAttrInt(const char* attr_name, int64_t value) override;
100 Status SetAttrFloat(const char* attr_name, float value) override;
101 Status SetAttrBool(const char* attr_name, bool value) override;
102 Status SetAttrType(const char* attr_name, DataType value) override;
103 Status SetAttrShape(const char* attr_name, const int64_t* dims,
104 const int num_dims) override;
105 Status SetAttrFunction(const char* attr_name,
106 const AbstractOperation* value) override;
107 Status SetAttrFunctionName(const char* attr_name, const char* data,
108 size_t length) override;
109 Status SetAttrTensor(const char* attr_name,
110 AbstractTensorInterface* tensor) override;
111 Status SetAttrStringList(const char* attr_name, const void* const* values,
112 const size_t* lengths, int num_values) override;
113 Status SetAttrFloatList(const char* attr_name, const float* values,
114 int num_values) override;
115 Status SetAttrIntList(const char* attr_name, const int64_t* values,
116 int num_values) override;
117 Status SetAttrTypeList(const char* attr_name, const DataType* values,
118 int num_values) override;
119 Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
120 int num_values) override;
121 Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
122 const int* num_dims, int num_values) override;
123 Status SetAttrFunctionList(
124 const char* attr_name,
125 absl::Span<const AbstractOperation*> values) override;
126
127 Status InputLength(const char* input_name, int* length) override;
128 Status OutputLength(const char* output_name, int* length) override;
129
130 const AbstractOpAttrs* GetOpAttrs() const override;
131 void AddAttrs(const AbstractOpAttrs* op_attrs) override;
132
SetStackTrace(ManagedStackTrace stack_trace)133 void SetStackTrace(ManagedStackTrace stack_trace) override {
134 stack_trace_ = stack_trace;
135 }
136
GetStackTrace()137 absl::optional<ManagedStackTrace> GetStackTrace() override {
138 return stack_trace_;
139 }
140
141 Status Reset(const char* op, const char* device_name, bool remote,
142 EagerExecutor* executor,
143 const absl::optional<EagerRemoteFunctionParams>
144 remote_func_params = absl::nullopt);
145
is_function()146 bool is_function() const { return is_function_; }
colocation_exempt()147 bool colocation_exempt() const { return colocation_exempt_; }
148
EagerContext()149 tensorflow::EagerContext& EagerContext() const { return ctx_; }
150
MutableAttrs()151 AttrBuilder* MutableAttrs() { return &attrs_; }
Attrs()152 const AttrBuilder& Attrs() const { return attrs_; }
153
154 // TensorHandleInputs and MutableTensorHandleInputs first check that all
155 // inputs are TensorHandles, i.e. that there are no custom device inputs. They
156 // return a bad status otherwise.
157 Status TensorHandleInputs(
158 const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
159 Status MutableTensorHandleInputs(
160 absl::InlinedVector<TensorHandle*, 4>** inputs);
161
Inputs()162 const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
163 const {
164 return inputs_;
165 }
166
167 void UpdateInput(int i, TensorHandle* h);
168
169 // Like TensorHandles, EagerOperations may be placed either on a virtual
170 // CustomDevice or on a physical Device.
Device()171 VariantDevice Device() const { return device_; }
172
173 // Indicates whether the op is assigned to a device that is local to the
174 // current host.
175 bool IsLocal() const;
176
GetCancellationManager()177 CancellationManager* GetCancellationManager() const {
178 return cancellation_manager_;
179 }
SetCancellationManager(CancellationManager * cancellation_manager)180 void SetCancellationManager(
181 CancellationManager* cancellation_manager) override {
182 cancellation_manager_ = cancellation_manager;
183 }
184
Executor()185 EagerExecutor& Executor() { return *executor_; }
186
187 string DebugString() const;
188
remote_func_params()189 const absl::optional<EagerRemoteFunctionParams>& remote_func_params() const {
190 return remote_func_params_;
191 }
192
193 // Op name recorded for memory debugging purpose.
op_name()194 const char* op_name() const { return op_name_; }
195
196 // For LLVM style RTTI.
classof(const AbstractOperation * ptr)197 static bool classof(const AbstractOperation* ptr) {
198 return ptr->getKind() == kEager;
199 }
200
201 private:
202 void AddTensorHandle(ImmediateExecutionTensorHandle* h);
203
204 const tensorflow::OpDef* GetOpDef(Status* status);
205
ClearInferenceState()206 void ClearInferenceState() {
207 op_def_ = nullptr;
208 inference_arg_idx_ = 0;
209 inference_attrs_.clear_no_resize();
210 }
211
212 Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
213 Status InferInputListAttrs(int num_inputs);
214
215 void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
216 const DataType dtype, int num_inputs);
217 void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
218 const std::vector<DataType>& dtypes);
219
220 tensorflow::EagerContext& ctx_;
221 const char* op_name_ = nullptr;
222 AttrBuilder attrs_;
223 const AttrTypeMap* attr_types_;
224
225 // The number of custom device TensorHandle inputs. These inputs need to be
226 // processed by CustomDeviceOpHandler first.
227 int custom_device_tensor_handles_count_ = 0;
228 absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
229
230 // The last device name given to SetDeviceName.
231 // This is used to avoid having to re-process the same device in repeated
232 // calls to SetDeviceName.
233 string last_set_device_name_;
234
235 // The operation's device name.
236 // This contains the named passed to SetDeviceName until device_ is set,
237 // at which point it contains the device_ name.
238 string device_name_;
239
240 // The parsed device name.
241 // This will always contain the result of
242 // DeviceNameUtils::ParseFullName(device_name_).
243 DeviceNameUtils::ParsedName device_parsed_name_;
244
245 // The operation's device.
246 // This is set by the execution device placement logic, and should conform
247 // with the contents of device_name_. Once it is set, the device_name_ is
248 // updated accordingly.
249 VariantDevice device_;
250
251 absl::optional<ManagedStackTrace> stack_trace_;
252 bool is_function_; // Conceptually const, but can't be because of Reset
253 bool colocation_exempt_;
254 CancellationManager* cancellation_manager_ = nullptr; // Not owned.
255 EagerExecutor* executor_; // Not owned.
256 absl::optional<EagerRemoteFunctionParams> remote_func_params_;
257
258 // Inference information
259 const tensorflow::OpDef* op_def_; // op definition from protobuf
260 int inference_arg_idx_; // arg definition index for the next input to be
261 // added
262 gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far
263 };
264
UpdateInput(int i,TensorHandle * h)265 inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
266 ImmediateExecutionTensorHandle** slot = &inputs_[i];
267 ImmediateExecutionTensorHandle* existing = *slot;
268 if (existing != h) {
269 h->Ref();
270 existing->Unref();
271 *slot = h; // Update inputs_[i] to h
272 }
273 }
274
OperationFromInterface(ImmediateExecutionOperation * operation)275 inline EagerOperation* OperationFromInterface(
276 ImmediateExecutionOperation* operation) {
277 return down_cast<EagerOperation*>(operation);
278 }
279
OperationFromInterface(const ImmediateExecutionOperation * operation)280 inline const EagerOperation* OperationFromInterface(
281 const ImmediateExecutionOperation* operation) {
282 return down_cast<const EagerOperation*>(operation);
283 }
284
285 } // namespace tensorflow
286
287 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
288