1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function.h"
17
18 #include <ctype.h>
19
20 #include <map>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/escaping.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/util/device_name_utils.h"
41 #include "tensorflow/core/util/equal_graph_def.h"
42
43 namespace tensorflow {
44
// Namespace-scope definitions of the static constexpr string constants that
// are members of FunctionLibraryDefinition.
// NOTE(review): presumably these exist so the constants can be ODR-used when
// the code is built with a pre-C++17 language standard — confirm against the
// class declaration and the supported toolchains before removing them.
/* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceArgOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kIntsOnDeviceAttr;
/* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
55
56 // Extracts the actual type from "attr_values" based on its definition
57 // "arg_def".
58 //
59 // If "arg_def" is a N*T type, *is_type_list is set to false, and
60 // *dtypes is set to be a vector of size N and each element is T.
61 //
62 // If "arg_def" is a list(type), *is_type_list is set to true, and
63 // *dtypes is set to be a vector of types specified in attrs for
64 // arg_def.
65 //
66 // Otherwise (arg_def is a simple type T), *is_type_list is set to
67 // false, and *dtypes is set to a single element vector, whose only
68 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)69 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
70 bool* is_type_list, DataTypeVector* dtypes) {
71 dtypes->clear();
72 if (!arg_def.type_list_attr().empty()) {
73 const AttrValue* v = attrs.Find(arg_def.type_list_attr());
74 if (v == nullptr) {
75 return errors::NotFound("type attr not found: ",
76 arg_def.type_list_attr());
77 }
78 *is_type_list = true;
79 for (int i = 0; i < v->list().type_size(); ++i) {
80 dtypes->push_back(v->list().type(i));
81 }
82 return Status::OK();
83 }
84
85 *is_type_list = false;
86 int num = 1;
87 if (!arg_def.number_attr().empty()) {
88 const AttrValue* v = attrs.Find(arg_def.number_attr());
89 if (v == nullptr) {
90 return errors::NotFound("type attr not found: ", arg_def.type_attr());
91 }
92 num = v->i();
93 }
94
95 DataType dtype;
96 if (arg_def.type() != DT_INVALID) {
97 dtype = arg_def.type();
98 } else if (arg_def.type_attr().empty()) {
99 dtype = DT_INVALID;
100 } else {
101 const AttrValue* v = attrs.Find(arg_def.type_attr());
102 if (v == nullptr) {
103 return errors::NotFound("type attr not found: ", arg_def.type_attr());
104 }
105 dtype = v->type();
106 }
107 dtypes->resize(num, dtype);
108 return Status::OK();
109 }
110
111 namespace {
112
113 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)114 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
115 SetAttrValue(val, &((*ndef->mutable_attr())[name]));
116 }
117
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)118 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
119 // attr_values should specify all attrs defined in fdef.
120 for (const auto& a : sig.attr()) {
121 const AttrValue* v = attr_values.Find(a.name());
122 if (!v) {
123 return errors::NotFound("Attr ", a.name(), " is not found from ",
124 SummarizeOpDef(sig));
125 }
126 Status status = AttrValueHasType(*v, a.type());
127 if (!status.ok()) {
128 errors::AppendToMessage(&status, "for attr '", a.name(), "'");
129 return status;
130 }
131 }
132
133 // TODO(josh11b): Enable this code once it works with function gradients.
134 // Right now the C++ function gradient code assumes it can pass
135 // all the attrs of the function to the gradient, and any attrs that
136 // the gradient doesn't care about will be ignored.
137 #if 0
138 if (attr_values.size() != sig.attr_size()) {
139 for (const auto& a : attr_values) {
140 // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
141 bool found = false;
142 for (const auto& s : sig.attr()) {
143 if (a.first == s.name()) {
144 found = true;
145 break;
146 }
147 }
148 if (!found) {
149 return errors::NotFound("Attr ", a.first, " is not found in ",
150 SummarizeOpDef(sig));
151 }
152 }
153 }
154 #endif
155
156 return Status::OK();
157 }
158
159 // A helper class for instantiating functions. This contains shared information
160 // like the resulting graph and node name index.
161 class FunctionInstantiationHelper {
162 public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)163 FunctionInstantiationHelper(GetFunctionSignature get_function,
164 InstantiationResult* result)
165 : get_function_(std ::move(get_function)), result_(*result) {
166 result_.nodes.clear();
167 }
168
169 // Builds index for nodes that can be used as node's input arguments.
170 // `resource_arg_unique_id`: if non-negative, will be populated to the
171 // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64 resource_arg_unique_id)172 Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
173 const FunctionDef::ArgAttrs* arg_attrs,
174 bool ints_on_device, int64 resource_arg_unique_id) {
175 bool is_type_list;
176 DataTypeVector dtypes;
177 TF_RETURN_IF_ERROR(
178 ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
179 CHECK_GE(dtypes.size(), size_t{1});
180 int arg_index = result_.nodes.size();
181 TF_RETURN_IF_ERROR(
182 AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
183 // Creates dtypes.size() nodes in the graph.
184 for (size_t i = 0; i < dtypes.size(); ++i) {
185 TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
186 {true, arg_index, 0, false, {dtypes[i]}}));
187 DCHECK_EQ(arg_index, result_.nodes.size());
188 string name = arg_def.name();
189 if (dtypes.size() > 1) {
190 strings::StrAppend(&name, "_", i);
191 }
192 NodeDef* gnode = AddNode(name);
193 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
194 gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
195 } else {
196 gnode->set_op(FunctionLibraryDefinition::kArgOp);
197 }
198 DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
199 AddAttr("T", dtype, gnode);
200 AddAttr("index", arg_index, gnode);
201 if (resource_arg_unique_id >= 0) {
202 AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
203 }
204 if (arg_attrs) {
205 for (const auto& arg_attr : arg_attrs->attr()) {
206 AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
207 }
208 }
209 result_.arg_types.push_back(dtypes[i]);
210 ++arg_index;
211 }
212 return Status::OK();
213 }
214
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)215 Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
216 const int arg_index) {
217 const OpDef* node_sig = nullptr;
218 TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
219 if (node_sig->output_arg_size() == 0) {
220 return AddItem(node.name(), {false, arg_index, 0, false, {}});
221 }
222 const int num_retval = node_sig->output_arg_size();
223 int start = 0;
224 bool is_type_list;
225 DataTypeVector dtypes;
226 for (int i = 0; i < num_retval; ++i) {
227 TF_RETURN_IF_ERROR(
228 ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
229 // Note that we rely on the backwards-compatibility test enforcing
230 // that output_arg(*).name() doesn't change here.
231 const string base_name =
232 strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
233 TF_RETURN_IF_ERROR(
234 AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
235 for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
236 TF_RETURN_IF_ERROR(
237 AddItem(strings::StrCat(base_name, ":", j),
238 {false, arg_index, start + j, false, {dtypes[j]}}));
239 }
240 start += dtypes.size();
241 }
242 return Status::OK();
243 }
244
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)245 Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
246 const OpDef* fnode_sig = nullptr;
247 TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
248 NodeDef* gnode = AddNode(fnode.name());
249 gnode->set_op(fnode.op());
250 gnode->set_device(fnode.device());
251 int gnode_idx = nodes_.size() - 1;
252
253 // Input
254 const int num_args = fnode_sig->input_arg_size();
255 bool is_type_list; // ignored
256 DataTypeVector dtypes;
257 int fnode_arg_index = 0;
258 for (int i = 0; i < num_args; ++i) {
259 TF_RETURN_IF_ERROR(
260 ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
261 // Consume inputs (indexed by fnode_arg_index) until we have
262 // matched each element of dtypes (indexed by j).
263 for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
264 if (fnode_arg_index >= fnode.input_size()) {
265 // Should never happen if we computed dtypes correctly.
266 return errors::InvalidArgument(
267 "Attempt to access beyond input size: ", fnode_arg_index,
268 " >= ", fnode.input_size());
269 }
270 // Look up the next input.
271 const string& input_name = fnode.input(fnode_arg_index);
272 const auto* item = GetItemOrNull(input_name);
273 if (item == nullptr) {
274 return errors::InvalidArgument(
275 "input ", input_name,
276 " is not found: ", FormatNodeDefForError(fnode));
277 }
278 if (item->dtypes.size() > dtypes.size() - j) {
279 return errors::InvalidArgument("Input ", input_name, " too long for ",
280 fnode_sig->input_arg(i).name());
281 }
282 // Match up all the elements of this input (indexed by k) with
283 // elements of dtypes (advancing j).
284 for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
285 if (item->dtypes[k] != dtypes[j]) {
286 return errors::InvalidArgument(
287 "input ", fnode_sig->input_arg(i).name(), "[", j,
288 "] expected type ", DataTypeString(dtypes[j]),
289 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
290 input_name, "[", k, "]");
291 }
292 if (item->is_func_arg) {
293 AddInput(gnode_idx, item->nid + k, 0);
294 } else {
295 AddInput(gnode_idx, item->nid, item->idx + k);
296 }
297 }
298 }
299 }
300
301 // Control deps.
302 for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
303 const string& input = fnode.input(i);
304 if (input.empty() || input[0] != '^') {
305 return errors::InvalidArgument("Expected input[", i, "] == '", input,
306 "' to be a control input.");
307 }
308 int nid = -1;
309 const string node_name = input.substr(1);
310 const string node_colon = node_name + ":";
311 const string node_colon_bound = node_name + ";";
312 // index_ is a map sorted lexicographically, so the key we are looking for
313 // must lie in the range [node_name, node_colon_bound).
314 auto it = index_.lower_bound(node_name);
315 while (it != index_.end() && it->first <= node_colon_bound) {
316 if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
317 nid = it->second.nid;
318 break;
319 }
320 ++it;
321 }
322 if (nid == -1) {
323 return errors::InvalidArgument("input[", i, "] == '", input,
324 "', is not found.");
325 }
326 AddDep(gnode_idx, nid);
327 }
328
329 // Attrs.
330 for (const auto& p : attrs) {
331 (*gnode->mutable_attr())[p.first] = p.second;
332 }
333
334 return Status::OK();
335 }
336
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)337 Status AddReturnNode(
338 const OpDef::ArgDef& ret_def, AttrSlice attrs,
339 const ::tensorflow::protobuf::Map<string, string>& ret_map,
340 bool ints_on_device, int* ret_index) {
341 auto ret_iter = ret_map.find(ret_def.name());
342 if (ret_iter == ret_map.end()) {
343 return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
344 }
345 bool is_type_list;
346 DataTypeVector dtypes;
347 TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
348 CHECK_GE(dtypes.size(), size_t{1});
349 const auto* item = GetItemOrNull(ret_iter->second);
350 if (item == nullptr) {
351 return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
352 ret_iter->second, " is not found.");
353 }
354 if (dtypes != item->dtypes) {
355 return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
356 " : ", DataTypeVectorString(dtypes),
357 " vs. ",
358 DataTypeVectorString(item->dtypes));
359 }
360 for (size_t i = 0; i < dtypes.size(); ++i) {
361 string name = strings::StrCat(ret_def.name(), "_RetVal");
362 if (dtypes.size() > 1) {
363 strings::StrAppend(&name, "_", i);
364 }
365 NodeDef* gnode = AddNode(name);
366 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
367 gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
368 } else {
369 gnode->set_op(FunctionLibraryDefinition::kRetOp);
370 }
371 AddInput(nodes_.size() - 1, item->nid, item->idx + i);
372 DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
373 AddAttr("T", dtype, gnode);
374 AddAttr("index", (*ret_index)++, gnode);
375 result_.ret_types.push_back(dtypes[i]);
376 }
377 return Status::OK();
378 }
379
380 // Adds the actual node inputs to the result graph by converting indexes to
381 // the node names.
AddNodeInputs()382 void AddNodeInputs() {
383 for (int i = 0; i < result_.nodes.size(); i++) {
384 NodeInfo& node_info = nodes_[i];
385 for (const auto& p : node_info.data_inputs) {
386 result_.nodes[i].add_input(Name(p.first, p.second));
387 }
388 for (int index : node_info.control_inputs) {
389 result_.nodes[i].add_input(Dep(index));
390 }
391 }
392 }
393
394 private:
395 // This is used to build a small index for all names that can be used as a
396 // node's input arguments.
397 //
398 // If is_func_arg is true, the name is a function's argument. In
399 // this case, the produced graph def has node[nid:nid + dtype.size()].
400 //
401 // Otherwise, the name is a function body's node return value. In
402 // this case, the produced graph def has one node node[nid] and
403 // the node's output index [idx ... idx + num) corresponds to the
404 // named outputs.
405 //
406 // In all cases, "dtype" specifies the data type.
407 struct NameInfoItem {
408 bool is_func_arg;
409 int nid;
410 int idx;
411 bool is_type_list;
412 DataTypeVector dtypes;
413 };
414
415 // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)416 Status AddItem(const string& name, const NameInfoItem& item) {
417 if (!index_.insert({name, item}).second) {
418 return errors::InvalidArgument(
419 strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
420 " name: "),
421 name);
422 }
423 return Status::OK();
424 }
425
GetItemOrNull(const string & name) const426 const NameInfoItem* GetItemOrNull(const string& name) const {
427 return gtl::FindOrNull(index_, name);
428 }
429
Dep(int node_index) const430 string Dep(int node_index) const {
431 return strings::StrCat("^", Name(node_index));
432 }
433
Name(int node_index) const434 string Name(int node_index) const {
435 CHECK_LT(node_index, nodes_.size());
436 return nodes_[node_index].name;
437 }
438
Name(int node_index,int output_index) const439 string Name(int node_index, int output_index) const {
440 if (output_index == 0) {
441 return Name(node_index);
442 } else {
443 return strings::StrCat(Name(node_index), ":", output_index);
444 }
445 }
446
AddNode(const string & name)447 NodeDef* AddNode(const string& name) {
448 result_.nodes.emplace_back();
449 NodeDef* gnode = &result_.nodes.back();
450 gnode->set_name(name);
451 nodes_.push_back({name, {}, {}});
452 CHECK_EQ(result_.nodes.size(), nodes_.size());
453 return gnode;
454 }
455
AddInput(int node_index,int output_node,int output_index)456 void AddInput(int node_index, int output_node, int output_index) {
457 CHECK_LT(node_index, nodes_.size());
458 nodes_[node_index].data_inputs.push_back(
459 std::make_pair(output_node, output_index));
460 }
461
AddDep(int node_index,int dep_index)462 void AddDep(int node_index, int dep_index) {
463 CHECK_LT(node_index, nodes_.size());
464 nodes_[node_index].control_inputs.push_back(dep_index);
465 }
466
467 GetFunctionSignature get_function_;
468 InstantiationResult& result_;
469 // A small index for all names that can be used as a node's input arguments.
470 std::map<string, NameInfoItem> index_;
471 // This contains information about a node in the new graph including the node
472 // names and input nodes' indexes.
473 struct NodeInfo {
474 string name;
475 // Data inputs where <n, k> means arg k of node n.
476 std::vector<std::pair<int, int>> data_inputs;
477 // Control inputs (dependencies).
478 std::vector<int> control_inputs;
479 };
480 // nodes_[i] is the information about result_.nodes[i].
481 std::vector<NodeInfo> nodes_;
482 };
483
484 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)485 string Print(const OpDef::ArgDef& arg) {
486 string out;
487 strings::StrAppend(&out, arg.name(), ":");
488 if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
489 if (!arg.number_attr().empty()) {
490 strings::StrAppend(&out, arg.number_attr(), "*");
491 }
492 if (arg.type() != DT_INVALID) {
493 strings::StrAppend(&out, DataTypeString(arg.type()));
494 } else {
495 strings::StrAppend(&out, arg.type_attr());
496 }
497 if (arg.is_ref()) strings::StrAppend(&out, ")");
498 return out;
499 }
500
501 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)502 string Print(const AttrValue& attr_value) {
503 if (attr_value.value_case() == AttrValue::kType) {
504 return DataTypeString(attr_value.type());
505 } else if ((attr_value.value_case() == AttrValue::kList) &&
506 (attr_value.list().type_size() > 0)) {
507 string ret = "{";
508 for (int i = 0; i < attr_value.list().type_size(); ++i) {
509 if (i > 0) strings::StrAppend(&ret, ", ");
510 strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
511 }
512 strings::StrAppend(&ret, "}");
513 return ret;
514 } else if (attr_value.value_case() == AttrValue::kFunc) {
515 if (attr_value.func().attr_size() == 0) {
516 return attr_value.func().name();
517 }
518 std::vector<string> entries;
519 for (const auto& p : attr_value.func().attr()) {
520 entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
521 }
522 std::sort(entries.begin(), entries.end());
523 return strings::StrCat(attr_value.func().name(), "[",
524 absl::StrJoin(entries, ", "), "]");
525 }
526 return SummarizeAttrValue(attr_value);
527 }
528
529 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)530 string Print(const NodeDef& n) {
531 string out;
532 strings::StrAppend(&out, n.name(), " = ", n.op());
533 if (n.attr_size() > 0) {
534 std::vector<string> entries;
535 for (auto& a : n.attr()) {
536 entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
537 }
538 std::sort(entries.begin(), entries.end());
539 // Add a short device string at the end of all attributes.
540 if (!n.device().empty()) {
541 DeviceNameUtils::ParsedName parsed;
542 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
543 entries.push_back(
544 strings::StrCat("device=", parsed.type, ":", parsed.id));
545 } else {
546 entries.push_back("device=<FAILED_TO_PARSE>");
547 }
548 }
549 strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
550 }
551 strings::StrAppend(&out, "(");
552 std::vector<StringPiece> dat;
553 std::vector<string> dep;
554 for (StringPiece s : n.input()) {
555 if (absl::ConsumePrefix(&s, "^")) {
556 dep.emplace_back(s);
557 } else {
558 dat.push_back(s);
559 }
560 }
561 strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
562 if (!dep.empty()) {
563 strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
564 }
565 return out;
566 }
567
Print(const FunctionDef & fdef)568 string Print(const FunctionDef& fdef) {
569 string out;
570 const OpDef& sig = fdef.signature();
571 strings::StrAppend(&out, "\n", sig.name());
572 if (sig.attr_size() > 0) {
573 strings::StrAppend(&out, "[");
574 for (int i = 0; i < sig.attr_size(); ++i) {
575 const auto& a = sig.attr(i);
576 if (i > 0) strings::StrAppend(&out, ", ");
577 if (a.type() == "type") {
578 strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
579 } else {
580 strings::StrAppend(&out, a.name(), ":", a.type());
581 }
582 }
583 strings::StrAppend(&out, "]");
584 }
585 strings::StrAppend(&out, "(");
586 for (int i = 0; i < sig.input_arg_size(); ++i) {
587 if (i > 0) strings::StrAppend(&out, ", ");
588 strings::StrAppend(&out, Print(sig.input_arg(i)));
589 }
590 strings::StrAppend(&out, ") -> (");
591 for (int i = 0; i < sig.output_arg_size(); ++i) {
592 if (i > 0) strings::StrAppend(&out, ", ");
593 strings::StrAppend(&out, Print(sig.output_arg(i)));
594 }
595 strings::StrAppend(&out, ") {\n");
596 for (const auto& n : fdef.node_def()) {
597 strings::StrAppend(&out, " ", Print(n), "\n");
598 }
599 for (const auto& cr : fdef.control_ret()) {
600 strings::StrAppend(&out, " @return ", cr.first, " = ", cr.second, "\n");
601 }
602 for (const auto& r : fdef.ret()) {
603 strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
604 }
605 strings::StrAppend(&out, "}\n");
606 return out;
607 }
608
Print(gtl::ArraySlice<const NodeDef * > nodes)609 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
610 std::vector<const NodeDef*> arg;
611 std::vector<const NodeDef*> ret;
612 std::vector<const NodeDef*> body;
613 for (const NodeDef* n : nodes) {
614 if (n->op() == FunctionLibraryDefinition::kArgOp ||
615 n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
616 arg.push_back(n);
617 } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
618 n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
619 ret.push_back(n);
620 } else {
621 body.push_back(n);
622 }
623 }
624 auto comp = [](const NodeDef* x, const NodeDef* y) {
625 int xi;
626 TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
627 int yi;
628 TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
629 return xi < yi;
630 };
631 std::sort(arg.begin(), arg.end(), comp);
632 std::sort(ret.begin(), ret.end(), comp);
633 string out;
634 strings::StrAppend(&out, "\n(");
635 auto get_type_and_device = [](const NodeDef& n) {
636 DataType dt;
637 if (!TryGetNodeAttr(n, "T", &dt)) {
638 dt = DT_INVALID;
639 }
640 if (!n.device().empty()) {
641 DeviceNameUtils::ParsedName parsed;
642 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
643 return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
644 parsed.id);
645 } else {
646 LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
647 << n.op() << ":" << n.name();
648 return strings::StrCat(DataTypeString(dt), "@",
649 "<FAILED_TO_PARSE_DEVICE>");
650 }
651 }
652 return DataTypeString(dt);
653 };
654 for (size_t i = 0; i < arg.size(); ++i) {
655 const NodeDef* n = arg[i];
656 if (i > 0) strings::StrAppend(&out, ", ");
657 CHECK_GE(n->attr_size(), 2);
658 strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
659 }
660 strings::StrAppend(&out, ") -> (");
661 for (size_t i = 0; i < ret.size(); ++i) {
662 const NodeDef* n = ret[i];
663 if (i > 0) strings::StrAppend(&out, ", ");
664 CHECK_LE(2, n->attr_size());
665
666 // The _RetVal op should have a unique non-control input. We assert that
667 // here and add it to the output.
668 bool found_non_control_input = false;
669 for (const string& input : n->input()) {
670 if (!input.empty() && input[0] != '^') {
671 DCHECK_EQ(found_non_control_input, false)
672 << "RetVal node has more than one non-control input: "
673 << absl::StrJoin(n->input(), ", ");
674 strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
675 found_non_control_input = true;
676 }
677 }
678 DCHECK_EQ(found_non_control_input, true)
679 << "RetVal did not have any non-control inputs: "
680 << absl::StrJoin(n->input(), ", ");
681 }
682 strings::StrAppend(&out, ") {\n");
683 for (size_t i = 0; i < body.size(); ++i) {
684 strings::StrAppend(&out, " ", Print(*body[i]), "\n");
685 }
686 strings::StrAppend(&out, "}\n");
687 return out;
688 }
689
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)690 Status AddDefaultAttrs(const string& op,
691 const GetFunctionSignature& get_function,
692 AttrValueMap* attrs) {
693 const OpDef* op_def = nullptr;
694 TF_RETURN_IF_ERROR(get_function(op, &op_def));
695 AttrSlice attr_slice(attrs);
696 for (const auto& attr_def : op_def->attr()) {
697 if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
698 if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
699 return errors::Internal("Somehow duplicated: ", attr_def.name());
700 }
701 }
702 }
703 return Status::OK();
704 }
705
706 } // end namespace
707
708 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)709 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
710 GetFunctionSignature get_function,
711 InstantiationResult* result) {
712 if (VLOG_IS_ON(5)) {
713 const auto& signature = fdef.signature();
714 VLOG(5) << "Instantiate function definition: name=" << signature.name()
715 << " #input_args=" << signature.input_arg_size()
716 << " #output_args=" << signature.output_arg_size()
717 << " #control_output=" << signature.control_output_size();
718 for (const auto& line : str_util::Split(Print(fdef), '\n')) {
719 VLOG(5) << "|| " << line;
720 }
721 }
722
723 const OpDef& sig = fdef.signature();
724 TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
725
726 bool ints_on_device =
727 fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
728 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
729
730 FunctionInstantiationHelper helper(get_function, result);
731 Status s;
732 for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
733 const OpDef::ArgDef& arg_def = sig.input_arg(i);
734 auto it = fdef.arg_attr().find(i);
735 const FunctionDef::ArgAttrs* arg_attrs =
736 it != fdef.arg_attr().end() ? &it->second : nullptr;
737 auto resource_id_it = fdef.resource_arg_unique_id().find(i);
738 int64 resource_arg_unique_id =
739 resource_id_it != fdef.resource_arg_unique_id().end()
740 ? resource_id_it->second
741 : -1LL;
742 s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
743 ints_on_device, resource_arg_unique_id);
744
745 if (!s.ok()) {
746 errors::AppendToMessage(&s, "In ", Print(arg_def));
747 return s;
748 }
749 }
750
751 auto substitute = [attr_values](StringPiece name, AttrValue* val) {
752 if (const AttrValue* v = attr_values.Find(name)) {
753 *val = *v;
754 return true;
755 }
756 return false;
757 };
758
759 // Makes a copy of all attrs in fdef and substitutes placeholders.
760 // After this step, every attr is bound to a concrete value.
761 std::vector<AttrValueMap> node_attrs;
762 node_attrs.resize(fdef.node_def_size());
763 for (int i = 0; i < fdef.node_def_size(); ++i) {
764 for (auto attr : fdef.node_def(i).attr()) {
765 if (!SubstitutePlaceholders(substitute, &attr.second)) {
766 return errors::InvalidArgument("Failed to bind all placeholders in ",
767 SummarizeAttrValue(attr.second));
768 }
769 if (!node_attrs[i].insert(attr).second) {
770 return errors::Internal("Somehow duplicated: ", attr.first);
771 }
772 }
773 TF_RETURN_IF_ERROR(
774 AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
775 }
776
777 for (int i = 0; i < fdef.node_def_size(); ++i) {
778 s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
779 result->nodes.size() + i);
780 if (!s.ok()) {
781 errors::AppendToMessage(&s, "In ",
782 FormatNodeDefForError(fdef.node_def(i)));
783 return s;
784 }
785 }
786 // Emits one node for each fdef.node_def.
787 for (int i = 0; i < fdef.node_def_size(); ++i) {
788 s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
789 if (!s.ok()) {
790 errors::AppendToMessage(&s, "In ",
791 FormatNodeDefForError(fdef.node_def(i)));
792 return s;
793 }
794 }
795
796 // Emits nodes for the function's return values.
797 int ret_index = 0;
798 for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
799 s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
800 &ret_index);
801 if (!s.ok()) {
802 errors::AppendToMessage(&s, "In function output ", Print(ret_def));
803 return s;
804 }
805 }
806
807 // Adds the actual node inputs using the input indexes.
808 helper.AddNodeInputs();
809
810 return Status::OK();
811 }
812
DebugString(const FunctionDef & func_def)813 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
814
DebugString(const GraphDef & instantiated_func_def)815 string DebugString(const GraphDef& instantiated_func_def) {
816 std::vector<const NodeDef*> ptrs;
817 for (const NodeDef& n : instantiated_func_def.node()) {
818 ptrs.push_back(&n);
819 }
820 return Print(ptrs);
821 }
822
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)823 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
824 std::vector<const NodeDef*> ptrs;
825 for (const NodeDef& n : instantiated_func_nodes) {
826 ptrs.push_back(&n);
827 }
828 return Print(ptrs);
829 }
830
DebugStringWhole(const GraphDef & gdef)831 string DebugStringWhole(const GraphDef& gdef) {
832 string ret;
833 for (const auto& fdef : gdef.library().function()) {
834 strings::StrAppend(&ret, Print(fdef));
835 }
836 strings::StrAppend(&ret, "\n");
837 for (const auto& ndef : gdef.node()) {
838 strings::StrAppend(&ret, Print(ndef), "\n");
839 }
840 return ret;
841 }
842
843 namespace {
844
845 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
846 // Python, it's possible to access unset attrs, which returns a default value
847 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)848 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
849 std::map<string, AttrValue> set_attrs;
850 for (const auto& pair : fdef.attr()) {
851 if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
852 set_attrs[pair.first] = pair.second;
853 }
854 }
855 return set_attrs;
856 }
857
858 } // end namespace
859
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)860 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
861 if (!OpDefEqual(f1.signature(), f2.signature())) return false;
862
863 std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
864 std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
865 if (f1_attrs.size() != f2_attrs.size()) return false;
866 for (const auto& iter1 : f1_attrs) {
867 auto iter2 = f2_attrs.find(iter1.first);
868 if (iter2 == f2_attrs.end()) return false;
869 if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
870 }
871
872 if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
873 return false;
874 }
875
876 std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
877 std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
878 if (ret1 != ret2) return false;
879
880 std::map<string, string> control_ret1(f1.control_ret().begin(),
881 f1.control_ret().end());
882 std::map<string, string> control_ret2(f2.control_ret().begin(),
883 f2.control_ret().end());
884 if (control_ret1 != control_ret2) return false;
885
886 return true;
887 }
888
FunctionDefHash(const FunctionDef & fdef)889 uint64 FunctionDefHash(const FunctionDef& fdef) {
890 // signature
891 uint64 h = OpDefHash(fdef.signature());
892
893 // attrs
894 std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
895 for (const auto& p : attrs) {
896 h = Hash64(p.first.data(), p.first.size(), h);
897 h = Hash64Combine(AttrValueHash(p.second), h);
898 }
899
900 // node defs
901 h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
902
903 // output names
904 std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
905 for (const auto& p : ret) {
906 h = Hash64(p.first.data(), p.first.size(), h);
907 h = Hash64(p.second.data(), p.second.size(), h);
908 }
909
910 // control output names
911 std::map<string, string> control_ret(fdef.control_ret().begin(),
912 fdef.control_ret().end());
913 for (const auto& p : control_ret) {
914 h = Hash64(p.first.data(), p.first.size(), h);
915 h = Hash64(p.second.data(), p.second.size(), h);
916 }
917
918 return h;
919 }
920
// Name of the function attr that FunctionLibraryRuntime::ExecutorType() reads
// when InstantiateOptions::executor_type is empty.
static constexpr const char* const kExecutorAttr = "_executor";
922
923 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)924 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
925 AttrSlice attrs) {
926 if (!options.executor_type.empty()) {
927 return options.executor_type;
928 } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
929 return executor_attr->s();
930 } else {
931 return string();
932 }
933 }
934
935 namespace {
936 class AttrKeyAndValue {
937 public:
938 enum ValueRepresentationOp {
939 kRaw,
940 kCEscape,
941 };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)942 AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
943 ValueRepresentationOp value_op = kRaw)
944 : key_name_(key_name),
945 key_suffix_(key_suffix),
946 value_op_(value_op),
947 value_(std::move(value)) {}
948
operator <(const AttrKeyAndValue & b) const949 bool operator<(const AttrKeyAndValue& b) const {
950 if (key_name_ != b.key_name_) {
951 return key_name_ < b.key_name_;
952 } else if (key_suffix_ != b.key_suffix_) {
953 return key_suffix_ < b.key_suffix_;
954 } else {
955 return value_ < b.value_;
956 }
957 }
958
AppendTo(bool first,string * s) const959 void AppendTo(bool first, string* s) const {
960 absl::string_view v;
961 bool add_escaped = false;
962 if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
963 // Use CEscape call below
964 add_escaped = true;
965 } else {
966 // Add raw value contents directly
967 v = value_;
968 }
969 if (key_suffix_ >= 0) {
970 strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
971 } else {
972 strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
973 }
974 if (add_escaped) {
975 strings::StrAppend(s, absl::CEscape(value_));
976 }
977 }
978
979 private:
NeedsEscaping(const string & s)980 static bool NeedsEscaping(const string& s) {
981 for (auto c : s) {
982 if (!isalnum(c) && (c != ' ')) {
983 return true;
984 }
985 }
986 return false;
987 }
988
989 absl::string_view key_name_;
990 int key_suffix_; // -1 if missing
991 ValueRepresentationOp value_op_;
992 string value_;
993 };
994 } // namespace
995
GetFunctionResourceInputDevice(const Tensor & input,const int arg_index,const FunctionDef & function_def,absl::flat_hash_map<string,std::vector<string>> * composite_devices)996 string GetFunctionResourceInputDevice(
997 const Tensor& input, const int arg_index, const FunctionDef& function_def,
998 absl::flat_hash_map<string, std::vector<string>>* composite_devices) {
999 const auto& handles = input.flat<ResourceHandle>();
1000 const ResourceHandle& handle0 = handles(0);
1001 string composite_device;
1002 auto iter = function_def.arg_attr().find(arg_index);
1003 if (iter != function_def.arg_attr().end()) {
1004 auto arg_attr = iter->second.attr().find("_composite_device");
1005 if (arg_attr != iter->second.attr().end()) {
1006 composite_device = arg_attr->second.s();
1007 }
1008 }
1009 if (!composite_device.empty()) {
1010 if (composite_devices->find(composite_device) == composite_devices->end()) {
1011 for (int i = 0; i < handles.size(); ++i) {
1012 (*composite_devices)[composite_device].push_back(handles(i).device());
1013 }
1014 }
1015 return composite_device;
1016 } else {
1017 return handle0.device();
1018 }
1019 }
1020
// Builds the canonical string "funcname[k1=v1,k2=v2,...]" for instantiating
// `funcname` with `attrs` and `options`. The entries are sorted, so the result
// does not depend on the order in which attrs and options are visited.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options) {
  absl::InlinedVector<AttrKeyAndValue, 8> entries;
  entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
                  options.input_devices.size());

  // Function attrs. The executor attr is skipped here and added further down
  // from the resolved executor type.
  for (const auto& attr : attrs) {
    if (attr.first != kExecutorAttr) {
      entries.emplace_back(attr.first, -1, Print(attr.second));
    }
  }

  // Placement-related options.
  if (!options.target.empty()) {
    entries.emplace_back("_target", -1, options.target,
                         AttrKeyAndValue::kCEscape);
  }
  for (int i = 0; i < options.input_devices.size(); ++i) {
    entries.emplace_back("_input_dev", i, options.input_devices[i],
                         AttrKeyAndValue::kCEscape);
  }
  for (int i = 0; i < options.output_devices.size(); ++i) {
    entries.emplace_back("_output_dev", i, options.output_devices[i],
                         AttrKeyAndValue::kCEscape);
  }

  // Dtype and shape of resource inputs, keyed by the map's key.
  for (const auto& index_and_info : options.input_resource_dtypes_and_shapes) {
    entries.emplace_back("_input_resource_dtype", index_and_info.first,
                         DataTypeString(index_and_info.second.dtype));
    entries.emplace_back("_input_resource_shape", index_and_info.first,
                         index_and_info.second.shape.DebugString(),
                         AttrKeyAndValue::kCEscape);
  }

  // The library definition is identified by its address.
  if (options.lib_def) {
    entries.emplace_back(
        "_lib_def", -1,
        absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def)));
  }
  if (!options.state_handle.empty()) {
    entries.emplace_back("_state_handle", -1, options.state_handle);
  }

  const string executor_type =
      FunctionLibraryRuntime::ExecutorType(options, attrs);
  if (!executor_type.empty()) {
    entries.emplace_back(kExecutorAttr, -1, executor_type);
  }

  if (options.config_proto.ByteSize() > 0) {
    string serialized_config;
    options.config_proto.SerializeToString(&serialized_config);
    entries.emplace_back("_config_proto", -1, serialized_config,
                         AttrKeyAndValue::kCEscape);
  }

  std::sort(entries.begin(), entries.end());

  string result = strings::StrCat(funcname, "[");
  for (auto it = entries.begin(); it != entries.end(); ++it) {
    it->AppendTo(/*first=*/it == entries.begin(), &result);
  }
  result += "]";
  return result;
}
1081
// Convenience overload: canonicalizes with default-constructed
// InstantiateOptions.
string Canonicalize(const string& funcname, AttrSlice attrs) {
  // Allocated once on first use and never deleted by this function.
  static const FunctionLibraryRuntime::InstantiateOptions* const
      kDefaultOptions = new FunctionLibraryRuntime::InstantiateOptions;
  return Canonicalize(funcname, attrs, *kDefaultOptions);
}
1087
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1088 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1089 DataTypeSlice ret_types)
1090 : arg_types_(arg_types.begin(), arg_types.end()),
1091 ret_types_(ret_types.begin(), ret_types.end()) {
1092 args_.resize(arg_types_.size());
1093 rets_.resize(ret_types_.size());
1094 }
1095
~FunctionCallFrame()1096 FunctionCallFrame::~FunctionCallFrame() {}
1097
SetArgs(gtl::ArraySlice<Tensor> args)1098 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1099 // Input type checks.
1100 if (args.size() != arg_types_.size()) {
1101 return errors::InvalidArgument("Expects ", arg_types_.size(),
1102 " arguments, but ", args.size(),
1103 " is provided");
1104 }
1105 for (size_t i = 0; i < args.size(); ++i) {
1106 if (arg_types_[i] != args[i].dtype()) {
1107 return errors::InvalidArgument(
1108 "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1109 DataTypeString(args[i].dtype()), " is provided");
1110 }
1111 args_[i] = args[i];
1112 }
1113 return Status::OK();
1114 }
1115
GetRetvals(std::vector<Tensor> * rets) const1116 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1117 rets->clear();
1118 rets->reserve(rets_.size());
1119 for (size_t i = 0; i < rets_.size(); ++i) {
1120 const auto& item = rets_[i];
1121 if (item.has_val) {
1122 rets->push_back(item.val);
1123 } else {
1124 return errors::Internal("Retval[", i, "] does not have value");
1125 }
1126 }
1127 return Status::OK();
1128 }
1129
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1130 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1131 bool allow_dead_tensors) {
1132 rets->clear();
1133 rets->reserve(rets_.size());
1134 for (size_t i = 0; i < rets_.size(); ++i) {
1135 if (rets_[i].has_val) {
1136 rets->emplace_back(std::move(rets_[i].val));
1137 } else if (allow_dead_tensors) {
1138 rets->emplace_back();
1139 } else {
1140 return errors::Internal("Retval[", i, "] does not have value");
1141 }
1142 }
1143 return Status::OK();
1144 }
1145
GetArg(int index,const Tensor ** val)1146 Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
1147 if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1148 return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1149 args_.size(), ")");
1150 }
1151 *val = &args_[index];
1152 return Status::OK();
1153 }
1154
SetRetval(int index,const Tensor & val)1155 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1156 if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1157 return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1158 rets_.size(), ")");
1159 }
1160 if (val.dtype() != ret_types_[index]) {
1161 return errors::InvalidArgument(
1162 "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1163 ", but ", DataTypeString(val.dtype()), " is provided.");
1164 }
1165 Retval* item = &rets_[index];
1166 if (!item->has_val) {
1167 item->has_val = true;
1168 item->val = val;
1169 } else {
1170 return errors::Internal("Retval[", index, "] has already been set.");
1171 }
1172 return Status::OK();
1173 }
1174
// Builds a registration record holding a copy of `fdef_in`, op registration
// data derived from the function's signature, and a copy of `stack_traces`.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
                                 const StackTracesMap& stack_traces)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      // NOTE(review): this reads the member `fdef`, not the parameter, so it
      // presumably relies on `fdef` being declared before
      // `op_registration_data` -- confirm against the struct declaration.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */),
      stack_traces(stack_traces) {}
