• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/function.h"

#include <ctype.h>

#include <limits>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/equal_graph_def.h"
44 
45 namespace tensorflow {
46 
47 /* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
48 /* static */ constexpr const char* const
49     FunctionLibraryDefinition::kDeviceArgOp;
50 /* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
51 /* static */ constexpr const char* const
52     FunctionLibraryDefinition::kDeviceRetOp;
53 /* static */ constexpr const char* const
54     FunctionLibraryDefinition::kIntsOnDeviceAttr;
55 /* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
56 /* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
57 
58 // Extracts the actual type from "attr_values" based on its definition
59 // "arg_def".
60 //
61 // If "arg_def" is a N*T type, *is_type_list is set to false, and
62 // *dtypes is set to be a vector of size N and each element is T.
63 //
64 // If "arg_def" is a list(type), *is_type_list is set to true, and
65 // *dtypes is set to be a vector of types specified in attrs for
66 // arg_def.
67 //
68 // Otherwise (arg_def is a simple type T), *is_type_list is set to
69 // false, and *dtypes is set to a single element vector, whose only
70 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)71 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
72                   bool* is_type_list, DataTypeVector* dtypes) {
73   dtypes->clear();
74   if (!arg_def.type_list_attr().empty()) {
75     const AttrValue* v = attrs.FindByString(arg_def.type_list_attr());
76     if (v == nullptr) {
77       return errors::NotFound("type list attr not found: ",
78                               arg_def.type_list_attr());
79     }
80     *is_type_list = true;
81     for (int i = 0; i < v->list().type_size(); ++i) {
82       dtypes->push_back(v->list().type(i));
83     }
84     return OkStatus();
85   }
86 
87   *is_type_list = false;
88   int num = 1;
89   if (!arg_def.number_attr().empty()) {
90     const AttrValue* v = attrs.FindByString(arg_def.number_attr());
91     if (v == nullptr) {
92       return errors::NotFound("number attr not found: ", arg_def.number_attr());
93     }
94     num = v->i();
95   }
96 
97   DataType dtype;
98   if (arg_def.type() != DT_INVALID) {
99     dtype = arg_def.type();
100   } else if (arg_def.type_attr().empty()) {
101     dtype = DT_INVALID;
102   } else {
103     const AttrValue* v = attrs.FindByString(arg_def.type_attr());
104     if (v == nullptr) {
105       return errors::NotFound("type attr not found: ", arg_def.type_attr());
106     }
107     dtype = v->type();
108   }
109   dtypes->resize(num, dtype);
110   return OkStatus();
111 }
112 
113 namespace {
114 
115 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)116 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
117   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
118 }
119 
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)120 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
121   // attr_values should specify all attrs defined in fdef, except for those
122   // which have a default value
123   for (const auto& attr : sig.attr()) {
124     const AttrValue* attr_value = attr_values.FindByString(attr.name());
125     if (attr_value) {
126       Status status = AttrValueHasType(*attr_value, attr.type());
127       if (!status.ok()) {
128         errors::AppendToMessage(&status, "for attr '", attr.name(), "'");
129         return status;
130       }
131     } else if (!attr.has_default_value()) {
132       return errors::NotFound("Attr ", attr.name(), " is not found from ",
133                               SummarizeOpDef(sig));
134     }
135   }
136 
137 // TODO(josh11b): Enable this code once it works with function gradients.
138 // Right now the C++ function gradient code assumes it can pass
139 // all the attrs of the function to the gradient, and any attrs that
140 // the gradient doesn't care about will be ignored.
141 #if 0
142   if (attr_values.size() != sig.attr_size()) {
143     for (const auto& a : attr_values) {
144       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
145       bool found = false;
146       for (const auto& s : sig.attr()) {
147         if (a.first == s.name()) {
148           found = true;
149           break;
150         }
151       }
152       if (!found) {
153         return errors::NotFound("Attr ", a.first, " is not found in ",
154                                 SummarizeOpDef(sig));
155       }
156     }
157   }
158 #endif
159 
160   return OkStatus();
161 }
162 
163 // A helper class for instantiating functions. This contains shared information
164 // like the resulting graph and node name index.
165 class FunctionInstantiationHelper {
166  public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)167   FunctionInstantiationHelper(GetFunctionSignature get_function,
168                               InstantiationResult* result)
169       : get_function_(std ::move(get_function)), result_(*result) {
170     result_.nodes.clear();
171   }
172 
173   // Builds index for nodes that can be used as node's input arguments.
174   // `resource_arg_unique_id`: if non-negative, will be populated to the
175   // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64_t resource_arg_unique_id)176   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
177                             const FunctionDef::ArgAttrs* arg_attrs,
178                             bool ints_on_device,
179                             int64_t resource_arg_unique_id) {
180     bool is_type_list;
181     DataTypeVector dtypes;
182     TF_RETURN_IF_ERROR(
183         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
184     if (dtypes.size() < size_t{1}) {
185       return errors::Internal("Expected a list of at least one dtype");
186     }
187     int arg_index = result_.nodes.size();
188     TF_RETURN_IF_ERROR(
189         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
190     // Creates dtypes.size() nodes in the graph.
191     for (size_t i = 0; i < dtypes.size(); ++i) {
192       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
193                                  {true, arg_index, 0, false, {dtypes[i]}}));
194       if (arg_index != result_.nodes.size()) {
195         return errors::Internal(
196             "Expected arg_index to be equal to the number of nodes in result.",
197             " Got ", arg_index, " and ", result_.nodes.size());
198       }
199       string name = arg_def.name();
200       if (dtypes.size() > 1) {
201         strings::StrAppend(&name, "_", i);
202       }
203       NodeDef* gnode = AddNode(name);
204       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
205         gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
206       } else {
207         gnode->set_op(FunctionLibraryDefinition::kArgOp);
208       }
209       DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
210       AddAttr("T", dtype, gnode);
211       AddAttr("index", arg_index, gnode);
212       if (resource_arg_unique_id >= 0) {
213         AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
214       }
215       if (arg_attrs) {
216         for (const auto& arg_attr : arg_attrs->attr()) {
217           AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
218         }
219       }
220       result_.arg_types.push_back(dtypes[i]);
221       ++arg_index;
222     }
223     return OkStatus();
224   }
225 
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)226   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
227                               const int arg_index) {
228     const OpDef* node_sig = nullptr;
229     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
230     if (node_sig->output_arg_size() == 0) {
231       return AddItem(node.name(), {false, arg_index, 0, false, {}});
232     }
233     const int num_retval = node_sig->output_arg_size();
234     int start = 0;
235     bool is_type_list;
236     DataTypeVector dtypes;
237     for (int i = 0; i < num_retval; ++i) {
238       TF_RETURN_IF_ERROR(
239           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
240       // Note that we rely on the backwards-compatibility test enforcing
241       // that output_arg(*).name() doesn't change here.
242       const string base_name =
243           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
244       TF_RETURN_IF_ERROR(
245           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
246       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
247         TF_RETURN_IF_ERROR(
248             AddItem(strings::StrCat(base_name, ":", j),
249                     {false, arg_index, start + j, false, {dtypes[j]}}));
250       }
251       start += dtypes.size();
252     }
253     return OkStatus();
254   }
255 
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)256   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
257     const OpDef* fnode_sig = nullptr;
258     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
259     NodeDef* gnode = AddNode(fnode.name());
260     gnode->set_op(fnode.op());
261     gnode->set_device(fnode.device());
262     int gnode_idx = nodes_.size() - 1;
263 
264     // Input
265     const int num_args = fnode_sig->input_arg_size();
266     bool is_type_list;  // ignored
267     DataTypeVector dtypes;
268     int fnode_arg_index = 0;
269     for (int i = 0; i < num_args; ++i) {
270       TF_RETURN_IF_ERROR(
271           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
272       // Consume inputs (indexed by fnode_arg_index) until we have
273       // matched each element of dtypes (indexed by j).
274       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
275         if (fnode_arg_index >= fnode.input_size()) {
276           // Should never happen if we computed dtypes correctly.
277           return errors::InvalidArgument(
278               "Attempt to access beyond input size: ", fnode_arg_index,
279               " >= ", fnode.input_size());
280         }
281         // Look up the next input.
282         const string& input_name = fnode.input(fnode_arg_index);
283         const auto* item = GetItemOrNull(input_name);
284         if (item == nullptr) {
285           return errors::InvalidArgument(
286               "input ", input_name,
287               " is not found: ", FormatNodeDefForError(fnode));
288         }
289         if (item->dtypes.size() > dtypes.size() - j) {
290           return errors::InvalidArgument("Input ", input_name, " too long for ",
291                                          fnode_sig->input_arg(i).name());
292         }
293         // Match up all the elements of this input (indexed by k) with
294         // elements of dtypes (advancing j).
295         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
296           if (item->dtypes[k] != dtypes[j]) {
297             return errors::InvalidArgument(
298                 "input ", fnode_sig->input_arg(i).name(), "[", j,
299                 "] expected type ", DataTypeString(dtypes[j]),
300                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
301                 input_name, "[", k, "]");
302           }
303           if (item->is_func_arg) {
304             AddInput(gnode_idx, item->nid + k, 0);
305           } else {
306             AddInput(gnode_idx, item->nid, item->idx + k);
307           }
308         }
309       }
310     }
311 
312     // Control deps.
313     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
314       const string& input = fnode.input(i);
315       if (input.empty() || input[0] != '^') {
316         return errors::InvalidArgument("Expected input[", i, "] == '", input,
317                                        "' to be a control input.");
318       }
319       int nid = -1;
320       const string node_name = input.substr(1);
321       const string node_colon = node_name + ":";
322       const string node_colon_bound = node_name + ";";
323       // index_ is a map sorted lexicographically, so the key we are looking for
324       // must lie in the range [node_name, node_colon_bound).
325       auto it = index_.lower_bound(node_name);
326       while (it != index_.end() && it->first <= node_colon_bound) {
327         if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
328           nid = it->second.nid;
329           break;
330         }
331         ++it;
332       }
333       if (nid == -1) {
334         return errors::InvalidArgument("input[", i, "] == '", input,
335                                        "', is not found.");
336       }
337       AddDep(gnode_idx, nid);
338     }
339 
340     // Attrs.
341     for (const auto& p : attrs) {
342       (*gnode->mutable_attr())[p.first] = p.second;
343     }
344 
345     // Experimental_debug_info.
346     if (fnode.has_experimental_debug_info()) {
347       gnode->mutable_experimental_debug_info()->MergeFrom(
348           fnode.experimental_debug_info());
349     }
350 
351     // Tye info.
352     // TODO(mdan): Might this need adjustment at instantiation?
353     if (fnode.has_experimental_type()) {
354       *gnode->mutable_experimental_type() = fnode.experimental_type();
355     }
356 
357     return OkStatus();
358   }
359 
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)360   Status AddReturnNode(
361       const OpDef::ArgDef& ret_def, AttrSlice attrs,
362       const ::tensorflow::protobuf::Map<string, string>& ret_map,
363       bool ints_on_device, int* ret_index) {
364     auto ret_iter = ret_map.find(ret_def.name());
365     if (ret_iter == ret_map.end()) {
366       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
367     }
368     bool is_type_list;
369     DataTypeVector dtypes;
370     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
371     CHECK_GE(dtypes.size(), size_t{1});
372     const auto* item = GetItemOrNull(ret_iter->second);
373     if (item == nullptr) {
374       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
375                                      ret_iter->second, " is not found.");
376     }
377     if (dtypes != item->dtypes) {
378       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
379                                      " : ", DataTypeVectorString(dtypes),
380                                      " vs. ",
381                                      DataTypeVectorString(item->dtypes));
382     }
383     for (size_t i = 0; i < dtypes.size(); ++i) {
384       string name = strings::StrCat(ret_def.name(), "_RetVal");
385       if (dtypes.size() > 1) {
386         strings::StrAppend(&name, "_", i);
387       }
388       NodeDef* gnode = AddNode(name);
389       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
390         gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
391       } else {
392         gnode->set_op(FunctionLibraryDefinition::kRetOp);
393       }
394       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
395       DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
396       AddAttr("T", dtype, gnode);
397       AddAttr("index", (*ret_index)++, gnode);
398       result_.ret_types.push_back(dtypes[i]);
399     }
400     return OkStatus();
401   }
402 
403   // Adds the actual node inputs to the result graph by converting indexes to
404   // the node names.
AddNodeInputs()405   void AddNodeInputs() {
406     for (int i = 0; i < result_.nodes.size(); i++) {
407       NodeInfo& node_info = nodes_[i];
408       for (const auto& p : node_info.data_inputs) {
409         result_.nodes[i].add_input(Name(p.first, p.second));
410       }
411       for (int index : node_info.control_inputs) {
412         result_.nodes[i].add_input(Dep(index));
413       }
414     }
415   }
416 
417  private:
418   // This is used to build a small index for all names that can be used as a
419   // node's input arguments.
420   //
421   // If is_func_arg is true, the name is a function's argument.  In
422   // this case, the produced graph def has node[nid:nid + dtype.size()].
423   //
424   // Otherwise, the name is a function body's node return value.  In
425   // this case, the produced graph def has one node node[nid] and
426   // the node's output index [idx ... idx + num) corresponds to the
427   // named outputs.
428   //
429   // In all cases, "dtype" specifies the data type.
430   struct NameInfoItem {
431     bool is_func_arg;
432     int nid;
433     int idx;
434     bool is_type_list;
435     DataTypeVector dtypes;
436   };
437 
438   // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)439   Status AddItem(const string& name, const NameInfoItem& item) {
440     if (!index_.insert({name, item}).second) {
441       return errors::InvalidArgument(
442           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
443                           " name: "),
444           name);
445     }
446     return OkStatus();
447   }
448 
GetItemOrNull(const string & name) const449   const NameInfoItem* GetItemOrNull(const string& name) const {
450     return gtl::FindOrNull(index_, name);
451   }
452 
Dep(int node_index) const453   string Dep(int node_index) const {
454     return strings::StrCat("^", Name(node_index));
455   }
456 
Name(int node_index) const457   string Name(int node_index) const {
458     CHECK_LT(node_index, nodes_.size());
459     return nodes_[node_index].name;
460   }
461 
Name(int node_index,int output_index) const462   string Name(int node_index, int output_index) const {
463     if (output_index == 0) {
464       return Name(node_index);
465     } else {
466       return strings::StrCat(Name(node_index), ":", output_index);
467     }
468   }
469 
AddNode(const string & name)470   NodeDef* AddNode(const string& name) {
471     result_.nodes.emplace_back();
472     NodeDef* gnode = &result_.nodes.back();
473     gnode->set_name(name);
474     nodes_.push_back({name, {}, {}});
475     CHECK_EQ(result_.nodes.size(), nodes_.size());
476     return gnode;
477   }
478 
AddInput(int node_index,int output_node,int output_index)479   void AddInput(int node_index, int output_node, int output_index) {
480     CHECK_LT(node_index, nodes_.size());
481     nodes_[node_index].data_inputs.push_back(
482         std::make_pair(output_node, output_index));
483   }
484 
AddDep(int node_index,int dep_index)485   void AddDep(int node_index, int dep_index) {
486     CHECK_LT(node_index, nodes_.size());
487     nodes_[node_index].control_inputs.push_back(dep_index);
488   }
489 
490   GetFunctionSignature get_function_;
491   InstantiationResult& result_;
492   // A small index for all names that can be used as a node's input arguments.
493   std::map<string, NameInfoItem> index_;
494   // This contains information about a node in the new graph including the node
495   // names and input nodes' indexes.
496   struct NodeInfo {
497     string name;
498     // Data inputs where <n, k> means arg k of node n.
499     std::vector<std::pair<int, int>> data_inputs;
500     // Control inputs (dependencies).
501     std::vector<int> control_inputs;
502   };
503   // nodes_[i] is the information about result_.nodes[i].
504   std::vector<NodeInfo> nodes_;
505 };
506 
507 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)508 string Print(const OpDef::ArgDef& arg) {
509   string out;
510   strings::StrAppend(&out, arg.name(), ":");
511   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
512   if (!arg.number_attr().empty()) {
513     strings::StrAppend(&out, arg.number_attr(), "*");
514   }
515   if (arg.type() != DT_INVALID) {
516     strings::StrAppend(&out, DataTypeString(arg.type()));
517   } else {
518     strings::StrAppend(&out, arg.type_attr());
519   }
520   if (arg.is_ref()) strings::StrAppend(&out, ")");
521   return out;
522 }
523 
524 // TODO(josh11b): Merge this with SummarizeAttrValue().
525 // When hash_string_attrs = true, string attributes are hashed instead of being
526 // truncated with ellipses. This is done to reduce the chance of collisions when
527 // looking up functions using the canonical representation.
Print(const AttrValue & attr_value,const bool hash_string_attrs=false)528 string Print(const AttrValue& attr_value,
529              const bool hash_string_attrs = false) {
530   if (attr_value.value_case() == AttrValue::kType) {
531     return DataTypeString(attr_value.type());
532   } else if ((attr_value.value_case() == AttrValue::kList) &&
533              (attr_value.list().type_size() > 0)) {
534     string ret = "{";
535     for (int i = 0; i < attr_value.list().type_size(); ++i) {
536       if (i > 0) strings::StrAppend(&ret, ", ");
537       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
538     }
539     strings::StrAppend(&ret, "}");
540     return ret;
541   } else if (attr_value.value_case() == AttrValue::kFunc) {
542     if (attr_value.func().attr_size() == 0) {
543       return attr_value.func().name();
544     }
545     std::vector<string> entries;
546     for (const auto& p : attr_value.func().attr()) {
547       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
548     }
549     std::sort(entries.begin(), entries.end());
550     return strings::StrCat(attr_value.func().name(), "[",
551                            absl::StrJoin(entries, ", "), "]");
552   } else if (attr_value.value_case() == AttrValue::kS && hash_string_attrs) {
553     return strings::StrCat(Fingerprint64(attr_value.s()));
554   }
555   return SummarizeAttrValue(attr_value);
556 }
557 
558 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)559 string Print(const NodeDef& n) {
560   string out;
561   strings::StrAppend(&out, n.name(), " = ", n.op());
562   if (n.attr_size() > 0) {
563     std::vector<string> entries;
564     for (auto& a : n.attr()) {
565       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
566     }
567     std::sort(entries.begin(), entries.end());
568     // Add a short device string at the end of all attributes.
569     if (!n.device().empty()) {
570       DeviceNameUtils::ParsedName parsed;
571       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
572         entries.push_back(
573             strings::StrCat("device=", parsed.type, ":", parsed.id));
574       } else {
575         entries.push_back("device=<FAILED_TO_PARSE>");
576       }
577     }
578     strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
579   }
580   strings::StrAppend(&out, "(");
581   std::vector<StringPiece> dat;
582   std::vector<string> dep;
583   for (StringPiece s : n.input()) {
584     if (absl::ConsumePrefix(&s, "^")) {
585       dep.emplace_back(s);
586     } else {
587       dat.push_back(s);
588     }
589   }
590   strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
591   if (!dep.empty()) {
592     strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
593   }
594   return out;
595 }
596 
Print(const FunctionDef & fdef)597 string Print(const FunctionDef& fdef) {
598   string out;
599   const OpDef& sig = fdef.signature();
600   strings::StrAppend(&out, "\n", sig.name());
601   if (sig.attr_size() > 0) {
602     strings::StrAppend(&out, "[");
603     for (int i = 0; i < sig.attr_size(); ++i) {
604       const auto& a = sig.attr(i);
605       if (i > 0) strings::StrAppend(&out, ", ");
606       if (a.type() == "type") {
607         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
608       } else {
609         strings::StrAppend(&out, a.name(), ":", a.type());
610       }
611     }
612     strings::StrAppend(&out, "]");
613   }
614   strings::StrAppend(&out, "(");
615   for (int i = 0; i < sig.input_arg_size(); ++i) {
616     if (i > 0) strings::StrAppend(&out, ", ");
617     strings::StrAppend(&out, Print(sig.input_arg(i)));
618   }
619   strings::StrAppend(&out, ") -> (");
620   for (int i = 0; i < sig.output_arg_size(); ++i) {
621     if (i > 0) strings::StrAppend(&out, ", ");
622     strings::StrAppend(&out, Print(sig.output_arg(i)));
623   }
624   strings::StrAppend(&out, ") {\n");
625   for (const auto& n : fdef.node_def()) {
626     strings::StrAppend(&out, "  ", Print(n), "\n");
627   }
628   for (const auto& cr : fdef.control_ret()) {
629     strings::StrAppend(&out, "  @return ", cr.first, " = ", cr.second, "\n");
630   }
631   for (const auto& r : fdef.ret()) {
632     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
633   }
634   strings::StrAppend(&out, "}\n");
635   return out;
636 }
637 
Print(gtl::ArraySlice<const NodeDef * > nodes)638 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
639   std::vector<const NodeDef*> arg;
640   std::vector<const NodeDef*> ret;
641   std::vector<const NodeDef*> body;
642   for (const NodeDef* n : nodes) {
643     if (n->op() == FunctionLibraryDefinition::kArgOp ||
644         n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
645       arg.push_back(n);
646     } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
647                n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
648       ret.push_back(n);
649     } else {
650       body.push_back(n);
651     }
652   }
653   auto comp = [](const NodeDef* x, const NodeDef* y) {
654     int xi;
655     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
656     int yi;
657     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
658     return xi < yi;
659   };
660   std::sort(arg.begin(), arg.end(), comp);
661   std::sort(ret.begin(), ret.end(), comp);
662   string out;
663   strings::StrAppend(&out, "\n(");
664   auto get_type_and_device = [](const NodeDef& n) {
665     DataType dt;
666     if (!TryGetNodeAttr(n, "T", &dt)) {
667       dt = DT_INVALID;
668     }
669     if (!n.device().empty()) {
670       DeviceNameUtils::ParsedName parsed;
671       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
672         return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
673                                parsed.id);
674       } else {
675         LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
676                      << n.op() << ":" << n.name();
677         return strings::StrCat(DataTypeString(dt), "@",
678                                "<FAILED_TO_PARSE_DEVICE>");
679       }
680     }
681     return DataTypeString(dt);
682   };
683   for (size_t i = 0; i < arg.size(); ++i) {
684     const NodeDef* n = arg[i];
685     if (i > 0) strings::StrAppend(&out, ", ");
686     CHECK_GE(n->attr_size(), 2);
687     strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
688   }
689   strings::StrAppend(&out, ") -> (");
690   for (size_t i = 0; i < ret.size(); ++i) {
691     const NodeDef* n = ret[i];
692     if (i > 0) strings::StrAppend(&out, ", ");
693     CHECK_LE(2, n->attr_size());
694 
695     // The _RetVal op should have a unique non-control input. We assert that
696     // here and add it to the output.
697     bool found_non_control_input = false;
698     for (const string& input : n->input()) {
699       if (!input.empty() && input[0] != '^') {
700         DCHECK_EQ(found_non_control_input, false)
701             << "RetVal node has more than one non-control input: "
702             << absl::StrJoin(n->input(), ", ");
703         strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
704         found_non_control_input = true;
705       }
706     }
707     DCHECK_EQ(found_non_control_input, true)
708         << "RetVal did not have any non-control inputs: "
709         << absl::StrJoin(n->input(), ", ");
710   }
711   strings::StrAppend(&out, ") {\n");
712   for (size_t i = 0; i < body.size(); ++i) {
713     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
714   }
715   strings::StrAppend(&out, "}\n");
716   return out;
717 }
718 
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)719 Status AddDefaultAttrs(const string& op,
720                        const GetFunctionSignature& get_function,
721                        AttrValueMap* attrs) {
722   const OpDef* op_def = nullptr;
723   TF_RETURN_IF_ERROR(get_function(op, &op_def));
724   AttrSlice attr_slice(attrs);
725   for (const auto& attr_def : op_def->attr()) {
726     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
727       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
728         return errors::Internal("Somehow duplicated: ", attr_def.name());
729       }
730     }
731   }
732   return OkStatus();
733 }
734 
735 }  // end namespace
736 
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)737 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
738                            GetFunctionSignature get_function,
739                            InstantiationResult* result) {
740   if (VLOG_IS_ON(5)) {
741     const auto& signature = fdef.signature();
742     VLOG(5) << "Instantiate function definition: name=" << signature.name()
743             << " #input_args=" << signature.input_arg_size()
744             << " #output_args=" << signature.output_arg_size()
745             << " #control_output=" << signature.control_output_size();
746     for (const auto& line : str_util::Split(Print(fdef), '\n')) {
747       VLOG(5) << "|| " << line;
748     }
749   }
750 
751   const OpDef& sig = fdef.signature();
752   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
753 
754   const AttrValue* attr_values_ints_on_device =
755       attr_values.Find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
756   bool ints_on_device =
757       (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
758        fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) ||
759       (attr_values_ints_on_device != nullptr &&
760        attr_values_ints_on_device->b());
761 
762   FunctionInstantiationHelper helper(get_function, result);
763   Status s;
764   for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
765     const OpDef::ArgDef& arg_def = sig.input_arg(i);
766     auto it = fdef.arg_attr().find(i);
767     const FunctionDef::ArgAttrs* arg_attrs =
768         it != fdef.arg_attr().end() ? &it->second : nullptr;
769     auto resource_id_it = fdef.resource_arg_unique_id().find(i);
770     int64_t resource_arg_unique_id =
771         resource_id_it != fdef.resource_arg_unique_id().end()
772             ? resource_id_it->second
773             : -1LL;
774     s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
775                                   ints_on_device, resource_arg_unique_id);
776 
777     if (!s.ok()) {
778       errors::AppendToMessage(&s, "In ", Print(arg_def));
779       return s;
780     }
781   }
782 
783   auto substitute = [attr_values, &sig](const string& name, AttrValue* val) {
784     // Look for a specified value...
785     if (const AttrValue* v = attr_values.FindByString(name)) {
786       *val = *v;
787       return true;
788     }
789     // .. and if not, then check for a default value.
790     if (const OpDef::AttrDef* attr = FindAttr(name, sig)) {
791       if (attr->has_default_value()) {
792         *val = attr->default_value();
793         return true;
794       }
795     }
796     // No luck finding a substitution.
797     return false;
798   };
799 
800   // Makes a copy of all attrs in fdef and substitutes placeholders.
801   // After this step, every attr is bound to a concrete value.
802   std::vector<AttrValueMap> node_attrs;
803   node_attrs.resize(fdef.node_def_size());
804   for (int i = 0; i < fdef.node_def_size(); ++i) {
805     for (auto attr : fdef.node_def(i).attr()) {
806       if (!SubstitutePlaceholders(substitute, &attr.second)) {
807         return errors::InvalidArgument("Failed to bind all placeholders in ",
808                                        SummarizeAttrValue(attr.second));
809       }
810       if (!node_attrs[i].insert(attr).second) {
811         return errors::Internal("Somehow duplicated: ", attr.first);
812       }
813     }
814     TF_RETURN_IF_ERROR(
815         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
816   }
817 
818   for (int i = 0; i < fdef.node_def_size(); ++i) {
819     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
820                                     result->nodes.size() + i);
821     if (!s.ok()) {
822       errors::AppendToMessage(&s, "In ",
823                               FormatNodeDefForError(fdef.node_def(i)));
824       return s;
825     }
826   }
827   // Emits one node for each fdef.node_def.
828   for (int i = 0; i < fdef.node_def_size(); ++i) {
829     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
830     if (!s.ok()) {
831       errors::AppendToMessage(&s, "In ",
832                               FormatNodeDefForError(fdef.node_def(i)));
833       return s;
834     }
835   }
836 
837   // Emits nodes for the function's return values.
838   int ret_index = 0;
839   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
840     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
841                              &ret_index);
842     if (!s.ok()) {
843       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
844       return s;
845     }
846   }
847 
848   // Adds the actual node inputs using the input indexes.
849   helper.AddNodeInputs();
850 
851   return OkStatus();
852 }
853 
DebugString(const FunctionDef & func_def)854 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
855 
DebugString(const GraphDef & instantiated_func_def)856 string DebugString(const GraphDef& instantiated_func_def) {
857   std::vector<const NodeDef*> ptrs;
858   for (const NodeDef& n : instantiated_func_def.node()) {
859     ptrs.push_back(&n);
860   }
861   return Print(ptrs);
862 }
863 
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)864 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
865   std::vector<const NodeDef*> ptrs;
866   for (const NodeDef& n : instantiated_func_nodes) {
867     ptrs.push_back(&n);
868   }
869   return Print(ptrs);
870 }
871 
DebugStringWhole(const GraphDef & gdef)872 string DebugStringWhole(const GraphDef& gdef) {
873   string ret;
874   for (const auto& fdef : gdef.library().function()) {
875     strings::StrAppend(&ret, Print(fdef));
876   }
877   strings::StrAppend(&ret, "\n");
878   for (const auto& ndef : gdef.node()) {
879     strings::StrAppend(&ret, Print(ndef), "\n");
880   }
881   return ret;
882 }
883 
884 namespace {
885 
886 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
887 // Python, it's possible to access unset attrs, which returns a default value
888 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)889 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
890   std::map<string, AttrValue> set_attrs;
891   for (const auto& pair : fdef.attr()) {
892     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
893       set_attrs[pair.first] = pair.second;
894     }
895   }
896   return set_attrs;
897 }
898 
899 }  // end namespace
900 
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)901 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
902   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
903 
904   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
905   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
906   if (f1_attrs.size() != f2_attrs.size()) return false;
907   for (const auto& iter1 : f1_attrs) {
908     auto iter2 = f2_attrs.find(iter1.first);
909     if (iter2 == f2_attrs.end()) return false;
910     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
911   }
912 
913   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
914     return false;
915   }
916 
917   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
918   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
919   if (ret1 != ret2) return false;
920 
921   std::map<string, string> control_ret1(f1.control_ret().begin(),
922                                         f1.control_ret().end());
923   std::map<string, string> control_ret2(f2.control_ret().begin(),
924                                         f2.control_ret().end());
925   if (control_ret1 != control_ret2) return false;
926 
927   return true;
928 }
929 
FunctionDefHash(const FunctionDef & fdef)930 uint64 FunctionDefHash(const FunctionDef& fdef) {
931   // signature
932   uint64 h = OpDefHash(fdef.signature());
933 
934   // attrs
935   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
936   for (const auto& p : attrs) {
937     h = Hash64(p.first.data(), p.first.size(), h);
938     h = Hash64Combine(AttrValueHash(p.second), h);
939   }
940 
941   // node defs
942   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
943 
944   // output names
945   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
946   for (const auto& p : ret) {
947     h = Hash64(p.first.data(), p.first.size(), h);
948     h = Hash64(p.second.data(), p.second.size(), h);
949   }
950 
951   // control output names
952   std::map<string, string> control_ret(fdef.control_ret().begin(),
953                                        fdef.control_ret().end());
954   for (const auto& p : control_ret) {
955     h = Hash64(p.first.data(), p.first.size(), h);
956     h = Hash64(p.second.data(), p.second.size(), h);
957   }
958 
959   return h;
960 }
961 
962 static constexpr const char* const kExecutorAttr = "_executor";
963 
964 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)965 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
966                                             AttrSlice attrs) {
967   if (!options.executor_type.empty()) {
968     return options.executor_type;
969   } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
970     return executor_attr->s();
971   } else {
972     return string();
973   }
974 }
975 
976 namespace {
977 class AttrKeyAndValue {
978  public:
979   enum ValueRepresentationOp {
980     kRaw,
981     kCEscape,
982   };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)983   AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
984                   ValueRepresentationOp value_op = kRaw)
985       : key_name_(key_name),
986         key_suffix_(key_suffix),
987         value_op_(value_op),
988         value_(std::move(value)) {}
989 
operator <(const AttrKeyAndValue & b) const990   bool operator<(const AttrKeyAndValue& b) const {
991     if (key_name_ != b.key_name_) {
992       return key_name_ < b.key_name_;
993     } else if (key_suffix_ != b.key_suffix_) {
994       return key_suffix_ < b.key_suffix_;
995     } else {
996       return value_ < b.value_;
997     }
998   }
999 
AppendTo(bool first,string * s) const1000   void AppendTo(bool first, string* s) const {
1001     absl::string_view v;
1002     bool add_escaped = false;
1003     if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
1004       // Use CEscape call below
1005       add_escaped = true;
1006     } else {
1007       // Add raw value contents directly
1008       v = value_;
1009     }
1010     if (key_suffix_ >= 0) {
1011       strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
1012     } else {
1013       strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
1014     }
1015     if (add_escaped) {
1016       strings::StrAppend(s, absl::CEscape(value_));
1017     }
1018   }
1019 
1020  private:
NeedsEscaping(const string & s)1021   static bool NeedsEscaping(const string& s) {
1022     for (auto c : s) {
1023       if (!isalnum(c) && (c != ' ')) {
1024         return true;
1025       }
1026     }
1027     return false;
1028   }
1029 
1030   absl::string_view key_name_;
1031   int key_suffix_;  // -1 if missing
1032   ValueRepresentationOp value_op_;
1033   string value_;
1034 };
1035 }  // namespace
1036 
GetFunctionResourceInputDevice(const Tensor & input,const int arg_index,const FunctionDef & function_def,absl::flat_hash_map<string,std::vector<string>> * composite_devices)1037 string GetFunctionResourceInputDevice(
1038     const Tensor& input, const int arg_index, const FunctionDef& function_def,
1039     absl::flat_hash_map<string, std::vector<string>>* composite_devices) {
1040   const auto& handles = input.flat<ResourceHandle>();
1041   const ResourceHandle& handle0 = handles(0);
1042   string composite_device;
1043   auto iter = function_def.arg_attr().find(arg_index);
1044   if (iter != function_def.arg_attr().end()) {
1045     auto arg_attr = iter->second.attr().find("_composite_device");
1046     if (arg_attr != iter->second.attr().end()) {
1047       composite_device = arg_attr->second.s();
1048     }
1049   }
1050   if (!composite_device.empty()) {
1051     if (composite_devices->find(composite_device) == composite_devices->end()) {
1052       for (int i = 0; i < handles.size(); ++i) {
1053         (*composite_devices)[composite_device].push_back(handles(i).device());
1054       }
1055     }
1056     return composite_device;
1057   } else {
1058     return handle0.device();
1059   }
1060 }
1061 
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)1062 string Canonicalize(const string& funcname, AttrSlice attrs,
1063                     const FunctionLibraryRuntime::InstantiateOptions& options) {
1064   absl::InlinedVector<AttrKeyAndValue, 8> entries;
1065   entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1066                   options.input_devices.size());
1067   for (const auto& p : attrs) {
1068     if (p.first != kExecutorAttr) {
1069       entries.push_back(AttrKeyAndValue(
1070           p.first, -1, Print(p.second, /*hash_string_attrs=*/true)));
1071     }
1072   }
1073   if (!options.target.empty()) {
1074     entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1075                                       AttrKeyAndValue::kCEscape));
1076   }
1077   for (int i = 0; i < options.input_devices.size(); ++i) {
1078     entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1079                                       AttrKeyAndValue::kCEscape));
1080   }
1081   for (int i = 0; i < options.output_devices.size(); ++i) {
1082     entries.push_back(AttrKeyAndValue("_output_dev", i,
1083                                       options.output_devices[i],
1084                                       AttrKeyAndValue::kCEscape));
1085   }
1086   for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1087     entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1088                                       DataTypeString(iter.second.dtype)));
1089     entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1090                                       iter.second.shape.DebugString(),
1091                                       AttrKeyAndValue::kCEscape));
1092   }
1093   if (options.lib_def) {
1094     entries.push_back(AttrKeyAndValue(
1095         "_lib_def", -1,
1096         absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1097   }
1098   if (!options.state_handle.empty()) {
1099     entries.push_back(
1100         AttrKeyAndValue("_state_handle", -1, options.state_handle));
1101   }
1102   string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1103   if (!executor_type.empty()) {
1104     entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1105   }
1106   if (options.config_proto.ByteSize() > 0) {
1107     string config_proto_serialized;
1108     SerializeToStringDeterministic(options.config_proto,
1109                                    &config_proto_serialized);
1110     entries.push_back(AttrKeyAndValue("_config_proto", -1,
1111                                       config_proto_serialized,
1112                                       AttrKeyAndValue::kCEscape));
1113   }
1114   std::sort(entries.begin(), entries.end());
1115   string result = strings::StrCat(funcname, "[");
1116   bool first = true;
1117   for (const auto& entry : entries) {
1118     entry.AppendTo(first, &result);
1119     first = false;
1120   }
1121   result += "]";
1122   return result;
1123 }
1124 
// Canonicalizes `funcname`/`attrs` using default InstantiateOptions. The
// options object is created once on first use and deliberately never freed.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  static const FunctionLibraryRuntime::InstantiateOptions& empty_options =
      *new FunctionLibraryRuntime::InstantiateOptions;
  return Canonicalize(funcname, attrs, empty_options);
}
1130 
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1131 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1132                                      DataTypeSlice ret_types)
1133     : arg_types_(arg_types.begin(), arg_types.end()),
1134       ret_types_(ret_types.begin(), ret_types.end()) {
1135   args_.resize(arg_types_.size());
1136   rets_.resize(ret_types_.size());
1137 }
1138 
~FunctionCallFrame()1139 FunctionCallFrame::~FunctionCallFrame() {}
1140 
SetArgs(gtl::ArraySlice<Tensor> args)1141 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1142   // Input type checks.
1143   if (args.size() != arg_types_.size()) {
1144     return errors::InvalidArgument("Expects ", arg_types_.size(),
1145                                    " arguments, but ", args.size(),
1146                                    " is provided");
1147   }
1148   for (size_t i = 0; i < args.size(); ++i) {
1149     if (arg_types_[i] != args[i].dtype()) {
1150       return errors::InvalidArgument(
1151           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1152           DataTypeString(args[i].dtype()), " is provided");
1153     }
1154     args_[i] = args[i];
1155   }
1156   return OkStatus();
1157 }
1158 
GetRetvals(std::vector<Tensor> * rets) const1159 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1160   rets->clear();
1161   rets->reserve(rets_.size());
1162   for (size_t i = 0; i < rets_.size(); ++i) {
1163     const auto& item = rets_[i];
1164     if (item.has_val) {
1165       rets->push_back(item.val);
1166     } else {
1167       return errors::Internal("Retval[", i, "] does not have value");
1168     }
1169   }
1170   return OkStatus();
1171 }
1172 
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1173 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1174                                          bool allow_dead_tensors) {
1175   rets->clear();
1176   rets->reserve(rets_.size());
1177   for (size_t i = 0; i < rets_.size(); ++i) {
1178     if (rets_[i].has_val) {
1179       rets->emplace_back(std::move(rets_[i].val));
1180     } else if (allow_dead_tensors) {
1181       rets->emplace_back();
1182     } else {
1183       return errors::Internal("Retval[", i, "] does not have value");
1184     }
1185   }
1186   return OkStatus();
1187 }
1188 
GetArg(int index,const Tensor ** val)1189 Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
1190   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1191     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1192                                    args_.size(), ")");
1193   }
1194   *val = &args_[index];
1195   return OkStatus();
1196 }
1197 
SetRetval(int index,const Tensor & val)1198 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1199   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1200     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1201                                    rets_.size(), ")");
1202   }
1203   if (val.dtype() != ret_types_[index]) {
1204     return errors::InvalidArgument(
1205         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1206         ", but ", DataTypeString(val.dtype()), " is provided.");
1207   }
1208   Retval* item = &rets_[index];
1209   if (!item->has_val) {
1210     item->has_val = true;
1211     item->val = val;
1212   } else {
1213     return errors::Internal("Retval[", index, "] has already been set.");
1214   }
1215   return OkStatus();
1216 }
1217 
1218 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef & fdef_in,const StackTracesMap & stack_traces)1219     FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
1220                                  const StackTracesMap& stack_traces)
1221     : fdef(fdef_in),
1222       // Exact shape inference for functions is handled by ShapeRefiner.
1223       // Here we pass a dummy shape inference function for legacy code paths.
1224       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
1225                            true /* is_function */),
1226       stack_traces(stack_traces) {}
1227 
// Copy constructor. Snapshots `other`'s function and gradient tables while
// holding a shared lock on `other`. Function entries are shared_ptrs, so the
// registrations themselves are shared rather than deep-copied.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  func_grad_ = other.func_grad_;
  function_defs_ = other.function_defs_;
}
1235 
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1236 FunctionLibraryDefinition::FunctionLibraryDefinition(
1237     const OpRegistryInterface* default_registry,
1238     const FunctionDefLibrary& def_lib)
1239     : default_registry_(default_registry),
1240       function_defs_(def_lib.function_size()) {
1241   for (const auto& fdef : def_lib.function()) {
1242     // The latter function definition wins.
1243     auto& ptr = function_defs_[fdef.signature().name()];
1244     ptr.reset(new FunctionDefAndOpRegistration(fdef));
1245   }
1246   for (const auto& grad : def_lib.gradient()) {
1247     func_grad_[grad.function_name()] = grad.gradient_func();
1248   }
1249 }
1250 
~FunctionLibraryDefinition()1251 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1252 
Contains(const string & func) const1253 bool FunctionLibraryDefinition::Contains(const string& func) const {
1254   tf_shared_lock l(mu_);
1255   return function_defs_.find(func) != function_defs_.end();
1256 }
1257 
Find(const string & func) const1258 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1259   tf_shared_lock l(mu_);
1260   auto result = FindHelper(func);
1261   if (result) {
1262     return &result->fdef;
1263   } else {
1264     return nullptr;
1265   }
1266 }
1267 
1268 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1269 FunctionLibraryDefinition::FindHelper(const string& func) const {
1270   auto iter = function_defs_.find(func);
1271   if (iter == function_defs_.end()) {
1272     return nullptr;
1273   } else {
1274     return iter->second;
1275   }
1276 }
1277 
AddFunctionDef(const FunctionDef & fdef,const StackTracesMap & stack_traces)1278 Status FunctionLibraryDefinition::AddFunctionDef(
1279     const FunctionDef& fdef, const StackTracesMap& stack_traces) {
1280   mutex_lock l(mu_);
1281   bool added;
1282   return AddFunctionDefHelper(fdef, stack_traces, &added);
1283 }
1284 
AddFunctionDefHelper(const FunctionDef & fdef,const StackTracesMap & stack_traces,bool * added)1285 Status FunctionLibraryDefinition::AddFunctionDefHelper(
1286     const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
1287   *added = false;
1288   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1289       function_defs_[fdef.signature().name()];
1290   if (entry) {
1291     if (!FunctionDefsEqual(entry->fdef, fdef)) {
1292       return errors::InvalidArgument(
1293           "Cannot add function '", fdef.signature().name(),
1294           "' because a different function with the same name already "
1295           "exists.");
1296     }
1297     // Ignore duplicate FunctionDefs.
1298     return OkStatus();
1299   }
1300   const OpDef* op_def;
1301   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1302     return errors::InvalidArgument(
1303         "Cannot add function '", fdef.signature().name(),
1304         "' because an op with the same name already exists.");
1305   }
1306   entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
1307   *added = true;
1308   return OkStatus();
1309 }
1310 
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1311 Status FunctionLibraryDefinition::AddHelper(
1312     std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1313   *added = false;
1314   std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1315       function_defs_[registration->fdef.signature().name()];
1316   if (entry) {
1317     if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1318       return errors::InvalidArgument(
1319           "Cannot add function '", registration->fdef.signature().name(),
1320           "' because a different function with the same name already "
1321           "exists.");
1322     }
1323     // Ignore duplicate FunctionDefs.
1324     return OkStatus();
1325   }
1326   const OpDef* op_def;
1327   if (default_registry_
1328           ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1329           .ok()) {
1330     return errors::InvalidArgument(
1331         "Cannot add function '", registration->fdef.signature().name(),
1332         "' because an op with the same name already exists.");
1333   }
1334   entry = std::move(registration);
1335   *added = true;
1336   return OkStatus();
1337 }
1338 
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1339 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1340     const string& func, const FunctionLibraryDefinition& other) {
1341   if (default_registry_ != other.default_registry_) {
1342     return errors::InvalidArgument(
1343         "Cannot copy function '", func,
1344         "' because CopyFunctionDefFrom() requires that both libraries have the "
1345         "same default registry.");
1346   }
1347   std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1348   {
1349     tf_shared_lock l(other.mu_);
1350     function_def = other.FindHelper(func);
1351   }
1352   if (!function_def) {
1353     return errors::InvalidArgument(
1354         "Cannot copy function '", func,
1355         "' because no function with that name exists in the other library.");
1356   }
1357   {
1358     mutex_lock l(mu_);
1359     std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1360     if (entry) {
1361       if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1362         return errors::InvalidArgument(
1363             "Cannot copy function '", func,
1364             "' because a different function with the same name already "
1365             "exists.");
1366       }
1367     } else {
1368       entry = std::move(function_def);
1369     }
1370   }
1371   return OkStatus();
1372 }
1373 
AddGradientDef(const GradientDef & grad)1374 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1375   mutex_lock l(mu_);
1376   bool added;
1377   return AddGradientDefHelper(grad, &added);
1378 }
1379 
AddGradientDefHelper(const GradientDef & grad,bool * added)1380 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1381                                                        bool* added) {
1382   *added = false;
1383   string* entry = &func_grad_[grad.function_name()];
1384   if (!entry->empty()) {
1385     if (*entry != grad.gradient_func()) {
1386       return errors::InvalidArgument(
1387           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1388           grad.function_name(), "' because it already has gradient function ",
1389           "'", *entry, "'");
1390     }
1391     // Ignore duplicate GradientDefs
1392     return OkStatus();
1393   }
1394   *entry = grad.gradient_func();
1395   *added = true;
1396   return OkStatus();
1397 }
1398 
AddLibrary(const FunctionLibraryDefinition & other)1399 Status FunctionLibraryDefinition::AddLibrary(
1400     const FunctionLibraryDefinition& other) {
1401   // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1402   // the duration of the function could lead to deadlock).
1403   FunctionLibraryDefinition clone(other);
1404   mutex_lock l(mu_);
1405   mutex_lock l2(clone.mu_);
1406   // Remember the funcs and grads that we added successfully so that
1407   // we can roll them back on error.
1408   std::vector<string> funcs;
1409   std::vector<string> funcs_with_grads;
1410   Status s;
1411   bool added;
1412   for (auto iter : clone.function_defs_) {
1413     s = AddHelper(iter.second, &added);
1414     if (!s.ok()) {
1415       Status remove_status = Remove(funcs, funcs_with_grads);
1416       if (!remove_status.ok()) {
1417         return remove_status;
1418       }
1419       return s;
1420     }
1421     if (added) {
1422       funcs.push_back(iter.second->fdef.signature().name());
1423     }
1424   }
1425   for (auto iter : clone.func_grad_) {
1426     GradientDef grad;
1427     grad.set_function_name(iter.first);
1428     grad.set_gradient_func(iter.second);
1429     s = AddGradientDefHelper(grad, &added);
1430     if (!s.ok()) {
1431       Status remove_status = Remove(funcs, funcs_with_grads);
1432       if (!remove_status.ok()) {
1433         return remove_status;
1434       }
1435       return s;
1436     }
1437     if (added) {
1438       funcs_with_grads.push_back(grad.function_name());
1439     }
1440   }
1441   return OkStatus();
1442 }
1443 
AddLibrary(const FunctionDefLibrary & lib_def)1444 Status FunctionLibraryDefinition::AddLibrary(
1445     const FunctionDefLibrary& lib_def) {
1446   // Remember the funcs and grads that we added successfully so that
1447   // we can roll them back on error.
1448   mutex_lock l(mu_);
1449   std::vector<string> funcs;
1450   std::vector<string> funcs_with_grads;
1451   Status s;
1452   bool added;
1453   for (const FunctionDef& fdef : lib_def.function()) {
1454     s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
1455     if (!s.ok()) {
1456       Status remove_status = Remove(funcs, funcs_with_grads);
1457       if (!remove_status.ok()) {
1458         return remove_status;
1459       }
1460       return s;
1461     }
1462     if (added) {
1463       funcs.push_back(fdef.signature().name());
1464     }
1465   }
1466   for (const GradientDef& grad : lib_def.gradient()) {
1467     s = AddGradientDefHelper(grad, &added);
1468     if (!s.ok()) {
1469       Status remove_status = Remove(funcs, funcs_with_grads);
1470       if (!remove_status.ok()) {
1471         return remove_status;
1472       }
1473       return s;
1474     }
1475     if (added) {
1476       funcs_with_grads.push_back(grad.function_name());
1477     }
1478   }
1479   return OkStatus();
1480 }
1481 
ReplaceFunction(const string & func,const FunctionDef & fdef,const StackTracesMap & stack_traces)1482 Status FunctionLibraryDefinition::ReplaceFunction(
1483     const string& func, const FunctionDef& fdef,
1484     const StackTracesMap& stack_traces) {
1485   mutex_lock l(mu_);
1486   bool added;
1487   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1488   TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, stack_traces, &added));
1489   return OkStatus();
1490 }
1491 
ReplaceGradient(const GradientDef & grad)1492 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1493   mutex_lock l(mu_);
1494   bool added;
1495   TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1496   TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1497   return OkStatus();
1498 }
1499 
RemoveFunction(const string & func)1500 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1501   mutex_lock l(mu_);
1502   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1503   return OkStatus();
1504 }
1505 
RemoveFunctionHelper(const string & func)1506 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1507   const auto& i = function_defs_.find(func);
1508   if (i == function_defs_.end()) {
1509     return errors::InvalidArgument("Tried to remove non-existent function '",
1510                                    func, "'.");
1511   }
1512   function_defs_.erase(i);
1513   return OkStatus();
1514 }
1515 
Clear()1516 void FunctionLibraryDefinition::Clear() {
1517   mutex_lock l(mu_);
1518   function_defs_.clear();
1519   func_grad_.clear();
1520 }
1521 
RemoveGradient(const string & func)1522 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1523   const auto& i = func_grad_.find(func);
1524   if (i == func_grad_.end()) {
1525     return errors::InvalidArgument("Tried to remove non-existent gradient '",
1526                                    func, "'.");
1527   }
1528   func_grad_.erase(i);
1529   return OkStatus();
1530 }
1531 
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1532 Status FunctionLibraryDefinition::Remove(
1533     const std::vector<string>& funcs,
1534     const std::vector<string>& funcs_with_grads) {
1535   Status s;
1536   for (const string& f : funcs) {
1537     s = RemoveFunctionHelper(f);
1538     if (!s.ok()) {
1539       return s;
1540     }
1541   }
1542   for (const string& f : funcs_with_grads) {
1543     s = RemoveGradient(f);
1544     if (!s.ok()) {
1545       return s;
1546     }
1547   }
1548   return OkStatus();
1549 }
1550 
// Returns the name of the gradient function registered for `func`, or the
// empty string if there is none. Takes a shared lock on `mu_`.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock lock(mu_);
  return FindGradientHelper(func);
}
1555 
// Lock-free variant of FindGradient: returns the registered gradient function
// name for `func`, or "" if none. GetAttrImpl calls it with `mu_` held.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  const auto found = func_grad_.find(func);
  if (found == func_grad_.end()) return "";
  return found->second;
}
1559 
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1560 Status FunctionLibraryDefinition::LookUp(
1561     const string& op, const OpRegistrationData** op_reg_data) const {
1562   tf_shared_lock l(mu_);
1563   auto iter = function_defs_.find(op);
1564   if (iter != function_defs_.end()) {
1565     *op_reg_data = &iter->second->op_registration_data;
1566     return OkStatus();
1567   }
1568   return default_registry_->LookUp(op, op_reg_data);
1569 }
1570 
UniqueFunctionName(StringPiece prefix) const1571 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1572   tf_shared_lock l(mu_);
1573   int index = 0;
1574   string name = strings::StrCat(prefix, index);
1575   while (function_defs_.find(name) != function_defs_.end()) {
1576     ++index;
1577     name = strings::StrCat(prefix, index);
1578   }
1579   return name;
1580 }
1581 
GetAttrImpl(const NodeDef & ndef) const1582 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1583     const NodeDef& ndef) const {
1584   if (ndef.op() != kGradientOp) {
1585     // If 'ndef' calls a function and the function's def has the attr,
1586     // returns it.
1587     return Find(ndef.op());
1588   }
1589 
1590   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1591   // Foo's attributes.
1592   const NameAttrList* forward_func_attrs;
1593   if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1594     return nullptr;
1595   }
1596   const string& func_name = forward_func_attrs->name();
1597   {
1598     tf_shared_lock l(mu_);
1599     const string& grad_name = FindGradientHelper(func_name);
1600     // If 'func' has a user-defined gradient function, uses the grad
1601     // function's attrs to see if noinline is specified. Otherwise,
1602     // uses func's attrs.
1603     if (!grad_name.empty()) {
1604       if (const auto helper = FindHelper(grad_name)) {
1605         return &(helper->fdef);
1606       } else {
1607         return nullptr;
1608       }
1609     }
1610     if (const auto helper = FindHelper(func_name)) {
1611       return &(helper->fdef);
1612     } else {
1613       return nullptr;
1614     }
1615   }
1616 }
1617 
ListFunctionNames() const1618 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1619   std::vector<string> function_names;
1620   tf_shared_lock l(mu_);
1621   function_names.reserve(function_defs_.size());
1622   for (const auto& it : function_defs_) {
1623     function_names.emplace_back(it.first);
1624   }
1625   return function_names;
1626 }
1627 
ToProto() const1628 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1629   FunctionDefLibrary lib;
1630   tf_shared_lock l(mu_);
1631   for (const auto& f : function_defs_) {
1632     *lib.add_function() = f.second->fdef;
1633   }
1634   for (const auto& g : func_grad_) {
1635     GradientDef* gd = lib.add_gradient();
1636     gd->set_function_name(g.first);
1637     gd->set_gradient_func(g.second);
1638   }
1639   return lib;
1640 }
1641 
1642 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1643 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1644                                           const string& attr, T* value) const {
1645   const FunctionDef* fdef = GetAttrImpl(ndef);
1646   if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1647     return OkStatus();
1648   }
1649   return errors::InvalidArgument("Attr ", attr, " is not defined.");
1650 }
1651 
1652 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1653 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1654                                           T* value) const {
1655   return GetAttr(node.def(), attr, value);
1656 }
1657 
1658 #define GET_ATTR(T)                                                            \
1659   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
1660                                                      const string&, T*) const; \
1661   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
1662                                                      const string&, T*) const;
1663 GET_ATTR(string)
1664 GET_ATTR(bool)
1665 #undef GET_ATTR
1666 
1667 namespace {
1668 
1669 constexpr char kApiImplements[] = "api_implements";
1670 
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1671 std::set<string> ReachableFunctions(
1672     const FunctionLibraryDefinition& flib,
1673     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1674   // Functions that are reachable from the graph.
1675   std::set<string> reachable_funcs;
1676 
1677   // For any functions, if it has attribute "api_implements" =
1678   // "some_interface" and it is reachable, then it means any other
1679   // function with same attribute name and value could also be potentially
1680   // reachable, eg via implementation_selector swapping the nodedef.
1681   absl::flat_hash_set<string> reachable_api_interface;
1682 
1683   // Functions might be reachable from the nested function calls, so we keep a
1684   // queue of functions that we have to check.
1685   gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1686 
1687   // Add reachable and not already processed functions to the functions queue.
1688   const auto add_to_func_queue = [&](const string& func_name) {
1689     const FunctionDef* func = flib.Find(func_name);
1690     if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1691       func_queue.push_back(func);
1692     }
1693   };
1694 
1695   // If any function with certain API name is reachable, all the other functions
1696   // with same API name should also be checked.
1697   const auto add_function_with_api_interface = [&](const string& api_name) {
1698     if (!reachable_api_interface.contains(api_name)) {
1699       reachable_api_interface.insert(api_name);
1700       for (const auto& func_name : flib.ListFunctionNames()) {
1701         const auto& func_def = flib.Find(func_name);
1702         const auto attr_it = func_def->attr().find(kApiImplements);
1703         if (attr_it != func_def->attr().end() &&
1704             attr_it->second.s() == api_name) {
1705           add_to_func_queue(func_name);
1706         }
1707       }
1708     }
1709   };
1710 
1711   // Add all the functions that are reachable from the given node to the queue.
1712   const auto process_node = [&](const NodeDef& node) {
1713     // Node itself can be a call to the function.
1714     add_to_func_queue(node.op());
1715 
1716     // Or node can have an attribute referencing a function.
1717     for (const auto& attr : node.attr()) {
1718       const auto& attr_value = attr.second;
1719 
1720       // 1. AttrValue.func
1721       if (attr_value.has_func()) {
1722         add_to_func_queue(attr_value.func().name());
1723       }
1724 
1725       // 2. AttrValue.ListValue.func
1726       if (attr_value.has_list()) {
1727         for (const auto& func : attr_value.list().func()) {
1728           add_to_func_queue(func.name());
1729         }
1730       }
1731     }
1732   };
1733 
1734   // Add all functions that are directly called from the optimized graph.
1735   std::for_each(nodes.begin(), nodes.end(), process_node);
1736 
1737   // Process all reachable functions.
1738   while (!func_queue.empty()) {
1739     const FunctionDef* func = func_queue.back();
1740     func_queue.pop_back();
1741 
1742     const string& func_name = func->signature().name();
1743     reachable_funcs.insert(func_name);
1744 
1745     const auto attr_it = func->attr().find(kApiImplements);
1746     if (attr_it != func->attr().end()) {
1747       add_function_with_api_interface(attr_it->second.s());
1748     }
1749 
1750     // Find all the functions called from the function body.
1751     const auto& func_body = func->node_def();
1752     std::for_each(func_body.begin(), func_body.end(), process_node);
1753 
1754     // Check if the function has a registered gradient.
1755     const string grad_func_name = flib.FindGradient(func_name);
1756     if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1757   }
1758 
1759   return reachable_funcs;
1760 }
1761 
// Builds a new library (sharing `flib`'s default registry) that contains only
// the functions reachable from `nodes`, plus their gradient mappings.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  const std::set<string> reachable_names = ReachableFunctions(flib, nodes);

  FunctionLibraryDefinition reachable_flib(flib.default_registry(),
                                           FunctionDefLibrary());

  for (const string& name : reachable_names) {
    // Copying from a valid library with the same default registry is not
    // expected to fail.
    const Status copied = reachable_flib.CopyFunctionDefFrom(name, flib);
    TF_DCHECK_OK(copied);

    const string grad_name = flib.FindGradient(name);
    if (grad_name.empty()) continue;

    GradientDef gradient;
    gradient.set_function_name(name);
    gradient.set_gradient_func(grad_name);
    // Each function is visited once, so it cannot already have a gradient.
    const Status registered = reachable_flib.AddGradientDef(gradient);
    TF_DCHECK_OK(registered);
  }

  return reachable_flib;
}
1789 
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1790 string AllocatorAttributesToString(
1791     const std::vector<AllocatorAttributes>& attrs) {
1792   string result("[");
1793   // AllocatorAttribute::DebugString produces around 85 bytes now.
1794   result.reserve(100 * attrs.size());
1795   for (const AllocatorAttributes& attr : attrs) {
1796     result.append(attr.DebugString());
1797     result.append(", ");
1798   }
1799   if (!attrs.empty()) {
1800     result.resize(result.size() - 2);
1801   }
1802   result.append("]");
1803   return result;
1804 }
1805 
// Summarizes an optional pointer-valued option as "set" / "unset" for debug
// output. Takes `const void*` because the pointee is never touched, so
// pointers to const objects are accepted without a cast.
const char* IsSet(const void* ptr) { return ptr == nullptr ? "unset" : "set"; }
1807 
1808 }  // namespace
1809 
ReachableDefinitions(const GraphDef & graph) const1810 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1811     const GraphDef& graph) const {
1812   return ReachableFunctionLibraryDefinition(*this, graph.node());
1813 }
1814 
ReachableDefinitions(const FunctionDef & func) const1815 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1816     const FunctionDef& func) const {
1817   return ReachableFunctionLibraryDefinition(*this, func.node_def());
1818 }
1819 
DebugString() const1820 string FunctionLibraryRuntime::Options::DebugString() const {
1821   return absl::StrCat(
1822       "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1823       " cancellation_manager=", IsSet(cancellation_manager),
1824       " collective_executor=", IsSet(collective_executor),
1825       " step_container=", IsSet(step_container),
1826       " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1827       " remote_execution=", remote_execution, " source_device=", source_device,
1828       " create_rendezvous=", create_rendezvous,
1829       " allow_dead_tensors=", allow_dead_tensors,
1830       " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1831       " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1832 }
1833 
InitFromString(StringPiece val)1834 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1835   if (val.size() >= 2 && val[0] == '$') {
1836     proto.set_placeholder(val.data() + 1, val.size() - 1);
1837   } else {
1838     SetAttrValue(val, &proto);
1839   }
1840 }
1841 
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1842 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1843     const string& name,
1844     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1845   AttrValueWrapper ret;
1846   ret.proto.mutable_func()->set_name(name);
1847   for (const auto& a : attrs) {
1848     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1849   }
1850   return ret;
1851 }
1852 
ToNodeDef() const1853 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1854   NodeDef n;
1855   n.set_op(this->op);
1856   n.set_name(GetName());
1857   for (const auto& a : this->attr) {
1858     n.mutable_attr()->insert({a.first, a.second.proto});
1859   }
1860   for (const string& a : this->arg) {
1861     n.add_input(a);
1862   }
1863   for (const string& d : this->dep) {
1864     n.add_input(strings::StrCat("^", d));
1865   }
1866   if (!this->device.empty()) {
1867     n.set_device(this->device);
1868   }
1869   if (!this->original_node_names.empty()) {
1870     *n.mutable_experimental_debug_info()->mutable_original_node_names() = {
1871         this->original_node_names.begin(), this->original_node_names.end()};
1872   }
1873   if (!this->original_func_names.empty()) {
1874     *n.mutable_experimental_debug_info()->mutable_original_func_names() = {
1875         this->original_func_names.begin(), this->original_func_names.end()};
1876   }
1877   return n;
1878 }
1879 
1880 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1881 FunctionDef FunctionDefHelper::Create(
1882     const string& function_name, gtl::ArraySlice<string> in_def,
1883     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1884     gtl::ArraySlice<Node> node_def,
1885     gtl::ArraySlice<std::pair<string, string>> ret_def,
1886     gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1887   FunctionDef fdef;
1888 
1889   // Signature
1890   OpDefBuilder b(function_name);
1891   for (const auto& i : in_def) b.Input(i);
1892   for (const auto& o : out_def) b.Output(o);
1893   for (const auto& a : attr_def) b.Attr(a);
1894   for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1895 
1896   OpRegistrationData op_reg_data;
1897   TF_CHECK_OK(b.Finalize(&op_reg_data));
1898   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1899 
1900   // Function body
1901   for (const auto& n : node_def) {
1902     *(fdef.add_node_def()) = n.ToNodeDef();
1903   }
1904 
1905   // Returns
1906   for (const auto& r : ret_def) {
1907     fdef.mutable_ret()->insert({r.first, r.second});
1908   }
1909 
1910   // Control returns
1911   for (const auto& cr : control_ret_def) {
1912     fdef.mutable_control_ret()->insert({cr.first, cr.second});
1913   }
1914 
1915   auto* op_def_registry = OpRegistry::Global();
1916   // Check if any op is stateful.
1917   for (const auto& n : node_def) {
1918     const OpDef* op_def = nullptr;
1919     auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1920     // Lookup can fail if e.g. we are calling a function that was not yet
1921     // defined.  If it happens, conservatively assume the op is stateful.
1922     if (!status.ok() || op_def->is_stateful()) {
1923       fdef.mutable_signature()->set_is_stateful(true);
1924     }
1925   }
1926 
1927   return fdef;
1928 }
1929 
1930 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1931 FunctionDef FunctionDefHelper::Create(
1932     const string& function_name, gtl::ArraySlice<string> in_def,
1933     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1934     gtl::ArraySlice<Node> node_def,
1935     gtl::ArraySlice<std::pair<string, string>> ret_def) {
1936   return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1937                 /*control_ret_def=*/{});
1938 }
1939 
1940 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1941 FunctionDef FunctionDefHelper::Define(const string& name,
1942                                       gtl::ArraySlice<string> arg_def,
1943                                       gtl::ArraySlice<string> ret_def,
1944                                       gtl::ArraySlice<string> attr_def,
1945                                       gtl::ArraySlice<Node> node_def) {
1946   FunctionDef fdef;
1947   OpDefBuilder b(name);
1948   for (const auto& a : arg_def) b.Input(a);
1949   for (const auto& r : ret_def) b.Output(r);
1950   for (const auto& a : attr_def) b.Attr(a);
1951 
1952   OpRegistrationData op_reg_data;
1953   TF_CHECK_OK(b.Finalize(&op_reg_data));
1954   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1955 
1956   // Mapping from legacy output names to NodeDef outputs.
1957   std::unordered_map<string, string> ret_index;
1958   for (const auto& a : fdef.signature().input_arg()) {
1959     ret_index[a.name()] = a.name();
1960   }
1961 
1962   // For looking up OpDefs
1963   auto* op_def_registry = OpRegistry::Global();
1964 
1965   // Function body
1966   for (const auto& src : node_def) {
1967     NodeDef* n = fdef.add_node_def();
1968     n->set_op(src.op);
1969     n->set_name(src.GetName());
1970     for (const auto& a : src.attr) {
1971       n->mutable_attr()->insert({a.first, a.second.proto});
1972     }
1973     for (const string& a : src.arg) {
1974       const auto iter = ret_index.find(a);
1975       CHECK(iter != ret_index.end())
1976           << "Node input '" << a << "' in '" << n->name() << "' of " << name;
1977       n->add_input(iter->second);
1978     }
1979     for (const string& d : src.dep) {
1980       n->add_input(strings::StrCat("^", d));
1981     }
1982 
1983     // Add the outputs of this node to ret_index.
1984     const OpDef* op_def = nullptr;
1985     TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1986     CHECK(op_def != nullptr) << n->op();
1987     NameRangeMap output_names;
1988     TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1989     for (const auto& o : output_names) {
1990       CHECK_LE(o.second.second, src.ret.size())
1991           << "Missing ret for output '" << o.first << "' in '" << n->name()
1992           << "' of " << name;
1993       for (int i = o.second.first; i < o.second.second; ++i) {
1994         ret_index[src.ret[i]] =
1995             strings::StrCat(n->name(), ":", o.first, ":", i - o.second.first);
1996       }
1997     }
1998     if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1999   }
2000 
2001   // Returns
2002   for (const auto& r : fdef.signature().output_arg()) {
2003     const auto iter = ret_index.find(r.name());
2004     CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
2005     fdef.mutable_ret()->insert({r.name(), iter->second});
2006   }
2007   return fdef;
2008 }
2009 
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)2010 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
2011                                       gtl::ArraySlice<string> ret_def,
2012                                       gtl::ArraySlice<string> attr_def,
2013                                       gtl::ArraySlice<Node> node_def) {
2014   return Define("_", arg_def, ret_def, attr_def, node_def);
2015 }
2016 
2017 namespace gradient {
2018 
2019 typedef std::unordered_map<string, Creator> OpGradFactory;
2020 
GetOpGradFactory()2021 OpGradFactory* GetOpGradFactory() {
2022   static OpGradFactory* factory = new OpGradFactory;
2023   return factory;
2024 }
2025 
RegisterOp(const string & op,Creator func)2026 bool RegisterOp(const string& op, Creator func) {
2027   CHECK(GetOpGradFactory()->insert({op, func}).second)
2028       << "Duplicated gradient for " << op;
2029   return true;
2030 }
2031 
GetOpGradientCreator(const string & op,Creator * creator)2032 Status GetOpGradientCreator(const string& op, Creator* creator) {
2033   auto fac = GetOpGradFactory();
2034   auto iter = fac->find(op);
2035   if (iter == fac->end()) {
2036     return errors::NotFound("No gradient defined for op: ", op);
2037   }
2038   *creator = iter->second;
2039   return OkStatus();
2040 }
2041 
2042 }  // end namespace gradient
2043 
2044 }  // namespace tensorflow
2045