1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
17
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
22 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
23 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/util/device_name_utils.h"
26
27 namespace tensorflow {
28
29 class EagerOperation {
30 public:
EagerOperation(tensorflow::EagerContext * ctx)31 explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {}
~EagerOperation()32 ~EagerOperation() {
33 for (tensorflow::TensorHandle* h : inputs_) {
34 h->Unref();
35 }
36 }
37
38 // An EagerOperation object can be reused for a different op by calling
39 // Clear(), and then Reset(...) with the same arguments that would have
40 // been provided to the constructor.
Clear()41 void Clear() {
42 for (tensorflow::TensorHandle* h : inputs_) {
43 h->Unref();
44 }
45 inputs_.clear();
46 ClearInferenceState();
47 }
48
49 tensorflow::Status Reset(const char* op, const char* raw_device_name,
50 bool remote, EagerExecutor* executor,
51 const absl::optional<EagerRemoteFunctionParams>
52 remote_func_params = absl::nullopt);
53
is_function()54 bool is_function() const { return is_function_; }
55
EagerContext()56 tensorflow::EagerContext& EagerContext() { return ctx_; }
57
MutableAttrs()58 tensorflow::AttrBuilder* MutableAttrs() { return &attrs_; }
Attrs()59 const tensorflow::AttrBuilder& Attrs() const { return attrs_; }
OpDef()60 const tensorflow::OpDef* OpDef() const { return op_def_; }
61
Inputs()62 const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& Inputs()
63 const {
64 return inputs_;
65 }
66 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>*
MutableInputs()67 MutableInputs() {
68 return &inputs_;
69 }
70
71 void AddInput(tensorflow::TensorHandle* h);
72 void UpdateInput(int i, tensorflow::TensorHandle* h);
73 void ConsumeInput(tensorflow::TensorHandle* h);
74
Name()75 const tensorflow::string& Name() const { return attrs_.op_name(); }
AttrTypes()76 const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; }
77
Device()78 tensorflow::Device* Device() const { return device_; }
SetDevice(tensorflow::Device * device)79 void SetDevice(tensorflow::Device* device) {
80 device_ = device;
81 raw_device_name_.clear();
82 device_name_ = device->name();
83 device_parsed_name_ = device->parsed_name();
84 }
85
GetDeviceName()86 const string& GetDeviceName() const { return device_name_; }
GetDeviceParsedName()87 const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
88 return device_parsed_name_;
89 }
90 tensorflow::Status SetDeviceName(const char* device,
91 const bool reset = false);
92
93 // Indicates whether the op is assigned to a device that is local to the
94 // current host.
95 bool IsLocal() const;
96
SetUseXla(bool use_xla)97 void SetUseXla(bool use_xla) { use_xla_ = use_xla; }
98
GetCancellationManager()99 CancellationManager* GetCancellationManager() const {
100 return cancellation_manager_;
101 }
SetCancellationManager(CancellationManager * cancellation_manager)102 void SetCancellationManager(CancellationManager* cancellation_manager) {
103 cancellation_manager_ = cancellation_manager;
104 }
105
Executor()106 EagerExecutor& Executor() { return *executor_; }
107
108 string DebugString() const;
109
remote_func_params()110 const absl::optional<EagerRemoteFunctionParams>& remote_func_params() const {
111 return remote_func_params_;
112 }
113
114 #ifdef TENSORFLOW_MEM_DEBUG
op_name()115 const char* op_name() const { return op_name_; }
116 const char* op_name_ = nullptr;
117 #endif
118
119 Status MaybeInferSingleInputAttrs(tensorflow::TensorHandle* handle);
120 Status InferInputListAttrs(int num_inputs);
121
122 private:
ClearInferenceState()123 void ClearInferenceState() {
124 op_def_ = nullptr;
125 inference_arg_idx_ = 0;
126 inference_attrs_.clear_no_resize();
127 }
128 void InferSingleTypeInputListAttrs(const tensorflow::OpDef::ArgDef& input_def,
129 const tensorflow::DataType dtype,
130 int num_inputs);
131 void InferMixedTypeInputListAttrs(
132 const tensorflow::OpDef::ArgDef& input_def,
133 const std::vector<tensorflow::DataType>& dtypes);
134
135 tensorflow::EagerContext& ctx_;
136 tensorflow::AttrBuilder attrs_;
137 const tensorflow::AttrTypeMap* attr_types_;
138 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
139 tensorflow::Device* device_;
140 string raw_device_name_;
141 string device_name_;
142 DeviceNameUtils::ParsedName device_parsed_name_;
143 bool use_xla_ = false;
144 bool is_function_; // Conceptually const, but can't be because of Reset
145 CancellationManager* cancellation_manager_ = nullptr; // Not owned.
146 EagerExecutor* executor_; // Not owned.
147 absl::optional<EagerRemoteFunctionParams> remote_func_params_;
148
149 // Inference information
150 const tensorflow::OpDef* op_def_; // op definition from protobuf
151 int inference_arg_idx_; // arg definition index for the next input to be
152 // added
153 tensorflow::gtl::FlatSet<std::string>
154 inference_attrs_; // attributes inferred so far
155 };
156
AddInput(tensorflow::TensorHandle * h)157 inline void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
158 h->Ref();
159 inputs_.push_back(h);
160 attrs_.NumInputs(static_cast<int>(inputs_.size()));
161 }
162
UpdateInput(int i,tensorflow::TensorHandle * h)163 inline void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
164 tensorflow::TensorHandle** slot = &inputs_[i];
165 tensorflow::TensorHandle* existing = *slot;
166 if (existing != h) {
167 h->Ref();
168 existing->Unref();
169 *slot = h; // Update inputs_[i] to h
170 }
171 }
172
ConsumeInput(tensorflow::TensorHandle * h)173 inline void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
174 inputs_.push_back(h);
175 attrs_.NumInputs(static_cast<int>(inputs_.size()));
176 }
177
178 } // namespace tensorflow
179
180 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
181