1184
// Copy constructor: shares `other`'s default registry and copies its function
// and gradient tables while holding a shared lock on `other`.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  func_grad_ = other.func_grad_;
  function_defs_ = other.function_defs_;
}
1192
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1193 FunctionLibraryDefinition::FunctionLibraryDefinition(
1194 const OpRegistryInterface* default_registry,
1195 const FunctionDefLibrary& def_lib)
1196 : default_registry_(default_registry),
1197 function_defs_(def_lib.function_size()) {
1198 for (const auto& fdef : def_lib.function()) {
1199 // The latter function definition wins.
1200 auto& ptr = function_defs_[fdef.signature().name()];
1201 ptr.reset(new FunctionDefAndOpRegistration(fdef));
1202 }
1203 for (const auto& grad : def_lib.gradient()) {
1204 func_grad_[grad.function_name()] = grad.gradient_func();
1205 }
1206 }
1207
~FunctionLibraryDefinition()1208 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1209
Contains(const string & func) const1210 bool FunctionLibraryDefinition::Contains(const string& func) const {
1211 tf_shared_lock l(mu_);
1212 return function_defs_.find(func) != function_defs_.end();
1213 }
1214
Find(const string & func) const1215 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1216 tf_shared_lock l(mu_);
1217 auto result = FindHelper(func);
1218 if (result) {
1219 return &result->fdef;
1220 } else {
1221 return nullptr;
1222 }
1223 }
1224
1225 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1226 FunctionLibraryDefinition::FindHelper(const string& func) const {
1227 auto iter = function_defs_.find(func);
1228 if (iter == function_defs_.end()) {
1229 return nullptr;
1230 } else {
1231 return iter->second;
1232 }
1233 }
1234
AddFunctionDef(const FunctionDef & fdef,const StackTracesMap & stack_traces)1235 Status FunctionLibraryDefinition::AddFunctionDef(
1236 const FunctionDef& fdef, const StackTracesMap& stack_traces) {
1237 mutex_lock l(mu_);
1238 bool added;
1239 return AddFunctionDefHelper(fdef, stack_traces, &added);
1240 }
1241
AddFunctionDefHelper(const FunctionDef & fdef,const StackTracesMap & stack_traces,bool * added)1242 Status FunctionLibraryDefinition::AddFunctionDefHelper(
1243 const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
1244 *added = false;
1245 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1246 function_defs_[fdef.signature().name()];
1247 if (entry) {
1248 if (!FunctionDefsEqual(entry->fdef, fdef)) {
1249 return errors::InvalidArgument(
1250 "Cannot add function '", fdef.signature().name(),
1251 "' because a different function with the same name already "
1252 "exists.");
1253 }
1254 // Ignore duplicate FunctionDefs.
1255 return Status::OK();
1256 }
1257 const OpDef* op_def;
1258 if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1259 return errors::InvalidArgument(
1260 "Cannot add function '", fdef.signature().name(),
1261 "' because an op with the same name already exists.");
1262 }
1263 entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
1264 *added = true;
1265 return Status::OK();
1266 }
1267
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1268 Status FunctionLibraryDefinition::AddHelper(
1269 std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1270 *added = false;
1271 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1272 function_defs_[registration->fdef.signature().name()];
1273 if (entry) {
1274 if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1275 return errors::InvalidArgument(
1276 "Cannot add function '", registration->fdef.signature().name(),
1277 "' because a different function with the same name already "
1278 "exists.");
1279 }
1280 // Ignore duplicate FunctionDefs.
1281 return Status::OK();
1282 }
1283 const OpDef* op_def;
1284 if (default_registry_
1285 ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1286 .ok()) {
1287 return errors::InvalidArgument(
1288 "Cannot add function '", registration->fdef.signature().name(),
1289 "' because an op with the same name already exists.");
1290 }
1291 entry = std::move(registration);
1292 *added = true;
1293 return Status::OK();
1294 }
1295
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1296 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1297 const string& func, const FunctionLibraryDefinition& other) {
1298 if (default_registry_ != other.default_registry_) {
1299 return errors::InvalidArgument(
1300 "Cannot copy function '", func,
1301 "' because CopyFunctionDefFrom() requires that both libraries have the "
1302 "same default registry.");
1303 }
1304 std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1305 {
1306 tf_shared_lock l(other.mu_);
1307 function_def = other.FindHelper(func);
1308 }
1309 if (!function_def) {
1310 return errors::InvalidArgument(
1311 "Cannot copy function '", func,
1312 "' because no function with that name exists in the other library.");
1313 }
1314 {
1315 mutex_lock l(mu_);
1316 std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1317 if (entry) {
1318 if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1319 return errors::InvalidArgument(
1320 "Cannot copy function '", func,
1321 "' because a different function with the same name already "
1322 "exists.");
1323 }
1324 } else {
1325 entry = std::move(function_def);
1326 }
1327 }
1328 return Status::OK();
1329 }
1330
AddGradientDef(const GradientDef & grad)1331 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1332 mutex_lock l(mu_);
1333 bool added;
1334 return AddGradientDefHelper(grad, &added);
1335 }
1336
AddGradientDefHelper(const GradientDef & grad,bool * added)1337 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1338 bool* added) {
1339 *added = false;
1340 string* entry = &func_grad_[grad.function_name()];
1341 if (!entry->empty()) {
1342 if (*entry != grad.gradient_func()) {
1343 return errors::InvalidArgument(
1344 "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1345 grad.function_name(), "' because it already has gradient function ",
1346 "'", *entry, "'");
1347 }
1348 // Ignore duplicate GradientDefs
1349 return Status::OK();
1350 }
1351 *entry = grad.gradient_func();
1352 *added = true;
1353 return Status::OK();
1354 }
1355
AddLibrary(const FunctionLibraryDefinition & other)1356 Status FunctionLibraryDefinition::AddLibrary(
1357 const FunctionLibraryDefinition& other) {
1358 // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1359 // the duration of the function could lead to deadlock).
1360 FunctionLibraryDefinition clone(other);
1361 mutex_lock l(mu_);
1362 mutex_lock l2(clone.mu_);
1363 // Remember the funcs and grads that we added successfully so that
1364 // we can roll them back on error.
1365 std::vector<string> funcs;
1366 std::vector<string> funcs_with_grads;
1367 Status s;
1368 bool added;
1369 for (auto iter : clone.function_defs_) {
1370 s = AddHelper(iter.second, &added);
1371 if (!s.ok()) {
1372 Remove(funcs, funcs_with_grads);
1373 return s;
1374 }
1375 if (added) {
1376 funcs.push_back(iter.second->fdef.signature().name());
1377 }
1378 }
1379 for (auto iter : clone.func_grad_) {
1380 GradientDef grad;
1381 grad.set_function_name(iter.first);
1382 grad.set_gradient_func(iter.second);
1383 s = AddGradientDefHelper(grad, &added);
1384 if (!s.ok()) {
1385 Remove(funcs, funcs_with_grads);
1386 return s;
1387 }
1388 if (added) {
1389 funcs_with_grads.push_back(grad.function_name());
1390 }
1391 }
1392 return Status::OK();
1393 }
1394
AddLibrary(const FunctionDefLibrary & lib_def)1395 Status FunctionLibraryDefinition::AddLibrary(
1396 const FunctionDefLibrary& lib_def) {
1397 // Remember the funcs and grads that we added successfully so that
1398 // we can roll them back on error.
1399 mutex_lock l(mu_);
1400 std::vector<string> funcs;
1401 std::vector<string> funcs_with_grads;
1402 Status s;
1403 bool added;
1404 for (const FunctionDef& fdef : lib_def.function()) {
1405 s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
1406 if (!s.ok()) {
1407 Remove(funcs, funcs_with_grads);
1408 return s;
1409 }
1410 if (added) {
1411 funcs.push_back(fdef.signature().name());
1412 }
1413 }
1414 for (const GradientDef& grad : lib_def.gradient()) {
1415 s = AddGradientDefHelper(grad, &added);
1416 if (!s.ok()) {
1417 Remove(funcs, funcs_with_grads);
1418 return s;
1419 }
1420 if (added) {
1421 funcs_with_grads.push_back(grad.function_name());
1422 }
1423 }
1424 return Status::OK();
1425 }
1426
ReplaceFunction(const string & func,const FunctionDef & fdef)1427 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
1428 const FunctionDef& fdef) {
1429 mutex_lock l(mu_);
1430 bool added;
1431 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1432 TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added));
1433 return Status::OK();
1434 }
1435
ReplaceGradient(const GradientDef & grad)1436 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1437 mutex_lock l(mu_);
1438 bool added;
1439 TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1440 TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1441 return Status::OK();
1442 }
1443
RemoveFunction(const string & func)1444 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1445 mutex_lock l(mu_);
1446 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1447 return Status::OK();
1448 }
1449
RemoveFunctionHelper(const string & func)1450 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1451 const auto& i = function_defs_.find(func);
1452 if (i == function_defs_.end()) {
1453 return errors::InvalidArgument("Tried to remove non-existent function '",
1454 func, "'.");
1455 }
1456 function_defs_.erase(i);
1457 return Status::OK();
1458 }
1459
Clear()1460 void FunctionLibraryDefinition::Clear() {
1461 mutex_lock l(mu_);
1462 function_defs_.clear();
1463 func_grad_.clear();
1464 }
1465
RemoveGradient(const string & func)1466 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1467 const auto& i = func_grad_.find(func);
1468 if (i == func_grad_.end()) {
1469 return errors::InvalidArgument("Tried to remove non-existent gradient '",
1470 func, "'.");
1471 }
1472 func_grad_.erase(i);
1473 return Status::OK();
1474 }
1475
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1476 void FunctionLibraryDefinition::Remove(
1477 const std::vector<string>& funcs,
1478 const std::vector<string>& funcs_with_grads) {
1479 for (const string& f : funcs) {
1480 Status s = RemoveFunctionHelper(f);
1481 DCHECK(s.ok());
1482 }
1483 for (const string& f : funcs_with_grads) {
1484 Status s = RemoveGradient(f);
1485 DCHECK(s.ok());
1486 }
1487 }
1488
// Returns the name of the gradient function registered for `func`, or the
// empty string if there is none. Takes a shared lock.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock l(mu_);
  return FindGradientHelper(func);
}
1493
// Lock-free variant of FindGradient(): returns the gradient function name
// registered for `func`, or the empty string if there is none.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  const auto it = func_grad_.find(func);
  if (it == func_grad_.end()) {
    return "";
  }
  return it->second;
}
1497
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1498 Status FunctionLibraryDefinition::LookUp(
1499 const string& op, const OpRegistrationData** op_reg_data) const {
1500 tf_shared_lock l(mu_);
1501 auto iter = function_defs_.find(op);
1502 if (iter != function_defs_.end()) {
1503 *op_reg_data = &iter->second->op_registration_data;
1504 return Status::OK();
1505 }
1506 return default_registry_->LookUp(op, op_reg_data);
1507 }
1508
UniqueFunctionName(StringPiece prefix) const1509 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1510 tf_shared_lock l(mu_);
1511 int index = 0;
1512 string name = strings::StrCat(prefix, index);
1513 while (function_defs_.find(name) != function_defs_.end()) {
1514 ++index;
1515 name = strings::StrCat(prefix, index);
1516 }
1517 return name;
1518 }
1519
GetAttrImpl(const NodeDef & ndef) const1520 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1521 const NodeDef& ndef) const {
1522 if (ndef.op() != kGradientOp) {
1523 // If 'ndef' calls a function and the function's def has the attr,
1524 // returns it.
1525 return Find(ndef.op());
1526 }
1527
1528 // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1529 // Foo's attributes.
1530 const NameAttrList* forward_func_attrs;
1531 if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1532 return nullptr;
1533 }
1534 const string& func_name = forward_func_attrs->name();
1535 {
1536 tf_shared_lock l(mu_);
1537 const string& grad_name = FindGradientHelper(func_name);
1538 // If 'func' has a user-defined gradient function, uses the grad
1539 // function's attrs to see if noinline is specified. Otherwise,
1540 // uses func's attrs.
1541 if (!grad_name.empty()) {
1542 if (const auto helper = FindHelper(grad_name)) {
1543 return &(helper->fdef);
1544 } else {
1545 return nullptr;
1546 }
1547 }
1548 if (const auto helper = FindHelper(func_name)) {
1549 return &(helper->fdef);
1550 } else {
1551 return nullptr;
1552 }
1553 }
1554 }
1555
ListFunctionNames() const1556 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1557 std::vector<string> function_names;
1558 tf_shared_lock l(mu_);
1559 function_names.reserve(function_defs_.size());
1560 for (const auto& it : function_defs_) {
1561 function_names.emplace_back(it.first);
1562 }
1563 return function_names;
1564 }
1565
ToProto() const1566 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1567 FunctionDefLibrary lib;
1568 tf_shared_lock l(mu_);
1569 for (const auto& f : function_defs_) {
1570 *lib.add_function() = f.second->fdef;
1571 }
1572 for (const auto& g : func_grad_) {
1573 GradientDef* gd = lib.add_gradient();
1574 gd->set_function_name(g.first);
1575 gd->set_gradient_func(g.second);
1576 }
1577 return lib;
1578 }
1579
1580 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1581 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1582 const string& attr, T* value) const {
1583 const FunctionDef* fdef = GetAttrImpl(ndef);
1584 if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1585 return Status::OK();
1586 }
1587 return errors::InvalidArgument("Attr ", attr, " is not defined.");
1588 }
1589
1590 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1591 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1592 T* value) const {
1593 return GetAttr(node.def(), attr, value);
1594 }
1595
// Explicitly instantiates both FunctionLibraryDefinition::GetAttr overloads
// (the Node and the NodeDef variants) for attr value type T. Only `string`
// and `bool` are instantiated here.
#define GET_ATTR(T)                                                            \
  template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
                                                     const string&, T*) const; \
  template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
                                                     const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
