/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/managed_stack_trace.h"

namespace tensorflow {

class EagerOperation : public ImmediateExecutionOperation {
 public:
  explicit EagerOperation(tensorflow::EagerContext* ctx)
      : ImmediateExecutionOperation(kEager), ctx_(*ctx), is_function_(false) {}
  ~EagerOperation() override {
    for (ImmediateExecutionTensorHandle* h : inputs_) {
      h->Unref();
    }
  }

  void Release() override { delete this; }

  void Clear() override;
  Status Reset(const char* op, const char* raw_device_name) override {
    return Reset(op, raw_device_name, false, nullptr);
  }

  const string& Name() const override { return attrs_.op_name(); }

  const string& DeviceName() const override { return device_name_; }

  ImmediateExecutionContext* GetContext() const override { return &ctx_; }

  const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
    return device_parsed_name_;
  }

  // Replaces the previous device name with the given one (see
  // AbstractOperation::SetDeviceName for more details).
  //
  // This also resets the internal device pointer, unless the given name refers
  // to a known custom device, in which case the internal device pointer is
  // updated to that device.
  Status SetDeviceName(const char* name) override;

  void SetDevice(VariantDevice device) {
    device_ = device;
    device_name_ = absl::visit(
        [](auto* device) { return device == nullptr ? "" : device->name(); },
        device);
    DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
    // TODO(b/154133594): Due to intricacies of external logic, we cannot
    // set this to device_name_, as would be natural, because we need the
    // next call to SetDeviceName to reset the device pointer.
    last_set_device_name_ = "\177";  // DEL (an invalid value)
  }

  Status SetAttrValue(const char* attr_name, const AttrValue& value);

  Status AddInput(AbstractTensorHandle* input) override;
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
  Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
  absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
  bool HasCustomDeviceInput() const override {
    return custom_device_tensor_handles_count_ > 0;
  }
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override;
  const tensorflow::OpDef* OpDef() const override { return op_def_; }

  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override;
  Status SetAttrInt(const char* attr_name, int64_t value) override;
  Status SetAttrFloat(const char* attr_name, float value) override;
  Status SetAttrBool(const char* attr_name, bool value) override;
  Status SetAttrType(const char* attr_name, DataType value) override;
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override;
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override;
  Status SetAttrFunctionName(const char* attr_name, const char* data,
                             size_t length) override;
  Status SetAttrTensor(const char* attr_name,
                       AbstractTensorInterface* tensor) override;
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override;
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override;
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override;
  Status SetAttrTypeList(const char* attr_name, const DataType* values,
                         int num_values) override;
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override;
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override;
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override;

  Status InputLength(const char* input_name, int* length) override;
  Status OutputLength(const char* output_name, int* length) override;

  const AbstractOpAttrs* GetOpAttrs() const override;
  void AddAttrs(const AbstractOpAttrs* op_attrs) override;

  void SetStackTrace(ManagedStackTrace stack_trace) override {
    stack_trace_ = stack_trace;
  }

  absl::optional<ManagedStackTrace> GetStackTrace() override {
    return stack_trace_;
  }

  Status Reset(const char* op, const char* device_name, bool remote,
               EagerExecutor* executor,
               const absl::optional<EagerFunctionParams> remote_func_params =
                   absl::nullopt);

  bool is_function() const { return is_function_; }
  bool colocation_exempt() const { return colocation_exempt_; }

  tensorflow::EagerContext& EagerContext() const { return ctx_; }

  AttrBuilder* MutableAttrs() { return &attrs_; }
  const AttrBuilder& Attrs() const { return attrs_; }

  // TensorHandleInputs and MutableTensorHandleInputs first check that all
  // inputs are TensorHandles, i.e. that there are no custom device inputs. They
  // return a bad status otherwise.
  Status TensorHandleInputs(
      const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
  Status MutableTensorHandleInputs(
      absl::InlinedVector<TensorHandle*, 4>** inputs);

  const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
      const {
    return inputs_;
  }

  void UpdateInput(int i, TensorHandle* h);

  // Like TensorHandles, EagerOperations may be placed either on a virtual
  // CustomDevice or on a physical Device.
  VariantDevice Device() const { return device_; }

  // Indicates whether the op is assigned to a device that is local to the
  // current host.
  bool IsLocal() const;

  CancellationManager* GetCancellationManager() const {
    return cancellation_manager_;
  }
  void SetCancellationManager(
      CancellationManager* cancellation_manager) override {
    cancellation_manager_ = cancellation_manager;
  }

  // Assigns the step_id value only if the op has a valid step id.
  // When eager_func_params.has_value() returns true, we can directly overwrite
  // its step id with the op's step id (if it is not the default value).
  // However, when eager_func_params.has_value() returns false, we need to
  // first create a new EagerFunctionParams object before assigning step_id;
  // otherwise, directly assigning step_id would leave eager_func_params in an
  // inconsistent state where:
  // (1) eager_func_params.has_value() returns false, but
  // (2) eager_func_params->step_id.has_value() returns true.
  void SetStepId(int64_t step_id) override {
    assert(is_function());
    if (step_id != EagerContext::kGlobalRendezvousId) {
      if (eager_func_params_.has_value()) {
        eager_func_params_->step_id = step_id;
      } else {
        eager_func_params_ = EagerFunctionParams{
            kInvalidOpId, /*is_component_function=*/false, step_id};
      }
    } else {
      LOG(WARNING) << "SetStepId() should not receive a global rendezvous id.";
    }
  }

  EagerExecutor& Executor() { return *executor_; }

  string DebugString() const;

  const absl::optional<EagerFunctionParams>& eager_func_params() const {
    return eager_func_params_;
  }

  // Op name recorded for memory debugging purposes.
  const char* op_name() const { return op_name_; }

  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kEager;
  }

 private:
  void AddTensorHandle(ImmediateExecutionTensorHandle* h);

  const tensorflow::OpDef* GetOpDef(Status* status);

  void ClearInferenceState() {
    op_def_ = nullptr;
    inference_arg_idx_ = 0;
    inference_attrs_.clear_no_resize();
  }

  Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
  Status InferInputListAttrs(int num_inputs);

  void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                     const DataType dtype, int num_inputs);
  void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                    const std::vector<DataType>& dtypes);

  tensorflow::EagerContext& ctx_;
  const char* op_name_ = nullptr;
  AttrBuilder attrs_;
  const AttrTypeMap* attr_types_;

  // The number of custom device TensorHandle inputs. These inputs need to be
  // processed by CustomDeviceOpHandler first.
  int custom_device_tensor_handles_count_ = 0;
  absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;

  // The last device name given to SetDeviceName.
  // This is used to avoid having to re-process the same device in repeated
  // calls to SetDeviceName.
  string last_set_device_name_;

  // The operation's device name.
  // This contains the name passed to SetDeviceName until device_ is set,
  // at which point it contains the name of device_.
  string device_name_;

  // The parsed device name.
  // This will always contain the result of
  // DeviceNameUtils::ParseFullName(device_name_).
  DeviceNameUtils::ParsedName device_parsed_name_;

  // The operation's device.
  // This is set by the execution device placement logic, and should conform
  // with the contents of device_name_. Once it is set, the device_name_ is
  // updated accordingly.
  VariantDevice device_;

  absl::optional<ManagedStackTrace> stack_trace_;
  bool is_function_;  // Conceptually const, but can't be because of Reset
  bool colocation_exempt_;
  CancellationManager* cancellation_manager_ = nullptr;  // Not owned.
  EagerExecutor* executor_;                              // Not owned.

  absl::optional<EagerFunctionParams> eager_func_params_;

  // Inference information
  const tensorflow::OpDef* op_def_;  // op definition from protobuf
  int inference_arg_idx_;  // arg definition index for the next input to be
                           // added
  gtl::FlatSet<std::string> inference_attrs_;  // attributes inferred so far
};
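
// Illustrative usage sketch of the class above. Names such as `ctx` and
// `handle` are placeholders, and the exact sequence is a simplification of
// what the eager runtime (e.g. the TFE_Op C API and EagerExecute) actually
// does:
//
//   EagerOperation op(ctx);
//   TF_RETURN_IF_ERROR(op.Reset("Identity", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(op.AddInput(handle));
//   std::vector<AbstractTensorHandle*> retvals(1);
//   int num_retvals = 1;
//   TF_RETURN_IF_ERROR(op.Execute(absl::MakeSpan(retvals), &num_retvals));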

inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
  ImmediateExecutionTensorHandle** slot = &inputs_[i];
  ImmediateExecutionTensorHandle* existing = *slot;
  if (existing != h) {
    h->Ref();
    existing->Unref();
    *slot = h;  // Update inputs_[i] to h
  }
}

inline EagerOperation* OperationFromInterface(
    ImmediateExecutionOperation* operation) {
  return down_cast<EagerOperation*>(operation);
}

inline const EagerOperation* OperationFromInterface(
    const ImmediateExecutionOperation* operation) {
  return down_cast<const EagerOperation*>(operation);
}
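
// Example (illustrative sketch; `generic_op` is a placeholder): code holding
// a generic ImmediateExecutionOperation* can use the LLVM-style RTTI hook
// declared above to check the kind before recovering the concrete type:
//
//   if (EagerOperation::classof(generic_op)) {
//     EagerOperation* eager_op = OperationFromInterface(generic_op);
//     // ... use eager_op ...
//   }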

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_