• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
22 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
23 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/util/device_name_utils.h"
26 
27 namespace tensorflow {
28 
29 class EagerOperation {
30  public:
EagerOperation(tensorflow::EagerContext * ctx)31   explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {}
~EagerOperation()32   ~EagerOperation() {
33     for (tensorflow::TensorHandle* h : inputs_) {
34       h->Unref();
35     }
36   }
37 
38   // An EagerOperation object can be reused for a different op by calling
39   // Clear(), and then Reset(...) with the same arguments that would have
40   // been provided to the constructor.
Clear()41   void Clear() {
42     for (tensorflow::TensorHandle* h : inputs_) {
43       h->Unref();
44     }
45     inputs_.clear();
46     ClearInferenceState();
47   }
48 
49   tensorflow::Status Reset(const char* op, const char* raw_device_name,
50                            bool remote, EagerExecutor* executor,
51                            const absl::optional<EagerRemoteFunctionParams>
52                                remote_func_params = absl::nullopt);
53 
is_function()54   bool is_function() const { return is_function_; }
55 
EagerContext()56   tensorflow::EagerContext& EagerContext() { return ctx_; }
57 
MutableAttrs()58   tensorflow::AttrBuilder* MutableAttrs() { return &attrs_; }
Attrs()59   const tensorflow::AttrBuilder& Attrs() const { return attrs_; }
OpDef()60   const tensorflow::OpDef* OpDef() const { return op_def_; }
61 
Inputs()62   const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& Inputs()
63       const {
64     return inputs_;
65   }
66   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>*
MutableInputs()67   MutableInputs() {
68     return &inputs_;
69   }
70 
71   void AddInput(tensorflow::TensorHandle* h);
72   void UpdateInput(int i, tensorflow::TensorHandle* h);
73   void ConsumeInput(tensorflow::TensorHandle* h);
74 
Name()75   const tensorflow::string& Name() const { return attrs_.op_name(); }
AttrTypes()76   const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; }
77 
Device()78   tensorflow::Device* Device() const { return device_; }
SetDevice(tensorflow::Device * device)79   void SetDevice(tensorflow::Device* device) {
80     device_ = device;
81     raw_device_name_.clear();
82     device_name_ = device->name();
83     device_parsed_name_ = device->parsed_name();
84   }
85 
GetDeviceName()86   const string& GetDeviceName() const { return device_name_; }
GetDeviceParsedName()87   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
88     return device_parsed_name_;
89   }
90   tensorflow::Status SetDeviceName(const char* device,
91                                    const bool reset = false);
92 
93   // Indicates whether the op is assigned to a device that is local to the
94   // current host.
95   bool IsLocal() const;
96 
SetUseXla(bool use_xla)97   void SetUseXla(bool use_xla) { use_xla_ = use_xla; }
98 
GetCancellationManager()99   CancellationManager* GetCancellationManager() const {
100     return cancellation_manager_;
101   }
SetCancellationManager(CancellationManager * cancellation_manager)102   void SetCancellationManager(CancellationManager* cancellation_manager) {
103     cancellation_manager_ = cancellation_manager;
104   }
105 
Executor()106   EagerExecutor& Executor() { return *executor_; }
107 
108   string DebugString() const;
109 
remote_func_params()110   const absl::optional<EagerRemoteFunctionParams>& remote_func_params() const {
111     return remote_func_params_;
112   }
113 
114 #ifdef TENSORFLOW_MEM_DEBUG
op_name()115   const char* op_name() const { return op_name_; }
116   const char* op_name_ = nullptr;
117 #endif
118 
119   Status MaybeInferSingleInputAttrs(tensorflow::TensorHandle* handle);
120   Status InferInputListAttrs(int num_inputs);
121 
122  private:
ClearInferenceState()123   void ClearInferenceState() {
124     op_def_ = nullptr;
125     inference_arg_idx_ = 0;
126     inference_attrs_.clear_no_resize();
127   }
128   void InferSingleTypeInputListAttrs(const tensorflow::OpDef::ArgDef& input_def,
129                                      const tensorflow::DataType dtype,
130                                      int num_inputs);
131   void InferMixedTypeInputListAttrs(
132       const tensorflow::OpDef::ArgDef& input_def,
133       const std::vector<tensorflow::DataType>& dtypes);
134 
135   tensorflow::EagerContext& ctx_;
136   tensorflow::AttrBuilder attrs_;
137   const tensorflow::AttrTypeMap* attr_types_;
138   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
139   tensorflow::Device* device_;
140   string raw_device_name_;
141   string device_name_;
142   DeviceNameUtils::ParsedName device_parsed_name_;
143   bool use_xla_ = false;
144   bool is_function_;  // Conceptually const, but can't be because of Reset
145   CancellationManager* cancellation_manager_ = nullptr;  // Not owned.
146   EagerExecutor* executor_;                              // Not owned.
147   absl::optional<EagerRemoteFunctionParams> remote_func_params_;
148 
149   // Inference information
150   const tensorflow::OpDef* op_def_;  // op definition from protobuf
151   int inference_arg_idx_;  // arg definition index for the next input to be
152                            // added
153   tensorflow::gtl::FlatSet<std::string>
154       inference_attrs_;  // attributes inferred so far
155 };
156 
AddInput(tensorflow::TensorHandle * h)157 inline void EagerOperation::AddInput(tensorflow::TensorHandle* h) {
158   h->Ref();
159   inputs_.push_back(h);
160   attrs_.NumInputs(static_cast<int>(inputs_.size()));
161 }
162 
UpdateInput(int i,tensorflow::TensorHandle * h)163 inline void EagerOperation::UpdateInput(int i, tensorflow::TensorHandle* h) {
164   tensorflow::TensorHandle** slot = &inputs_[i];
165   tensorflow::TensorHandle* existing = *slot;
166   if (existing != h) {
167     h->Ref();
168     existing->Unref();
169     *slot = h;  // Update inputs_[i] to h
170   }
171 }
172 
ConsumeInput(tensorflow::TensorHandle * h)173 inline void EagerOperation::ConsumeInput(tensorflow::TensorHandle* h) {
174   inputs_.push_back(h);
175   attrs_.NumInputs(static_cast<int>(inputs_.size()));
176 }
177 
178 }  // namespace tensorflow
179 
180 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
181