1604
1605 namespace {
1606
// Name of the function attr whose string value identifies the API interface a
// function implements; ReachableFunctions() treats all functions sharing a
// value as reachable together.
constexpr char kApiImplements[] = "api_implements";
1608
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1609 std::set<string> ReachableFunctions(
1610 const FunctionLibraryDefinition& flib,
1611 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1612 // Functions that are reachable from the graph.
1613 std::set<string> reachable_funcs;
1614
1615 // For any functions, if it has attribute "api_implements" =
1616 // "some_interface" and it is reachable, then it means any other
1617 // function with same attribute name and value could also be potentially
1618 // reachable, eg via implementation_selector swapping the nodedef.
1619 absl::flat_hash_set<string> reachable_api_interface;
1620
1621 // Functions might be reachable from the nested function calls, so we keep a
1622 // queue of functions that we have to check.
1623 gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1624
1625 // Add reachable and not already processed functions to the functions queue.
1626 const auto add_to_func_queue = [&](const string& func_name) {
1627 const FunctionDef* func = flib.Find(func_name);
1628 if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1629 func_queue.push_back(func);
1630 }
1631 };
1632
1633 // If any function with certain API name is reachable, all the other functions
1634 // with same API name should also be checked.
1635 const auto add_function_with_api_interface = [&](const string& api_name) {
1636 if (!reachable_api_interface.contains(api_name)) {
1637 reachable_api_interface.insert(api_name);
1638 for (const auto& func_name : flib.ListFunctionNames()) {
1639 const auto& func_def = flib.Find(func_name);
1640 const auto attr_it = func_def->attr().find(kApiImplements);
1641 if (attr_it != func_def->attr().end() &&
1642 attr_it->second.s() == api_name) {
1643 add_to_func_queue(func_name);
1644 }
1645 }
1646 }
1647 };
1648
1649 // Add all the functions that are reachable from the given node to the queue.
1650 const auto process_node = [&](const NodeDef& node) {
1651 // Node itself can be a call to the function.
1652 add_to_func_queue(node.op());
1653
1654 // Or node can have an attribute referencing a function.
1655 for (const auto& attr : node.attr()) {
1656 const auto& attr_value = attr.second;
1657
1658 // 1. AttrValue.func
1659 if (attr_value.has_func()) {
1660 add_to_func_queue(attr_value.func().name());
1661 }
1662
1663 // 2. AttrValue.ListValue.func
1664 if (attr_value.has_list()) {
1665 for (const auto& func : attr_value.list().func()) {
1666 add_to_func_queue(func.name());
1667 }
1668 }
1669 }
1670 };
1671
1672 // Add all functions that are directly called from the optimized graph.
1673 std::for_each(nodes.begin(), nodes.end(), process_node);
1674
1675 // Process all reachable functions.
1676 while (!func_queue.empty()) {
1677 const FunctionDef* func = func_queue.back();
1678 func_queue.pop_back();
1679
1680 const string& func_name = func->signature().name();
1681 reachable_funcs.insert(func_name);
1682
1683 const auto attr_it = func->attr().find(kApiImplements);
1684 if (attr_it != func->attr().end()) {
1685 add_function_with_api_interface(attr_it->second.s());
1686 }
1687
1688 // Find all the functions called from the function body.
1689 const auto& func_body = func->node_def();
1690 std::for_each(func_body.begin(), func_body.end(), process_node);
1691
1692 // Check if the function has a registered gradient.
1693 const string grad_func_name = flib.FindGradient(func_name);
1694 if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1695 }
1696
1697 return reachable_funcs;
1698 }
1699
// Builds a new library, using `flib`'s default registry, that contains only
// the functions of `flib` reachable from `nodes`, together with the gradient
// mappings registered for those functions.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  const std::set<string> reachable = ReachableFunctions(flib, nodes);

  FunctionLibraryDefinition result(flib.default_registry(),
                                   FunctionDefLibrary());

  for (const string& name : reachable) {
    // This should never fail, because we copy functions from a valid flib and
    // use the same default registry.
    Status copy_status = result.CopyFunctionDefFrom(name, flib);
    TF_DCHECK_OK(copy_status);

    const string gradient_name = flib.FindGradient(name);
    if (gradient_name.empty()) continue;

    GradientDef gradient_def;
    gradient_def.set_function_name(name);
    gradient_def.set_gradient_func(gradient_name);
    // It can only fail if function already has a gradient function.
    const Status gradient_status = result.AddGradientDef(gradient_def);
    TF_DCHECK_OK(gradient_status);
  }

  return result;
}
1727
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1728 string AllocatorAttributesToString(
1729 const std::vector<AllocatorAttributes>& attrs) {
1730 string result("[");
1731 // AllocatorAttribute::DebugString produces around 85 bytes now.
1732 result.reserve(100 * attrs.size());
1733 for (const AllocatorAttributes& attr : attrs) {
1734 result.append(attr.DebugString());
1735 result.append(", ");
1736 }
1737 if (!attrs.empty()) {
1738 result.resize(result.size() - 2);
1739 }
1740 result.append("]");
1741 return result;
1742 }
1743
// Describes a pointer for debug output without printing its address:
// "set" when `ptr` is non-null, "unset" when it is null.
const char* IsSet(void* ptr) {
  if (ptr != nullptr) return "set";
  return "unset";
}
1745
1746 } // namespace
1747
ReachableDefinitions(const GraphDef & graph) const1748 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1749 const GraphDef& graph) const {
1750 return ReachableFunctionLibraryDefinition(*this, graph.node());
1751 }
1752
ReachableDefinitions(const FunctionDef & func) const1753 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1754 const FunctionDef& func) const {
1755 return ReachableFunctionLibraryDefinition(*this, func.node_def());
1756 }
1757
DebugString() const1758 string FunctionLibraryRuntime::Options::DebugString() const {
1759 return absl::StrCat(
1760 "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1761 " cancellation_manager=", IsSet(cancellation_manager),
1762 " collective_executor=", IsSet(collective_executor),
1763 " step_container=", IsSet(step_container),
1764 " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1765 " remote_execution=", remote_execution, " source_device=", source_device,
1766 " create_rendezvous=", create_rendezvous,
1767 " allow_dead_tensors=", allow_dead_tensors,
1768 " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1769 " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1770 }
1771
InitFromString(StringPiece val)1772 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1773 if (val.size() >= 2 && val[0] == '$') {
1774 proto.set_placeholder(val.data() + 1, val.size() - 1);
1775 } else {
1776 SetAttrValue(val, &proto);
1777 }
1778 }
1779
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1780 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1781 const string& name,
1782 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1783 AttrValueWrapper ret;
1784 ret.proto.mutable_func()->set_name(name);
1785 for (const auto& a : attrs) {
1786 ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1787 }
1788 return ret;
1789 }
1790
ToNodeDef() const1791 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1792 NodeDef n;
1793 n.set_op(this->op);
1794 n.set_name(this->ret[0]);
1795 for (const auto& a : this->attr) {
1796 n.mutable_attr()->insert({a.first, a.second.proto});
1797 }
1798 for (const string& a : this->arg) {
1799 n.add_input(a);
1800 }
1801 for (const string& d : this->dep) {
1802 n.add_input(strings::StrCat("^", d));
1803 }
1804 if (!this->device.empty()) {
1805 n.set_device(this->device);
1806 }
1807 return n;
1808 }
1809
1810 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1811 FunctionDef FunctionDefHelper::Create(
1812 const string& function_name, gtl::ArraySlice<string> in_def,
1813 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1814 gtl::ArraySlice<Node> node_def,
1815 gtl::ArraySlice<std::pair<string, string>> ret_def,
1816 gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1817 FunctionDef fdef;
1818
1819 // Signature
1820 OpDefBuilder b(function_name);
1821 for (const auto& i : in_def) b.Input(i);
1822 for (const auto& o : out_def) b.Output(o);
1823 for (const auto& a : attr_def) b.Attr(a);
1824 for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1825
1826 OpRegistrationData op_reg_data;
1827 TF_CHECK_OK(b.Finalize(&op_reg_data));
1828 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1829
1830 // Function body
1831 for (const auto& n : node_def) {
1832 *(fdef.add_node_def()) = n.ToNodeDef();
1833 }
1834
1835 // Returns
1836 for (const auto& r : ret_def) {
1837 fdef.mutable_ret()->insert({r.first, r.second});
1838 }
1839
1840 // Control returns
1841 for (const auto& cr : control_ret_def) {
1842 fdef.mutable_control_ret()->insert({cr.first, cr.second});
1843 }
1844
1845 auto* op_def_registry = OpRegistry::Global();
1846 // Check if any op is stateful.
1847 for (const auto& n : node_def) {
1848 const OpDef* op_def = nullptr;
1849 auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1850 // Lookup can fail if e.g. we are calling a function that was not yet
1851 // defined. If it happens, conservatively assume the op is stateful.
1852 if (!status.ok() || op_def->is_stateful()) {
1853 fdef.mutable_signature()->set_is_stateful(true);
1854 }
1855 }
1856
1857 return fdef;
1858 }
1859
1860 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1861 FunctionDef FunctionDefHelper::Create(
1862 const string& function_name, gtl::ArraySlice<string> in_def,
1863 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1864 gtl::ArraySlice<Node> node_def,
1865 gtl::ArraySlice<std::pair<string, string>> ret_def) {
1866 return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1867 /*control_ret_def=*/{});
1868 }
1869
1870 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1871 FunctionDef FunctionDefHelper::Define(const string& name,
1872 gtl::ArraySlice<string> arg_def,
1873 gtl::ArraySlice<string> ret_def,
1874 gtl::ArraySlice<string> attr_def,
1875 gtl::ArraySlice<Node> node_def) {
1876 FunctionDef fdef;
1877 OpDefBuilder b(name);
1878 for (const auto& a : arg_def) b.Input(a);
1879 for (const auto& r : ret_def) b.Output(r);
1880 for (const auto& a : attr_def) b.Attr(a);
1881
1882 OpRegistrationData op_reg_data;
1883 TF_CHECK_OK(b.Finalize(&op_reg_data));
1884 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1885
1886 // Mapping from legacy output names to NodeDef outputs.
1887 std::unordered_map<string, string> ret_index;
1888 for (const auto& a : fdef.signature().input_arg()) {
1889 ret_index[a.name()] = a.name();
1890 }
1891
1892 // For looking up OpDefs
1893 auto* op_def_registry = OpRegistry::Global();
1894
1895 // Function body
1896 for (const auto& src : node_def) {
1897 NodeDef* n = fdef.add_node_def();
1898 n->set_op(src.op);
1899 n->set_name(src.ret[0]);
1900 for (const auto& a : src.attr) {
1901 n->mutable_attr()->insert({a.first, a.second.proto});
1902 }
1903 for (const string& a : src.arg) {
1904 const auto iter = ret_index.find(a);
1905 CHECK(iter != ret_index.end())
1906 << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
1907 n->add_input(iter->second);
1908 }
1909 for (const string& d : src.dep) {
1910 n->add_input(strings::StrCat("^", d));
1911 }
1912
1913 // Add the outputs of this node to ret_index.
1914 const OpDef* op_def = nullptr;
1915 TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1916 CHECK(op_def != nullptr) << n->op();
1917 NameRangeMap output_names;
1918 TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1919 for (const auto& o : output_names) {
1920 CHECK_LE(o.second.second, src.ret.size())
1921 << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
1922 << "' of " << name;
1923 for (int i = o.second.first; i < o.second.second; ++i) {
1924 ret_index[src.ret[i]] =
1925 strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
1926 }
1927 }
1928 if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1929 }
1930
1931 // Returns
1932 for (const auto& r : fdef.signature().output_arg()) {
1933 const auto iter = ret_index.find(r.name());
1934 CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1935 fdef.mutable_ret()->insert({r.name(), iter->second});
1936 }
1937 return fdef;
1938 }
1939
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1940 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1941 gtl::ArraySlice<string> ret_def,
1942 gtl::ArraySlice<string> attr_def,
1943 gtl::ArraySlice<Node> node_def) {
1944 return Define("_", arg_def, ret_def, attr_def, node_def);
1945 }
1946
1947 namespace gradient {
1948
1949 typedef std::unordered_map<string, Creator> OpGradFactory;
1950
GetOpGradFactory()1951 OpGradFactory* GetOpGradFactory() {
1952 static OpGradFactory* factory = new OpGradFactory;
1953 return factory;
1954 }
1955
RegisterOp(const string & op,Creator func)1956 bool RegisterOp(const string& op, Creator func) {
1957 CHECK(GetOpGradFactory()->insert({op, func}).second)
1958 << "Duplicated gradient for " << op;
1959 return true;
1960 }
1961
GetOpGradientCreator(const string & op,Creator * creator)1962 Status GetOpGradientCreator(const string& op, Creator* creator) {
1963 auto fac = GetOpGradFactory();
1964 auto iter = fac->find(op);
1965 if (iter == fac->end()) {
1966 return errors::NotFound("No gradient defined for op: ", op);
1967 }
1968 *creator = iter->second;
1969 return Status::OK();
1970 }
1971
1972 } // end namespace gradient
1973
1974 } // namespace tensorflow
1975