• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 
18 #include <ctype.h>
19 
20 #include <map>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/escaping.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/util/device_name_utils.h"
41 #include "tensorflow/core/util/equal_graph_def.h"
42 
43 namespace tensorflow {
44 
45 /* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
46 /* static */ constexpr const char* const
47     FunctionLibraryDefinition::kDeviceArgOp;
48 /* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
49 /* static */ constexpr const char* const
50     FunctionLibraryDefinition::kDeviceRetOp;
51 /* static */ constexpr const char* const
52     FunctionLibraryDefinition::kIntsOnDeviceAttr;
53 /* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
54 /* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
55 
56 // Extracts the actual type from "attr_values" based on its definition
57 // "arg_def".
58 //
59 // If "arg_def" is a N*T type, *is_type_list is set to false, and
60 // *dtypes is set to be a vector of size N and each element is T.
61 //
62 // If "arg_def" is a list(type), *is_type_list is set to true, and
63 // *dtypes is set to be a vector of types specified in attrs for
64 // arg_def.
65 //
66 // Otherwise (arg_def is a simple type T), *is_type_list is set to
67 // false, and *dtypes is set to a single element vector, whose only
68 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)69 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
70                   bool* is_type_list, DataTypeVector* dtypes) {
71   dtypes->clear();
72   if (!arg_def.type_list_attr().empty()) {
73     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
74     if (v == nullptr) {
75       return errors::NotFound("type attr not found: ",
76                               arg_def.type_list_attr());
77     }
78     *is_type_list = true;
79     for (int i = 0; i < v->list().type_size(); ++i) {
80       dtypes->push_back(v->list().type(i));
81     }
82     return Status::OK();
83   }
84 
85   *is_type_list = false;
86   int num = 1;
87   if (!arg_def.number_attr().empty()) {
88     const AttrValue* v = attrs.Find(arg_def.number_attr());
89     if (v == nullptr) {
90       return errors::NotFound("type attr not found: ", arg_def.type_attr());
91     }
92     num = v->i();
93   }
94 
95   DataType dtype;
96   if (arg_def.type() != DT_INVALID) {
97     dtype = arg_def.type();
98   } else if (arg_def.type_attr().empty()) {
99     dtype = DT_INVALID;
100   } else {
101     const AttrValue* v = attrs.Find(arg_def.type_attr());
102     if (v == nullptr) {
103       return errors::NotFound("type attr not found: ", arg_def.type_attr());
104     }
105     dtype = v->type();
106   }
107   dtypes->resize(num, dtype);
108   return Status::OK();
109 }
110 
111 namespace {
112 
113 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)114 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
115   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
116 }
117 
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)118 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
119   // attr_values should specify all attrs defined in fdef.
120   for (const auto& a : sig.attr()) {
121     const AttrValue* v = attr_values.Find(a.name());
122     if (!v) {
123       return errors::NotFound("Attr ", a.name(), " is not found from ",
124                               SummarizeOpDef(sig));
125     }
126     Status status = AttrValueHasType(*v, a.type());
127     if (!status.ok()) {
128       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
129       return status;
130     }
131   }
132 
133 // TODO(josh11b): Enable this code once it works with function gradients.
134 // Right now the C++ function gradient code assumes it can pass
135 // all the attrs of the function to the gradient, and any attrs that
136 // the gradient doesn't care about will be ignored.
137 #if 0
138   if (attr_values.size() != sig.attr_size()) {
139     for (const auto& a : attr_values) {
140       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
141       bool found = false;
142       for (const auto& s : sig.attr()) {
143         if (a.first == s.name()) {
144           found = true;
145           break;
146         }
147       }
148       if (!found) {
149         return errors::NotFound("Attr ", a.first, " is not found in ",
150                                 SummarizeOpDef(sig));
151       }
152     }
153   }
154 #endif
155 
156   return Status::OK();
157 }
158 
159 // A helper class for instantiating functions. This contains shared information
160 // like the resulting graph and node name index.
161 class FunctionInstantiationHelper {
162  public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)163   FunctionInstantiationHelper(GetFunctionSignature get_function,
164                               InstantiationResult* result)
165       : get_function_(std ::move(get_function)), result_(*result) {
166     result_.nodes.clear();
167   }
168 
169   // Builds index for nodes that can be used as node's input arguments.
170   // `resource_arg_unique_id`: if non-negative, will be populated to the
171   // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64 resource_arg_unique_id)172   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
173                             const FunctionDef::ArgAttrs* arg_attrs,
174                             bool ints_on_device, int64 resource_arg_unique_id) {
175     bool is_type_list;
176     DataTypeVector dtypes;
177     TF_RETURN_IF_ERROR(
178         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
179     CHECK_GE(dtypes.size(), size_t{1});
180     int arg_index = result_.nodes.size();
181     TF_RETURN_IF_ERROR(
182         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
183     // Creates dtypes.size() nodes in the graph.
184     for (size_t i = 0; i < dtypes.size(); ++i) {
185       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
186                                  {true, arg_index, 0, false, {dtypes[i]}}));
187       DCHECK_EQ(arg_index, result_.nodes.size());
188       string name = arg_def.name();
189       if (dtypes.size() > 1) {
190         strings::StrAppend(&name, "_", i);
191       }
192       NodeDef* gnode = AddNode(name);
193       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
194         gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
195       } else {
196         gnode->set_op(FunctionLibraryDefinition::kArgOp);
197       }
198       DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
199       AddAttr("T", dtype, gnode);
200       AddAttr("index", arg_index, gnode);
201       if (resource_arg_unique_id >= 0) {
202         AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
203       }
204       if (arg_attrs) {
205         for (const auto& arg_attr : arg_attrs->attr()) {
206           AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
207         }
208       }
209       result_.arg_types.push_back(dtypes[i]);
210       ++arg_index;
211     }
212     return Status::OK();
213   }
214 
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)215   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
216                               const int arg_index) {
217     const OpDef* node_sig = nullptr;
218     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
219     if (node_sig->output_arg_size() == 0) {
220       return AddItem(node.name(), {false, arg_index, 0, false, {}});
221     }
222     const int num_retval = node_sig->output_arg_size();
223     int start = 0;
224     bool is_type_list;
225     DataTypeVector dtypes;
226     for (int i = 0; i < num_retval; ++i) {
227       TF_RETURN_IF_ERROR(
228           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
229       // Note that we rely on the backwards-compatibility test enforcing
230       // that output_arg(*).name() doesn't change here.
231       const string base_name =
232           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
233       TF_RETURN_IF_ERROR(
234           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
235       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
236         TF_RETURN_IF_ERROR(
237             AddItem(strings::StrCat(base_name, ":", j),
238                     {false, arg_index, start + j, false, {dtypes[j]}}));
239       }
240       start += dtypes.size();
241     }
242     return Status::OK();
243   }
244 
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)245   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
246     const OpDef* fnode_sig = nullptr;
247     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
248     NodeDef* gnode = AddNode(fnode.name());
249     gnode->set_op(fnode.op());
250     gnode->set_device(fnode.device());
251     int gnode_idx = nodes_.size() - 1;
252 
253     // Input
254     const int num_args = fnode_sig->input_arg_size();
255     bool is_type_list;  // ignored
256     DataTypeVector dtypes;
257     int fnode_arg_index = 0;
258     for (int i = 0; i < num_args; ++i) {
259       TF_RETURN_IF_ERROR(
260           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
261       // Consume inputs (indexed by fnode_arg_index) until we have
262       // matched each element of dtypes (indexed by j).
263       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
264         if (fnode_arg_index >= fnode.input_size()) {
265           // Should never happen if we computed dtypes correctly.
266           return errors::InvalidArgument(
267               "Attempt to access beyond input size: ", fnode_arg_index,
268               " >= ", fnode.input_size());
269         }
270         // Look up the next input.
271         const string& input_name = fnode.input(fnode_arg_index);
272         const auto* item = GetItemOrNull(input_name);
273         if (item == nullptr) {
274           return errors::InvalidArgument(
275               "input ", input_name,
276               " is not found: ", FormatNodeDefForError(fnode));
277         }
278         if (item->dtypes.size() > dtypes.size() - j) {
279           return errors::InvalidArgument("Input ", input_name, " too long for ",
280                                          fnode_sig->input_arg(i).name());
281         }
282         // Match up all the elements of this input (indexed by k) with
283         // elements of dtypes (advancing j).
284         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
285           if (item->dtypes[k] != dtypes[j]) {
286             return errors::InvalidArgument(
287                 "input ", fnode_sig->input_arg(i).name(), "[", j,
288                 "] expected type ", DataTypeString(dtypes[j]),
289                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
290                 input_name, "[", k, "]");
291           }
292           if (item->is_func_arg) {
293             AddInput(gnode_idx, item->nid + k, 0);
294           } else {
295             AddInput(gnode_idx, item->nid, item->idx + k);
296           }
297         }
298       }
299     }
300 
301     // Control deps.
302     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
303       const string& input = fnode.input(i);
304       if (input.empty() || input[0] != '^') {
305         return errors::InvalidArgument("Expected input[", i, "] == '", input,
306                                        "' to be a control input.");
307       }
308       int nid = -1;
309       const string node_name = input.substr(1);
310       const string node_colon = node_name + ":";
311       const string node_colon_bound = node_name + ";";
312       // index_ is a map sorted lexicographically, so the key we are looking for
313       // must lie in the range [node_name, node_colon_bound).
314       auto it = index_.lower_bound(node_name);
315       while (it != index_.end() && it->first <= node_colon_bound) {
316         if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
317           nid = it->second.nid;
318           break;
319         }
320         ++it;
321       }
322       if (nid == -1) {
323         return errors::InvalidArgument("input[", i, "] == '", input,
324                                        "', is not found.");
325       }
326       AddDep(gnode_idx, nid);
327     }
328 
329     // Attrs.
330     for (const auto& p : attrs) {
331       (*gnode->mutable_attr())[p.first] = p.second;
332     }
333 
334     return Status::OK();
335   }
336 
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)337   Status AddReturnNode(
338       const OpDef::ArgDef& ret_def, AttrSlice attrs,
339       const ::tensorflow::protobuf::Map<string, string>& ret_map,
340       bool ints_on_device, int* ret_index) {
341     auto ret_iter = ret_map.find(ret_def.name());
342     if (ret_iter == ret_map.end()) {
343       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
344     }
345     bool is_type_list;
346     DataTypeVector dtypes;
347     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
348     CHECK_GE(dtypes.size(), size_t{1});
349     const auto* item = GetItemOrNull(ret_iter->second);
350     if (item == nullptr) {
351       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
352                                      ret_iter->second, " is not found.");
353     }
354     if (dtypes != item->dtypes) {
355       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
356                                      " : ", DataTypeVectorString(dtypes),
357                                      " vs. ",
358                                      DataTypeVectorString(item->dtypes));
359     }
360     for (size_t i = 0; i < dtypes.size(); ++i) {
361       string name = strings::StrCat(ret_def.name(), "_RetVal");
362       if (dtypes.size() > 1) {
363         strings::StrAppend(&name, "_", i);
364       }
365       NodeDef* gnode = AddNode(name);
366       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
367         gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
368       } else {
369         gnode->set_op(FunctionLibraryDefinition::kRetOp);
370       }
371       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
372       DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
373       AddAttr("T", dtype, gnode);
374       AddAttr("index", (*ret_index)++, gnode);
375       result_.ret_types.push_back(dtypes[i]);
376     }
377     return Status::OK();
378   }
379 
380   // Adds the actual node inputs to the result graph by converting indexes to
381   // the node names.
AddNodeInputs()382   void AddNodeInputs() {
383     for (int i = 0; i < result_.nodes.size(); i++) {
384       NodeInfo& node_info = nodes_[i];
385       for (const auto& p : node_info.data_inputs) {
386         result_.nodes[i].add_input(Name(p.first, p.second));
387       }
388       for (int index : node_info.control_inputs) {
389         result_.nodes[i].add_input(Dep(index));
390       }
391     }
392   }
393 
394  private:
395   // This is used to build a small index for all names that can be used as a
396   // node's input arguments.
397   //
398   // If is_func_arg is true, the name is a function's argument.  In
399   // this case, the produced graph def has node[nid:nid + dtype.size()].
400   //
401   // Otherwise, the name is a function body's node return value.  In
402   // this case, the produced graph def has one node node[nid] and
403   // the node's output index [idx ... idx + num) corresponds to the
404   // named outputs.
405   //
406   // In all cases, "dtype" specifies the data type.
407   struct NameInfoItem {
408     bool is_func_arg;
409     int nid;
410     int idx;
411     bool is_type_list;
412     DataTypeVector dtypes;
413   };
414 
415   // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)416   Status AddItem(const string& name, const NameInfoItem& item) {
417     if (!index_.insert({name, item}).second) {
418       return errors::InvalidArgument(
419           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
420                           " name: "),
421           name);
422     }
423     return Status::OK();
424   }
425 
GetItemOrNull(const string & name) const426   const NameInfoItem* GetItemOrNull(const string& name) const {
427     return gtl::FindOrNull(index_, name);
428   }
429 
Dep(int node_index) const430   string Dep(int node_index) const {
431     return strings::StrCat("^", Name(node_index));
432   }
433 
Name(int node_index) const434   string Name(int node_index) const {
435     CHECK_LT(node_index, nodes_.size());
436     return nodes_[node_index].name;
437   }
438 
Name(int node_index,int output_index) const439   string Name(int node_index, int output_index) const {
440     if (output_index == 0) {
441       return Name(node_index);
442     } else {
443       return strings::StrCat(Name(node_index), ":", output_index);
444     }
445   }
446 
AddNode(const string & name)447   NodeDef* AddNode(const string& name) {
448     result_.nodes.emplace_back();
449     NodeDef* gnode = &result_.nodes.back();
450     gnode->set_name(name);
451     nodes_.push_back({name, {}, {}});
452     CHECK_EQ(result_.nodes.size(), nodes_.size());
453     return gnode;
454   }
455 
AddInput(int node_index,int output_node,int output_index)456   void AddInput(int node_index, int output_node, int output_index) {
457     CHECK_LT(node_index, nodes_.size());
458     nodes_[node_index].data_inputs.push_back(
459         std::make_pair(output_node, output_index));
460   }
461 
AddDep(int node_index,int dep_index)462   void AddDep(int node_index, int dep_index) {
463     CHECK_LT(node_index, nodes_.size());
464     nodes_[node_index].control_inputs.push_back(dep_index);
465   }
466 
467   GetFunctionSignature get_function_;
468   InstantiationResult& result_;
469   // A small index for all names that can be used as a node's input arguments.
470   std::map<string, NameInfoItem> index_;
471   // This contains information about a node in the new graph including the node
472   // names and input nodes' indexes.
473   struct NodeInfo {
474     string name;
475     // Data inputs where <n, k> means arg k of node n.
476     std::vector<std::pair<int, int>> data_inputs;
477     // Control inputs (dependencies).
478     std::vector<int> control_inputs;
479   };
480   // nodes_[i] is the information about result_.nodes[i].
481   std::vector<NodeInfo> nodes_;
482 };
483 
484 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)485 string Print(const OpDef::ArgDef& arg) {
486   string out;
487   strings::StrAppend(&out, arg.name(), ":");
488   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
489   if (!arg.number_attr().empty()) {
490     strings::StrAppend(&out, arg.number_attr(), "*");
491   }
492   if (arg.type() != DT_INVALID) {
493     strings::StrAppend(&out, DataTypeString(arg.type()));
494   } else {
495     strings::StrAppend(&out, arg.type_attr());
496   }
497   if (arg.is_ref()) strings::StrAppend(&out, ")");
498   return out;
499 }
500 
501 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)502 string Print(const AttrValue& attr_value) {
503   if (attr_value.value_case() == AttrValue::kType) {
504     return DataTypeString(attr_value.type());
505   } else if ((attr_value.value_case() == AttrValue::kList) &&
506              (attr_value.list().type_size() > 0)) {
507     string ret = "{";
508     for (int i = 0; i < attr_value.list().type_size(); ++i) {
509       if (i > 0) strings::StrAppend(&ret, ", ");
510       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
511     }
512     strings::StrAppend(&ret, "}");
513     return ret;
514   } else if (attr_value.value_case() == AttrValue::kFunc) {
515     if (attr_value.func().attr_size() == 0) {
516       return attr_value.func().name();
517     }
518     std::vector<string> entries;
519     for (const auto& p : attr_value.func().attr()) {
520       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
521     }
522     std::sort(entries.begin(), entries.end());
523     return strings::StrCat(attr_value.func().name(), "[",
524                            absl::StrJoin(entries, ", "), "]");
525   }
526   return SummarizeAttrValue(attr_value);
527 }
528 
529 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)530 string Print(const NodeDef& n) {
531   string out;
532   strings::StrAppend(&out, n.name(), " = ", n.op());
533   if (n.attr_size() > 0) {
534     std::vector<string> entries;
535     for (auto& a : n.attr()) {
536       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
537     }
538     std::sort(entries.begin(), entries.end());
539     // Add a short device string at the end of all attributes.
540     if (!n.device().empty()) {
541       DeviceNameUtils::ParsedName parsed;
542       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
543         entries.push_back(
544             strings::StrCat("device=", parsed.type, ":", parsed.id));
545       } else {
546         entries.push_back("device=<FAILED_TO_PARSE>");
547       }
548     }
549     strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
550   }
551   strings::StrAppend(&out, "(");
552   std::vector<StringPiece> dat;
553   std::vector<string> dep;
554   for (StringPiece s : n.input()) {
555     if (absl::ConsumePrefix(&s, "^")) {
556       dep.emplace_back(s);
557     } else {
558       dat.push_back(s);
559     }
560   }
561   strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
562   if (!dep.empty()) {
563     strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
564   }
565   return out;
566 }
567 
Print(const FunctionDef & fdef)568 string Print(const FunctionDef& fdef) {
569   string out;
570   const OpDef& sig = fdef.signature();
571   strings::StrAppend(&out, "\n", sig.name());
572   if (sig.attr_size() > 0) {
573     strings::StrAppend(&out, "[");
574     for (int i = 0; i < sig.attr_size(); ++i) {
575       const auto& a = sig.attr(i);
576       if (i > 0) strings::StrAppend(&out, ", ");
577       if (a.type() == "type") {
578         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
579       } else {
580         strings::StrAppend(&out, a.name(), ":", a.type());
581       }
582     }
583     strings::StrAppend(&out, "]");
584   }
585   strings::StrAppend(&out, "(");
586   for (int i = 0; i < sig.input_arg_size(); ++i) {
587     if (i > 0) strings::StrAppend(&out, ", ");
588     strings::StrAppend(&out, Print(sig.input_arg(i)));
589   }
590   strings::StrAppend(&out, ") -> (");
591   for (int i = 0; i < sig.output_arg_size(); ++i) {
592     if (i > 0) strings::StrAppend(&out, ", ");
593     strings::StrAppend(&out, Print(sig.output_arg(i)));
594   }
595   strings::StrAppend(&out, ") {\n");
596   for (const auto& n : fdef.node_def()) {
597     strings::StrAppend(&out, "  ", Print(n), "\n");
598   }
599   for (const auto& cr : fdef.control_ret()) {
600     strings::StrAppend(&out, "  @return ", cr.first, " = ", cr.second, "\n");
601   }
602   for (const auto& r : fdef.ret()) {
603     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
604   }
605   strings::StrAppend(&out, "}\n");
606   return out;
607 }
608 
Print(gtl::ArraySlice<const NodeDef * > nodes)609 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
610   std::vector<const NodeDef*> arg;
611   std::vector<const NodeDef*> ret;
612   std::vector<const NodeDef*> body;
613   for (const NodeDef* n : nodes) {
614     if (n->op() == FunctionLibraryDefinition::kArgOp ||
615         n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
616       arg.push_back(n);
617     } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
618                n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
619       ret.push_back(n);
620     } else {
621       body.push_back(n);
622     }
623   }
624   auto comp = [](const NodeDef* x, const NodeDef* y) {
625     int xi;
626     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
627     int yi;
628     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
629     return xi < yi;
630   };
631   std::sort(arg.begin(), arg.end(), comp);
632   std::sort(ret.begin(), ret.end(), comp);
633   string out;
634   strings::StrAppend(&out, "\n(");
635   auto get_type_and_device = [](const NodeDef& n) {
636     DataType dt;
637     if (!TryGetNodeAttr(n, "T", &dt)) {
638       dt = DT_INVALID;
639     }
640     if (!n.device().empty()) {
641       DeviceNameUtils::ParsedName parsed;
642       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
643         return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
644                                parsed.id);
645       } else {
646         LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
647                      << n.op() << ":" << n.name();
648         return strings::StrCat(DataTypeString(dt), "@",
649                                "<FAILED_TO_PARSE_DEVICE>");
650       }
651     }
652     return DataTypeString(dt);
653   };
654   for (size_t i = 0; i < arg.size(); ++i) {
655     const NodeDef* n = arg[i];
656     if (i > 0) strings::StrAppend(&out, ", ");
657     CHECK_GE(n->attr_size(), 2);
658     strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
659   }
660   strings::StrAppend(&out, ") -> (");
661   for (size_t i = 0; i < ret.size(); ++i) {
662     const NodeDef* n = ret[i];
663     if (i > 0) strings::StrAppend(&out, ", ");
664     CHECK_LE(2, n->attr_size());
665 
666     // The _RetVal op should have a unique non-control input. We assert that
667     // here and add it to the output.
668     bool found_non_control_input = false;
669     for (const string& input : n->input()) {
670       if (!input.empty() && input[0] != '^') {
671         DCHECK_EQ(found_non_control_input, false)
672             << "RetVal node has more than one non-control input: "
673             << absl::StrJoin(n->input(), ", ");
674         strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
675         found_non_control_input = true;
676       }
677     }
678     DCHECK_EQ(found_non_control_input, true)
679         << "RetVal did not have any non-control inputs: "
680         << absl::StrJoin(n->input(), ", ");
681   }
682   strings::StrAppend(&out, ") {\n");
683   for (size_t i = 0; i < body.size(); ++i) {
684     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
685   }
686   strings::StrAppend(&out, "}\n");
687   return out;
688 }
689 
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)690 Status AddDefaultAttrs(const string& op,
691                        const GetFunctionSignature& get_function,
692                        AttrValueMap* attrs) {
693   const OpDef* op_def = nullptr;
694   TF_RETURN_IF_ERROR(get_function(op, &op_def));
695   AttrSlice attr_slice(attrs);
696   for (const auto& attr_def : op_def->attr()) {
697     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
698       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
699         return errors::Internal("Somehow duplicated: ", attr_def.name());
700       }
701     }
702   }
703   return Status::OK();
704 }
705 
706 }  // end namespace
707 
708 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)709 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
710                            GetFunctionSignature get_function,
711                            InstantiationResult* result) {
712   if (VLOG_IS_ON(5)) {
713     const auto& signature = fdef.signature();
714     VLOG(5) << "Instantiate function definition: name=" << signature.name()
715             << " #input_args=" << signature.input_arg_size()
716             << " #output_args=" << signature.output_arg_size()
717             << " #control_output=" << signature.control_output_size();
718     for (const auto& line : str_util::Split(Print(fdef), '\n')) {
719       VLOG(5) << "|| " << line;
720     }
721   }
722 
723   const OpDef& sig = fdef.signature();
724   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
725 
726   bool ints_on_device =
727       fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
728       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
729 
730   FunctionInstantiationHelper helper(get_function, result);
731   Status s;
732   for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
733     const OpDef::ArgDef& arg_def = sig.input_arg(i);
734     auto it = fdef.arg_attr().find(i);
735     const FunctionDef::ArgAttrs* arg_attrs =
736         it != fdef.arg_attr().end() ? &it->second : nullptr;
737     auto resource_id_it = fdef.resource_arg_unique_id().find(i);
738     int64 resource_arg_unique_id =
739         resource_id_it != fdef.resource_arg_unique_id().end()
740             ? resource_id_it->second
741             : -1LL;
742     s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
743                                   ints_on_device, resource_arg_unique_id);
744 
745     if (!s.ok()) {
746       errors::AppendToMessage(&s, "In ", Print(arg_def));
747       return s;
748     }
749   }
750 
751   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
752     if (const AttrValue* v = attr_values.Find(name)) {
753       *val = *v;
754       return true;
755     }
756     return false;
757   };
758 
759   // Makes a copy of all attrs in fdef and substitutes placeholders.
760   // After this step, every attr is bound to a concrete value.
761   std::vector<AttrValueMap> node_attrs;
762   node_attrs.resize(fdef.node_def_size());
763   for (int i = 0; i < fdef.node_def_size(); ++i) {
764     for (auto attr : fdef.node_def(i).attr()) {
765       if (!SubstitutePlaceholders(substitute, &attr.second)) {
766         return errors::InvalidArgument("Failed to bind all placeholders in ",
767                                        SummarizeAttrValue(attr.second));
768       }
769       if (!node_attrs[i].insert(attr).second) {
770         return errors::Internal("Somehow duplicated: ", attr.first);
771       }
772     }
773     TF_RETURN_IF_ERROR(
774         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
775   }
776 
777   for (int i = 0; i < fdef.node_def_size(); ++i) {
778     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
779                                     result->nodes.size() + i);
780     if (!s.ok()) {
781       errors::AppendToMessage(&s, "In ",
782                               FormatNodeDefForError(fdef.node_def(i)));
783       return s;
784     }
785   }
786   // Emits one node for each fdef.node_def.
787   for (int i = 0; i < fdef.node_def_size(); ++i) {
788     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
789     if (!s.ok()) {
790       errors::AppendToMessage(&s, "In ",
791                               FormatNodeDefForError(fdef.node_def(i)));
792       return s;
793     }
794   }
795 
796   // Emits nodes for the function's return values.
797   int ret_index = 0;
798   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
799     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
800                              &ret_index);
801     if (!s.ok()) {
802       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
803       return s;
804     }
805   }
806 
807   // Adds the actual node inputs using the input indexes.
808   helper.AddNodeInputs();
809 
810   return Status::OK();
811 }
812 
DebugString(const FunctionDef & func_def)813 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
814 
DebugString(const GraphDef & instantiated_func_def)815 string DebugString(const GraphDef& instantiated_func_def) {
816   std::vector<const NodeDef*> ptrs;
817   for (const NodeDef& n : instantiated_func_def.node()) {
818     ptrs.push_back(&n);
819   }
820   return Print(ptrs);
821 }
822 
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)823 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
824   std::vector<const NodeDef*> ptrs;
825   for (const NodeDef& n : instantiated_func_nodes) {
826     ptrs.push_back(&n);
827   }
828   return Print(ptrs);
829 }
830 
DebugStringWhole(const GraphDef & gdef)831 string DebugStringWhole(const GraphDef& gdef) {
832   string ret;
833   for (const auto& fdef : gdef.library().function()) {
834     strings::StrAppend(&ret, Print(fdef));
835   }
836   strings::StrAppend(&ret, "\n");
837   for (const auto& ndef : gdef.node()) {
838     strings::StrAppend(&ret, Print(ndef), "\n");
839   }
840   return ret;
841 }
842 
843 namespace {
844 
845 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
846 // Python, it's possible to access unset attrs, which returns a default value
847 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)848 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
849   std::map<string, AttrValue> set_attrs;
850   for (const auto& pair : fdef.attr()) {
851     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
852       set_attrs[pair.first] = pair.second;
853     }
854   }
855   return set_attrs;
856 }
857 
858 }  // end namespace
859 
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)860 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
861   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
862 
863   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
864   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
865   if (f1_attrs.size() != f2_attrs.size()) return false;
866   for (const auto& iter1 : f1_attrs) {
867     auto iter2 = f2_attrs.find(iter1.first);
868     if (iter2 == f2_attrs.end()) return false;
869     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
870   }
871 
872   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
873     return false;
874   }
875 
876   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
877   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
878   if (ret1 != ret2) return false;
879 
880   std::map<string, string> control_ret1(f1.control_ret().begin(),
881                                         f1.control_ret().end());
882   std::map<string, string> control_ret2(f2.control_ret().begin(),
883                                         f2.control_ret().end());
884   if (control_ret1 != control_ret2) return false;
885 
886   return true;
887 }
888 
FunctionDefHash(const FunctionDef & fdef)889 uint64 FunctionDefHash(const FunctionDef& fdef) {
890   // signature
891   uint64 h = OpDefHash(fdef.signature());
892 
893   // attrs
894   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
895   for (const auto& p : attrs) {
896     h = Hash64(p.first.data(), p.first.size(), h);
897     h = Hash64Combine(AttrValueHash(p.second), h);
898   }
899 
900   // node defs
901   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
902 
903   // output names
904   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
905   for (const auto& p : ret) {
906     h = Hash64(p.first.data(), p.first.size(), h);
907     h = Hash64(p.second.data(), p.second.size(), h);
908   }
909 
910   // control output names
911   std::map<string, string> control_ret(fdef.control_ret().begin(),
912                                        fdef.control_ret().end());
913   for (const auto& p : control_ret) {
914     h = Hash64(p.first.data(), p.first.size(), h);
915     h = Hash64(p.second.data(), p.second.size(), h);
916   }
917 
918   return h;
919 }
920 
921 static constexpr const char* const kExecutorAttr = "_executor";
922 
923 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)924 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
925                                             AttrSlice attrs) {
926   if (!options.executor_type.empty()) {
927     return options.executor_type;
928   } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
929     return executor_attr->s();
930   } else {
931     return string();
932   }
933 }
934 
935 namespace {
936 class AttrKeyAndValue {
937  public:
938   enum ValueRepresentationOp {
939     kRaw,
940     kCEscape,
941   };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)942   AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
943                   ValueRepresentationOp value_op = kRaw)
944       : key_name_(key_name),
945         key_suffix_(key_suffix),
946         value_op_(value_op),
947         value_(std::move(value)) {}
948 
operator <(const AttrKeyAndValue & b) const949   bool operator<(const AttrKeyAndValue& b) const {
950     if (key_name_ != b.key_name_) {
951       return key_name_ < b.key_name_;
952     } else if (key_suffix_ != b.key_suffix_) {
953       return key_suffix_ < b.key_suffix_;
954     } else {
955       return value_ < b.value_;
956     }
957   }
958 
AppendTo(bool first,string * s) const959   void AppendTo(bool first, string* s) const {
960     absl::string_view v;
961     bool add_escaped = false;
962     if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
963       // Use CEscape call below
964       add_escaped = true;
965     } else {
966       // Add raw value contents directly
967       v = value_;
968     }
969     if (key_suffix_ >= 0) {
970       strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
971     } else {
972       strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
973     }
974     if (add_escaped) {
975       strings::StrAppend(s, absl::CEscape(value_));
976     }
977   }
978 
979  private:
NeedsEscaping(const string & s)980   static bool NeedsEscaping(const string& s) {
981     for (auto c : s) {
982       if (!isalnum(c) && (c != ' ')) {
983         return true;
984       }
985     }
986     return false;
987   }
988 
989   absl::string_view key_name_;
990   int key_suffix_;  // -1 if missing
991   ValueRepresentationOp value_op_;
992   string value_;
993 };
994 }  // namespace
995 
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)996 string Canonicalize(const string& funcname, AttrSlice attrs,
997                     const FunctionLibraryRuntime::InstantiateOptions& options) {
998   absl::InlinedVector<AttrKeyAndValue, 8> entries;
999   entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1000                   options.input_devices.size());
1001   for (const auto& p : attrs) {
1002     if (p.first != kExecutorAttr) {
1003       entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
1004     }
1005   }
1006   if (!options.target.empty()) {
1007     entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1008                                       AttrKeyAndValue::kCEscape));
1009   }
1010   for (int i = 0; i < options.input_devices.size(); ++i) {
1011     entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1012                                       AttrKeyAndValue::kCEscape));
1013   }
1014   for (int i = 0; i < options.output_devices.size(); ++i) {
1015     entries.push_back(AttrKeyAndValue("_output_dev", i,
1016                                       options.output_devices[i],
1017                                       AttrKeyAndValue::kCEscape));
1018   }
1019   for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1020     entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1021                                       DataTypeString(iter.second.dtype)));
1022     entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1023                                       iter.second.shape.DebugString(),
1024                                       AttrKeyAndValue::kCEscape));
1025   }
1026   if (options.lib_def) {
1027     entries.push_back(AttrKeyAndValue(
1028         "_lib_def", -1,
1029         absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1030   }
1031   if (!options.state_handle.empty()) {
1032     entries.push_back(
1033         AttrKeyAndValue("_state_handle", -1, options.state_handle));
1034   }
1035   string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1036   if (!executor_type.empty()) {
1037     entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1038   }
1039   if (options.config_proto.ByteSize() > 0) {
1040     string config_proto_serialized;
1041     options.config_proto.SerializeToString(&config_proto_serialized);
1042     entries.push_back(AttrKeyAndValue("_config_proto", -1,
1043                                       config_proto_serialized,
1044                                       AttrKeyAndValue::kCEscape));
1045   }
1046   std::sort(entries.begin(), entries.end());
1047   string result = strings::StrCat(funcname, "[");
1048   bool first = true;
1049   for (const auto& entry : entries) {
1050     entry.AppendTo(first, &result);
1051     first = false;
1052   }
1053   result += "]";
1054   return result;
1055 }
1056 
// Canonicalizes `funcname` + `attrs` using default-constructed
// InstantiateOptions. The options object is created once and intentionally
// never deleted.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  static const auto* const kDefaultOptions =
      new FunctionLibraryRuntime::InstantiateOptions;
  return Canonicalize(funcname, attrs, *kDefaultOptions);
}
1062 
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1063 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1064                                      DataTypeSlice ret_types)
1065     : arg_types_(arg_types.begin(), arg_types.end()),
1066       ret_types_(ret_types.begin(), ret_types.end()) {
1067   args_.resize(arg_types_.size());
1068   rets_.resize(ret_types_.size());
1069 }
1070 
~FunctionCallFrame()1071 FunctionCallFrame::~FunctionCallFrame() {}
1072 
SetArgs(gtl::ArraySlice<Tensor> args)1073 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1074   // Input type checks.
1075   if (args.size() != arg_types_.size()) {
1076     return errors::InvalidArgument("Expects ", arg_types_.size(),
1077                                    " arguments, but ", args.size(),
1078                                    " is provided");
1079   }
1080   for (size_t i = 0; i < args.size(); ++i) {
1081     if (arg_types_[i] != args[i].dtype()) {
1082       return errors::InvalidArgument(
1083           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1084           DataTypeString(args[i].dtype()), " is provided");
1085     }
1086     args_[i] = args[i];
1087   }
1088   return Status::OK();
1089 }
1090 
GetRetvals(std::vector<Tensor> * rets) const1091 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1092   rets->clear();
1093   rets->reserve(rets_.size());
1094   for (size_t i = 0; i < rets_.size(); ++i) {
1095     const auto& item = rets_[i];
1096     if (item.has_val) {
1097       rets->push_back(item.val);
1098     } else {
1099       return errors::Internal("Retval[", i, "] does not have value");
1100     }
1101   }
1102   return Status::OK();
1103 }
1104 
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1105 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1106                                          bool allow_dead_tensors) {
1107   rets->clear();
1108   rets->reserve(rets_.size());
1109   for (size_t i = 0; i < rets_.size(); ++i) {
1110     if (rets_[i].has_val) {
1111       rets->emplace_back(std::move(rets_[i].val));
1112     } else if (allow_dead_tensors) {
1113       rets->emplace_back();
1114     } else {
1115       return errors::Internal("Retval[", i, "] does not have value");
1116     }
1117   }
1118   return Status::OK();
1119 }
1120 
GetArg(int index,Tensor * val) const1121 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
1122   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1123     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1124                                    args_.size(), ")");
1125   }
1126   *val = args_[index];
1127   return Status::OK();
1128 }
1129 
SetRetval(int index,const Tensor & val)1130 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1131   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1132     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1133                                    rets_.size(), ")");
1134   }
1135   if (val.dtype() != ret_types_[index]) {
1136     return errors::InvalidArgument(
1137         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1138         ", but ", DataTypeString(val.dtype()), " is provided.");
1139   }
1140   Retval* item = &rets_[index];
1141   if (!item->has_val) {
1142     item->has_val = true;
1143     item->val = val;
1144   } else {
1145     return errors::Internal("Retval[", index, "] has already been set.");
1146   }
1147   return Status::OK();
1148 }
1149 
1150 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef & fdef_in)1151     FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
1152     : fdef(fdef_in),
1153       // Exact shape inference for functions is handled by ShapeRefiner.
1154       // Here we pass a dummy shape inference function for legacy code paths.
1155       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
1156                            true /* is_function */) {}
1157 
// Copy constructor. Snapshots `other`'s function and gradient maps under a
// shared lock on `other`. The function map holds shared_ptrs, so registrations
// are shared with `other` rather than duplicated.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  function_defs_ = other.function_defs_;
  func_grad_ = other.func_grad_;
}
1165 
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1166 FunctionLibraryDefinition::FunctionLibraryDefinition(
1167     const OpRegistryInterface* default_registry,
1168     const FunctionDefLibrary& def_lib)
1169     : default_registry_(default_registry),
1170       function_defs_(def_lib.function_size()) {
1171   for (const auto& fdef : def_lib.function()) {
1172     // The latter function definition wins.
1173     auto& ptr = function_defs_[fdef.signature().name()];
1174     ptr.reset(new FunctionDefAndOpRegistration(fdef));
1175   }
1176   for (const auto& grad : def_lib.gradient()) {
1177     func_grad_[grad.function_name()] = grad.gradient_func();
1178   }
1179 }
1180 
~FunctionLibraryDefinition()1181 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1182 
Contains(const string & func) const1183 bool FunctionLibraryDefinition::Contains(const string& func) const {
1184   tf_shared_lock l(mu_);
1185   return function_defs_.find(func) != function_defs_.end();
1186 }
1187 
Find(const string & func) const1188 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1189   tf_shared_lock l(mu_);
1190   auto result = FindHelper(func);
1191   if (result) {
1192     return &result->fdef;
1193   } else {
1194     return nullptr;
1195   }
1196 }
1197 
1198 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1199 FunctionLibraryDefinition::FindHelper(const string& func) const {
1200   auto iter = function_defs_.find(func);
1201   if (iter == function_defs_.end()) {
1202     return nullptr;
1203   } else {
1204     return iter->second;
1205   }
1206 }
1207 
AddFunctionDef(const FunctionDef & fdef)1208 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
1209   mutex_lock l(mu_);
1210   bool added;
1211   return AddFunctionDefHelper(fdef, &added);
1212 }
1213 
AddFunctionDefHelper(const FunctionDef & fdef,bool * added)1214 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
1215                                                        bool* added) {
1216   *added = false;
1217   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1218       function_defs_[fdef.signature().name()];
1219   if (entry) {
1220     if (!FunctionDefsEqual(entry->fdef, fdef)) {
1221       return errors::InvalidArgument(
1222           "Cannot add function '", fdef.signature().name(),
1223           "' because a different function with the same name already "
1224           "exists.");
1225     }
1226     // Ignore duplicate FunctionDefs.
1227     return Status::OK();
1228   }
1229   const OpDef* op_def;
1230   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1231     return errors::InvalidArgument(
1232         "Cannot add function '", fdef.signature().name(),
1233         "' because an op with the same name already exists.");
1234   }
1235   entry = std::make_shared<FunctionDefAndOpRegistration>(fdef);
1236   *added = true;
1237   return Status::OK();
1238 }
1239 
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1240 Status FunctionLibraryDefinition::AddHelper(
1241     std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1242   *added = false;
1243   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1244       function_defs_[registration->fdef.signature().name()];
1245   if (entry) {
1246     if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1247       return errors::InvalidArgument(
1248           "Cannot add function '", registration->fdef.signature().name(),
1249           "' because a different function with the same name already "
1250           "exists.");
1251     }
1252     // Ignore duplicate FunctionDefs.
1253     return Status::OK();
1254   }
1255   const OpDef* op_def;
1256   if (default_registry_
1257           ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1258           .ok()) {
1259     return errors::InvalidArgument(
1260         "Cannot add function '", registration->fdef.signature().name(),
1261         "' because an op with the same name already exists.");
1262   }
1263   entry = std::move(registration);
1264   *added = true;
1265   return Status::OK();
1266 }
1267 
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1268 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1269     const string& func, const FunctionLibraryDefinition& other) {
1270   if (default_registry_ != other.default_registry_) {
1271     return errors::InvalidArgument(
1272         "Cannot copy function '", func,
1273         "' because CopyFunctionDefFrom() requires that both libraries have the "
1274         "same default registry.");
1275   }
1276   std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1277   {
1278     tf_shared_lock l(other.mu_);
1279     function_def = other.FindHelper(func);
1280   }
1281   if (!function_def) {
1282     return errors::InvalidArgument(
1283         "Cannot copy function '", func,
1284         "' because no function with that name exists in the other library.");
1285   }
1286   {
1287     mutex_lock l(mu_);
1288     std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1289     if (entry) {
1290       if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1291         return errors::InvalidArgument(
1292             "Cannot copy function '", func,
1293             "' because a different function with the same name already "
1294             "exists.");
1295       }
1296     } else {
1297       entry = std::move(function_def);
1298     }
1299   }
1300   return Status::OK();
1301 }
1302 
AddGradientDef(const GradientDef & grad)1303 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1304   mutex_lock l(mu_);
1305   bool added;
1306   return AddGradientDefHelper(grad, &added);
1307 }
1308 
AddGradientDefHelper(const GradientDef & grad,bool * added)1309 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1310                                                        bool* added) {
1311   *added = false;
1312   string* entry = &func_grad_[grad.function_name()];
1313   if (!entry->empty()) {
1314     if (*entry != grad.gradient_func()) {
1315       return errors::InvalidArgument(
1316           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1317           grad.function_name(), "' because it already has gradient function ",
1318           "'", *entry, "'");
1319     }
1320     // Ignore duplicate GradientDefs
1321     return Status::OK();
1322   }
1323   *entry = grad.gradient_func();
1324   *added = true;
1325   return Status::OK();
1326 }
1327 
AddLibrary(const FunctionLibraryDefinition & other)1328 Status FunctionLibraryDefinition::AddLibrary(
1329     const FunctionLibraryDefinition& other) {
1330   // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1331   // the duration of the function could lead to deadlock).
1332   FunctionLibraryDefinition clone(other);
1333   mutex_lock l(mu_);
1334   mutex_lock l2(clone.mu_);
1335   // Remember the funcs and grads that we added successfully so that
1336   // we can roll them back on error.
1337   std::vector<string> funcs;
1338   std::vector<string> funcs_with_grads;
1339   Status s;
1340   bool added;
1341   for (auto iter : clone.function_defs_) {
1342     s = AddHelper(iter.second, &added);
1343     if (!s.ok()) {
1344       Remove(funcs, funcs_with_grads);
1345       return s;
1346     }
1347     if (added) {
1348       funcs.push_back(iter.second->fdef.signature().name());
1349     }
1350   }
1351   for (auto iter : clone.func_grad_) {
1352     GradientDef grad;
1353     grad.set_function_name(iter.first);
1354     grad.set_gradient_func(iter.second);
1355     s = AddGradientDefHelper(grad, &added);
1356     if (!s.ok()) {
1357       Remove(funcs, funcs_with_grads);
1358       return s;
1359     }
1360     if (added) {
1361       funcs_with_grads.push_back(grad.function_name());
1362     }
1363   }
1364   return Status::OK();
1365 }
1366 
AddLibrary(const FunctionDefLibrary & lib_def)1367 Status FunctionLibraryDefinition::AddLibrary(
1368     const FunctionDefLibrary& lib_def) {
1369   // Remember the funcs and grads that we added successfully so that
1370   // we can roll them back on error.
1371   mutex_lock l(mu_);
1372   std::vector<string> funcs;
1373   std::vector<string> funcs_with_grads;
1374   Status s;
1375   bool added;
1376   for (const FunctionDef& fdef : lib_def.function()) {
1377     s = AddFunctionDefHelper(fdef, &added);
1378     if (!s.ok()) {
1379       Remove(funcs, funcs_with_grads);
1380       return s;
1381     }
1382     if (added) {
1383       funcs.push_back(fdef.signature().name());
1384     }
1385   }
1386   for (const GradientDef& grad : lib_def.gradient()) {
1387     s = AddGradientDefHelper(grad, &added);
1388     if (!s.ok()) {
1389       Remove(funcs, funcs_with_grads);
1390       return s;
1391     }
1392     if (added) {
1393       funcs_with_grads.push_back(grad.function_name());
1394     }
1395   }
1396   return Status::OK();
1397 }
1398 
ReplaceFunction(const string & func,const FunctionDef & fdef)1399 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
1400                                                   const FunctionDef& fdef) {
1401   mutex_lock l(mu_);
1402   bool added;
1403   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1404   TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
1405   return Status::OK();
1406 }
1407 
ReplaceGradient(const GradientDef & grad)1408 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1409   mutex_lock l(mu_);
1410   bool added;
1411   TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1412   TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1413   return Status::OK();
1414 }
1415 
RemoveFunction(const string & func)1416 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1417   mutex_lock l(mu_);
1418   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1419   return Status::OK();
1420 }
1421 
RemoveFunctionHelper(const string & func)1422 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1423   const auto& i = function_defs_.find(func);
1424   if (i == function_defs_.end()) {
1425     return errors::InvalidArgument("Tried to remove non-existent function '",
1426                                    func, "'.");
1427   }
1428   function_defs_.erase(i);
1429   return Status::OK();
1430 }
1431 
RemoveGradient(const string & func)1432 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1433   const auto& i = func_grad_.find(func);
1434   if (i == func_grad_.end()) {
1435     return errors::InvalidArgument("Tried to remove non-existent gradient '",
1436                                    func, "'.");
1437   }
1438   func_grad_.erase(i);
1439   return Status::OK();
1440 }
1441 
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1442 void FunctionLibraryDefinition::Remove(
1443     const std::vector<string>& funcs,
1444     const std::vector<string>& funcs_with_grads) {
1445   for (const string& f : funcs) {
1446     Status s = RemoveFunctionHelper(f);
1447     DCHECK(s.ok());
1448   }
1449   for (const string& f : funcs_with_grads) {
1450     Status s = RemoveGradient(f);
1451     DCHECK(s.ok());
1452   }
1453 }
1454 
// Returns the name of the gradient function registered for `func`, or "" if
// there is none.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock l(mu_);
  return FindGradientHelper(func);
}
1459 
// Lock-free variant of FindGradient(); callers hold `mu_`.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  const auto it = func_grad_.find(func);
  return it == func_grad_.end() ? string() : it->second;
}
1463 
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1464 Status FunctionLibraryDefinition::LookUp(
1465     const string& op, const OpRegistrationData** op_reg_data) const {
1466   tf_shared_lock l(mu_);
1467   auto iter = function_defs_.find(op);
1468   if (iter != function_defs_.end()) {
1469     *op_reg_data = &iter->second->op_registration_data;
1470     return Status::OK();
1471   }
1472   return default_registry_->LookUp(op, op_reg_data);
1473 }
1474 
UniqueFunctionName(StringPiece prefix) const1475 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1476   tf_shared_lock l(mu_);
1477   int index = 0;
1478   string name = strings::StrCat(prefix, index);
1479   while (function_defs_.find(name) != function_defs_.end()) {
1480     ++index;
1481     name = strings::StrCat(prefix, index);
1482   }
1483   return name;
1484 }
1485 
GetAttrImpl(const NodeDef & ndef) const1486 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1487     const NodeDef& ndef) const {
1488   if (ndef.op() != kGradientOp) {
1489     // If 'ndef' calls a function and the function's def has the attr,
1490     // returns it.
1491     return Find(ndef.op());
1492   }
1493 
1494   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1495   // Foo's attributes.
1496   const NameAttrList* forward_func_attrs;
1497   if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1498     return nullptr;
1499   }
1500   const string& func_name = forward_func_attrs->name();
1501   {
1502     tf_shared_lock l(mu_);
1503     const string& grad_name = FindGradientHelper(func_name);
1504     // If 'func' has a user-defined gradient function, uses the grad
1505     // function's attrs to see if noinline is specified. Otherwise,
1506     // uses func's attrs.
1507     if (!grad_name.empty()) {
1508       return &(FindHelper(grad_name)->fdef);
1509     }
1510     return &(FindHelper(func_name)->fdef);
1511   }
1512 }
1513 
ListFunctionNames() const1514 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1515   std::vector<string> function_names;
1516   tf_shared_lock l(mu_);
1517   function_names.reserve(function_defs_.size());
1518   for (const auto& it : function_defs_) {
1519     function_names.emplace_back(it.first);
1520   }
1521   return function_names;
1522 }
1523 
ToProto() const1524 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1525   FunctionDefLibrary lib;
1526   tf_shared_lock l(mu_);
1527   for (const auto& f : function_defs_) {
1528     *lib.add_function() = f.second->fdef;
1529   }
1530   for (const auto& g : func_grad_) {
1531     GradientDef* gd = lib.add_gradient();
1532     gd->set_function_name(g.first);
1533     gd->set_gradient_func(g.second);
1534   }
1535   return lib;
1536 }
1537 
1538 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1539 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1540                                           const string& attr, T* value) const {
1541   const FunctionDef* fdef = GetAttrImpl(ndef);
1542   if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1543     return Status::OK();
1544   }
1545   return errors::InvalidArgument("Attr ", attr, " is not defined.");
1546 }
1547 
1548 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1549 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1550                                           T* value) const {
1551   return GetAttr(node.def(), attr, value);
1552 }
1553 
1554 #define GET_ATTR(T)                                                            \
1555   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
1556                                                      const string&, T*) const; \
1557   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
1558                                                      const string&, T*) const;
1559 GET_ATTR(string)
1560 GET_ATTR(bool)
1561 #undef GET_ATTR
1562 
1563 namespace {
1564 
1565 constexpr char kApiImplements[] = "api_implements";
1566 
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1567 std::set<string> ReachableFunctions(
1568     const FunctionLibraryDefinition& flib,
1569     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1570   // Functions that are reachable from the graph.
1571   std::set<string> reachable_funcs;
1572 
1573   // For any functions, if it has attribute "api_implements" =
1574   // "some_interface" and it is reachable, then it means any other
1575   // function with same attribute name and value could also be potentially
1576   // reachable, eg via implementation_selector swapping the nodedef.
1577   absl::flat_hash_set<string> reachable_api_interface;
1578 
1579   // Functions might be reachable from the nested function calls, so we keep a
1580   // queue of functions that we have to check.
1581   gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1582 
1583   // Add reachable and not already processed functions to the functions queue.
1584   const auto add_to_func_queue = [&](const string& func_name) {
1585     const FunctionDef* func = flib.Find(func_name);
1586     if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1587       func_queue.push_back(func);
1588     }
1589   };
1590 
1591   // If any function with certain API name is reachable, all the other functions
1592   // with same API name should also be checked.
1593   const auto add_function_with_api_interface = [&](const string& api_name) {
1594     if (!reachable_api_interface.contains(api_name)) {
1595       reachable_api_interface.insert(api_name);
1596       for (const auto& func_name : flib.ListFunctionNames()) {
1597         const auto& func_def = flib.Find(func_name);
1598         const auto attr_it = func_def->attr().find(kApiImplements);
1599         if (attr_it != func_def->attr().end() &&
1600             attr_it->second.s() == api_name) {
1601           add_to_func_queue(func_name);
1602         }
1603       }
1604     }
1605   };
1606 
1607   // Add all the functions that are reachable from the given node to the queue.
1608   const auto process_node = [&](const NodeDef& node) {
1609     // Node itself can be a call to the function.
1610     add_to_func_queue(node.op());
1611 
1612     // Or node can have an attribute referencing a function.
1613     for (const auto& attr : node.attr()) {
1614       const auto& attr_value = attr.second;
1615 
1616       // 1. AttrValue.func
1617       if (attr_value.has_func()) {
1618         add_to_func_queue(attr_value.func().name());
1619       }
1620 
1621       // 2. AttrValue.ListValue.func
1622       if (attr_value.has_list()) {
1623         for (const auto& func : attr_value.list().func()) {
1624           add_to_func_queue(func.name());
1625         }
1626       }
1627     }
1628   };
1629 
1630   // Add all functions that are directly called from the optimized graph.
1631   std::for_each(nodes.begin(), nodes.end(), process_node);
1632 
1633   // Process all reachable functions.
1634   while (!func_queue.empty()) {
1635     const FunctionDef* func = func_queue.back();
1636     func_queue.pop_back();
1637 
1638     const string& func_name = func->signature().name();
1639     reachable_funcs.insert(func_name);
1640 
1641     const auto attr_it = func->attr().find(kApiImplements);
1642     if (attr_it != func->attr().end()) {
1643       add_function_with_api_interface(attr_it->second.s());
1644     }
1645 
1646     // Find all the functions called from the function body.
1647     const auto& func_body = func->node_def();
1648     std::for_each(func_body.begin(), func_body.end(), process_node);
1649 
1650     // Check if the function has a registered gradient.
1651     const string grad_func_name = flib.FindGradient(func_name);
1652     if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1653   }
1654 
1655   return reachable_funcs;
1656 }
1657 
ReachableFunctionLibraryDefinition(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1658 FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
1659     const FunctionLibraryDefinition& flib,
1660     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1661   std::set<string> reachable_funcs = ReachableFunctions(flib, nodes);
1662 
1663   FunctionLibraryDefinition reachable_flib(flib.default_registry(),
1664                                            FunctionDefLibrary());
1665 
1666   for (const string& func_name : reachable_funcs) {
1667     // This should never fail, because we copy functions from a valid flib and
1668     // use the same default registry.
1669     Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib);
1670     TF_DCHECK_OK(added);
1671 
1672     const string grad_func_name = flib.FindGradient(func_name);
1673     if (!grad_func_name.empty()) {
1674       GradientDef grad;
1675       grad.set_function_name(func_name);
1676       grad.set_gradient_func(grad_func_name);
1677       // It can only fail if function already has a gradient function.
1678       const Status added_grad = reachable_flib.AddGradientDef(grad);
1679       TF_DCHECK_OK(added_grad);
1680     }
1681   }
1682 
1683   return reachable_flib;
1684 }
1685 
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1686 string AllocatorAttributesToString(
1687     const std::vector<AllocatorAttributes>& attrs) {
1688   string result("[");
1689   // AllocatorAttribute::DebugString produces around 85 bytes now.
1690   result.reserve(100 * attrs.size());
1691   for (const AllocatorAttributes& attr : attrs) {
1692     result.append(attr.DebugString());
1693     result.append(", ");
1694   }
1695   if (!attrs.empty()) {
1696     result.resize(result.size() - 2);
1697   }
1698   result.append("]");
1699   return result;
1700 }
1701 
IsSet(void * ptr)1702 const char* IsSet(void* ptr) { return ptr == nullptr ? "unset" : "set"; }
1703 
1704 }  // namespace
1705 
ReachableDefinitions(const GraphDef & graph) const1706 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1707     const GraphDef& graph) const {
1708   return ReachableFunctionLibraryDefinition(*this, graph.node());
1709 }
1710 
ReachableDefinitions(const FunctionDef & func) const1711 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1712     const FunctionDef& func) const {
1713   return ReachableFunctionLibraryDefinition(*this, func.node_def());
1714 }
1715 
DebugString() const1716 string FunctionLibraryRuntime::Options::DebugString() const {
1717   return absl::StrCat(
1718       "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1719       " cancellation_manager=", IsSet(cancellation_manager),
1720       " collective_executor=", IsSet(collective_executor),
1721       " step_container=", IsSet(step_container),
1722       " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1723       " remote_execution=", remote_execution, " source_device=", source_device,
1724       " create_rendezvous=", create_rendezvous,
1725       " allow_dead_tensors=", allow_dead_tensors,
1726       " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1727       " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1728 }
1729 
InitFromString(StringPiece val)1730 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1731   if (val.size() >= 2 && val[0] == '$') {
1732     proto.set_placeholder(val.data() + 1, val.size() - 1);
1733   } else {
1734     SetAttrValue(val, &proto);
1735   }
1736 }
1737 
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1738 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1739     const string& name,
1740     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1741   AttrValueWrapper ret;
1742   ret.proto.mutable_func()->set_name(name);
1743   for (const auto& a : attrs) {
1744     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1745   }
1746   return ret;
1747 }
1748 
ToNodeDef() const1749 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1750   NodeDef n;
1751   n.set_op(this->op);
1752   n.set_name(this->ret[0]);
1753   for (const auto& a : this->attr) {
1754     n.mutable_attr()->insert({a.first, a.second.proto});
1755   }
1756   for (const string& a : this->arg) {
1757     n.add_input(a);
1758   }
1759   for (const string& d : this->dep) {
1760     n.add_input(strings::StrCat("^", d));
1761   }
1762   if (!this->device.empty()) {
1763     n.set_device(this->device);
1764   }
1765   return n;
1766 }
1767 
1768 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1769 FunctionDef FunctionDefHelper::Create(
1770     const string& function_name, gtl::ArraySlice<string> in_def,
1771     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1772     gtl::ArraySlice<Node> node_def,
1773     gtl::ArraySlice<std::pair<string, string>> ret_def,
1774     gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1775   FunctionDef fdef;
1776 
1777   // Signature
1778   OpDefBuilder b(function_name);
1779   for (const auto& i : in_def) b.Input(i);
1780   for (const auto& o : out_def) b.Output(o);
1781   for (const auto& a : attr_def) b.Attr(a);
1782   for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1783 
1784   OpRegistrationData op_reg_data;
1785   TF_CHECK_OK(b.Finalize(&op_reg_data));
1786   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1787 
1788   // Function body
1789   for (const auto& n : node_def) {
1790     *(fdef.add_node_def()) = n.ToNodeDef();
1791   }
1792 
1793   // Returns
1794   for (const auto& r : ret_def) {
1795     fdef.mutable_ret()->insert({r.first, r.second});
1796   }
1797 
1798   // Control returns
1799   for (const auto& cr : control_ret_def) {
1800     fdef.mutable_control_ret()->insert({cr.first, cr.second});
1801   }
1802 
1803   auto* op_def_registry = OpRegistry::Global();
1804   // Check if any op is stateful.
1805   for (const auto& n : node_def) {
1806     const OpDef* op_def = nullptr;
1807     auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1808     // Lookup can fail if e.g. we are calling a function that was not yet
1809     // defined.  If it happens, conservatively assume the op is stateful.
1810     if (!status.ok() || op_def->is_stateful()) {
1811       fdef.mutable_signature()->set_is_stateful(true);
1812     }
1813   }
1814 
1815   return fdef;
1816 }
1817 
1818 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1819 FunctionDef FunctionDefHelper::Create(
1820     const string& function_name, gtl::ArraySlice<string> in_def,
1821     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1822     gtl::ArraySlice<Node> node_def,
1823     gtl::ArraySlice<std::pair<string, string>> ret_def) {
1824   return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1825                 /*control_ret_def=*/{});
1826 }
1827 
1828 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1829 FunctionDef FunctionDefHelper::Define(const string& name,
1830                                       gtl::ArraySlice<string> arg_def,
1831                                       gtl::ArraySlice<string> ret_def,
1832                                       gtl::ArraySlice<string> attr_def,
1833                                       gtl::ArraySlice<Node> node_def) {
1834   FunctionDef fdef;
1835   OpDefBuilder b(name);
1836   for (const auto& a : arg_def) b.Input(a);
1837   for (const auto& r : ret_def) b.Output(r);
1838   for (const auto& a : attr_def) b.Attr(a);
1839 
1840   OpRegistrationData op_reg_data;
1841   TF_CHECK_OK(b.Finalize(&op_reg_data));
1842   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1843 
1844   // Mapping from legacy output names to NodeDef outputs.
1845   std::unordered_map<string, string> ret_index;
1846   for (const auto& a : fdef.signature().input_arg()) {
1847     ret_index[a.name()] = a.name();
1848   }
1849 
1850   // For looking up OpDefs
1851   auto* op_def_registry = OpRegistry::Global();
1852 
1853   // Function body
1854   for (const auto& src : node_def) {
1855     NodeDef* n = fdef.add_node_def();
1856     n->set_op(src.op);
1857     n->set_name(src.ret[0]);
1858     for (const auto& a : src.attr) {
1859       n->mutable_attr()->insert({a.first, a.second.proto});
1860     }
1861     for (const string& a : src.arg) {
1862       const auto iter = ret_index.find(a);
1863       CHECK(iter != ret_index.end())
1864           << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
1865       n->add_input(iter->second);
1866     }
1867     for (const string& d : src.dep) {
1868       n->add_input(strings::StrCat("^", d));
1869     }
1870 
1871     // Add the outputs of this node to ret_index.
1872     const OpDef* op_def = nullptr;
1873     TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1874     CHECK(op_def != nullptr) << n->op();
1875     NameRangeMap output_names;
1876     TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1877     for (const auto& o : output_names) {
1878       CHECK_LE(o.second.second, src.ret.size())
1879           << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1880           << "' of " << name;
1881       for (int i = o.second.first; i < o.second.second; ++i) {
1882         ret_index[src.ret[i]] =
1883             strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1884       }
1885     }
1886     if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1887   }
1888 
1889   // Returns
1890   for (const auto& r : fdef.signature().output_arg()) {
1891     const auto iter = ret_index.find(r.name());
1892     CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1893     fdef.mutable_ret()->insert({r.name(), iter->second});
1894   }
1895   return fdef;
1896 }
1897 
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1898 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1899                                       gtl::ArraySlice<string> ret_def,
1900                                       gtl::ArraySlice<string> attr_def,
1901                                       gtl::ArraySlice<Node> node_def) {
1902   return Define("_", arg_def, ret_def, attr_def, node_def);
1903 }
1904 
1905 namespace gradient {
1906 
1907 typedef std::unordered_map<string, Creator> OpGradFactory;
1908 
GetOpGradFactory()1909 OpGradFactory* GetOpGradFactory() {
1910   static OpGradFactory* factory = new OpGradFactory;
1911   return factory;
1912 }
1913 
RegisterOp(const string & op,Creator func)1914 bool RegisterOp(const string& op, Creator func) {
1915   CHECK(GetOpGradFactory()->insert({op, func}).second)
1916       << "Duplicated gradient for " << op;
1917   return true;
1918 }
1919 
GetOpGradientCreator(const string & op,Creator * creator)1920 Status GetOpGradientCreator(const string& op, Creator* creator) {
1921   auto fac = GetOpGradFactory();
1922   auto iter = fac->find(op);
1923   if (iter == fac->end()) {
1924     return errors::NotFound("No gradient defined for op: ", op);
1925   }
1926   *creator = iter->second;
1927   return Status::OK();
1928 }
1929 
1930 }  // end namespace gradient
1931 
1932 }  // namespace tensorflow
1933