1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function.h"
17
18 #include <map>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/function.pb_text.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
34 #include "tensorflow/core/lib/gtl/map_util.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37 #include "tensorflow/core/util/equal_graph_def.h"
38
39 namespace tensorflow {
40
41 // Extracts the actual type from "attr_values" based on its definition
42 // "arg_def".
43 //
44 // If "arg_def" is a N*T type, *is_type_list is set to false, and
45 // *dtypes is set to be a vector of size N and each element is T.
46 //
47 // If "arg_def" is a list(type), *is_type_list is set to true, and
48 // *dtypes is set to be a vector of types specified in attrs for
49 // arg_def.
50 //
51 // Otherwise (arg_def is a simple type T), *is_type_list is set to
52 // false, and *dtypes is set to a single element vector, whose only
53 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)54 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
55 bool* is_type_list, DataTypeVector* dtypes) {
56 dtypes->clear();
57 if (!arg_def.type_list_attr().empty()) {
58 const AttrValue* v = attrs.Find(arg_def.type_list_attr());
59 if (v == nullptr) {
60 return errors::NotFound("type attr not found: ",
61 arg_def.type_list_attr());
62 }
63 *is_type_list = true;
64 for (int i = 0; i < v->list().type_size(); ++i) {
65 dtypes->push_back(v->list().type(i));
66 }
67 return Status::OK();
68 }
69
70 *is_type_list = false;
71 int num = 1;
72 if (!arg_def.number_attr().empty()) {
73 const AttrValue* v = attrs.Find(arg_def.number_attr());
74 if (v == nullptr) {
75 return errors::NotFound("type attr not found: ", arg_def.type_attr());
76 }
77 num = v->i();
78 }
79
80 DataType dtype;
81 if (arg_def.type() != DT_INVALID) {
82 dtype = arg_def.type();
83 } else if (arg_def.type_attr().empty()) {
84 dtype = DT_INVALID;
85 } else {
86 const AttrValue* v = attrs.Find(arg_def.type_attr());
87 if (v == nullptr) {
88 return errors::NotFound("type attr not found: ", arg_def.type_attr());
89 }
90 dtype = v->type();
91 }
92 dtypes->resize(num, dtype);
93 return Status::OK();
94 }
95
96 namespace {
97
98 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)99 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
100 SetAttrValue(val, &((*ndef->mutable_attr())[name]));
101 }
102
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)103 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
104 // attr_values should specify all attrs defined in fdef.
105 for (const auto& a : sig.attr()) {
106 const AttrValue* v = attr_values.Find(a.name());
107 if (!v) {
108 return errors::NotFound("Attr ", a.name(), " is not found from ",
109 SummarizeOpDef(sig));
110 }
111 Status status = AttrValueHasType(*v, a.type());
112 if (!status.ok()) {
113 errors::AppendToMessage(&status, "for attr '", a.name(), "'");
114 return status;
115 }
116 }
117
118 // TODO(josh11b): Enable this code once it works with function gradients.
119 // Right now the C++ function gradient code assumes it can pass
120 // all the attrs of the function to the gradient, and any attrs that
121 // the gradient doesn't care about will be ignored.
122 #if 0
123 if (attr_values.size() != sig.attr_size()) {
124 for (const auto& a : attr_values) {
125 // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
126 bool found = false;
127 for (const auto& s : sig.attr()) {
128 if (a.first == s.name()) {
129 found = true;
130 break;
131 }
132 }
133 if (!found) {
134 return errors::NotFound("Attr ", a.first, " is not found in ",
135 SummarizeOpDef(sig));
136 }
137 }
138 }
139 #endif
140
141 return Status::OK();
142 }
143
144 // A helper class for instantiating functions. This contains shared information
145 // like the resulting graph and node name index.
146 class FunctionInstantiationHelper {
147 public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)148 FunctionInstantiationHelper(GetFunctionSignature get_function,
149 InstantiationResult* result)
150 : get_function_(std ::move(get_function)), result_(*result) {
151 result_.nodes.clear();
152 }
153
154 // Builds index for nodes that can be used as node's input arguments.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,bool ints_on_device)155 Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
156 bool ints_on_device) {
157 bool is_type_list;
158 DataTypeVector dtypes;
159 TF_RETURN_IF_ERROR(
160 ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
161 CHECK_GE(dtypes.size(), size_t{1});
162 int arg_index = result_.nodes.size();
163 TF_RETURN_IF_ERROR(
164 AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
165 // Creates dtypes.size() nodes in the graph.
166 for (size_t i = 0; i < dtypes.size(); ++i) {
167 TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
168 {true, arg_index, 0, false, {dtypes[i]}}));
169 DCHECK_EQ(arg_index, result_.nodes.size());
170 string name = arg_def.name();
171 if (dtypes.size() > 1) {
172 strings::StrAppend(&name, "_", i);
173 }
174 NodeDef* gnode = AddNode(name);
175 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
176 gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
177 } else {
178 gnode->set_op(FunctionLibraryDefinition::kArgOp);
179 }
180 AddAttr("T", dtypes[i], gnode);
181 AddAttr("index", arg_index, gnode);
182 result_.arg_types.push_back(dtypes[i]);
183 ++arg_index;
184 }
185 return Status::OK();
186 }
187
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)188 Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
189 const int arg_index) {
190 const OpDef* node_sig = nullptr;
191 TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
192 if (node_sig->output_arg_size() == 0) {
193 return AddItem(node.name(), {false, arg_index, 0, false, {}});
194 }
195 const int num_retval = node_sig->output_arg_size();
196 int start = 0;
197 bool is_type_list;
198 DataTypeVector dtypes;
199 for (int i = 0; i < num_retval; ++i) {
200 TF_RETURN_IF_ERROR(
201 ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
202 // Note that we rely on the backwards-compatibility test enforcing
203 // that output_arg(*).name() doesn't change here.
204 const string base_name =
205 strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
206 TF_RETURN_IF_ERROR(
207 AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
208 for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
209 TF_RETURN_IF_ERROR(
210 AddItem(strings::StrCat(base_name, ":", j),
211 {false, arg_index, start + j, false, {dtypes[j]}}));
212 }
213 start += dtypes.size();
214 }
215 return Status::OK();
216 }
217
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)218 Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
219 const OpDef* fnode_sig = nullptr;
220 TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
221 NodeDef* gnode = AddNode(fnode.name());
222 gnode->set_op(fnode.op());
223 gnode->set_device(fnode.device());
224 int gnode_idx = nodes_.size() - 1;
225
226 // Input
227 const int num_args = fnode_sig->input_arg_size();
228 bool is_type_list; // ignored
229 DataTypeVector dtypes;
230 int fnode_arg_index = 0;
231 for (int i = 0; i < num_args; ++i) {
232 TF_RETURN_IF_ERROR(
233 ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
234 // Consume inputs (indexed by fnode_arg_index) until we have
235 // matched each element of dtypes (indexed by j).
236 for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
237 if (fnode_arg_index >= fnode.input_size()) {
238 // Should never happen if we computed dtypes correctly.
239 return errors::InvalidArgument(
240 "Attempt to access beyond input size: ", fnode_arg_index,
241 " >= ", fnode.input_size());
242 }
243 // Look up the next input.
244 const string& input_name = fnode.input(fnode_arg_index);
245 const auto* item = GetItemOrNull(input_name);
246 if (item == nullptr) {
247 return errors::InvalidArgument(
248 "input ", input_name,
249 " is not found: ", FormatNodeDefForError(fnode));
250 }
251 if (item->dtypes.size() > dtypes.size() - j) {
252 return errors::InvalidArgument("Input ", input_name, " too long for ",
253 fnode_sig->input_arg(i).name());
254 }
255 // Match up all the elements of this input (indexed by k) with
256 // elements of dtypes (advancing j).
257 for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
258 if (item->dtypes[k] != dtypes[j]) {
259 return errors::InvalidArgument(
260 "input ", fnode_sig->input_arg(i).name(), "[", j,
261 "] expected type ", DataTypeString(dtypes[j]),
262 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
263 input_name, "[", k, "]");
264 }
265 if (item->is_func_arg) {
266 AddInput(gnode_idx, item->nid + k, 0);
267 } else {
268 AddInput(gnode_idx, item->nid, item->idx + k);
269 }
270 }
271 }
272 }
273
274 // Control deps.
275 for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
276 const string& input = fnode.input(i);
277 if (input.empty() || input[0] != '^') {
278 return errors::InvalidArgument("Expected input[", i, "] == '", input,
279 "' to be a control input.");
280 }
281 int nid = -1;
282 const string node_name = input.substr(1);
283 const string node_colon = node_name + ":";
284 const string node_colon_bound = node_name + ";";
285 // index_ is a map sorted lexicographically, so the key we are looking for
286 // must lie in the range [node_name, node_colon_bound).
287 auto it = index_.lower_bound(node_name);
288 while (it != index_.end() && it->first <= node_colon_bound) {
289 if (it->first == node_name ||
290 tensorflow::str_util::StartsWith(it->first, node_colon)) {
291 nid = it->second.nid;
292 break;
293 }
294 ++it;
295 }
296 if (nid == -1) {
297 return errors::InvalidArgument("input[", i, "] == '", input,
298 "', is not found.");
299 }
300 AddDep(gnode_idx, nid);
301 }
302
303 // Attrs.
304 for (const auto& p : attrs) {
305 (*gnode->mutable_attr())[p.first] = p.second;
306 }
307
308 return Status::OK();
309 }
310
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)311 Status AddReturnNode(
312 const OpDef::ArgDef& ret_def, AttrSlice attrs,
313 const ::tensorflow::protobuf::Map<string, string>& ret_map,
314 bool ints_on_device, int* ret_index) {
315 auto ret_iter = ret_map.find(ret_def.name());
316 if (ret_iter == ret_map.end()) {
317 return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
318 }
319 bool is_type_list;
320 DataTypeVector dtypes;
321 TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
322 CHECK_GE(dtypes.size(), size_t{1});
323 const auto* item = GetItemOrNull(ret_iter->second);
324 if (item == nullptr) {
325 return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
326 ret_iter->second, " is not found.");
327 }
328 if (dtypes != item->dtypes) {
329 return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
330 " : ", DataTypeVectorString(dtypes),
331 " vs. ",
332 DataTypeVectorString(item->dtypes));
333 }
334 for (size_t i = 0; i < dtypes.size(); ++i) {
335 string name = strings::StrCat(ret_def.name(), "_RetVal");
336 if (dtypes.size() > 1) {
337 strings::StrAppend(&name, "_", i);
338 }
339 NodeDef* gnode = AddNode(name);
340 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
341 gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
342 } else {
343 gnode->set_op(FunctionLibraryDefinition::kRetOp);
344 }
345 AddInput(nodes_.size() - 1, item->nid, item->idx + i);
346 AddAttr("T", dtypes[i], gnode);
347 AddAttr("index", (*ret_index)++, gnode);
348 result_.ret_types.push_back(dtypes[i]);
349 }
350 return Status::OK();
351 }
352
353 // Adds the actual node inputs to the result graph by converting indexes to
354 // the node names.
AddNodeInputs()355 void AddNodeInputs() {
356 for (int i = 0; i < result_.nodes.size(); i++) {
357 NodeInfo& node_info = nodes_[i];
358 for (const auto& p : node_info.data_inputs) {
359 result_.nodes[i].add_input(Name(p.first, p.second));
360 }
361 for (int index : node_info.control_inputs) {
362 result_.nodes[i].add_input(Dep(index));
363 }
364 }
365 }
366
367 private:
368 // This is used to build a small index for all names that can be used as a
369 // node's input arguments.
370 //
371 // If is_func_arg is true, the name is a function's argument. In
372 // this case, the produced graph def has node[nid:nid + dtype.size()].
373 //
374 // Otherwise, the name is a function body's node return value. In
375 // this case, the produced graph def has one node node[nid] and
376 // the node's output index [idx ... idx + num) corresponds to the
377 // named outputs.
378 //
379 // In all cases, "dtype" specifies the data type.
380 struct NameInfoItem {
381 bool is_func_arg;
382 int nid;
383 int idx;
384 bool is_type_list;
385 DataTypeVector dtypes;
386 };
387
388 // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)389 Status AddItem(const string& name, const NameInfoItem& item) {
390 if (!index_.insert({name, item}).second) {
391 return errors::InvalidArgument(
392 strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
393 " name: "),
394 name);
395 }
396 return Status::OK();
397 }
398
GetItemOrNull(const string & name) const399 const NameInfoItem* GetItemOrNull(const string& name) const {
400 return gtl::FindOrNull(index_, name);
401 }
402
Dep(int node_index) const403 string Dep(int node_index) const {
404 return strings::StrCat("^", Name(node_index));
405 }
406
Name(int node_index) const407 string Name(int node_index) const {
408 CHECK_LT(node_index, nodes_.size());
409 return nodes_[node_index].name;
410 }
411
Name(int node_index,int output_index) const412 string Name(int node_index, int output_index) const {
413 if (output_index == 0) {
414 return Name(node_index);
415 } else {
416 return strings::StrCat(Name(node_index), ":", output_index);
417 }
418 }
419
AddNode(const string & name)420 NodeDef* AddNode(const string& name) {
421 result_.nodes.emplace_back();
422 NodeDef* gnode = &result_.nodes.back();
423 gnode->set_name(name);
424 nodes_.push_back({name, {}, {}});
425 CHECK_EQ(result_.nodes.size(), nodes_.size());
426 return gnode;
427 }
428
AddInput(int node_index,int output_node,int output_index)429 void AddInput(int node_index, int output_node, int output_index) {
430 CHECK_LT(node_index, nodes_.size());
431 nodes_[node_index].data_inputs.push_back(
432 std::make_pair(output_node, output_index));
433 }
434
AddDep(int node_index,int dep_index)435 void AddDep(int node_index, int dep_index) {
436 CHECK_LT(node_index, nodes_.size());
437 nodes_[node_index].control_inputs.push_back(dep_index);
438 }
439
440 GetFunctionSignature get_function_;
441 InstantiationResult& result_;
442 // A small index for all names that can be used as a node's input arguments.
443 std::map<string, NameInfoItem> index_;
444 // This contains information about a node in the new graph including the node
445 // names and input nodes' indexes.
446 struct NodeInfo {
447 string name;
448 // Data inputs where <n, k> means arg k of node n.
449 std::vector<std::pair<int, int>> data_inputs;
450 // Control inputs (dependencies).
451 std::vector<int> control_inputs;
452 };
453 // nodes_[i] is the information about result_.nodes[i].
454 std::vector<NodeInfo> nodes_;
455 };
456
457 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)458 string Print(const OpDef::ArgDef& arg) {
459 string out;
460 strings::StrAppend(&out, arg.name(), ":");
461 if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
462 if (!arg.number_attr().empty()) {
463 strings::StrAppend(&out, arg.number_attr(), "*");
464 }
465 if (arg.type() != DT_INVALID) {
466 strings::StrAppend(&out, DataTypeString(arg.type()));
467 } else {
468 strings::StrAppend(&out, arg.type_attr());
469 }
470 if (arg.is_ref()) strings::StrAppend(&out, ")");
471 return out;
472 }
473
474 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)475 string Print(const AttrValue& attr_value) {
476 if (attr_value.value_case() == AttrValue::kType) {
477 return DataTypeString(attr_value.type());
478 } else if ((attr_value.value_case() == AttrValue::kList) &&
479 (attr_value.list().type_size() > 0)) {
480 string ret = "{";
481 for (int i = 0; i < attr_value.list().type_size(); ++i) {
482 if (i > 0) strings::StrAppend(&ret, ", ");
483 strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
484 }
485 strings::StrAppend(&ret, "}");
486 return ret;
487 } else if (attr_value.value_case() == AttrValue::kFunc) {
488 if (attr_value.func().attr_size() == 0) {
489 return attr_value.func().name();
490 }
491 std::vector<string> entries;
492 for (auto p : attr_value.func().attr()) {
493 entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
494 }
495 std::sort(entries.begin(), entries.end());
496 return strings::StrCat(attr_value.func().name(), "[",
497 str_util::Join(entries, ", "), "]");
498 }
499 return SummarizeAttrValue(attr_value);
500 }
501
502 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)503 string Print(const NodeDef& n) {
504 string out;
505 strings::StrAppend(&out, n.name(), " = ", n.op());
506 if (n.attr_size() > 0) {
507 std::vector<string> entries;
508 for (auto& a : n.attr()) {
509 entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
510 }
511 std::sort(entries.begin(), entries.end());
512 // Add a short device string at the end of all attributes.
513 if (!n.device().empty()) {
514 DeviceNameUtils::ParsedName parsed;
515 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
516 entries.push_back(
517 strings::StrCat("device=", parsed.type, ":", parsed.id));
518 } else {
519 entries.push_back("device=<FAILED_TO_PARSE>");
520 }
521 }
522 strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
523 }
524 strings::StrAppend(&out, "(");
525 std::vector<StringPiece> dat;
526 std::vector<string> dep;
527 for (StringPiece s : n.input()) {
528 if (str_util::ConsumePrefix(&s, "^")) {
529 dep.emplace_back(s);
530 } else {
531 dat.push_back(s);
532 }
533 }
534 strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
535 if (!dep.empty()) {
536 strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
537 }
538 return out;
539 }
540
Print(const FunctionDef & fdef)541 string Print(const FunctionDef& fdef) {
542 string out;
543 const OpDef& sig = fdef.signature();
544 strings::StrAppend(&out, "\n", sig.name());
545 if (sig.attr_size() > 0) {
546 strings::StrAppend(&out, "[");
547 for (int i = 0; i < sig.attr_size(); ++i) {
548 const auto& a = sig.attr(i);
549 if (i > 0) strings::StrAppend(&out, ", ");
550 if (a.type() == "type") {
551 strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
552 } else {
553 strings::StrAppend(&out, a.name(), ":", a.type());
554 }
555 }
556 strings::StrAppend(&out, "]");
557 }
558 strings::StrAppend(&out, "(");
559 for (int i = 0; i < sig.input_arg_size(); ++i) {
560 if (i > 0) strings::StrAppend(&out, ", ");
561 strings::StrAppend(&out, Print(sig.input_arg(i)));
562 }
563 strings::StrAppend(&out, ") -> (");
564 for (int i = 0; i < sig.output_arg_size(); ++i) {
565 if (i > 0) strings::StrAppend(&out, ", ");
566 strings::StrAppend(&out, Print(sig.output_arg(i)));
567 }
568 strings::StrAppend(&out, ") {\n");
569 for (const auto& n : fdef.node_def()) {
570 strings::StrAppend(&out, " ", Print(n), "\n");
571 }
572 for (const auto& cr : fdef.control_ret()) {
573 strings::StrAppend(&out, " @return ", cr.first, " = ", cr.second, "\n");
574 }
575 for (const auto& r : fdef.ret()) {
576 strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
577 }
578 strings::StrAppend(&out, "}\n");
579 return out;
580 }
581
Print(gtl::ArraySlice<const NodeDef * > nodes)582 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
583 std::vector<const NodeDef*> arg;
584 std::vector<const NodeDef*> ret;
585 std::vector<const NodeDef*> body;
586 for (const NodeDef* n : nodes) {
587 if (n->op() == FunctionLibraryDefinition::kArgOp ||
588 n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
589 arg.push_back(n);
590 } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
591 n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
592 ret.push_back(n);
593 } else {
594 body.push_back(n);
595 }
596 }
597 auto comp = [](const NodeDef* x, const NodeDef* y) {
598 int xi;
599 TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
600 int yi;
601 TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
602 return xi < yi;
603 };
604 std::sort(arg.begin(), arg.end(), comp);
605 std::sort(ret.begin(), ret.end(), comp);
606 string out;
607 strings::StrAppend(&out, "\n(");
608 auto get_type_and_device = [](const NodeDef& n) {
609 DataType dt;
610 if (!GetNodeAttr(n, "T", &dt).ok()) {
611 dt = DT_INVALID;
612 }
613 if (!n.device().empty()) {
614 DeviceNameUtils::ParsedName parsed;
615 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
616 return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
617 parsed.id);
618 } else {
619 LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
620 << n.op() << ":" << n.name();
621 return strings::StrCat(DataTypeString(dt), "@",
622 "<FAILED_TO_PARSE_DEVICE>");
623 }
624 }
625 return DataTypeString(dt);
626 };
627 for (size_t i = 0; i < arg.size(); ++i) {
628 const NodeDef* n = arg[i];
629 if (i > 0) strings::StrAppend(&out, ", ");
630 CHECK_GE(n->attr_size(), 2);
631 strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
632 }
633 strings::StrAppend(&out, ") -> (");
634 for (size_t i = 0; i < ret.size(); ++i) {
635 const NodeDef* n = ret[i];
636 if (i > 0) strings::StrAppend(&out, ", ");
637 CHECK_LE(2, n->attr_size());
638
639 // The _RetVal op should have a unique non-control input. We assert that
640 // here and add it to the output.
641 bool found_non_control_input = false;
642 for (const string& input : n->input()) {
643 if (!input.empty() && input[0] != '^') {
644 DCHECK_EQ(found_non_control_input, false)
645 << "RetVal node has more than one non-control input: "
646 << absl::StrJoin(n->input(), ", ");
647 strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
648 found_non_control_input = true;
649 }
650 }
651 DCHECK_EQ(found_non_control_input, true)
652 << "RetVal did not have any non-control inputs: "
653 << absl::StrJoin(n->input(), ", ");
654 }
655 strings::StrAppend(&out, ") {\n");
656 for (size_t i = 0; i < body.size(); ++i) {
657 strings::StrAppend(&out, " ", Print(*body[i]), "\n");
658 }
659 strings::StrAppend(&out, "}\n");
660 return out;
661 }
662
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)663 Status AddDefaultAttrs(const string& op,
664 const GetFunctionSignature& get_function,
665 AttrValueMap* attrs) {
666 const OpDef* op_def = nullptr;
667 TF_RETURN_IF_ERROR(get_function(op, &op_def));
668 AttrSlice attr_slice(attrs);
669 for (const auto& attr_def : op_def->attr()) {
670 if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
671 if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
672 return errors::Internal("Somehow duplicated: ", attr_def.name());
673 }
674 }
675 }
676 return Status::OK();
677 }
678
679 } // end namespace
680
681 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)682 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
683 GetFunctionSignature get_function,
684 InstantiationResult* result) {
685 VLOG(4) << "Instantiation Function: " << Print(fdef);
686
687 const OpDef& sig = fdef.signature();
688 TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
689
690 bool ints_on_device =
691 fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
692 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
693
694 FunctionInstantiationHelper helper(get_function, result);
695 Status s;
696 for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
697 s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
698 if (!s.ok()) {
699 errors::AppendToMessage(&s, "In ", Print(arg_def));
700 return s;
701 }
702 }
703
704 auto substitute = [attr_values](StringPiece name, AttrValue* val) {
705 if (const AttrValue* v = attr_values.Find(name)) {
706 *val = *v;
707 return true;
708 }
709 return false;
710 };
711
712 // Makes a copy of all attrs in fdef and substitutes placeholders.
713 // After this step, every attr is bound to a concrete value.
714 std::vector<AttrValueMap> node_attrs;
715 node_attrs.resize(fdef.node_def_size());
716 for (int i = 0; i < fdef.node_def_size(); ++i) {
717 for (auto attr : fdef.node_def(i).attr()) {
718 if (!SubstitutePlaceholders(substitute, &attr.second)) {
719 return errors::InvalidArgument("Failed to bind all placeholders in ",
720 SummarizeAttrValue(attr.second));
721 }
722 if (!node_attrs[i].insert(attr).second) {
723 return errors::Internal("Somehow duplicated: ", attr.first);
724 }
725 }
726 TF_RETURN_IF_ERROR(
727 AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
728 }
729
730 for (int i = 0; i < fdef.node_def_size(); ++i) {
731 s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
732 result->nodes.size() + i);
733 if (!s.ok()) {
734 errors::AppendToMessage(&s, "In ",
735 FormatNodeDefForError(fdef.node_def(i)));
736 return s;
737 }
738 }
739 // Emits one node for each fdef.node_def.
740 for (int i = 0; i < fdef.node_def_size(); ++i) {
741 s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
742 if (!s.ok()) {
743 errors::AppendToMessage(&s, "In ",
744 FormatNodeDefForError(fdef.node_def(i)));
745 return s;
746 }
747 }
748
749 // Emits nodes for the function's return values.
750 int ret_index = 0;
751 for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
752 s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
753 &ret_index);
754 if (!s.ok()) {
755 errors::AppendToMessage(&s, "In function output ", Print(ret_def));
756 return s;
757 }
758 }
759
760 // Adds the actual node inputs using the input indexes.
761 helper.AddNodeInputs();
762
763 return Status::OK();
764 }
765
DebugString(const FunctionDef & func_def)766 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
767
DebugString(const GraphDef & instantiated_func_def)768 string DebugString(const GraphDef& instantiated_func_def) {
769 std::vector<const NodeDef*> ptrs;
770 for (const NodeDef& n : instantiated_func_def.node()) {
771 ptrs.push_back(&n);
772 }
773 return Print(ptrs);
774 }
775
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)776 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
777 std::vector<const NodeDef*> ptrs;
778 for (const NodeDef& n : instantiated_func_nodes) {
779 ptrs.push_back(&n);
780 }
781 return Print(ptrs);
782 }
783
DebugStringWhole(const GraphDef & gdef)784 string DebugStringWhole(const GraphDef& gdef) {
785 string ret;
786 for (const auto& fdef : gdef.library().function()) {
787 strings::StrAppend(&ret, Print(fdef));
788 }
789 strings::StrAppend(&ret, "\n");
790 for (const auto& ndef : gdef.node()) {
791 strings::StrAppend(&ret, Print(ndef), "\n");
792 }
793 return ret;
794 }
795
796 namespace {
797
798 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
799 // Python, it's possible to access unset attrs, which returns a default value
800 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)801 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
802 std::map<string, AttrValue> set_attrs;
803 for (auto pair : fdef.attr()) {
804 if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
805 set_attrs[pair.first] = pair.second;
806 }
807 }
808 return set_attrs;
809 }
810
811 } // end namespace
812
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)813 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
814 if (!OpDefEqual(f1.signature(), f2.signature())) return false;
815
816 std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
817 std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
818 if (f1_attrs.size() != f2_attrs.size()) return false;
819 for (auto iter1 : f1_attrs) {
820 auto iter2 = f2_attrs.find(iter1.first);
821 if (iter2 == f2_attrs.end()) return false;
822 if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
823 }
824
825 if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
826 return false;
827 }
828
829 std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
830 std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
831 if (ret1 != ret2) return false;
832
833 std::map<string, string> control_ret1(f1.control_ret().begin(),
834 f1.control_ret().end());
835 std::map<string, string> control_ret2(f2.control_ret().begin(),
836 f2.control_ret().end());
837 if (control_ret1 != control_ret2) return false;
838
839 return true;
840 }
841
FunctionDefHash(const FunctionDef & fdef)842 uint64 FunctionDefHash(const FunctionDef& fdef) {
843 // signature
844 uint64 h = OpDefHash(fdef.signature());
845
846 // attrs
847 std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
848 for (const auto& p : attrs) {
849 h = Hash64(p.first.data(), p.first.size(), h);
850 h = Hash64Combine(AttrValueHash(p.second), h);
851 }
852
853 // node defs
854 h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
855
856 // output names
857 std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
858 for (const auto& p : ret) {
859 h = Hash64(p.first.data(), p.first.size(), h);
860 h = Hash64(p.second.data(), p.second.size(), h);
861 }
862
863 // control output names
864 std::map<string, string> control_ret(fdef.control_ret().begin(),
865 fdef.control_ret().end());
866 for (const auto& p : control_ret) {
867 h = Hash64(p.first.data(), p.first.size(), h);
868 h = Hash64(p.second.data(), p.second.size(), h);
869 }
870
871 return h;
872 }
873
874 static constexpr const char* const kExecutorAttr = "_executor";
875
876 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)877 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
878 AttrSlice attrs) {
879 if (!options.executor_type.empty()) {
880 return options.executor_type;
881 } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
882 return executor_attr->s();
883 } else {
884 return string();
885 }
886 }
887
// Builds the canonical key for instantiating `funcname` with `attrs` and
// `options`, of the form "funcname[e1,e2,...]" where the entries are sorted
// "name=value" strings. Attrs contribute "attr=Print(value)" (except the
// executor attr), and the non-default instantiate options contribute entries
// with a leading underscore.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options) {
  std::vector<string> entries;
  // Upper bound on the entries pushed below: one per attr and per
  // input/output device, plus at most one each for _target, _overlay_lib,
  // _state_handle, _executor and _config_proto.
  entries.reserve(attrs.size() + options.input_devices.size() +
                  options.output_devices.size() + 5);
  // Iterate by reference to avoid copying every (name, AttrValue) pair.
  for (const auto& p : attrs) {
    // The executor is added separately below through ExecutorType(), which
    // lets options.executor_type take precedence over the attr.
    if (p.first != kExecutorAttr) {
      entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    }
  }
  if (!options.target.empty()) {
    entries.push_back(
        strings::StrCat("_target", "=", str_util::CEscape(options.target)));
  }
  for (size_t i = 0; i < options.input_devices.size(); ++i) {
    entries.push_back(strings::StrCat(
        "_input_dev", i, "=", str_util::CEscape(options.input_devices[i])));
  }
  for (size_t i = 0; i < options.output_devices.size(); ++i) {
    entries.push_back(strings::StrCat(
        "_output_dev", i, "=", str_util::CEscape(options.output_devices[i])));
  }
  if (options.overlay_lib) {
    // The overlay library is identified by its address.
    entries.push_back(strings::StrCat(
        "_overlay_lib", "=", reinterpret_cast<uintptr_t>(options.overlay_lib)));
  }
  if (!options.state_handle.empty()) {
    entries.push_back(
        strings::StrCat("_state_handle", "=", options.state_handle));
  }
  const string executor_type =
      FunctionLibraryRuntime::ExecutorType(options, attrs);
  if (!executor_type.empty()) {
    entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
  }
  string config_proto_serialized;
  options.config_proto.SerializeToString(&config_proto_serialized);
  if (!config_proto_serialized.empty()) {
    entries.push_back(strings::StrCat(
        "_config_proto", "=", str_util::CEscape(config_proto_serialized)));
  }
  // Sorting makes the key independent of attr iteration order.
  std::sort(entries.begin(), entries.end());
  return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
