• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 
18 #include <ctype.h>
19 
20 #include <map>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/escaping.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/lib/strings/proto_serialization.h"
41 #include "tensorflow/core/util/device_name_utils.h"
42 #include "tensorflow/core/util/equal_graph_def.h"
43 
44 namespace tensorflow {
45 
// Out-of-line definitions of FunctionLibraryDefinition's static constexpr
// string constants. Before C++17 such definitions are required whenever the
// constants are ODR-used (e.g. bound to a reference); from C++17 on, static
// constexpr data members are implicitly inline and these redeclarations are
// redundant (but harmless).
46 /* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
47 /* static */ constexpr const char* const
48     FunctionLibraryDefinition::kDeviceArgOp;
49 /* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
50 /* static */ constexpr const char* const
51     FunctionLibraryDefinition::kDeviceRetOp;
52 /* static */ constexpr const char* const
53     FunctionLibraryDefinition::kIntsOnDeviceAttr;
54 /* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
55 /* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
56 
57 // Extracts the actual type from "attr_values" based on its definition
58 // "arg_def".
59 //
60 // If "arg_def" is a N*T type, *is_type_list is set to false, and
61 // *dtypes is set to be a vector of size N and each element is T.
62 //
63 // If "arg_def" is a list(type), *is_type_list is set to true, and
64 // *dtypes is set to be a vector of types specified in attrs for
65 // arg_def.
66 //
67 // Otherwise (arg_def is a simple type T), *is_type_list is set to
68 // false, and *dtypes is set to a single element vector, whose only
69 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)70 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
71                   bool* is_type_list, DataTypeVector* dtypes) {
72   dtypes->clear();
73   if (!arg_def.type_list_attr().empty()) {
74     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
75     if (v == nullptr) {
76       return errors::NotFound("type attr not found: ",
77                               arg_def.type_list_attr());
78     }
79     *is_type_list = true;
80     for (int i = 0; i < v->list().type_size(); ++i) {
81       dtypes->push_back(v->list().type(i));
82     }
83     return Status::OK();
84   }
85 
86   *is_type_list = false;
87   int num = 1;
88   if (!arg_def.number_attr().empty()) {
89     const AttrValue* v = attrs.Find(arg_def.number_attr());
90     if (v == nullptr) {
91       return errors::NotFound("type attr not found: ", arg_def.type_attr());
92     }
93     num = v->i();
94   }
95 
96   DataType dtype;
97   if (arg_def.type() != DT_INVALID) {
98     dtype = arg_def.type();
99   } else if (arg_def.type_attr().empty()) {
100     dtype = DT_INVALID;
101   } else {
102     const AttrValue* v = attrs.Find(arg_def.type_attr());
103     if (v == nullptr) {
104       return errors::NotFound("type attr not found: ", arg_def.type_attr());
105     }
106     dtype = v->type();
107   }
108   dtypes->resize(num, dtype);
109   return Status::OK();
110 }
111 
112 namespace {
113 
114 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)115 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
116   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
117 }
118 
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)119 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
120   // attr_values should specify all attrs defined in fdef.
121   for (const auto& a : sig.attr()) {
122     const AttrValue* v = attr_values.Find(a.name());
123     if (!v) {
124       return errors::NotFound("Attr ", a.name(), " is not found from ",
125                               SummarizeOpDef(sig));
126     }
127     Status status = AttrValueHasType(*v, a.type());
128     if (!status.ok()) {
129       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
130       return status;
131     }
132   }
133 
134 // TODO(josh11b): Enable this code once it works with function gradients.
135 // Right now the C++ function gradient code assumes it can pass
136 // all the attrs of the function to the gradient, and any attrs that
137 // the gradient doesn't care about will be ignored.
138 #if 0
139   if (attr_values.size() != sig.attr_size()) {
140     for (const auto& a : attr_values) {
141       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
142       bool found = false;
143       for (const auto& s : sig.attr()) {
144         if (a.first == s.name()) {
145           found = true;
146           break;
147         }
148       }
149       if (!found) {
150         return errors::NotFound("Attr ", a.first, " is not found in ",
151                                 SummarizeOpDef(sig));
152       }
153     }
154   }
155 #endif
156 
157   return Status::OK();
158 }
159 
160 // A helper class for instantiating functions. This contains shared information
161 // like the resulting graph and node name index.
162 class FunctionInstantiationHelper {
163  public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)164   FunctionInstantiationHelper(GetFunctionSignature get_function,
165                               InstantiationResult* result)
166       : get_function_(std ::move(get_function)), result_(*result) {
167     result_.nodes.clear();
168   }
169 
170   // Builds index for nodes that can be used as node's input arguments.
171   // `resource_arg_unique_id`: if non-negative, will be populated to the
172   // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64_t resource_arg_unique_id)173   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
174                             const FunctionDef::ArgAttrs* arg_attrs,
175                             bool ints_on_device,
176                             int64_t resource_arg_unique_id) {
177     bool is_type_list;
178     DataTypeVector dtypes;
179     TF_RETURN_IF_ERROR(
180         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
181     CHECK_GE(dtypes.size(), size_t{1});
182     int arg_index = result_.nodes.size();
183     TF_RETURN_IF_ERROR(
184         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
185     // Creates dtypes.size() nodes in the graph.
186     for (size_t i = 0; i < dtypes.size(); ++i) {
187       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
188                                  {true, arg_index, 0, false, {dtypes[i]}}));
189       DCHECK_EQ(arg_index, result_.nodes.size());
190       string name = arg_def.name();
191       if (dtypes.size() > 1) {
192         strings::StrAppend(&name, "_", i);
193       }
194       NodeDef* gnode = AddNode(name);
195       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
196         gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
197       } else {
198         gnode->set_op(FunctionLibraryDefinition::kArgOp);
199       }
200       DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
201       AddAttr("T", dtype, gnode);
202       AddAttr("index", arg_index, gnode);
203       if (resource_arg_unique_id >= 0) {
204         AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
205       }
206       if (arg_attrs) {
207         for (const auto& arg_attr : arg_attrs->attr()) {
208           AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
209         }
210       }
211       result_.arg_types.push_back(dtypes[i]);
212       ++arg_index;
213     }
214     return Status::OK();
215   }
216 
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)217   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
218                               const int arg_index) {
219     const OpDef* node_sig = nullptr;
220     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
221     if (node_sig->output_arg_size() == 0) {
222       return AddItem(node.name(), {false, arg_index, 0, false, {}});
223     }
224     const int num_retval = node_sig->output_arg_size();
225     int start = 0;
226     bool is_type_list;
227     DataTypeVector dtypes;
228     for (int i = 0; i < num_retval; ++i) {
229       TF_RETURN_IF_ERROR(
230           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
231       // Note that we rely on the backwards-compatibility test enforcing
232       // that output_arg(*).name() doesn't change here.
233       const string base_name =
234           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
235       TF_RETURN_IF_ERROR(
236           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
237       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
238         TF_RETURN_IF_ERROR(
239             AddItem(strings::StrCat(base_name, ":", j),
240                     {false, arg_index, start + j, false, {dtypes[j]}}));
241       }
242       start += dtypes.size();
243     }
244     return Status::OK();
245   }
246 
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)247   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
248     const OpDef* fnode_sig = nullptr;
249     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
250     NodeDef* gnode = AddNode(fnode.name());
251     gnode->set_op(fnode.op());
252     gnode->set_device(fnode.device());
253     int gnode_idx = nodes_.size() - 1;
254 
255     // Input
256     const int num_args = fnode_sig->input_arg_size();
257     bool is_type_list;  // ignored
258     DataTypeVector dtypes;
259     int fnode_arg_index = 0;
260     for (int i = 0; i < num_args; ++i) {
261       TF_RETURN_IF_ERROR(
262           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
263       // Consume inputs (indexed by fnode_arg_index) until we have
264       // matched each element of dtypes (indexed by j).
265       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
266         if (fnode_arg_index >= fnode.input_size()) {
267           // Should never happen if we computed dtypes correctly.
268           return errors::InvalidArgument(
269               "Attempt to access beyond input size: ", fnode_arg_index,
270               " >= ", fnode.input_size());
271         }
272         // Look up the next input.
273         const string& input_name = fnode.input(fnode_arg_index);
274         const auto* item = GetItemOrNull(input_name);
275         if (item == nullptr) {
276           return errors::InvalidArgument(
277               "input ", input_name,
278               " is not found: ", FormatNodeDefForError(fnode));
279         }
280         if (item->dtypes.size() > dtypes.size() - j) {
281           return errors::InvalidArgument("Input ", input_name, " too long for ",
282                                          fnode_sig->input_arg(i).name());
283         }
284         // Match up all the elements of this input (indexed by k) with
285         // elements of dtypes (advancing j).
286         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
287           if (item->dtypes[k] != dtypes[j]) {
288             return errors::InvalidArgument(
289                 "input ", fnode_sig->input_arg(i).name(), "[", j,
290                 "] expected type ", DataTypeString(dtypes[j]),
291                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
292                 input_name, "[", k, "]");
293           }
294           if (item->is_func_arg) {
295             AddInput(gnode_idx, item->nid + k, 0);
296           } else {
297             AddInput(gnode_idx, item->nid, item->idx + k);
298           }
299         }
300       }
301     }
302 
303     // Control deps.
304     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
305       const string& input = fnode.input(i);
306       if (input.empty() || input[0] != '^') {
307         return errors::InvalidArgument("Expected input[", i, "] == '", input,
308                                        "' to be a control input.");
309       }
310       int nid = -1;
311       const string node_name = input.substr(1);
312       const string node_colon = node_name + ":";
313       const string node_colon_bound = node_name + ";";
314       // index_ is a map sorted lexicographically, so the key we are looking for
315       // must lie in the range [node_name, node_colon_bound).
316       auto it = index_.lower_bound(node_name);
317       while (it != index_.end() && it->first <= node_colon_bound) {
318         if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
319           nid = it->second.nid;
320           break;
321         }
322         ++it;
323       }
324       if (nid == -1) {
325         return errors::InvalidArgument("input[", i, "] == '", input,
326                                        "', is not found.");
327       }
328       AddDep(gnode_idx, nid);
329     }
330 
331     // Attrs.
332     for (const auto& p : attrs) {
333       (*gnode->mutable_attr())[p.first] = p.second;
334     }
335 
336     return Status::OK();
337   }
338 
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)339   Status AddReturnNode(
340       const OpDef::ArgDef& ret_def, AttrSlice attrs,
341       const ::tensorflow::protobuf::Map<string, string>& ret_map,
342       bool ints_on_device, int* ret_index) {
343     auto ret_iter = ret_map.find(ret_def.name());
344     if (ret_iter == ret_map.end()) {
345       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
346     }
347     bool is_type_list;
348     DataTypeVector dtypes;
349     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
350     CHECK_GE(dtypes.size(), size_t{1});
351     const auto* item = GetItemOrNull(ret_iter->second);
352     if (item == nullptr) {
353       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
354                                      ret_iter->second, " is not found.");
355     }
356     if (dtypes != item->dtypes) {
357       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
358                                      " : ", DataTypeVectorString(dtypes),
359                                      " vs. ",
360                                      DataTypeVectorString(item->dtypes));
361     }
362     for (size_t i = 0; i < dtypes.size(); ++i) {
363       string name = strings::StrCat(ret_def.name(), "_RetVal");
364       if (dtypes.size() > 1) {
365         strings::StrAppend(&name, "_", i);
366       }
367       NodeDef* gnode = AddNode(name);
368       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
369         gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
370       } else {
371         gnode->set_op(FunctionLibraryDefinition::kRetOp);
372       }
373       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
374       DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
375       AddAttr("T", dtype, gnode);
376       AddAttr("index", (*ret_index)++, gnode);
377       result_.ret_types.push_back(dtypes[i]);
378     }
379     return Status::OK();
380   }
381 
382   // Adds the actual node inputs to the result graph by converting indexes to
383   // the node names.
AddNodeInputs()384   void AddNodeInputs() {
385     for (int i = 0; i < result_.nodes.size(); i++) {
386       NodeInfo& node_info = nodes_[i];
387       for (const auto& p : node_info.data_inputs) {
388         result_.nodes[i].add_input(Name(p.first, p.second));
389       }
390       for (int index : node_info.control_inputs) {
391         result_.nodes[i].add_input(Dep(index));
392       }
393     }
394   }
395 
396  private:
397   // This is used to build a small index for all names that can be used as a
398   // node's input arguments.
399   //
400   // If is_func_arg is true, the name is a function's argument.  In
401   // this case, the produced graph def has node[nid:nid + dtype.size()].
402   //
403   // Otherwise, the name is a function body's node return value.  In
404   // this case, the produced graph def has one node node[nid] and
405   // the node's output index [idx ... idx + num) corresponds to the
406   // named outputs.
407   //
408   // In all cases, "dtype" specifies the data type.
409   struct NameInfoItem {
410     bool is_func_arg;
411     int nid;
412     int idx;
413     bool is_type_list;
414     DataTypeVector dtypes;
415   };
416 
417   // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)418   Status AddItem(const string& name, const NameInfoItem& item) {
419     if (!index_.insert({name, item}).second) {
420       return errors::InvalidArgument(
421           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
422                           " name: "),
423           name);
424     }
425     return Status::OK();
426   }
427 
GetItemOrNull(const string & name) const428   const NameInfoItem* GetItemOrNull(const string& name) const {
429     return gtl::FindOrNull(index_, name);
430   }
431 
Dep(int node_index) const432   string Dep(int node_index) const {
433     return strings::StrCat("^", Name(node_index));
434   }
435 
Name(int node_index) const436   string Name(int node_index) const {
437     CHECK_LT(node_index, nodes_.size());
438     return nodes_[node_index].name;
439   }
440 
Name(int node_index,int output_index) const441   string Name(int node_index, int output_index) const {
442     if (output_index == 0) {
443       return Name(node_index);
444     } else {
445       return strings::StrCat(Name(node_index), ":", output_index);
446     }
447   }
448 
AddNode(const string & name)449   NodeDef* AddNode(const string& name) {
450     result_.nodes.emplace_back();
451     NodeDef* gnode = &result_.nodes.back();
452     gnode->set_name(name);
453     nodes_.push_back({name, {}, {}});
454     CHECK_EQ(result_.nodes.size(), nodes_.size());
455     return gnode;
456   }
457 
AddInput(int node_index,int output_node,int output_index)458   void AddInput(int node_index, int output_node, int output_index) {
459     CHECK_LT(node_index, nodes_.size());
460     nodes_[node_index].data_inputs.push_back(
461         std::make_pair(output_node, output_index));
462   }
463 
AddDep(int node_index,int dep_index)464   void AddDep(int node_index, int dep_index) {
465     CHECK_LT(node_index, nodes_.size());
466     nodes_[node_index].control_inputs.push_back(dep_index);
467   }
468 
469   GetFunctionSignature get_function_;
470   InstantiationResult& result_;
471   // A small index for all names that can be used as a node's input arguments.
472   std::map<string, NameInfoItem> index_;
473   // This contains information about a node in the new graph including the node
474   // names and input nodes' indexes.
475   struct NodeInfo {
476     string name;
477     // Data inputs where <n, k> means arg k of node n.
478     std::vector<std::pair<int, int>> data_inputs;
479     // Control inputs (dependencies).
480     std::vector<int> control_inputs;
481   };
482   // nodes_[i] is the information about result_.nodes[i].
483   std::vector<NodeInfo> nodes_;
484 };
485 
486 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)487 string Print(const OpDef::ArgDef& arg) {
488   string out;
489   strings::StrAppend(&out, arg.name(), ":");
490   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
491   if (!arg.number_attr().empty()) {
492     strings::StrAppend(&out, arg.number_attr(), "*");
493   }
494   if (arg.type() != DT_INVALID) {
495     strings::StrAppend(&out, DataTypeString(arg.type()));
496   } else {
497     strings::StrAppend(&out, arg.type_attr());
498   }
499   if (arg.is_ref()) strings::StrAppend(&out, ")");
500   return out;
501 }
502 
503 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)504 string Print(const AttrValue& attr_value) {
505   if (attr_value.value_case() == AttrValue::kType) {
506     return DataTypeString(attr_value.type());
507   } else if ((attr_value.value_case() == AttrValue::kList) &&
508              (attr_value.list().type_size() > 0)) {
509     string ret = "{";
510     for (int i = 0; i < attr_value.list().type_size(); ++i) {
511       if (i > 0) strings::StrAppend(&ret, ", ");
512       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
513     }
514     strings::StrAppend(&ret, "}");
515     return ret;
516   } else if (attr_value.value_case() == AttrValue::kFunc) {
517     if (attr_value.func().attr_size() == 0) {
518       return attr_value.func().name();
519     }
520     std::vector<string> entries;
521     for (const auto& p : attr_value.func().attr()) {
522       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
523     }
524     std::sort(entries.begin(), entries.end());
525     return strings::StrCat(attr_value.func().name(), "[",
526                            absl::StrJoin(entries, ", "), "]");
527   }
528   return SummarizeAttrValue(attr_value);
529 }
530 
531 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)532 string Print(const NodeDef& n) {
533   string out;
534   strings::StrAppend(&out, n.name(), " = ", n.op());
535   if (n.attr_size() > 0) {
536     std::vector<string> entries;
537     for (auto& a : n.attr()) {
538       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
539     }
540     std::sort(entries.begin(), entries.end());
541     // Add a short device string at the end of all attributes.
542     if (!n.device().empty()) {
543       DeviceNameUtils::ParsedName parsed;
544       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
545         entries.push_back(
546             strings::StrCat("device=", parsed.type, ":", parsed.id));
547       } else {
548         entries.push_back("device=<FAILED_TO_PARSE>");
549       }
550     }
551     strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
552   }
553   strings::StrAppend(&out, "(");
554   std::vector<StringPiece> dat;
555   std::vector<string> dep;
556   for (StringPiece s : n.input()) {
557     if (absl::ConsumePrefix(&s, "^")) {
558       dep.emplace_back(s);
559     } else {
560       dat.push_back(s);
561     }
562   }
563   strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
564   if (!dep.empty()) {
565     strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
566   }
567   return out;
568 }
569 
Print(const FunctionDef & fdef)570 string Print(const FunctionDef& fdef) {
571   string out;
572   const OpDef& sig = fdef.signature();
573   strings::StrAppend(&out, "\n", sig.name());
574   if (sig.attr_size() > 0) {
575     strings::StrAppend(&out, "[");
576     for (int i = 0; i < sig.attr_size(); ++i) {
577       const auto& a = sig.attr(i);
578       if (i > 0) strings::StrAppend(&out, ", ");
579       if (a.type() == "type") {
580         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
581       } else {
582         strings::StrAppend(&out, a.name(), ":", a.type());
583       }
584     }
585     strings::StrAppend(&out, "]");
586   }
587   strings::StrAppend(&out, "(");
588   for (int i = 0; i < sig.input_arg_size(); ++i) {
589     if (i > 0) strings::StrAppend(&out, ", ");
590     strings::StrAppend(&out, Print(sig.input_arg(i)));
591   }
592   strings::StrAppend(&out, ") -> (");
593   for (int i = 0; i < sig.output_arg_size(); ++i) {
594     if (i > 0) strings::StrAppend(&out, ", ");
595     strings::StrAppend(&out, Print(sig.output_arg(i)));
596   }
597   strings::StrAppend(&out, ") {\n");
598   for (const auto& n : fdef.node_def()) {
599     strings::StrAppend(&out, "  ", Print(n), "\n");
600   }
601   for (const auto& cr : fdef.control_ret()) {
602     strings::StrAppend(&out, "  @return ", cr.first, " = ", cr.second, "\n");
603   }
604   for (const auto& r : fdef.ret()) {
605     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
606   }
607   strings::StrAppend(&out, "}\n");
608   return out;
609 }
610 
Print(gtl::ArraySlice<const NodeDef * > nodes)611 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
612   std::vector<const NodeDef*> arg;
613   std::vector<const NodeDef*> ret;
614   std::vector<const NodeDef*> body;
615   for (const NodeDef* n : nodes) {
616     if (n->op() == FunctionLibraryDefinition::kArgOp ||
617         n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
618       arg.push_back(n);
619     } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
620                n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
621       ret.push_back(n);
622     } else {
623       body.push_back(n);
624     }
625   }
626   auto comp = [](const NodeDef* x, const NodeDef* y) {
627     int xi;
628     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
629     int yi;
630     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
631     return xi < yi;
632   };
633   std::sort(arg.begin(), arg.end(), comp);
634   std::sort(ret.begin(), ret.end(), comp);
635   string out;
636   strings::StrAppend(&out, "\n(");
637   auto get_type_and_device = [](const NodeDef& n) {
638     DataType dt;
639     if (!TryGetNodeAttr(n, "T", &dt)) {
640       dt = DT_INVALID;
641     }
642     if (!n.device().empty()) {
643       DeviceNameUtils::ParsedName parsed;
644       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
645         return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
646                                parsed.id);
647       } else {
648         LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
649                      << n.op() << ":" << n.name();
650         return strings::StrCat(DataTypeString(dt), "@",
651                                "<FAILED_TO_PARSE_DEVICE>");
652       }
653     }
654     return DataTypeString(dt);
655   };
656   for (size_t i = 0; i < arg.size(); ++i) {
657     const NodeDef* n = arg[i];
658     if (i > 0) strings::StrAppend(&out, ", ");
659     CHECK_GE(n->attr_size(), 2);
660     strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
661   }
662   strings::StrAppend(&out, ") -> (");
663   for (size_t i = 0; i < ret.size(); ++i) {
664     const NodeDef* n = ret[i];
665     if (i > 0) strings::StrAppend(&out, ", ");
666     CHECK_LE(2, n->attr_size());
667 
668     // The _RetVal op should have a unique non-control input. We assert that
669     // here and add it to the output.
670     bool found_non_control_input = false;
671     for (const string& input : n->input()) {
672       if (!input.empty() && input[0] != '^') {
673         DCHECK_EQ(found_non_control_input, false)
674             << "RetVal node has more than one non-control input: "
675             << absl::StrJoin(n->input(), ", ");
676         strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
677         found_non_control_input = true;
678       }
679     }
680     DCHECK_EQ(found_non_control_input, true)
681         << "RetVal did not have any non-control inputs: "
682         << absl::StrJoin(n->input(), ", ");
683   }
684   strings::StrAppend(&out, ") {\n");
685   for (size_t i = 0; i < body.size(); ++i) {
686     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
687   }
688   strings::StrAppend(&out, "}\n");
689   return out;
690 }
691 
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)692 Status AddDefaultAttrs(const string& op,
693                        const GetFunctionSignature& get_function,
694                        AttrValueMap* attrs) {
695   const OpDef* op_def = nullptr;
696   TF_RETURN_IF_ERROR(get_function(op, &op_def));
697   AttrSlice attr_slice(attrs);
698   for (const auto& attr_def : op_def->attr()) {
699     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
700       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
701         return errors::Internal("Somehow duplicated: ", attr_def.name());
702       }
703     }
704   }
705   return Status::OK();
706 }
707 
708 }  // end namespace
709 
710 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)711 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
712                            GetFunctionSignature get_function,
713                            InstantiationResult* result) {
714   if (VLOG_IS_ON(5)) {
715     const auto& signature = fdef.signature();
716     VLOG(5) << "Instantiate function definition: name=" << signature.name()
717             << " #input_args=" << signature.input_arg_size()
718             << " #output_args=" << signature.output_arg_size()
719             << " #control_output=" << signature.control_output_size();
720     for (const auto& line : str_util::Split(Print(fdef), '\n')) {
721       VLOG(5) << "|| " << line;
722     }
723   }
724 
725   const OpDef& sig = fdef.signature();
726   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
727 
728   bool ints_on_device =
729       fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
730       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
731 
732   FunctionInstantiationHelper helper(get_function, result);
733   Status s;
734   for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
735     const OpDef::ArgDef& arg_def = sig.input_arg(i);
736     auto it = fdef.arg_attr().find(i);
737     const FunctionDef::ArgAttrs* arg_attrs =
738         it != fdef.arg_attr().end() ? &it->second : nullptr;
739     auto resource_id_it = fdef.resource_arg_unique_id().find(i);
740     int64_t resource_arg_unique_id =
741         resource_id_it != fdef.resource_arg_unique_id().end()
742             ? resource_id_it->second
743             : -1LL;
744     s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
745                                   ints_on_device, resource_arg_unique_id);
746 
747     if (!s.ok()) {
748       errors::AppendToMessage(&s, "In ", Print(arg_def));
749       return s;
750     }
751   }
752 
753   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
754     if (const AttrValue* v = attr_values.Find(name)) {
755       *val = *v;
756       return true;
757     }
758     return false;
759   };
760 
761   // Makes a copy of all attrs in fdef and substitutes placeholders.
762   // After this step, every attr is bound to a concrete value.
763   std::vector<AttrValueMap> node_attrs;
764   node_attrs.resize(fdef.node_def_size());
765   for (int i = 0; i < fdef.node_def_size(); ++i) {
766     for (auto attr : fdef.node_def(i).attr()) {
767       if (!SubstitutePlaceholders(substitute, &attr.second)) {
768         return errors::InvalidArgument("Failed to bind all placeholders in ",
769                                        SummarizeAttrValue(attr.second));
770       }
771       if (!node_attrs[i].insert(attr).second) {
772         return errors::Internal("Somehow duplicated: ", attr.first);
773       }
774     }
775     TF_RETURN_IF_ERROR(
776         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
777   }
778 
779   for (int i = 0; i < fdef.node_def_size(); ++i) {
780     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
781                                     result->nodes.size() + i);
782     if (!s.ok()) {
783       errors::AppendToMessage(&s, "In ",
784                               FormatNodeDefForError(fdef.node_def(i)));
785       return s;
786     }
787   }
788   // Emits one node for each fdef.node_def.
789   for (int i = 0; i < fdef.node_def_size(); ++i) {
790     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
791     if (!s.ok()) {
792       errors::AppendToMessage(&s, "In ",
793                               FormatNodeDefForError(fdef.node_def(i)));
794       return s;
795     }
796   }
797 
798   // Emits nodes for the function's return values.
799   int ret_index = 0;
800   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
801     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
802                              &ret_index);
803     if (!s.ok()) {
804       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
805       return s;
806     }
807   }
808 
809   // Adds the actual node inputs using the input indexes.
810   helper.AddNodeInputs();
811 
812   return Status::OK();
813 }
814 
DebugString(const FunctionDef & func_def)815 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
816 
DebugString(const GraphDef & instantiated_func_def)817 string DebugString(const GraphDef& instantiated_func_def) {
818   std::vector<const NodeDef*> ptrs;
819   for (const NodeDef& n : instantiated_func_def.node()) {
820     ptrs.push_back(&n);
821   }
822   return Print(ptrs);
823 }
824 
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)825 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
826   std::vector<const NodeDef*> ptrs;
827   for (const NodeDef& n : instantiated_func_nodes) {
828     ptrs.push_back(&n);
829   }
830   return Print(ptrs);
831 }
832 
DebugStringWhole(const GraphDef & gdef)833 string DebugStringWhole(const GraphDef& gdef) {
834   string ret;
835   for (const auto& fdef : gdef.library().function()) {
836     strings::StrAppend(&ret, Print(fdef));
837   }
838   strings::StrAppend(&ret, "\n");
839   for (const auto& ndef : gdef.node()) {
840     strings::StrAppend(&ret, Print(ndef), "\n");
841   }
842   return ret;
843 }
844 
845 namespace {
846 
847 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
848 // Python, it's possible to access unset attrs, which returns a default value
849 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)850 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
851   std::map<string, AttrValue> set_attrs;
852   for (const auto& pair : fdef.attr()) {
853     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
854       set_attrs[pair.first] = pair.second;
855     }
856   }
857   return set_attrs;
858 }
859 
860 }  // end namespace
861 
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)862 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
863   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
864 
865   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
866   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
867   if (f1_attrs.size() != f2_attrs.size()) return false;
868   for (const auto& iter1 : f1_attrs) {
869     auto iter2 = f2_attrs.find(iter1.first);
870     if (iter2 == f2_attrs.end()) return false;
871     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
872   }
873 
874   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
875     return false;
876   }
877 
878   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
879   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
880   if (ret1 != ret2) return false;
881 
882   std::map<string, string> control_ret1(f1.control_ret().begin(),
883                                         f1.control_ret().end());
884   std::map<string, string> control_ret2(f2.control_ret().begin(),
885                                         f2.control_ret().end());
886   if (control_ret1 != control_ret2) return false;
887 
888   return true;
889 }
890 
FunctionDefHash(const FunctionDef & fdef)891 uint64 FunctionDefHash(const FunctionDef& fdef) {
892   // signature
893   uint64 h = OpDefHash(fdef.signature());
894 
895   // attrs
896   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
897   for (const auto& p : attrs) {
898     h = Hash64(p.first.data(), p.first.size(), h);
899     h = Hash64Combine(AttrValueHash(p.second), h);
900   }
901 
902   // node defs
903   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
904 
905   // output names
906   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
907   for (const auto& p : ret) {
908     h = Hash64(p.first.data(), p.first.size(), h);
909     h = Hash64(p.second.data(), p.second.size(), h);
910   }
911 
912   // control output names
913   std::map<string, string> control_ret(fdef.control_ret().begin(),
914                                        fdef.control_ret().end());
915   for (const auto& p : control_ret) {
916     h = Hash64(p.first.data(), p.first.size(), h);
917     h = Hash64(p.second.data(), p.second.size(), h);
918   }
919 
920   return h;
921 }
922 
923 static constexpr const char* const kExecutorAttr = "_executor";
924 
925 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)926 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
927                                             AttrSlice attrs) {
928   if (!options.executor_type.empty()) {
929     return options.executor_type;
930   } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
931     return executor_attr->s();
932   } else {
933     return string();
934   }
935 }
936 
937 namespace {
938 class AttrKeyAndValue {
939  public:
940   enum ValueRepresentationOp {
941     kRaw,
942     kCEscape,
943   };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)944   AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
945                   ValueRepresentationOp value_op = kRaw)
946       : key_name_(key_name),
947         key_suffix_(key_suffix),
948         value_op_(value_op),
949         value_(std::move(value)) {}
950 
operator <(const AttrKeyAndValue & b) const951   bool operator<(const AttrKeyAndValue& b) const {
952     if (key_name_ != b.key_name_) {
953       return key_name_ < b.key_name_;
954     } else if (key_suffix_ != b.key_suffix_) {
955       return key_suffix_ < b.key_suffix_;
956     } else {
957       return value_ < b.value_;
958     }
959   }
960 
AppendTo(bool first,string * s) const961   void AppendTo(bool first, string* s) const {
962     absl::string_view v;
963     bool add_escaped = false;
964     if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
965       // Use CEscape call below
966       add_escaped = true;
967     } else {
968       // Add raw value contents directly
969       v = value_;
970     }
971     if (key_suffix_ >= 0) {
972       strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
973     } else {
974       strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
975     }
976     if (add_escaped) {
977       strings::StrAppend(s, absl::CEscape(value_));
978     }
979   }
980 
981  private:
NeedsEscaping(const string & s)982   static bool NeedsEscaping(const string& s) {
983     for (auto c : s) {
984       if (!isalnum(c) && (c != ' ')) {
985         return true;
986       }
987     }
988     return false;
989   }
990 
991   absl::string_view key_name_;
992   int key_suffix_;  // -1 if missing
993   ValueRepresentationOp value_op_;
994   string value_;
995 };
996 }  // namespace
997 
GetFunctionResourceInputDevice(const Tensor & input,const int arg_index,const FunctionDef & function_def,absl::flat_hash_map<string,std::vector<string>> * composite_devices)998 string GetFunctionResourceInputDevice(
999     const Tensor& input, const int arg_index, const FunctionDef& function_def,
1000     absl::flat_hash_map<string, std::vector<string>>* composite_devices) {
1001   const auto& handles = input.flat<ResourceHandle>();
1002   const ResourceHandle& handle0 = handles(0);
1003   string composite_device;
1004   auto iter = function_def.arg_attr().find(arg_index);
1005   if (iter != function_def.arg_attr().end()) {
1006     auto arg_attr = iter->second.attr().find("_composite_device");
1007     if (arg_attr != iter->second.attr().end()) {
1008       composite_device = arg_attr->second.s();
1009     }
1010   }
1011   if (!composite_device.empty()) {
1012     if (composite_devices->find(composite_device) == composite_devices->end()) {
1013       for (int i = 0; i < handles.size(); ++i) {
1014         (*composite_devices)[composite_device].push_back(handles(i).device());
1015       }
1016     }
1017     return composite_device;
1018   } else {
1019     return handle0.device();
1020   }
1021 }
1022 
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)1023 string Canonicalize(const string& funcname, AttrSlice attrs,
1024                     const FunctionLibraryRuntime::InstantiateOptions& options) {
1025   absl::InlinedVector<AttrKeyAndValue, 8> entries;
1026   entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1027                   options.input_devices.size());
1028   for (const auto& p : attrs) {
1029     if (p.first != kExecutorAttr) {
1030       entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
1031     }
1032   }
1033   if (!options.target.empty()) {
1034     entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1035                                       AttrKeyAndValue::kCEscape));
1036   }
1037   for (int i = 0; i < options.input_devices.size(); ++i) {
1038     entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1039                                       AttrKeyAndValue::kCEscape));
1040   }
1041   for (int i = 0; i < options.output_devices.size(); ++i) {
1042     entries.push_back(AttrKeyAndValue("_output_dev", i,
1043                                       options.output_devices[i],
1044                                       AttrKeyAndValue::kCEscape));
1045   }
1046   for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1047     entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1048                                       DataTypeString(iter.second.dtype)));
1049     entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1050                                       iter.second.shape.DebugString(),
1051                                       AttrKeyAndValue::kCEscape));
1052   }
1053   if (options.lib_def) {
1054     entries.push_back(AttrKeyAndValue(
1055         "_lib_def", -1,
1056         absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1057   }
1058   if (!options.state_handle.empty()) {
1059     entries.push_back(
1060         AttrKeyAndValue("_state_handle", -1, options.state_handle));
1061   }
1062   string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1063   if (!executor_type.empty()) {
1064     entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1065   }
1066   if (options.config_proto.ByteSize() > 0) {
1067     string config_proto_serialized;
1068     SerializeToStringDeterministic(options.config_proto,
1069                                    &config_proto_serialized);
1070     entries.push_back(AttrKeyAndValue("_config_proto", -1,
1071                                       config_proto_serialized,
1072                                       AttrKeyAndValue::kCEscape));
1073   }
1074   std::sort(entries.begin(), entries.end());
1075   string result = strings::StrCat(funcname, "[");
1076   bool first = true;
1077   for (const auto& entry : entries) {
1078     entry.AppendTo(first, &result);
1079     first = false;
1080   }
1081   result += "]";
1082   return result;
1083 }
1084 
// Canonical key for `funcname` + `attrs` with default-constructed options.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  // Allocated once and intentionally never freed.
  static const FunctionLibraryRuntime::InstantiateOptions* const
      kDefaultOptions = new FunctionLibraryRuntime::InstantiateOptions;
  const FunctionLibraryRuntime::InstantiateOptions& defaults = *kDefaultOptions;
  return Canonicalize(funcname, attrs, defaults);
}
1090 
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1091 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1092                                      DataTypeSlice ret_types)
1093     : arg_types_(arg_types.begin(), arg_types.end()),
1094       ret_types_(ret_types.begin(), ret_types.end()) {
1095   args_.resize(arg_types_.size());
1096   rets_.resize(ret_types_.size());
1097 }
1098 
~FunctionCallFrame()1099 FunctionCallFrame::~FunctionCallFrame() {}
1100 
SetArgs(gtl::ArraySlice<Tensor> args)1101 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1102   // Input type checks.
1103   if (args.size() != arg_types_.size()) {
1104     return errors::InvalidArgument("Expects ", arg_types_.size(),
1105                                    " arguments, but ", args.size(),
1106                                    " is provided");
1107   }
1108   for (size_t i = 0; i < args.size(); ++i) {
1109     if (arg_types_[i] != args[i].dtype()) {
1110       return errors::InvalidArgument(
1111           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1112           DataTypeString(args[i].dtype()), " is provided");
1113     }
1114     args_[i] = args[i];
1115   }
1116   return Status::OK();
1117 }
1118 
GetRetvals(std::vector<Tensor> * rets) const1119 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1120   rets->clear();
1121   rets->reserve(rets_.size());
1122   for (size_t i = 0; i < rets_.size(); ++i) {
1123     const auto& item = rets_[i];
1124     if (item.has_val) {
1125       rets->push_back(item.val);
1126     } else {
1127       return errors::Internal("Retval[", i, "] does not have value");
1128     }
1129   }
1130   return Status::OK();
1131 }
1132 
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1133 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1134                                          bool allow_dead_tensors) {
1135   rets->clear();
1136   rets->reserve(rets_.size());
1137   for (size_t i = 0; i < rets_.size(); ++i) {
1138     if (rets_[i].has_val) {
1139       rets->emplace_back(std::move(rets_[i].val));
1140     } else if (allow_dead_tensors) {
1141       rets->emplace_back();
1142     } else {
1143       return errors::Internal("Retval[", i, "] does not have value");
1144     }
1145   }
1146   return Status::OK();
1147 }
1148 
GetArg(int index,const Tensor ** val)1149 Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
1150   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1151     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1152                                    args_.size(), ")");
1153   }
1154   *val = &args_[index];
1155   return Status::OK();
1156 }
1157 
SetRetval(int index,const Tensor & val)1158 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1159   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1160     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1161                                    rets_.size(), ")");
1162   }
1163   if (val.dtype() != ret_types_[index]) {
1164     return errors::InvalidArgument(
1165         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1166         ", but ", DataTypeString(val.dtype()), " is provided.");
1167   }
1168   Retval* item = &rets_[index];
1169   if (!item->has_val) {
1170     item->has_val = true;
1171     item->val = val;
1172   } else {
1173     return errors::Internal("Retval[", index, "] has already been set.");
1174   }
1175   return Status::OK();
1176 }
1177 
// Stores a copy of `fdef_in` together with an OpRegistrationData built from
// its signature (flagged as a function) and the given `stack_traces`.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
                                 const StackTracesMap& stack_traces)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      // NOTE(review): this reads the member `fdef`, so it relies on `fdef`
      // being declared before `op_registration_data` in the header; confirm.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */),
      // The parameter shadows the member here: this copies the argument.
      stack_traces(stack_traces) {}
