/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/eager_operation.h"

#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"

namespace tensorflow {

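// Prepares this EagerOperation for (re)use: resolves the op's attribute types
// and OpDef (or, for functions, checks that the function is registered),
// clears previously set state, and applies the requested device name.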
Status EagerOperation::Reset(
    const char* op, const char* raw_device_name, bool remote,
    EagerExecutor* executor,
    const absl::optional<EagerRemoteFunctionParams> remote_func_params) {
  DCHECK(inputs_.empty());
  ClearInferenceState();
  bool is_function = false;
  TF_RETURN_IF_ERROR(AttrTypeMapForOp(op, &attr_types_, &is_function));

  if (!is_function) {
    TF_RETURN_IF_ERROR(OpDefForOp(op, &op_def_));
  } else if (!remote && !ctx_.FindFunctionByName(op)) {
    return errors::NotFound(
        "'", op,
        "' is neither a type of a primitive operation nor a name "
        "of a function registered in binary running on ",
        port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
  }
  attrs_.Reset(op);
  device_ = nullptr;
  use_xla_ = false;
  is_function_ = is_function;
  cancellation_manager_ = nullptr;
  executor_ = executor ? executor : &ctx_.Executor();
  remote_func_params_ = remote_func_params;
#ifdef TENSORFLOW_MEM_DEBUG
  op_name_ = op;
#endif
  return SetDeviceName(raw_device_name, true);
}

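// Infers the type attribute of a single (non-list) input argument from the
// handle's dtype, unless the attribute has already been set. If the argument
// is actually a list, automatic inference is disabled for the rest of the op.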
tensorflow::Status EagerOperation::MaybeInferSingleInputAttrs(
    TensorHandle* handle) {
  if (!op_def_) return Status::OK();

  const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
  if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
    // Some clients that are still setting their input attributes manually are
    // adding an input list to their op by calling `TFE_OpAddInput` for each of
    // its elements instead of calling `TFE_OpAddInputList`. When this happens,
    // we cannot detect the end of such a list, and thus lose track of the
    // input arguments in the op definition. To guarantee backward
    // compatibility with those clients, disable automatic inference in this
    // case.
    ClearInferenceState();
    return Status::OK();
  }
  const std::string& type_attr = input_def.type_attr();
  if (!type_attr.empty() &&
      inference_attrs_.find(type_attr) == inference_attrs_.end()) {
    MutableAttrs()->Set(type_attr, handle->dtype);
    inference_attrs_.insert(type_attr);
  }
  return Status::OK();
}

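// Records the length and element type of a homogeneous input list by setting
// its number attribute and type attribute, if not already inferred.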
void EagerOperation::InferSingleTypeInputListAttrs(
    const tensorflow::OpDef::ArgDef& input_def,
    const tensorflow::DataType dtype, int num_inputs) {
  if (inference_attrs_.find(input_def.number_attr()) ==
      inference_attrs_.end()) {
    MutableAttrs()->Set(input_def.number_attr(), num_inputs);
    inference_attrs_.insert(input_def.number_attr());
  }
  if (inference_attrs_.find(input_def.type_attr()) == inference_attrs_.end()) {
    MutableAttrs()->Set(input_def.type_attr(), dtype);
    inference_attrs_.insert(input_def.type_attr());
  }
}

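// Records the per-element types of a heterogeneous input list by setting its
// type-list attribute, if not already inferred.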
void EagerOperation::InferMixedTypeInputListAttrs(
    const tensorflow::OpDef::ArgDef& input_def,
    const std::vector<tensorflow::DataType>& dtypes) {
  if (inference_attrs_.find(input_def.type_list_attr()) ==
      inference_attrs_.end()) {
    MutableAttrs()->Set(input_def.type_list_attr(),
                        tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
                            dtypes.data(), dtypes.size()));
    inference_attrs_.insert(input_def.type_list_attr());
  }
}

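// Infers the list-related attributes for the next input argument, dispatching
// to the mixed-type or single-type helper depending on whether the OpDef
// declares a type-list attribute or a number/type attribute pair.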
tensorflow::Status EagerOperation::InferInputListAttrs(int num_inputs) {
  if (!op_def_) return Status::OK();

  int start = inference_arg_idx_;
  const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
  if (!input_def.type_list_attr().empty()) {
    std::vector<tensorflow::DataType> dtypes(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      dtypes[i] = inputs_[start + i]->dtype;
    }
    InferMixedTypeInputListAttrs(input_def, dtypes);
  } else if (!input_def.type_attr().empty() &&
             !input_def.number_attr().empty()) {
    InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
  } else {
    return tensorflow::errors::InvalidArgument("Invalid input list definition");
  }
  return tensorflow::Status::OK();
}

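// Parses and caches the requested device name. A null or empty `device`
// clears the cached names when `reset` is true; a malformed device string
// yields an InvalidArgument error.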
tensorflow::Status EagerOperation::SetDeviceName(const char* device,
                                                 const bool reset) {
  if (device != nullptr && strlen(device) > 0) {
    if (device != raw_device_name_) {
      if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
        return errors::InvalidArgument("Malformed device specification '",
                                       device,
                                       "' in eager op: ", DebugString());
      }
      raw_device_name_ = device;
      device_name_ =
          DeviceNameUtils::HasSomeDetails(device_parsed_name_)
              ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
              : "";
    }
  } else if (reset) {
    raw_device_name_.clear();
    device_name_.clear();
    device_parsed_name_.Clear();
  }
  return Status::OK();
}

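// Returns true if the operation targets the local host: either no remote
// device manager is configured, the device name carries no job/replica/task,
// or its job/replica/task matches that of the host CPU device.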
bool EagerOperation::IsLocal() const {
  if (ctx_.remote_device_mgr() == nullptr) return true;

  if (!device_parsed_name_.has_job && !device_parsed_name_.has_replica &&
      !device_parsed_name_.has_task)
    return true;
  auto& host_cpu_name = ctx_.HostCPU()->parsed_name();
  return device_parsed_name_.job == host_cpu_name.job &&
         device_parsed_name_.replica == host_cpu_name.replica &&
         device_parsed_name_.task == host_cpu_name.task;
}

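// Builds a human-readable description of the operation: its name, device,
// inputs, and attributes (rendered through a temporary NodeDef).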
string EagerOperation::DebugString() const {
  string out;
  VLOG(1) << "EagerOperation::DebugString() over " << this;

  strings::StrAppend(&out, "Name: ", Name(), "\n");
  strings::StrAppend(&out, "Device Name: [", device_name_, "]\n");
  strings::StrAppend(
      &out, "Device: ", Device() ? Device()->DebugString() : "[]", "\n");
  for (const auto& input : inputs_) {
    VLOG(1) << "Input ptr: " << input;
    strings::StrAppend(&out, "Input: ", input->DebugString(), "\n");
  }

  NodeDef ndef;
  Attrs().FillAttrValueMap(ndef.mutable_attr());
  strings::StrAppend(&out, "Attrs: ", ndef.DebugString(), "\n");
  return out;
}

}  // namespace tensorflow