931
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)932 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
933 DataTypeSlice ret_types)
934 : arg_types_(arg_types.begin(), arg_types.end()),
935 ret_types_(ret_types.begin(), ret_types.end()) {
936 args_.resize(arg_types_.size());
937 rets_.resize(ret_types_.size());
938 }
939
~FunctionCallFrame()940 FunctionCallFrame::~FunctionCallFrame() {}
941
SetArgs(gtl::ArraySlice<Tensor> args)942 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
943 // Input type checks.
944 if (args.size() != arg_types_.size()) {
945 return errors::InvalidArgument("Expects ", arg_types_.size(),
946 " arguments, but ", args.size(),
947 " is provided");
948 }
949 for (size_t i = 0; i < args.size(); ++i) {
950 if (arg_types_[i] != args[i].dtype()) {
951 return errors::InvalidArgument(
952 "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
953 DataTypeString(args[i].dtype()), " is provided");
954 }
955 args_[i] = args[i];
956 }
957 return Status::OK();
958 }
959
GetRetvals(std::vector<Tensor> * rets) const960 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
961 rets->clear();
962 rets->reserve(rets_.size());
963 for (size_t i = 0; i < rets_.size(); ++i) {
964 const auto& item = rets_[i];
965 if (item.has_val) {
966 rets->push_back(item.val);
967 } else {
968 return errors::Internal("Retval[", i, "] does not have value");
969 }
970 }
971 return Status::OK();
972 }
973
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)974 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
975 bool allow_dead_tensors) {
976 rets->clear();
977 rets->reserve(rets_.size());
978 for (size_t i = 0; i < rets_.size(); ++i) {
979 if (rets_[i].has_val) {
980 rets->emplace_back(std::move(rets_[i].val));
981 } else if (allow_dead_tensors) {
982 rets->emplace_back();
983 } else {
984 return errors::Internal("Retval[", i, "] does not have value");
985 }
986 }
987 return Status::OK();
988 }
989
GetArg(int index,Tensor * val) const990 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
991 if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
992 return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
993 args_.size(), ")");
994 }
995 *val = args_[index];
996 return Status::OK();
997 }
998
SetRetval(int index,const Tensor & val)999 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1000 if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1001 return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1002 rets_.size(), ")");
1003 }
1004 if (val.dtype() != ret_types_[index]) {
1005 return errors::InvalidArgument(
1006 "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1007 ", but ", DataTypeString(val.dtype()), " is provided.");
1008 }
1009 Retval* item = &rets_[index];
1010 if (!item->has_val) {
1011 item->has_val = true;
1012 item->val = val;
1013 } else {
1014 return errors::Internal("Retval[", index, "] has already been set.");
1015 }
1016 return Status::OK();
1017 }
1018
1019 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef & fdef_in)1020 FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
1021 : fdef(fdef_in),
1022 // Exact shape inference for functions is handled by ShapeRefiner.
1023 // Here we pass a dummy shape inference function for legacy code paths.
1024 op_registration_data(fdef.signature(), shape_inference::UnknownShape,
1025 true /* is_function */) {}
1026
// Copies all functions and gradients of `other`, holding `other.mu_` in shared
// mode for the duration of the copy. Crashes (TF_CHECK_OK) if a function
// cannot be added, e.g. because it clashes with an op in the default registry.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  for (const auto& entry : other.function_defs_) {
    TF_CHECK_OK(AddFunctionDef(entry.second->fdef));
  }
  func_grad_ = other.func_grad_;
}
1036
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1037 FunctionLibraryDefinition::FunctionLibraryDefinition(
1038 const OpRegistryInterface* default_registry,
1039 const FunctionDefLibrary& def_lib)
1040 : default_registry_(default_registry),
1041 function_defs_(def_lib.function_size()) {
1042 for (const auto& fdef : def_lib.function()) {
1043 // The latter function definition wins.
1044 auto& ptr = function_defs_[fdef.signature().name()];
1045 ptr.reset(new FunctionDefAndOpRegistration(fdef));
1046 }
1047 for (const auto& grad : def_lib.gradient()) {
1048 func_grad_[grad.function_name()] = grad.gradient_func();
1049 }
1050 }
1051
~FunctionLibraryDefinition()1052 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1053
Contains(const string & func) const1054 bool FunctionLibraryDefinition::Contains(const string& func) const {
1055 tf_shared_lock l(mu_);
1056 return function_defs_.find(func) != function_defs_.end();
1057 }
1058
Find(const string & func) const1059 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1060 tf_shared_lock l(mu_);
1061 return FindHelper(func);
1062 }
1063
FindHelper(const string & func) const1064 const FunctionDef* FunctionLibraryDefinition::FindHelper(
1065 const string& func) const {
1066 auto iter = function_defs_.find(func);
1067 if (iter == function_defs_.end()) {
1068 return nullptr;
1069 } else {
1070 return &iter->second->fdef;
1071 }
1072 }
1073
AddFunctionDef(const FunctionDef & fdef)1074 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
1075 mutex_lock l(mu_);
1076 bool added;
1077 return AddFunctionDefHelper(fdef, &added);
1078 }
1079
AddFunctionDefHelper(const FunctionDef & fdef,bool * added)1080 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
1081 bool* added) {
1082 *added = false;
1083 std::unique_ptr<FunctionDefAndOpRegistration>* entry =
1084 &function_defs_[fdef.signature().name()];
1085 if (*entry != nullptr) {
1086 if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
1087 return errors::InvalidArgument(
1088 "Cannot add function '", fdef.signature().name(),
1089 "' because a different function with the same name already "
1090 "exists.");
1091 }
1092 // Ignore duplicate FunctionDefs
1093 return Status::OK();
1094 }
1095 const OpDef* op_def;
1096 if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1097 return errors::InvalidArgument(
1098 "Cannot add function '", fdef.signature().name(),
1099 "' because an op with the same name already exists.");
1100 }
1101 entry->reset(new FunctionDefAndOpRegistration(fdef));
1102 *added = true;
1103 return Status::OK();
1104 }
1105
AddGradientDef(const GradientDef & grad)1106 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1107 mutex_lock l(mu_);
1108 bool added;
1109 return AddGradientDefHelper(grad, &added);
1110 }
1111
AddGradientDefHelper(const GradientDef & grad,bool * added)1112 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1113 bool* added) {
1114 *added = false;
1115 string* entry = &func_grad_[grad.function_name()];
1116 if (!entry->empty()) {
1117 if (*entry != grad.gradient_func()) {
1118 return errors::InvalidArgument(
1119 "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1120 grad.function_name(), "' because it already has gradient function ",
1121 "'", *entry, "'");
1122 }
1123 // Ignore duplicate GradientDefs
1124 return Status::OK();
1125 }
1126 *entry = grad.gradient_func();
1127 *added = true;
1128 return Status::OK();
1129 }
1130
AddLibrary(const FunctionLibraryDefinition & other)1131 Status FunctionLibraryDefinition::AddLibrary(
1132 const FunctionLibraryDefinition& other) {
1133 // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1134 // the duration of the function could lead to deadlock).
1135 FunctionLibraryDefinition clone(other);
1136 mutex_lock l(mu_);
1137 // Remember the funcs and grads that we added successfully so that
1138 // we can roll them back on error.
1139 std::vector<string> funcs;
1140 std::vector<string> funcs_with_grads;
1141 Status s;
1142 bool added;
1143 for (auto iter : clone.function_defs_) {
1144 s = AddFunctionDefHelper(iter.second->fdef, &added);
1145 if (!s.ok()) {
1146 Remove(funcs, funcs_with_grads);
1147 return s;
1148 }
1149 if (added) {
1150 funcs.push_back(iter.second->fdef.signature().name());
1151 }
1152 }
1153 for (auto iter : clone.func_grad_) {
1154 GradientDef grad;
1155 grad.set_function_name(iter.first);
1156 grad.set_gradient_func(iter.second);
1157 s = AddGradientDefHelper(grad, &added);
1158 if (!s.ok()) {
1159 Remove(funcs, funcs_with_grads);
1160 return s;
1161 }
1162 if (added) {
1163 funcs_with_grads.push_back(grad.function_name());
1164 }
1165 }
1166 return Status::OK();
1167 }
1168
AddLibrary(const FunctionDefLibrary & lib_def)1169 Status FunctionLibraryDefinition::AddLibrary(
1170 const FunctionDefLibrary& lib_def) {
1171 // Remember the funcs and grads that we added successfully so that
1172 // we can roll them back on error.
1173 mutex_lock l(mu_);
1174 std::vector<string> funcs;
1175 std::vector<string> funcs_with_grads;
1176 Status s;
1177 bool added;
1178 for (const FunctionDef& fdef : lib_def.function()) {
1179 s = AddFunctionDefHelper(fdef, &added);
1180 if (!s.ok()) {
1181 Remove(funcs, funcs_with_grads);
1182 return s;
1183 }
1184 if (added) {
1185 funcs.push_back(fdef.signature().name());
1186 }
1187 }
1188 for (const GradientDef& grad : lib_def.gradient()) {
1189 s = AddGradientDefHelper(grad, &added);
1190 if (!s.ok()) {
1191 Remove(funcs, funcs_with_grads);
1192 return s;
1193 }
1194 if (added) {
1195 funcs_with_grads.push_back(grad.function_name());
1196 }
1197 }
1198 return Status::OK();
1199 }
1200
ReplaceFunction(const string & func,const FunctionDef & fdef)1201 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
1202 const FunctionDef& fdef) {
1203 mutex_lock l(mu_);
1204 bool added;
1205 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1206 TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
1207 return Status::OK();
1208 }
1209
ReplaceGradient(const GradientDef & grad)1210 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1211 mutex_lock l(mu_);
1212 bool added;
1213 TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1214 TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1215 return Status::OK();
1216 }
1217
RemoveFunction(const string & func)1218 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1219 mutex_lock l(mu_);
1220 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1221 return Status::OK();
1222 }
1223
RemoveFunctionHelper(const string & func)1224 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1225 const auto& i = function_defs_.find(func);
1226 if (i == function_defs_.end()) {
1227 return errors::InvalidArgument("Tried to remove non-existent function ",
1228 func);
1229 }
1230 function_defs_.erase(i);
1231 return Status::OK();
1232 }
1233
RemoveGradient(const string & func)1234 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1235 const auto& i = func_grad_.find(func);
1236 if (i == func_grad_.end()) {
1237 return errors::InvalidArgument("Tried to remove non-existent gradient ",
1238 func);
1239 }
1240 func_grad_.erase(i);
1241 return Status::OK();
1242 }
1243
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1244 void FunctionLibraryDefinition::Remove(
1245 const std::vector<string>& funcs,
1246 const std::vector<string>& funcs_with_grads) {
1247 for (const string& f : funcs) {
1248 Status s = RemoveFunctionHelper(f);
1249 DCHECK(s.ok());
1250 }
1251 for (const string& f : funcs_with_grads) {
1252 Status s = RemoveGradient(f);
1253 DCHECK(s.ok());
1254 }
1255 }
1256
// Returns the name of the gradient function registered for `func`, or "" if
// none is registered.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock l(mu_);
  return FindGradientHelper(func);
}
1261
// Lock-free variant of FindGradient(); callers are expected to hold mu_.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  auto it = func_grad_.find(func);
  if (it == func_grad_.end()) return "";
  return it->second;
}
1265
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1266 Status FunctionLibraryDefinition::LookUp(
1267 const string& op, const OpRegistrationData** op_reg_data) const {
1268 tf_shared_lock l(mu_);
1269 auto iter = function_defs_.find(op);
1270 if (iter != function_defs_.end()) {
1271 *op_reg_data = &iter->second->op_registration_data;
1272 return Status::OK();
1273 }
1274 return default_registry_->LookUp(op, op_reg_data);
1275 }
1276
UniqueFunctionName(StringPiece prefix) const1277 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1278 tf_shared_lock l(mu_);
1279 int index = 0;
1280 string name = strings::StrCat(prefix, index);
1281 while (function_defs_.find(name) != function_defs_.end()) {
1282 ++index;
1283 name = strings::StrCat(prefix, index);
1284 }
1285 return name;
1286 }
1287
GetAttrImpl(const NodeDef & ndef) const1288 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1289 const NodeDef& ndef) const {
1290 if (ndef.op() != kGradientOp) {
1291 // If 'ndef' calls a function and the function's def has the attr,
1292 // returns it.
1293 return Find(ndef.op());
1294 }
1295
1296 // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1297 // Foo's attributes.
1298 const NameAttrList* forward_func_attrs;
1299 if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
1300 return nullptr;
1301 }
1302 const string& func_name = forward_func_attrs->name();
1303 {
1304 tf_shared_lock l(mu_);
1305 const string& grad_name = FindGradientHelper(func_name);
1306 // If 'func' has a user-defined gradient function, uses the grad
1307 // function's attrs to see if noinline is specified. Otherwise,
1308 // uses func's attrs.
1309 if (!grad_name.empty()) {
1310 return FindHelper(grad_name);
1311 }
1312 return FindHelper(func_name);
1313 }
1314 }
1315
ListFunctionNames() const1316 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1317 std::vector<string> function_names;
1318 tf_shared_lock l(mu_);
1319 function_names.reserve(function_defs_.size());
1320 for (const auto& it : function_defs_) {
1321 function_names.emplace_back(it.first);
1322 }
1323 return function_names;
1324 }
1325
ToProto() const1326 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1327 FunctionDefLibrary lib;
1328 tf_shared_lock l(mu_);
1329 for (const auto& f : function_defs_) {
1330 *lib.add_function() = f.second->fdef;
1331 }
1332 for (const auto& g : func_grad_) {
1333 GradientDef* gd = lib.add_gradient();
1334 gd->set_function_name(g.first);
1335 gd->set_gradient_func(g.second);
1336 }
1337 return lib;
1338 }
1339
1340 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1341 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1342 const string& attr, T* value) const {
1343 const FunctionDef* fdef = GetAttrImpl(ndef);
1344 if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
1345 return Status::OK();
1346 }
1347 return errors::InvalidArgument("Attr ", attr, " is not defined.");
1348 }
1349
1350 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1351 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1352 T* value) const {
1353 return GetAttr(node.def(), attr, value);
1354 }
1355
1356 #define GET_ATTR(T) \
1357 template Status FunctionLibraryDefinition::GetAttr(const Node&, \
1358 const string&, T*) const; \
1359 template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \
1360 const string&, T*) const;
1361 GET_ATTR(string)
1362 GET_ATTR(bool)
1363 #undef GET_ATTR
1364
1365 namespace {
1366
1367 constexpr char kApiImplements[] = "api_implements";
1368
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1369 absl::flat_hash_set<string> ReachableFunctions(
1370 const FunctionLibraryDefinition& flib,
1371 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1372 // Functions that are reachable from the graph.
1373 absl::flat_hash_set<string> reachable_funcs;
1374
1375 // For any functions, if it has attribute "api_implements" =
1376 // "some_interface" and it is reachable, then it means any other
1377 // function with same attribute name and value could also be potentially
1378 // reachable, eg via implementation_selector swapping the
1379 // nodedef.
1380 absl::flat_hash_set<string> reachable_api_interface;
1381
1382 // Functions might be reachable from the nested function calls, so we keep a
1383 // queue of functions that we have to check.
1384 gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1385
1386 // Add reachable and not already processed functions to the functions queue.
1387 const auto add_to_func_queue = [&](const string& func_name) {
1388 const FunctionDef* func = flib.Find(func_name);
1389 if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1390 func_queue.push_back(func);
1391 }
1392 };
1393
1394 // Add all the functions that are reachable from the given node to the queue.
1395 const auto process_node = [&](const NodeDef& node) {
1396 // Node itself can be a call to the function.
1397 add_to_func_queue(node.op());
1398
1399 // Or node can have an attribute referencing a function.
1400 for (const auto& attr : node.attr()) {
1401 const auto& attr_value = attr.second;
1402
1403 // 1. AttrValue.func
1404 if (attr_value.has_func()) {
1405 add_to_func_queue(attr_value.func().name());
1406 }
1407
1408 // 2. AttrValue.ListValue.func
1409 if (attr_value.has_list()) {
1410 for (const auto& func : attr_value.list().func()) {
1411 add_to_func_queue(func.name());
1412 }
1413 }
1414 }
1415 };
1416
1417 // Add all functions that are directly called from the optimized graph.
1418 std::for_each(nodes.begin(), nodes.end(), process_node);
1419
1420 // Process all reachable functions.
1421 while (!func_queue.empty()) {
1422 const FunctionDef* func = func_queue.back();
1423 func_queue.pop_back();
1424
1425 const string& func_name = func->signature().name();
1426 reachable_funcs.insert(func_name);
1427
1428 const auto attr_it = func->attr().find(kApiImplements);
1429 if (attr_it != func->attr().end()) {
1430 reachable_api_interface.insert(attr_it->second.s());
1431 }
1432
1433 // Find all the functions called from the function body.
1434 const auto& func_body = func->node_def();
1435 std::for_each(func_body.begin(), func_body.end(), process_node);
1436
1437 // Check if the function has a registered gradient.
1438 const string grad_func_name = flib.FindGradient(func_name);
1439 if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1440 }
1441
1442 for (const auto& func_name : flib.ListFunctionNames()) {
1443 const auto& func_def = flib.Find(func_name);
1444 const auto attr_it = func_def->attr().find(kApiImplements);
1445 if (attr_it != func_def->attr().end()) {
1446 if (reachable_api_interface.contains(attr_it->second.s())) {
1447 reachable_funcs.insert(func_name);
1448 }
1449 }
1450 }
1451
1452 return reachable_funcs;
1453 }
1454
ReachableFunctionLibraryDefinition(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1455 FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
1456 const FunctionLibraryDefinition& flib,
1457 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1458 absl::flat_hash_set<string> reachable_funcs = ReachableFunctions(flib, nodes);
1459
1460 FunctionLibraryDefinition reachable_flib(flib.default_registry(),
1461 FunctionDefLibrary());
1462
1463 for (const string& func_name : reachable_funcs) {
1464 const FunctionDef* func = flib.Find(func_name);
1465 DCHECK_NE(func, nullptr);
1466 // That should never fail, because we copy functions from valid flib and use
1467 // the same default registry.
1468 const Status added = reachable_flib.AddFunctionDef(*func);
1469 DCHECK(added.ok());
1470
1471 const string grad_func_name = flib.FindGradient(func_name);
1472 if (!grad_func_name.empty()) {
1473 GradientDef grad;
1474 grad.set_function_name(func_name);
1475 grad.set_gradient_func(grad_func_name);
1476 // It can only fail if function already has a gradient function.
1477 const Status added_grad = reachable_flib.AddGradientDef(grad);
1478 DCHECK(added_grad.ok());
1479 }
1480 }
1481
1482 return reachable_flib;
1483 }
1484
1485 } // namespace
1486
ReachableDefinitions(const GraphDef & graph) const1487 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1488 const GraphDef& graph) const {
1489 return ReachableFunctionLibraryDefinition(*this, graph.node());
1490 }
1491
ReachableDefinitions(const FunctionDef & func) const1492 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1493 const FunctionDef& func) const {
1494 return ReachableFunctionLibraryDefinition(*this, func.node_def());
1495 }
1496
InitFromString(StringPiece val)1497 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1498 if (val.size() >= 2 && val[0] == '$') {
1499 proto.set_placeholder(val.data() + 1, val.size() - 1);
1500 } else {
1501 SetAttrValue(val, &proto);
1502 }
1503 }
1504
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1505 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1506 const string& name,
1507 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1508 AttrValueWrapper ret;
1509 ret.proto.mutable_func()->set_name(name);
1510 for (const auto& a : attrs) {
1511 ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1512 }
1513 return ret;
1514 }
1515
ToNodeDef() const1516 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1517 NodeDef n;
1518 n.set_op(this->op);
1519 n.set_name(this->ret[0]);
1520 for (const auto& a : this->attr) {
1521 n.mutable_attr()->insert({a.first, a.second.proto});
1522 }
1523 for (const string& a : this->arg) {
1524 n.add_input(a);
1525 }
1526 for (const string& d : this->dep) {
1527 n.add_input(strings::StrCat("^", d));
1528 }
1529 if (!this->device.empty()) {
1530 n.set_device(this->device);
1531 }
1532 return n;
1533 }
1534
1535 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1536 FunctionDef FunctionDefHelper::Create(
1537 const string& function_name, gtl::ArraySlice<string> in_def,
1538 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1539 gtl::ArraySlice<Node> node_def,
1540 gtl::ArraySlice<std::pair<string, string>> ret_def,
1541 gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1542 FunctionDef fdef;
1543
1544 // Signature
1545 OpDefBuilder b(function_name);
1546 for (const auto& i : in_def) b.Input(i);
1547 for (const auto& o : out_def) b.Output(o);
1548 for (const auto& a : attr_def) b.Attr(a);
1549 for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1550
1551 OpRegistrationData op_reg_data;
1552 TF_CHECK_OK(b.Finalize(&op_reg_data));
1553 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1554
1555 // Function body
1556 for (const auto& n : node_def) {
1557 *(fdef.add_node_def()) = n.ToNodeDef();
1558 }
1559
1560 // Returns
1561 for (const auto& r : ret_def) {
1562 fdef.mutable_ret()->insert({r.first, r.second});
1563 }
1564
1565 // Control returns
1566 for (const auto& cr : control_ret_def) {
1567 fdef.mutable_control_ret()->insert({cr.first, cr.second});
1568 }
1569
1570 auto* op_def_registry = OpRegistry::Global();
1571 // Check if any op is stateful.
1572 for (const auto& n : node_def) {
1573 const OpDef* op_def = nullptr;
1574 auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1575 // Lookup can fail if e.g. we are calling a function that was not yet
1576 // defined. If it happens, conservatively assume the op is stateful.
1577 if (!status.ok() || op_def->is_stateful()) {
1578 fdef.mutable_signature()->set_is_stateful(true);
1579 }
1580 }
1581
1582 return fdef;
1583 }
1584
1585 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1586 FunctionDef FunctionDefHelper::Create(
1587 const string& function_name, gtl::ArraySlice<string> in_def,
1588 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1589 gtl::ArraySlice<Node> node_def,
1590 gtl::ArraySlice<std::pair<string, string>> ret_def) {
1591 return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1592 /*control_ret_def=*/{});
1593 }
1594
1595 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1596 FunctionDef FunctionDefHelper::Define(const string& name,
1597 gtl::ArraySlice<string> arg_def,
1598 gtl::ArraySlice<string> ret_def,
1599 gtl::ArraySlice<string> attr_def,
1600 gtl::ArraySlice<Node> node_def) {
1601 FunctionDef fdef;
1602 OpDefBuilder b(name);
1603 for (const auto& a : arg_def) b.Input(a);
1604 for (const auto& r : ret_def) b.Output(r);
1605 for (const auto& a : attr_def) b.Attr(a);
1606
1607 OpRegistrationData op_reg_data;
1608 TF_CHECK_OK(b.Finalize(&op_reg_data));
1609 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1610
1611 // Mapping from legacy output names to NodeDef outputs.
1612 std::unordered_map<string, string> ret_index;
1613 for (const auto& a : fdef.signature().input_arg()) {
1614 ret_index[a.name()] = a.name();
1615 }
1616
1617 // For looking up OpDefs
1618 auto* op_def_registry = OpRegistry::Global();
1619
1620 // Function body
1621 for (const auto& src : node_def) {
1622 NodeDef* n = fdef.add_node_def();
1623 n->set_op(src.op);
1624 n->set_name(src.ret[0]);
1625 for (const auto& a : src.attr) {
1626 n->mutable_attr()->insert({a.first, a.second.proto});
1627 }
1628 for (const string& a : src.arg) {
1629 const auto iter = ret_index.find(a);
1630 CHECK(iter != ret_index.end())
1631 << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
1632 n->add_input(iter->second);
1633 }
1634 for (const string& d : src.dep) {
1635 n->add_input(strings::StrCat("^", d));
1636 }
1637
1638 // Add the outputs of this node to ret_index.
1639 const OpDef* op_def = nullptr;
1640 TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1641 CHECK(op_def != nullptr) << n->op();
1642 NameRangeMap output_names;
1643 TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1644 for (const auto& o : output_names) {
1645 CHECK_LE(o.second.second, src.ret.size())
1646 << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1647 << "' of " << name;
1648 for (int i = o.second.first; i < o.second.second; ++i) {
1649 ret_index[src.ret[i]] =
1650 strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1651 }
1652 }
1653 if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1654 }
1655
1656 // Returns
1657 for (const auto& r : fdef.signature().output_arg()) {
1658 const auto iter = ret_index.find(r.name());
1659 CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1660 fdef.mutable_ret()->insert({r.name(), iter->second});
1661 }
1662 return fdef;
1663 }
1664
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1665 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1666 gtl::ArraySlice<string> ret_def,
1667 gtl::ArraySlice<string> attr_def,
1668 gtl::ArraySlice<Node> node_def) {
1669 return Define("_", arg_def, ret_def, attr_def, node_def);
1670 }
1671
1672 namespace gradient {
1673
1674 typedef std::unordered_map<string, Creator> OpGradFactory;
1675
GetOpGradFactory()1676 OpGradFactory* GetOpGradFactory() {
1677 static OpGradFactory* factory = new OpGradFactory;
1678 return factory;
1679 }
1680
RegisterOp(const string & op,Creator func)1681 bool RegisterOp(const string& op, Creator func) {
1682 CHECK(GetOpGradFactory()->insert({op, func}).second)
1683 << "Duplicated gradient for " << op;
1684 return true;
1685 }
1686
GetOpGradientCreator(const string & op,Creator * creator)1687 Status GetOpGradientCreator(const string& op, Creator* creator) {
1688 auto fac = GetOpGradFactory();
1689 auto iter = fac->find(op);
1690 if (iter == fac->end()) {
1691 return errors::NotFound("No gradient defined for op: ", op);
1692 }
1693 *creator = iter->second;
1694 return Status::OK();
1695 }
1696
1697 } // end namespace gradient
1698
1699 } // namespace tensorflow
1700