• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
17 
18 #include "absl/container/inlined_vector.h"
19 #include "absl/types/optional.h"
20 #include "absl/types/span.h"
21 #include "absl/types/variant.h"
22 #include "tensorflow/c/eager/abstract_tensor_handle.h"
23 #include "tensorflow/c/eager/immediate_execution_operation.h"
24 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
25 #include "tensorflow/core/common_runtime/eager/context.h"
26 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
27 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
28 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
29 #include "tensorflow/core/framework/cancellation.h"
30 #include "tensorflow/core/framework/device_attributes.pb.h"
31 #include "tensorflow/core/framework/op_def.pb.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33 #include "tensorflow/core/util/managed_stack_trace.h"
34 
35 namespace tensorflow {
36 
37 class EagerOperation : public ImmediateExecutionOperation {
38  public:
EagerOperation(tensorflow::EagerContext * ctx)39   explicit EagerOperation(tensorflow::EagerContext* ctx)
40       : ImmediateExecutionOperation(kEager), ctx_(*ctx) {}
~EagerOperation()41   ~EagerOperation() override {
42     for (ImmediateExecutionTensorHandle* h : inputs_) {
43       h->Unref();
44     }
45   }
46 
Release()47   void Release() override { delete this; }
48 
49   void Clear() override;
Reset(const char * op,const char * raw_device_name)50   Status Reset(const char* op, const char* raw_device_name) override {
51     return Reset(op, raw_device_name, false, nullptr);
52   }
53 
Name()54   const string& Name() const override { return attrs_.op_name(); }
55 
DeviceName()56   const string& DeviceName() const override { return device_name_; }
57 
GetContext()58   ImmediateExecutionContext* GetContext() const override { return &ctx_; }
59 
GetDeviceParsedName()60   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
61     return device_parsed_name_;
62   }
63 
64   // Replaces the previous device name with the given one (see
65   // AbstractOperation::SetDeviceName for more details).
66   //
67   // This also resets the internal device pointer, unless the given name refers
68   // to a known custom device, in which case the internal device pointer is
69   // updated to that device.
70   Status SetDeviceName(const char* name) override;
71 
SetDevice(VariantDevice device)72   void SetDevice(VariantDevice device) {
73     device_ = device;
74     device_name_ = absl::visit(
75         [](auto* device) { return device == nullptr ? "" : device->name(); },
76         device);
77     DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
78     // TODO(b/154133594): Due to intricacies of external logic, we can not
79     // set this do device_name_ as it would be natural, because we need the
80     // next call to SetDeviceName to reset the device pointer.
81     last_set_device_name_ = "\177";  // DEL (an invalid value)
82   }
83 
84   Status SetAttrValue(const char* attr_name, const AttrValue& value);
85 
86   Status AddInput(AbstractTensorHandle* input) override;
87   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
88   Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
89   absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
HasCustomDeviceInput()90   bool HasCustomDeviceInput() const override {
91     return custom_device_tensor_handles_count_ > 0;
92   }
93   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
94                  int* num_retvals) override;
OpDef()95   const tensorflow::OpDef* OpDef() const override { return op_def_; };
96 
97   Status SetAttrString(const char* attr_name, const char* data,
98                        size_t length) override;
99   Status SetAttrInt(const char* attr_name, int64_t value) override;
100   Status SetAttrFloat(const char* attr_name, float value) override;
101   Status SetAttrBool(const char* attr_name, bool value) override;
102   Status SetAttrType(const char* attr_name, DataType value) override;
103   Status SetAttrShape(const char* attr_name, const int64_t* dims,
104                       const int num_dims) override;
105   Status SetAttrFunction(const char* attr_name,
106                          const AbstractOperation* value) override;
107   Status SetAttrFunctionName(const char* attr_name, const char* data,
108                              size_t length) override;
109   Status SetAttrTensor(const char* attr_name,
110                        AbstractTensorInterface* tensor) override;
111   Status SetAttrStringList(const char* attr_name, const void* const* values,
112                            const size_t* lengths, int num_values) override;
113   Status SetAttrFloatList(const char* attr_name, const float* values,
114                           int num_values) override;
115   Status SetAttrIntList(const char* attr_name, const int64_t* values,
116                         int num_values) override;
117   Status SetAttrTypeList(const char* attr_name, const DataType* values,
118                          int num_values) override;
119   Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
120                          int num_values) override;
121   Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
122                           const int* num_dims, int num_values) override;
123   Status SetAttrFunctionList(
124       const char* attr_name,
125       absl::Span<const AbstractOperation*> values) override;
126 
127   Status InputLength(const char* input_name, int* length) override;
128   Status OutputLength(const char* output_name, int* length) override;
129 
130   const AbstractOpAttrs* GetOpAttrs() const override;
131   void AddAttrs(const AbstractOpAttrs* op_attrs) override;
132 
SetStackTrace(ManagedStackTrace stack_trace)133   void SetStackTrace(ManagedStackTrace stack_trace) override {
134     stack_trace_ = stack_trace;
135   }
136 
GetStackTrace()137   absl::optional<ManagedStackTrace> GetStackTrace() override {
138     return stack_trace_;
139   }
140 
141   Status Reset(const char* op, const char* device_name, bool remote,
142                EagerExecutor* executor,
143                const absl::optional<EagerRemoteFunctionParams>
144                    remote_func_params = absl::nullopt);
145 
is_function()146   bool is_function() const { return is_function_; }
colocation_exempt()147   bool colocation_exempt() const { return colocation_exempt_; }
148 
EagerContext()149   tensorflow::EagerContext& EagerContext() const { return ctx_; }
150 
MutableAttrs()151   AttrBuilder* MutableAttrs() { return &attrs_; }
Attrs()152   const AttrBuilder& Attrs() const { return attrs_; }
153 
154   // TensorHandleInputs and MutableTensorHandleInputs first check that all
155   // inputs are TensorHandles, i.e. that there are no custom device inputs. They
156   // return a bad status otherwise.
157   Status TensorHandleInputs(
158       const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
159   Status MutableTensorHandleInputs(
160       absl::InlinedVector<TensorHandle*, 4>** inputs);
161 
Inputs()162   const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
163       const {
164     return inputs_;
165   }
166 
167   void UpdateInput(int i, TensorHandle* h);
168 
169   // Like TensorHandles, EagerOperations may be placed either on a virtual
170   // CustomDevice or on a physical Device.
Device()171   VariantDevice Device() const { return device_; }
172 
173   // Indicates whether the op is assigned to a device that is local to the
174   // current host.
175   bool IsLocal() const;
176 
GetCancellationManager()177   CancellationManager* GetCancellationManager() const {
178     return cancellation_manager_;
179   }
SetCancellationManager(CancellationManager * cancellation_manager)180   void SetCancellationManager(CancellationManager* cancellation_manager) {
181     cancellation_manager_ = cancellation_manager;
182   }
183 
Executor()184   EagerExecutor& Executor() { return *executor_; }
185 
186   string DebugString() const;
187 
remote_func_params()188   const absl::optional<EagerRemoteFunctionParams>& remote_func_params() const {
189     return remote_func_params_;
190   }
191 
192   // Op name recorded for memory debugging purpose.
op_name()193   const char* op_name() const { return op_name_; }
194 
195   // For LLVM style RTTI.
classof(const AbstractOperation * ptr)196   static bool classof(const AbstractOperation* ptr) {
197     return ptr->getKind() == kEager;
198   }
199 
200  private:
201   void AddTensorHandle(ImmediateExecutionTensorHandle* h);
202 
203   const tensorflow::OpDef* GetOpDef(Status* status);
204 
ClearInferenceState()205   void ClearInferenceState() {
206     op_def_ = nullptr;
207     inference_arg_idx_ = 0;
208     inference_attrs_.clear_no_resize();
209   }
210 
211   Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
212   Status InferInputListAttrs(int num_inputs);
213 
214   void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
215                                      const DataType dtype, int num_inputs);
216   void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
217                                     const std::vector<DataType>& dtypes);
218 
219   tensorflow::EagerContext& ctx_;
220   const char* op_name_ = nullptr;
221   AttrBuilder attrs_;
222   const AttrTypeMap* attr_types_;
223 
224   // The number of custom device TensorHandle inputs. These inputs need to be
225   // processed by CustomDeviceOpHandler first.
226   int custom_device_tensor_handles_count_ = 0;
227   absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
228 
229   // The last device name given to SetDeviceName.
230   // This is used to avoid having to re-process the same device in repeated
231   // calls to SetDeviceName.
232   string last_set_device_name_;
233 
234   // The operation's device name.
235   // This contains the named passed to SetDeviceName until device_ is set,
236   // at which point it contains the device_ name.
237   string device_name_;
238 
239   // The parsed device name.
240   // This will always contain the result of
241   // DeviceNameUtils::ParseFullName(device_name_).
242   DeviceNameUtils::ParsedName device_parsed_name_;
243 
244   // The operation's device.
245   // This is set by the execution device placement logic, and should conform
246   // with the contents of device_name_. Once it is set, the device_name_ is
247   // updated accordingly.
248   VariantDevice device_;
249 
250   absl::optional<ManagedStackTrace> stack_trace_;
251   bool is_function_;  // Conceptually const, but can't be because of Reset
252   bool colocation_exempt_;
253   CancellationManager* cancellation_manager_ = nullptr;  // Not owned.
254   EagerExecutor* executor_;                              // Not owned.
255   absl::optional<EagerRemoteFunctionParams> remote_func_params_;
256 
257   // Inference information
258   const tensorflow::OpDef* op_def_;  // op definition from protobuf
259   int inference_arg_idx_;  // arg definition index for the next input to be
260                            // added
261   gtl::FlatSet<std::string> inference_attrs_;  // attributes inferred so far
262 };
263 
UpdateInput(int i,TensorHandle * h)264 inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
265   ImmediateExecutionTensorHandle** slot = &inputs_[i];
266   ImmediateExecutionTensorHandle* existing = *slot;
267   if (existing != h) {
268     h->Ref();
269     existing->Unref();
270     *slot = h;  // Update inputs_[i] to h
271   }
272 }
273 
OperationFromInterface(ImmediateExecutionOperation * operation)274 inline EagerOperation* OperationFromInterface(
275     ImmediateExecutionOperation* operation) {
276   return down_cast<EagerOperation*>(operation);
277 }
278 
OperationFromInterface(const ImmediateExecutionOperation * operation)279 inline const EagerOperation* OperationFromInterface(
280     const ImmediateExecutionOperation* operation) {
281   return down_cast<const EagerOperation*>(operation);
282 }
283 
284 }  // namespace tensorflow
285 
286 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
287