1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
16
17 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
18 #include "tensorflow/core/platform/errors.h"
19 #include "tensorflow/core/platform/host_info.h"
20
21 namespace tensorflow {
22
Reset(const char * op,const char * raw_device_name,bool remote,EagerExecutor * executor,const absl::optional<EagerRemoteFunctionParams> remote_func_params)23 Status EagerOperation::Reset(
24 const char* op, const char* raw_device_name, bool remote,
25 EagerExecutor* executor,
26 const absl::optional<EagerRemoteFunctionParams> remote_func_params) {
27 DCHECK(inputs_.empty());
28 ClearInferenceState();
29 bool is_function = false;
30 TF_RETURN_IF_ERROR(AttrTypeMapForOp(op, &attr_types_, &is_function));
31
32 if (!is_function) {
33 TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_));
34 } else if (!remote && !ctx_.FindFunctionByName(op)) {
35 return errors::NotFound(
36 "'", op,
37 "' is neither a type of a primitive operation nor a name "
38 "of a function registered in binary running on ",
39 port::Hostname(),
40 ". Make sure the operation or function is "
41 "registered in the binary running in this process.");
42 }
43 attrs_.Reset(op);
44 device_ = nullptr;
45 use_xla_ = false;
46 is_function_ = is_function;
47 cancellation_manager_ = nullptr;
48 executor_ = executor ? executor : &ctx_.Executor();
49 remote_func_params_ = remote_func_params;
50 #ifdef TENSORFLOW_MEM_DEBUG
51 op_name_ = op;
52 #endif
53 return SetDeviceName(raw_device_name, true);
54 }
55
MaybeInferSingleInputAttrs(TensorHandle * handle)56 tensorflow::Status EagerOperation::MaybeInferSingleInputAttrs(
57 TensorHandle* handle) {
58 if (!op_def_) return Status::OK();
59
60 const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
61 if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
62 // Some clients that are still setting their input attributes manually are
63 // adding input list to their op by calling `TFE_OpAddInput` for each of
64 // its elements instead of calling `TFE_OpAddInputList`. When this happens,
65 // we cannot detect the end of such list, thus lose track of the input
66 // arguments in the op definition. To guarantee backward compatibility with
67 // those clients, disable automatic inference in this case.
68 ClearInferenceState();
69 return Status::OK();
70 }
71 const std::string& type_attr = input_def.type_attr();
72 if (!type_attr.empty() &&
73 inference_attrs_.find(type_attr) == inference_attrs_.end()) {
74 MutableAttrs()->Set(type_attr, handle->dtype);
75 inference_attrs_.insert(type_attr);
76 }
77 return Status::OK();
78 }
79
InferSingleTypeInputListAttrs(const tensorflow::OpDef::ArgDef & input_def,const tensorflow::DataType dtype,int num_inputs)80 void EagerOperation::InferSingleTypeInputListAttrs(
81 const tensorflow::OpDef::ArgDef& input_def,
82 const tensorflow::DataType dtype, int num_inputs) {
83 if (inference_attrs_.find(input_def.number_attr()) ==
84 inference_attrs_.end()) {
85 MutableAttrs()->Set(input_def.number_attr(), num_inputs);
86 inference_attrs_.insert(input_def.number_attr());
87 }
88 if (inference_attrs_.find(input_def.type_attr()) == inference_attrs_.end()) {
89 MutableAttrs()->Set(input_def.type_attr(), dtype);
90 inference_attrs_.insert(input_def.type_attr());
91 }
92 }
93
InferMixedTypeInputListAttrs(const tensorflow::OpDef::ArgDef & input_def,const std::vector<tensorflow::DataType> & dtypes)94 void EagerOperation::InferMixedTypeInputListAttrs(
95 const tensorflow::OpDef::ArgDef& input_def,
96 const std::vector<tensorflow::DataType>& dtypes) {
97 if (inference_attrs_.find(input_def.type_list_attr()) ==
98 inference_attrs_.end()) {
99 MutableAttrs()->Set(input_def.type_list_attr(),
100 tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
101 dtypes.data(), dtypes.size()));
102 inference_attrs_.insert(input_def.type_list_attr());
103 }
104 }
105
InferInputListAttrs(int num_inputs)106 tensorflow::Status EagerOperation::InferInputListAttrs(int num_inputs) {
107 if (!op_def_) return Status::OK();
108
109 int start = inference_arg_idx_;
110 const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
111 if (!input_def.type_list_attr().empty()) {
112 std::vector<tensorflow::DataType> dtypes(num_inputs);
113 for (int i = 0; i < num_inputs; ++i) {
114 dtypes[i] = inputs_[start + i]->dtype;
115 }
116 InferMixedTypeInputListAttrs(input_def, dtypes);
117 } else if (!input_def.type_attr().empty() &&
118 !input_def.number_attr().empty()) {
119 InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
120 } else {
121 return tensorflow::errors::InvalidArgument("Invalid input list definition");
122 }
123 return tensorflow::Status::OK();
124 }
125
SetDeviceName(const char * device,const bool reset)126 tensorflow::Status EagerOperation::SetDeviceName(const char* device,
127 const bool reset) {
128 if (device != nullptr && strlen(device) > 0) {
129 if (device != raw_device_name_) {
130 if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
131 return errors::InvalidArgument("Malformed device specification '",
132 device,
133 "' in eager op: ", DebugString());
134 }
135 raw_device_name_ = device;
136 device_name_ =
137 DeviceNameUtils::HasSomeDetails(device_parsed_name_)
138 ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
139 : "";
140 }
141 } else if (reset) {
142 raw_device_name_.clear();
143 device_name_.clear();
144 device_parsed_name_.Clear();
145 }
146 return Status::OK();
147 }
148
IsLocal() const149 bool EagerOperation::IsLocal() const {
150 if (ctx_.remote_device_mgr() == nullptr) return true;
151
152 if (!device_parsed_name_.has_job && !device_parsed_name_.has_replica &&
153 !device_parsed_name_.has_task)
154 return true;
155 auto& host_cpu_name = ctx_.HostCPU()->parsed_name();
156 return device_parsed_name_.job == host_cpu_name.job &&
157 device_parsed_name_.replica == host_cpu_name.replica &&
158 device_parsed_name_.task == host_cpu_name.task;
159 }
160
DebugString() const161 string EagerOperation::DebugString() const {
162 string out;
163 VLOG(1) << "EagerOperation::DebugString() over " << this;
164
165 strings::StrAppend(&out, "Name: ", Name(), "\n");
166 strings::StrAppend(&out, "Device Name: [", device_name_, "]\n");
167 strings::StrAppend(
168 &out, "Device: ", Device() ? Device()->DebugString() : "[]", "\n");
169 for (const auto& input : inputs_) {
170 VLOG(1) << "Input ptr: " << input;
171 strings::StrAppend(&out, "Input: ", input->DebugString(), "\n");
172 }
173
174 NodeDef ndef;
175 Attrs().FillAttrValueMap(ndef.mutable_attr());
176 strings::StrAppend(&out, "Attrs: ", ndef.DebugString(), "\n");
177 return out;
178 }
179
180 } // namespace tensorflow
181