1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/function_api_info.h"
17
18 #include <string>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23
24 namespace tensorflow {
25 namespace grappler {
FunctionApiInfo()26 FunctionApiInfo::FunctionApiInfo() {}
~FunctionApiInfo()27 FunctionApiInfo::~FunctionApiInfo() {}
28
Init(const FunctionDef & function_def)29 Status FunctionApiInfo::Init(const FunctionDef& function_def) {
30 function_type_ = FunctionApiInfo::FunctionType::INFERENCE;
31 for (const auto& attr : function_def.attr()) {
32 if (attr.first == "api_preferred_device") {
33 preferred_device_ = attr.second.s();
34 }
35 if (attr.first == "api_implements") {
36 interface_name_ = attr.second.s();
37 }
38 if (attr.first == "forward_function_name") {
39 function_type_ = FunctionApiInfo::FunctionType::BACKWARD;
40 pairing_function_name_ = attr.second.s();
41 }
42 if (attr.first == "backward_function_name") {
43 function_type_ = FunctionApiInfo::FunctionType::FORWARD;
44 pairing_function_name_ = attr.second.s();
45 }
46 }
47
48 input_arg_dtypes_.reserve(function_def.signature().input_arg_size());
49 for (const auto& input_arg : function_def.signature().input_arg()) {
50 input_arg_dtypes_.emplace_back(input_arg.type());
51 }
52 output_arg_dtypes_.reserve(function_def.signature().output_arg_size());
53 for (const auto& output_arg : function_def.signature().output_arg()) {
54 output_arg_dtypes_.emplace_back(output_arg.type());
55 }
56
57 if (interface_name_.empty() && !preferred_device_.empty()) {
58 return errors::InvalidArgument(
59 "Function '", function_def.signature().name(),
60 "' has a preferred device, but does not implement an interface");
61 }
62 return Status::OK();
63 }
64
preferred_device() const65 const string& FunctionApiInfo::preferred_device() const {
66 return preferred_device_;
67 }
68
interface_name() const69 const string& FunctionApiInfo::interface_name() const {
70 return interface_name_;
71 }
72
function_type() const73 const FunctionApiInfo::FunctionType FunctionApiInfo::function_type() const {
74 return function_type_;
75 }
76
pairing_function_name() const77 const string& FunctionApiInfo::pairing_function_name() const {
78 return pairing_function_name_;
79 }
80
input_arg_dtypes() const81 const DataTypeVector& FunctionApiInfo::input_arg_dtypes() const {
82 return input_arg_dtypes_;
83 }
84
output_arg_dtypes() const85 const DataTypeVector& FunctionApiInfo::output_arg_dtypes() const {
86 return output_arg_dtypes_;
87 }
88
FunctionLibraryApiInfo()89 FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
~FunctionLibraryApiInfo()90 FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
91
92 namespace {
IsSameArgDef(const OpDef::ArgDef & arg1,const OpDef::ArgDef & arg2)93 bool IsSameArgDef(const OpDef::ArgDef& arg1, const OpDef::ArgDef& arg2) {
94 if (arg1.type() != arg2.type()) return false;
95 if (arg1.type_attr() != arg2.type_attr()) return false;
96 if (arg1.number_attr() != arg2.number_attr()) return false;
97 if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
98 if (arg1.is_ref() != arg2.is_ref()) return false;
99 return true;
100 }
101
IsSameSignature(const FunctionDef & f1,const FunctionDef & f2,const bool check_inputs,const bool check_outputs)102 bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2,
103 const bool check_inputs, const bool check_outputs) {
104 const auto& sig1 = f1.signature();
105 const auto& sig2 = f2.signature();
106 // Functions have positional semantics, so we don't check for names.
107 if (check_inputs) {
108 if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
109 for (int k = 0; k < sig1.input_arg_size(); ++k) {
110 if (!IsSameArgDef(sig1.input_arg(k), sig2.input_arg(k))) return false;
111 }
112 }
113 if (check_outputs) {
114 if (f1.ret().size() != f2.ret().size()) return false;
115 if (sig1.output_arg_size() != sig2.output_arg_size()) return false;
116 for (int k = 0; k < sig1.output_arg_size(); ++k) {
117 if (!IsSameArgDef(sig1.output_arg(k), sig2.output_arg(k))) return false;
118 }
119 }
120 return true;
121 }
122
ValidateSignature(const string & interface_name,const std::vector<const FunctionDef * > & equiv_funcs,const FunctionApiInfo::FunctionType function_type)123 Status ValidateSignature(const string& interface_name,
124 const std::vector<const FunctionDef*>& equiv_funcs,
125 const FunctionApiInfo::FunctionType function_type) {
126 if (equiv_funcs.size() < 2) return Status::OK();
127 for (size_t k = 1; k < equiv_funcs.size(); ++k) {
128 const bool check_input =
129 (function_type == FunctionApiInfo::FunctionType::INFERENCE ||
130 function_type == FunctionApiInfo::FunctionType::FORWARD);
131 const bool check_output =
132 (function_type == FunctionApiInfo::FunctionType::INFERENCE ||
133 function_type == FunctionApiInfo::FunctionType::BACKWARD);
134 if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k], check_input,
135 check_output)) {
136 return errors::InvalidArgument(
137 "Functions '", equiv_funcs[0]->signature().name(), "' and '",
138 equiv_funcs[k]->signature().name(), "' both implement '",
139 interface_name, "' but their signatures do not match.");
140 }
141 }
142 return Status::OK();
143 }
144
ValidateSignatures(const std::unordered_map<string,std::vector<const FunctionDef * >> & intf_to_func,const FunctionApiInfo::FunctionType function_type)145 Status ValidateSignatures(
146 const std::unordered_map<string, std::vector<const FunctionDef*>>&
147 intf_to_func,
148 const FunctionApiInfo::FunctionType function_type) {
149 for (const auto& item : intf_to_func)
150 TF_RETURN_IF_ERROR(
151 ValidateSignature(item.first, item.second, function_type));
152 return Status::OK();
153 }
154 } // namespace
155
Init(const FunctionDefLibrary & function_library)156 Status FunctionLibraryApiInfo::Init(
157 const FunctionDefLibrary& function_library) {
158 std::unordered_map<string, std::vector<const FunctionDef*>> infer_funcs;
159 std::unordered_map<string, std::vector<const FunctionDef*>> fwd_funcs;
160 std::unordered_map<string, std::vector<const FunctionDef*>> bwd_funcs;
161 for (const auto& function : function_library.function()) {
162 std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
163 TF_RETURN_IF_ERROR(func_info->Init(function));
164 // Ignore the function if it does not implement any interface.
165 if (func_info->interface_name().empty()) continue;
166
167 const string& function_name = function.signature().name();
168 const string& interface_name = func_info->interface_name();
169 VLOG(3) << "Got " << func_info->function_type()
170 << " function: " << function_name
171 << " with interface: " << interface_name;
172 switch (func_info->function_type()) {
173 case FunctionApiInfo::FunctionType::INFERENCE:
174 intf_to_inference_funcs_[interface_name].emplace_back(function_name);
175 infer_funcs[interface_name].emplace_back(&function);
176 break;
177 case FunctionApiInfo::FunctionType::FORWARD:
178 intf_to_forward_funcs_[interface_name].emplace_back(function_name);
179 fwd_funcs[interface_name].emplace_back(&function);
180 break;
181 case FunctionApiInfo::FunctionType::BACKWARD:
182 intf_to_backward_funcs_[interface_name].emplace_back(function_name);
183 bwd_funcs[interface_name].emplace_back(&function);
184 break;
185 default:
186 return errors::InvalidArgument("Unrecognized function type: ",
187 func_info->function_type());
188 }
189 func_info_[function_name] = std::move(func_info);
190 }
191 TF_RETURN_IF_ERROR(ValidateSignatures(
192 infer_funcs, FunctionApiInfo::FunctionType::INFERENCE));
193 TF_RETURN_IF_ERROR(
194 ValidateSignatures(fwd_funcs, FunctionApiInfo::FunctionType::FORWARD));
195 TF_RETURN_IF_ERROR(
196 ValidateSignatures(bwd_funcs, FunctionApiInfo::FunctionType::BACKWARD));
197 return Status::OK();
198 }
199
GetEquivalentImplementations(const string & function_name,std::vector<string> * other_functions) const200 Status FunctionLibraryApiInfo::GetEquivalentImplementations(
201 const string& function_name, std::vector<string>* other_functions) const {
202 const auto func_it = func_info_.find(function_name);
203 if (func_it == func_info_.end()) return Status::OK();
204 const FunctionApiInfo* func_info = func_it->second.get();
205
206 absl::flat_hash_map<string, std::vector<string>>::const_iterator it;
207 switch (func_info->function_type()) {
208 case FunctionApiInfo::FunctionType::INFERENCE:
209 it = intf_to_inference_funcs_.find(func_info->interface_name());
210 break;
211 case FunctionApiInfo::FunctionType::FORWARD:
212 it = intf_to_forward_funcs_.find(func_info->interface_name());
213 break;
214 case FunctionApiInfo::FunctionType::BACKWARD:
215 it = intf_to_backward_funcs_.find(func_info->interface_name());
216 break;
217 default:
218 return errors::InvalidArgument("Unrecognized function type: ",
219 func_info->function_type());
220 }
221
222 for (const auto& func_name : it->second) {
223 if (func_name == function_name) continue;
224 other_functions->emplace_back(func_name);
225 }
226 return Status::OK();
227 }
228
GetApiInfo(const string & function_name) const229 const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
230 const string& function_name) const {
231 const auto it = func_info_.find(function_name);
232 if (it == func_info_.end()) return nullptr;
233 return it->second.get();
234 }
235
236 } // end namespace grappler
237 } // end namespace tensorflow
238