1187 
// Copy constructor. Takes `other`'s reader lock while copying its tables.
// Function entries are shared_ptrs, so the registrations themselves are
// shared rather than deep-copied.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  func_grad_ = other.func_grad_;
  function_defs_ = other.function_defs_;
}
1195 
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1196 FunctionLibraryDefinition::FunctionLibraryDefinition(
1197     const OpRegistryInterface* default_registry,
1198     const FunctionDefLibrary& def_lib)
1199     : default_registry_(default_registry),
1200       function_defs_(def_lib.function_size()) {
1201   for (const auto& fdef : def_lib.function()) {
1202     // The latter function definition wins.
1203     auto& ptr = function_defs_[fdef.signature().name()];
1204     ptr.reset(new FunctionDefAndOpRegistration(fdef));
1205   }
1206   for (const auto& grad : def_lib.gradient()) {
1207     func_grad_[grad.function_name()] = grad.gradient_func();
1208   }
1209 }
1210 
~FunctionLibraryDefinition()1211 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1212 
Contains(const string & func) const1213 bool FunctionLibraryDefinition::Contains(const string& func) const {
1214   tf_shared_lock l(mu_);
1215   return function_defs_.find(func) != function_defs_.end();
1216 }
1217 
Find(const string & func) const1218 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1219   tf_shared_lock l(mu_);
1220   auto result = FindHelper(func);
1221   if (result) {
1222     return &result->fdef;
1223   } else {
1224     return nullptr;
1225   }
1226 }
1227 
1228 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1229 FunctionLibraryDefinition::FindHelper(const string& func) const {
1230   auto iter = function_defs_.find(func);
1231   if (iter == function_defs_.end()) {
1232     return nullptr;
1233   } else {
1234     return iter->second;
1235   }
1236 }
1237 
AddFunctionDef(const FunctionDef & fdef,const StackTracesMap & stack_traces)1238 Status FunctionLibraryDefinition::AddFunctionDef(
1239     const FunctionDef& fdef, const StackTracesMap& stack_traces) {
1240   mutex_lock l(mu_);
1241   bool added;
1242   return AddFunctionDefHelper(fdef, stack_traces, &added);
1243 }
1244 
AddFunctionDefHelper(const FunctionDef & fdef,const StackTracesMap & stack_traces,bool * added)1245 Status FunctionLibraryDefinition::AddFunctionDefHelper(
1246     const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
1247   *added = false;
1248   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1249       function_defs_[fdef.signature().name()];
1250   if (entry) {
1251     if (!FunctionDefsEqual(entry->fdef, fdef)) {
1252       return errors::InvalidArgument(
1253           "Cannot add function '", fdef.signature().name(),
1254           "' because a different function with the same name already "
1255           "exists.");
1256     }
1257     // Ignore duplicate FunctionDefs.
1258     return Status::OK();
1259   }
1260   const OpDef* op_def;
1261   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1262     return errors::InvalidArgument(
1263         "Cannot add function '", fdef.signature().name(),
1264         "' because an op with the same name already exists.");
1265   }
1266   entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
1267   *added = true;
1268   return Status::OK();
1269 }
1270 
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1271 Status FunctionLibraryDefinition::AddHelper(
1272     std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1273   *added = false;
1274   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1275       function_defs_[registration->fdef.signature().name()];
1276   if (entry) {
1277     if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1278       return errors::InvalidArgument(
1279           "Cannot add function '", registration->fdef.signature().name(),
1280           "' because a different function with the same name already "
1281           "exists.");
1282     }
1283     // Ignore duplicate FunctionDefs.
1284     return Status::OK();
1285   }
1286   const OpDef* op_def;
1287   if (default_registry_
1288           ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1289           .ok()) {
1290     return errors::InvalidArgument(
1291         "Cannot add function '", registration->fdef.signature().name(),
1292         "' because an op with the same name already exists.");
1293   }
1294   entry = std::move(registration);
1295   *added = true;
1296   return Status::OK();
1297 }
1298 
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1299 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1300     const string& func, const FunctionLibraryDefinition& other) {
1301   if (default_registry_ != other.default_registry_) {
1302     return errors::InvalidArgument(
1303         "Cannot copy function '", func,
1304         "' because CopyFunctionDefFrom() requires that both libraries have the "
1305         "same default registry.");
1306   }
1307   std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1308   {
1309     tf_shared_lock l(other.mu_);
1310     function_def = other.FindHelper(func);
1311   }
1312   if (!function_def) {
1313     return errors::InvalidArgument(
1314         "Cannot copy function '", func,
1315         "' because no function with that name exists in the other library.");
1316   }
1317   {
1318     mutex_lock l(mu_);
1319     std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1320     if (entry) {
1321       if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1322         return errors::InvalidArgument(
1323             "Cannot copy function '", func,
1324             "' because a different function with the same name already "
1325             "exists.");
1326       }
1327     } else {
1328       entry = std::move(function_def);
1329     }
1330   }
1331   return Status::OK();
1332 }
1333 
AddGradientDef(const GradientDef & grad)1334 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1335   mutex_lock l(mu_);
1336   bool added;
1337   return AddGradientDefHelper(grad, &added);
1338 }
1339 
AddGradientDefHelper(const GradientDef & grad,bool * added)1340 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1341                                                        bool* added) {
1342   *added = false;
1343   string* entry = &func_grad_[grad.function_name()];
1344   if (!entry->empty()) {
1345     if (*entry != grad.gradient_func()) {
1346       return errors::InvalidArgument(
1347           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1348           grad.function_name(), "' because it already has gradient function ",
1349           "'", *entry, "'");
1350     }
1351     // Ignore duplicate GradientDefs
1352     return Status::OK();
1353   }
1354   *entry = grad.gradient_func();
1355   *added = true;
1356   return Status::OK();
1357 }
1358 
AddLibrary(const FunctionLibraryDefinition & other)1359 Status FunctionLibraryDefinition::AddLibrary(
1360     const FunctionLibraryDefinition& other) {
1361   // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1362   // the duration of the function could lead to deadlock).
1363   FunctionLibraryDefinition clone(other);
1364   mutex_lock l(mu_);
1365   mutex_lock l2(clone.mu_);
1366   // Remember the funcs and grads that we added successfully so that
1367   // we can roll them back on error.
1368   std::vector<string> funcs;
1369   std::vector<string> funcs_with_grads;
1370   Status s;
1371   bool added;
1372   for (auto iter : clone.function_defs_) {
1373     s = AddHelper(iter.second, &added);
1374     if (!s.ok()) {
1375       Status remove_status = Remove(funcs, funcs_with_grads);
1376       if (!remove_status.ok()) {
1377         return remove_status;
1378       }
1379       return s;
1380     }
1381     if (added) {
1382       funcs.push_back(iter.second->fdef.signature().name());
1383     }
1384   }
1385   for (auto iter : clone.func_grad_) {
1386     GradientDef grad;
1387     grad.set_function_name(iter.first);
1388     grad.set_gradient_func(iter.second);
1389     s = AddGradientDefHelper(grad, &added);
1390     if (!s.ok()) {
1391       Status remove_status = Remove(funcs, funcs_with_grads);
1392       if (!remove_status.ok()) {
1393         return remove_status;
1394       }
1395       return s;
1396     }
1397     if (added) {
1398       funcs_with_grads.push_back(grad.function_name());
1399     }
1400   }
1401   return Status::OK();
1402 }
1403 
AddLibrary(const FunctionDefLibrary & lib_def)1404 Status FunctionLibraryDefinition::AddLibrary(
1405     const FunctionDefLibrary& lib_def) {
1406   // Remember the funcs and grads that we added successfully so that
1407   // we can roll them back on error.
1408   mutex_lock l(mu_);
1409   std::vector<string> funcs;
1410   std::vector<string> funcs_with_grads;
1411   Status s;
1412   bool added;
1413   for (const FunctionDef& fdef : lib_def.function()) {
1414     s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
1415     if (!s.ok()) {
1416       Status remove_status = Remove(funcs, funcs_with_grads);
1417       if (!remove_status.ok()) {
1418         return remove_status;
1419       }
1420       return s;
1421     }
1422     if (added) {
1423       funcs.push_back(fdef.signature().name());
1424     }
1425   }
1426   for (const GradientDef& grad : lib_def.gradient()) {
1427     s = AddGradientDefHelper(grad, &added);
1428     if (!s.ok()) {
1429       Status remove_status = Remove(funcs, funcs_with_grads);
1430       if (!remove_status.ok()) {
1431         return remove_status;
1432       }
1433       return s;
1434     }
1435     if (added) {
1436       funcs_with_grads.push_back(grad.function_name());
1437     }
1438   }
1439   return Status::OK();
1440 }
1441 
ReplaceFunction(const string & func,const FunctionDef & fdef,const StackTracesMap & stack_traces)1442 Status FunctionLibraryDefinition::ReplaceFunction(
1443     const string& func, const FunctionDef& fdef,
1444     const StackTracesMap& stack_traces) {
1445   mutex_lock l(mu_);
1446   bool added;
1447   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1448   TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, stack_traces, &added));
1449   return Status::OK();
1450 }
1451 
ReplaceGradient(const GradientDef & grad)1452 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1453   mutex_lock l(mu_);
1454   bool added;
1455   TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1456   TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1457   return Status::OK();
1458 }
1459 
RemoveFunction(const string & func)1460 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1461   mutex_lock l(mu_);
1462   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1463   return Status::OK();
1464 }
1465 
RemoveFunctionHelper(const string & func)1466 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1467   const auto& i = function_defs_.find(func);
1468   if (i == function_defs_.end()) {
1469     return errors::InvalidArgument("Tried to remove non-existent function '",
1470                                    func, "'.");
1471   }
1472   function_defs_.erase(i);
1473   return Status::OK();
1474 }
1475 
Clear()1476 void FunctionLibraryDefinition::Clear() {
1477   mutex_lock l(mu_);
1478   function_defs_.clear();
1479   func_grad_.clear();
1480 }
1481 
RemoveGradient(const string & func)1482 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1483   const auto& i = func_grad_.find(func);
1484   if (i == func_grad_.end()) {
1485     return errors::InvalidArgument("Tried to remove non-existent gradient '",
1486                                    func, "'.");
1487   }
1488   func_grad_.erase(i);
1489   return Status::OK();
1490 }
1491 
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1492 Status FunctionLibraryDefinition::Remove(
1493     const std::vector<string>& funcs,
1494     const std::vector<string>& funcs_with_grads) {
1495   Status s;
1496   for (const string& f : funcs) {
1497     s = RemoveFunctionHelper(f);
1498     if (!s.ok()) {
1499       return s;
1500     }
1501   }
1502   for (const string& f : funcs_with_grads) {
1503     s = RemoveGradient(f);
1504     if (!s.ok()) {
1505       return s;
1506     }
1507   }
1508   return Status::OK();
1509 }
1510 
// Name of the gradient function registered for `func`, or "" if there is
// none. Takes the reader lock.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock lock(mu_);
  return FindGradientHelper(func);
}
1515 
// Lock-free variant of FindGradient (the caller holds `mu_`).
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  const auto it = func_grad_.find(func);
  return it == func_grad_.end() ? string() : it->second;
}
1519 
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1520 Status FunctionLibraryDefinition::LookUp(
1521     const string& op, const OpRegistrationData** op_reg_data) const {
1522   tf_shared_lock l(mu_);
1523   auto iter = function_defs_.find(op);
1524   if (iter != function_defs_.end()) {
1525     *op_reg_data = &iter->second->op_registration_data;
1526     return Status::OK();
1527   }
1528   return default_registry_->LookUp(op, op_reg_data);
1529 }
1530 
UniqueFunctionName(StringPiece prefix) const1531 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1532   tf_shared_lock l(mu_);
1533   int index = 0;
1534   string name = strings::StrCat(prefix, index);
1535   while (function_defs_.find(name) != function_defs_.end()) {
1536     ++index;
1537     name = strings::StrCat(prefix, index);
1538   }
1539   return name;
1540 }
1541 
GetAttrImpl(const NodeDef & ndef) const1542 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1543     const NodeDef& ndef) const {
1544   if (ndef.op() != kGradientOp) {
1545     // If 'ndef' calls a function and the function's def has the attr,
1546     // returns it.
1547     return Find(ndef.op());
1548   }
1549 
1550   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1551   // Foo's attributes.
1552   const NameAttrList* forward_func_attrs;
1553   if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1554     return nullptr;
1555   }
1556   const string& func_name = forward_func_attrs->name();
1557   {
1558     tf_shared_lock l(mu_);
1559     const string& grad_name = FindGradientHelper(func_name);
1560     // If 'func' has a user-defined gradient function, uses the grad
1561     // function's attrs to see if noinline is specified. Otherwise,
1562     // uses func's attrs.
1563     if (!grad_name.empty()) {
1564       if (const auto helper = FindHelper(grad_name)) {
1565         return &(helper->fdef);
1566       } else {
1567         return nullptr;
1568       }
1569     }
1570     if (const auto helper = FindHelper(func_name)) {
1571       return &(helper->fdef);
1572     } else {
1573       return nullptr;
1574     }
1575   }
1576 }
1577 
ListFunctionNames() const1578 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1579   std::vector<string> function_names;
1580   tf_shared_lock l(mu_);
1581   function_names.reserve(function_defs_.size());
1582   for (const auto& it : function_defs_) {
1583     function_names.emplace_back(it.first);
1584   }
1585   return function_names;
1586 }
1587 
ToProto() const1588 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1589   FunctionDefLibrary lib;
1590   tf_shared_lock l(mu_);
1591   for (const auto& f : function_defs_) {
1592     *lib.add_function() = f.second->fdef;
1593   }
1594   for (const auto& g : func_grad_) {
1595     GradientDef* gd = lib.add_gradient();
1596     gd->set_function_name(g.first);
1597     gd->set_gradient_func(g.second);
1598   }
1599   return lib;
1600 }
1601 
1602 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1603 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1604                                           const string& attr, T* value) const {
1605   const FunctionDef* fdef = GetAttrImpl(ndef);
1606   if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1607     return Status::OK();
1608   }
1609   return errors::InvalidArgument("Attr ", attr, " is not defined.");
1610 }
1611 
1612 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1613 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1614                                           T* value) const {
1615   return GetAttr(node.def(), attr, value);
1616 }
1617 
1618 #define GET_ATTR(T)                                                            \
1619   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
1620                                                      const string&, T*) const; \
1621   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
1622                                                      const string&, T*) const;
1623 GET_ATTR(string)
1624 GET_ATTR(bool)
1625 #undef GET_ATTR
1626 
1627 namespace {
1628 
1629 constexpr char kApiImplements[] = "api_implements";
1630 
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1631 std::set<string> ReachableFunctions(
1632     const FunctionLibraryDefinition& flib,
1633     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1634   // Functions that are reachable from the graph.
1635   std::set<string> reachable_funcs;
1636 
1637   // For any functions, if it has attribute "api_implements" =
1638   // "some_interface" and it is reachable, then it means any other
1639   // function with same attribute name and value could also be potentially
1640   // reachable, eg via implementation_selector swapping the nodedef.
1641   absl::flat_hash_set<string> reachable_api_interface;
1642 
1643   // Functions might be reachable from the nested function calls, so we keep a
1644   // queue of functions that we have to check.
1645   gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1646 
1647   // Add reachable and not already processed functions to the functions queue.
1648   const auto add_to_func_queue = [&](const string& func_name) {
1649     const FunctionDef* func = flib.Find(func_name);
1650     if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1651       func_queue.push_back(func);
1652     }
1653   };
1654 
1655   // If any function with certain API name is reachable, all the other functions
1656   // with same API name should also be checked.
1657   const auto add_function_with_api_interface = [&](const string& api_name) {
1658     if (!reachable_api_interface.contains(api_name)) {
1659       reachable_api_interface.insert(api_name);
1660       for (const auto& func_name : flib.ListFunctionNames()) {
1661         const auto& func_def = flib.Find(func_name);
1662         const auto attr_it = func_def->attr().find(kApiImplements);
1663         if (attr_it != func_def->attr().end() &&
1664             attr_it->second.s() == api_name) {
1665           add_to_func_queue(func_name);
1666         }
1667       }
1668     }
1669   };
1670 
1671   // Add all the functions that are reachable from the given node to the queue.
1672   const auto process_node = [&](const NodeDef& node) {
1673     // Node itself can be a call to the function.
1674     add_to_func_queue(node.op());
1675 
1676     // Or node can have an attribute referencing a function.
1677     for (const auto& attr : node.attr()) {
1678       const auto& attr_value = attr.second;
1679 
1680       // 1. AttrValue.func
1681       if (attr_value.has_func()) {
1682         add_to_func_queue(attr_value.func().name());
1683       }
1684 
1685       // 2. AttrValue.ListValue.func
1686       if (attr_value.has_list()) {
1687         for (const auto& func : attr_value.list().func()) {
1688           add_to_func_queue(func.name());
1689         }
1690       }
1691     }
1692   };
1693 
1694   // Add all functions that are directly called from the optimized graph.
1695   std::for_each(nodes.begin(), nodes.end(), process_node);
1696 
1697   // Process all reachable functions.
1698   while (!func_queue.empty()) {
1699     const FunctionDef* func = func_queue.back();
1700     func_queue.pop_back();
1701 
1702     const string& func_name = func->signature().name();
1703     reachable_funcs.insert(func_name);
1704 
1705     const auto attr_it = func->attr().find(kApiImplements);
1706     if (attr_it != func->attr().end()) {
1707       add_function_with_api_interface(attr_it->second.s());
1708     }
1709 
1710     // Find all the functions called from the function body.
1711     const auto& func_body = func->node_def();
1712     std::for_each(func_body.begin(), func_body.end(), process_node);
1713 
1714     // Check if the function has a registered gradient.
1715     const string grad_func_name = flib.FindGradient(func_name);
1716     if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1717   }
1718 
1719   return reachable_funcs;
1720 }
1721 
ReachableFunctionLibraryDefinition(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1722 FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
1723     const FunctionLibraryDefinition& flib,
1724     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1725   std::set<string> reachable_funcs = ReachableFunctions(flib, nodes);
1726 
1727   FunctionLibraryDefinition reachable_flib(flib.default_registry(),
1728                                            FunctionDefLibrary());
1729 
1730   for (const string& func_name : reachable_funcs) {
1731     // This should never fail, because we copy functions from a valid flib and
1732     // use the same default registry.
1733     Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib);
1734     TF_DCHECK_OK(added);
1735 
1736     const string grad_func_name = flib.FindGradient(func_name);
1737     if (!grad_func_name.empty()) {
1738       GradientDef grad;
1739       grad.set_function_name(func_name);
1740       grad.set_gradient_func(grad_func_name);
1741       // It can only fail if function already has a gradient function.
1742       const Status added_grad = reachable_flib.AddGradientDef(grad);
1743       TF_DCHECK_OK(added_grad);
1744     }
1745   }
1746 
1747   return reachable_flib;
1748 }
1749 
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1750 string AllocatorAttributesToString(
1751     const std::vector<AllocatorAttributes>& attrs) {
1752   string result("[");
1753   // AllocatorAttribute::DebugString produces around 85 bytes now.
1754   result.reserve(100 * attrs.size());
1755   for (const AllocatorAttributes& attr : attrs) {
1756     result.append(attr.DebugString());
1757     result.append(", ");
1758   }
1759   if (!attrs.empty()) {
1760     result.resize(result.size() - 2);
1761   }
1762   result.append("]");
1763   return result;
1764 }
1765 
IsSet(void * ptr)1766 const char* IsSet(void* ptr) { return ptr == nullptr ? "unset" : "set"; }
1767 
1768 }  // namespace
1769 
ReachableDefinitions(const GraphDef & graph) const1770 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1771     const GraphDef& graph) const {
1772   return ReachableFunctionLibraryDefinition(*this, graph.node());
1773 }
1774 
ReachableDefinitions(const FunctionDef & func) const1775 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1776     const FunctionDef& func) const {
1777   return ReachableFunctionLibraryDefinition(*this, func.node_def());
1778 }
1779 
DebugString() const1780 string FunctionLibraryRuntime::Options::DebugString() const {
1781   return absl::StrCat(
1782       "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1783       " cancellation_manager=", IsSet(cancellation_manager),
1784       " collective_executor=", IsSet(collective_executor),
1785       " step_container=", IsSet(step_container),
1786       " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1787       " remote_execution=", remote_execution, " source_device=", source_device,
1788       " create_rendezvous=", create_rendezvous,
1789       " allow_dead_tensors=", allow_dead_tensors,
1790       " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1791       " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1792 }
1793 
InitFromString(StringPiece val)1794 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1795   if (val.size() >= 2 && val[0] == '$') {
1796     proto.set_placeholder(val.data() + 1, val.size() - 1);
1797   } else {
1798     SetAttrValue(val, &proto);
1799   }
1800 }
1801 
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1802 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1803     const string& name,
1804     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1805   AttrValueWrapper ret;
1806   ret.proto.mutable_func()->set_name(name);
1807   for (const auto& a : attrs) {
1808     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1809   }
1810   return ret;
1811 }
1812 
ToNodeDef() const1813 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1814   NodeDef n;
1815   n.set_op(this->op);
1816   n.set_name(GetName());
1817   for (const auto& a : this->attr) {
1818     n.mutable_attr()->insert({a.first, a.second.proto});
1819   }
1820   for (const string& a : this->arg) {
1821     n.add_input(a);
1822   }
1823   for (const string& d : this->dep) {
1824     n.add_input(strings::StrCat("^", d));
1825   }
1826   if (!this->device.empty()) {
1827     n.set_device(this->device);
1828   }
1829   return n;
1830 }
1831 
1832 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1833 FunctionDef FunctionDefHelper::Create(
1834     const string& function_name, gtl::ArraySlice<string> in_def,
1835     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1836     gtl::ArraySlice<Node> node_def,
1837     gtl::ArraySlice<std::pair<string, string>> ret_def,
1838     gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1839   FunctionDef fdef;
1840 
1841   // Signature
1842   OpDefBuilder b(function_name);
1843   for (const auto& i : in_def) b.Input(i);
1844   for (const auto& o : out_def) b.Output(o);
1845   for (const auto& a : attr_def) b.Attr(a);
1846   for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1847 
1848   OpRegistrationData op_reg_data;
1849   TF_CHECK_OK(b.Finalize(&op_reg_data));
1850   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1851 
1852   // Function body
1853   for (const auto& n : node_def) {
1854     *(fdef.add_node_def()) = n.ToNodeDef();
1855   }
1856 
1857   // Returns
1858   for (const auto& r : ret_def) {
1859     fdef.mutable_ret()->insert({r.first, r.second});
1860   }
1861 
1862   // Control returns
1863   for (const auto& cr : control_ret_def) {
1864     fdef.mutable_control_ret()->insert({cr.first, cr.second});
1865   }
1866 
1867   auto* op_def_registry = OpRegistry::Global();
1868   // Check if any op is stateful.
1869   for (const auto& n : node_def) {
1870     const OpDef* op_def = nullptr;
1871     auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1872     // Lookup can fail if e.g. we are calling a function that was not yet
1873     // defined.  If it happens, conservatively assume the op is stateful.
1874     if (!status.ok() || op_def->is_stateful()) {
1875       fdef.mutable_signature()->set_is_stateful(true);
1876     }
1877   }
1878 
1879   return fdef;
1880 }
1881 
1882 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1883 FunctionDef FunctionDefHelper::Create(
1884     const string& function_name, gtl::ArraySlice<string> in_def,
1885     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1886     gtl::ArraySlice<Node> node_def,
1887     gtl::ArraySlice<std::pair<string, string>> ret_def) {
1888   return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1889                 /*control_ret_def=*/{});
1890 }
1891 
1892 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1893 FunctionDef FunctionDefHelper::Define(const string& name,
1894                                       gtl::ArraySlice<string> arg_def,
1895                                       gtl::ArraySlice<string> ret_def,
1896                                       gtl::ArraySlice<string> attr_def,
1897                                       gtl::ArraySlice<Node> node_def) {
1898   FunctionDef fdef;
1899   OpDefBuilder b(name);
1900   for (const auto& a : arg_def) b.Input(a);
1901   for (const auto& r : ret_def) b.Output(r);
1902   for (const auto& a : attr_def) b.Attr(a);
1903 
1904   OpRegistrationData op_reg_data;
1905   TF_CHECK_OK(b.Finalize(&op_reg_data));
1906   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1907 
1908   // Mapping from legacy output names to NodeDef outputs.
1909   std::unordered_map<string, string> ret_index;
1910   for (const auto& a : fdef.signature().input_arg()) {
1911     ret_index[a.name()] = a.name();
1912   }
1913 
1914   // For looking up OpDefs
1915   auto* op_def_registry = OpRegistry::Global();
1916 
1917   // Function body
1918   for (const auto& src : node_def) {
1919     NodeDef* n = fdef.add_node_def();
1920     n->set_op(src.op);
1921     n->set_name(src.GetName());
1922     for (const auto& a : src.attr) {
1923       n->mutable_attr()->insert({a.first, a.second.proto});
1924     }
1925     for (const string& a : src.arg) {
1926       const auto iter = ret_index.find(a);
1927       CHECK(iter != ret_index.end())
1928           << "Node input '" << a << "' in '" << n->name() << "' of " << name;
1929       n->add_input(iter->second);
1930     }
1931     for (const string& d : src.dep) {
1932       n->add_input(strings::StrCat("^", d));
1933     }
1934 
1935     // Add the outputs of this node to ret_index.
1936     const OpDef* op_def = nullptr;
1937     TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1938     CHECK(op_def != nullptr) << n->op();
1939     NameRangeMap output_names;
1940     TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1941     for (const auto& o : output_names) {
1942       CHECK_LE(o.second.second, src.ret.size())
1943           << "Missing ret for output '" << o.first << "' in '" << n->name()
1944           << "' of " << name;
1945       for (int i = o.second.first; i < o.second.second; ++i) {
1946         ret_index[src.ret[i]] =
1947             strings::StrCat(n->name(), ":", o.first, ":", i - o.second.first);
1948       }
1949     }
1950     if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1951   }
1952 
1953   // Returns
1954   for (const auto& r : fdef.signature().output_arg()) {
1955     const auto iter = ret_index.find(r.name());
1956     CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1957     fdef.mutable_ret()->insert({r.name(), iter->second});
1958   }
1959   return fdef;
1960 }
1961 
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1962 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1963                                       gtl::ArraySlice<string> ret_def,
1964                                       gtl::ArraySlice<string> attr_def,
1965                                       gtl::ArraySlice<Node> node_def) {
1966   return Define("_", arg_def, ret_def, attr_def, node_def);
1967 }
1968 
1969 namespace gradient {
1970 
1971 typedef std::unordered_map<string, Creator> OpGradFactory;
1972 
GetOpGradFactory()1973 OpGradFactory* GetOpGradFactory() {
1974   static OpGradFactory* factory = new OpGradFactory;
1975   return factory;
1976 }
1977 
RegisterOp(const string & op,Creator func)1978 bool RegisterOp(const string& op, Creator func) {
1979   CHECK(GetOpGradFactory()->insert({op, func}).second)
1980       << "Duplicated gradient for " << op;
1981   return true;
1982 }
1983 
GetOpGradientCreator(const string & op,Creator * creator)1984 Status GetOpGradientCreator(const string& op, Creator* creator) {
1985   auto fac = GetOpGradFactory();
1986   auto iter = fac->find(op);
1987   if (iter == fac->end()) {
1988     return errors::NotFound("No gradient defined for op: ", op);
1989   }
1990   *creator = iter->second;
1991   return Status::OK();
1992 }
1993 
1994 }  // end namespace gradient
1995 
1996 }  // namespace tensorflow
1997