1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/function.h"

#include <ctype.h>

#include <limits>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/equal_graph_def.h"
43
44 namespace tensorflow {
45
46 /* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
47 /* static */ constexpr const char* const
48 FunctionLibraryDefinition::kDeviceArgOp;
49 /* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
50 /* static */ constexpr const char* const
51 FunctionLibraryDefinition::kDeviceRetOp;
52 /* static */ constexpr const char* const
53 FunctionLibraryDefinition::kIntsOnDeviceAttr;
54 /* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
55 /* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
56
57 // Extracts the actual type from "attr_values" based on its definition
58 // "arg_def".
59 //
60 // If "arg_def" is a N*T type, *is_type_list is set to false, and
61 // *dtypes is set to be a vector of size N and each element is T.
62 //
63 // If "arg_def" is a list(type), *is_type_list is set to true, and
64 // *dtypes is set to be a vector of types specified in attrs for
65 // arg_def.
66 //
67 // Otherwise (arg_def is a simple type T), *is_type_list is set to
68 // false, and *dtypes is set to a single element vector, whose only
69 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)70 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
71 bool* is_type_list, DataTypeVector* dtypes) {
72 dtypes->clear();
73 if (!arg_def.type_list_attr().empty()) {
74 const AttrValue* v = attrs.Find(arg_def.type_list_attr());
75 if (v == nullptr) {
76 return errors::NotFound("type attr not found: ",
77 arg_def.type_list_attr());
78 }
79 *is_type_list = true;
80 for (int i = 0; i < v->list().type_size(); ++i) {
81 dtypes->push_back(v->list().type(i));
82 }
83 return Status::OK();
84 }
85
86 *is_type_list = false;
87 int num = 1;
88 if (!arg_def.number_attr().empty()) {
89 const AttrValue* v = attrs.Find(arg_def.number_attr());
90 if (v == nullptr) {
91 return errors::NotFound("type attr not found: ", arg_def.type_attr());
92 }
93 num = v->i();
94 }
95
96 DataType dtype;
97 if (arg_def.type() != DT_INVALID) {
98 dtype = arg_def.type();
99 } else if (arg_def.type_attr().empty()) {
100 dtype = DT_INVALID;
101 } else {
102 const AttrValue* v = attrs.Find(arg_def.type_attr());
103 if (v == nullptr) {
104 return errors::NotFound("type attr not found: ", arg_def.type_attr());
105 }
106 dtype = v->type();
107 }
108 dtypes->resize(num, dtype);
109 return Status::OK();
110 }
111
112 namespace {
113
114 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)115 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
116 SetAttrValue(val, &((*ndef->mutable_attr())[name]));
117 }
118
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)119 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
120 // attr_values should specify all attrs defined in fdef.
121 for (const auto& a : sig.attr()) {
122 const AttrValue* v = attr_values.Find(a.name());
123 if (!v) {
124 return errors::NotFound("Attr ", a.name(), " is not found from ",
125 SummarizeOpDef(sig));
126 }
127 Status status = AttrValueHasType(*v, a.type());
128 if (!status.ok()) {
129 errors::AppendToMessage(&status, "for attr '", a.name(), "'");
130 return status;
131 }
132 }
133
134 // TODO(josh11b): Enable this code once it works with function gradients.
135 // Right now the C++ function gradient code assumes it can pass
136 // all the attrs of the function to the gradient, and any attrs that
137 // the gradient doesn't care about will be ignored.
138 #if 0
139 if (attr_values.size() != sig.attr_size()) {
140 for (const auto& a : attr_values) {
141 // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
142 bool found = false;
143 for (const auto& s : sig.attr()) {
144 if (a.first == s.name()) {
145 found = true;
146 break;
147 }
148 }
149 if (!found) {
150 return errors::NotFound("Attr ", a.first, " is not found in ",
151 SummarizeOpDef(sig));
152 }
153 }
154 }
155 #endif
156
157 return Status::OK();
158 }
159
160 // A helper class for instantiating functions. This contains shared information
161 // like the resulting graph and node name index.
162 class FunctionInstantiationHelper {
163 public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)164 FunctionInstantiationHelper(GetFunctionSignature get_function,
165 InstantiationResult* result)
166 : get_function_(std ::move(get_function)), result_(*result) {
167 result_.nodes.clear();
168 }
169
170 // Builds index for nodes that can be used as node's input arguments.
171 // `resource_arg_unique_id`: if non-negative, will be populated to the
172 // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64_t resource_arg_unique_id)173 Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
174 const FunctionDef::ArgAttrs* arg_attrs,
175 bool ints_on_device,
176 int64_t resource_arg_unique_id) {
177 bool is_type_list;
178 DataTypeVector dtypes;
179 TF_RETURN_IF_ERROR(
180 ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
181 CHECK_GE(dtypes.size(), size_t{1});
182 int arg_index = result_.nodes.size();
183 TF_RETURN_IF_ERROR(
184 AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
185 // Creates dtypes.size() nodes in the graph.
186 for (size_t i = 0; i < dtypes.size(); ++i) {
187 TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
188 {true, arg_index, 0, false, {dtypes[i]}}));
189 DCHECK_EQ(arg_index, result_.nodes.size());
190 string name = arg_def.name();
191 if (dtypes.size() > 1) {
192 strings::StrAppend(&name, "_", i);
193 }
194 NodeDef* gnode = AddNode(name);
195 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
196 gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
197 } else {
198 gnode->set_op(FunctionLibraryDefinition::kArgOp);
199 }
200 DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
201 AddAttr("T", dtype, gnode);
202 AddAttr("index", arg_index, gnode);
203 if (resource_arg_unique_id >= 0) {
204 AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
205 }
206 if (arg_attrs) {
207 for (const auto& arg_attr : arg_attrs->attr()) {
208 AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
209 }
210 }
211 result_.arg_types.push_back(dtypes[i]);
212 ++arg_index;
213 }
214 return Status::OK();
215 }
216
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)217 Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
218 const int arg_index) {
219 const OpDef* node_sig = nullptr;
220 TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
221 if (node_sig->output_arg_size() == 0) {
222 return AddItem(node.name(), {false, arg_index, 0, false, {}});
223 }
224 const int num_retval = node_sig->output_arg_size();
225 int start = 0;
226 bool is_type_list;
227 DataTypeVector dtypes;
228 for (int i = 0; i < num_retval; ++i) {
229 TF_RETURN_IF_ERROR(
230 ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
231 // Note that we rely on the backwards-compatibility test enforcing
232 // that output_arg(*).name() doesn't change here.
233 const string base_name =
234 strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
235 TF_RETURN_IF_ERROR(
236 AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
237 for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
238 TF_RETURN_IF_ERROR(
239 AddItem(strings::StrCat(base_name, ":", j),
240 {false, arg_index, start + j, false, {dtypes[j]}}));
241 }
242 start += dtypes.size();
243 }
244 return Status::OK();
245 }
246
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)247 Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
248 const OpDef* fnode_sig = nullptr;
249 TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
250 NodeDef* gnode = AddNode(fnode.name());
251 gnode->set_op(fnode.op());
252 gnode->set_device(fnode.device());
253 int gnode_idx = nodes_.size() - 1;
254
255 // Input
256 const int num_args = fnode_sig->input_arg_size();
257 bool is_type_list; // ignored
258 DataTypeVector dtypes;
259 int fnode_arg_index = 0;
260 for (int i = 0; i < num_args; ++i) {
261 TF_RETURN_IF_ERROR(
262 ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
263 // Consume inputs (indexed by fnode_arg_index) until we have
264 // matched each element of dtypes (indexed by j).
265 for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
266 if (fnode_arg_index >= fnode.input_size()) {
267 // Should never happen if we computed dtypes correctly.
268 return errors::InvalidArgument(
269 "Attempt to access beyond input size: ", fnode_arg_index,
270 " >= ", fnode.input_size());
271 }
272 // Look up the next input.
273 const string& input_name = fnode.input(fnode_arg_index);
274 const auto* item = GetItemOrNull(input_name);
275 if (item == nullptr) {
276 return errors::InvalidArgument(
277 "input ", input_name,
278 " is not found: ", FormatNodeDefForError(fnode));
279 }
280 if (item->dtypes.size() > dtypes.size() - j) {
281 return errors::InvalidArgument("Input ", input_name, " too long for ",
282 fnode_sig->input_arg(i).name());
283 }
284 // Match up all the elements of this input (indexed by k) with
285 // elements of dtypes (advancing j).
286 for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
287 if (item->dtypes[k] != dtypes[j]) {
288 return errors::InvalidArgument(
289 "input ", fnode_sig->input_arg(i).name(), "[", j,
290 "] expected type ", DataTypeString(dtypes[j]),
291 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
292 input_name, "[", k, "]");
293 }
294 if (item->is_func_arg) {
295 AddInput(gnode_idx, item->nid + k, 0);
296 } else {
297 AddInput(gnode_idx, item->nid, item->idx + k);
298 }
299 }
300 }
301 }
302
303 // Control deps.
304 for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
305 const string& input = fnode.input(i);
306 if (input.empty() || input[0] != '^') {
307 return errors::InvalidArgument("Expected input[", i, "] == '", input,
308 "' to be a control input.");
309 }
310 int nid = -1;
311 const string node_name = input.substr(1);
312 const string node_colon = node_name + ":";
313 const string node_colon_bound = node_name + ";";
314 // index_ is a map sorted lexicographically, so the key we are looking for
315 // must lie in the range [node_name, node_colon_bound).
316 auto it = index_.lower_bound(node_name);
317 while (it != index_.end() && it->first <= node_colon_bound) {
318 if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
319 nid = it->second.nid;
320 break;
321 }
322 ++it;
323 }
324 if (nid == -1) {
325 return errors::InvalidArgument("input[", i, "] == '", input,
326 "', is not found.");
327 }
328 AddDep(gnode_idx, nid);
329 }
330
331 // Attrs.
332 for (const auto& p : attrs) {
333 (*gnode->mutable_attr())[p.first] = p.second;
334 }
335
336 return Status::OK();
337 }
338
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)339 Status AddReturnNode(
340 const OpDef::ArgDef& ret_def, AttrSlice attrs,
341 const ::tensorflow::protobuf::Map<string, string>& ret_map,
342 bool ints_on_device, int* ret_index) {
343 auto ret_iter = ret_map.find(ret_def.name());
344 if (ret_iter == ret_map.end()) {
345 return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
346 }
347 bool is_type_list;
348 DataTypeVector dtypes;
349 TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
350 CHECK_GE(dtypes.size(), size_t{1});
351 const auto* item = GetItemOrNull(ret_iter->second);
352 if (item == nullptr) {
353 return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
354 ret_iter->second, " is not found.");
355 }
356 if (dtypes != item->dtypes) {
357 return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
358 " : ", DataTypeVectorString(dtypes),
359 " vs. ",
360 DataTypeVectorString(item->dtypes));
361 }
362 for (size_t i = 0; i < dtypes.size(); ++i) {
363 string name = strings::StrCat(ret_def.name(), "_RetVal");
364 if (dtypes.size() > 1) {
365 strings::StrAppend(&name, "_", i);
366 }
367 NodeDef* gnode = AddNode(name);
368 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
369 gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
370 } else {
371 gnode->set_op(FunctionLibraryDefinition::kRetOp);
372 }
373 AddInput(nodes_.size() - 1, item->nid, item->idx + i);
374 DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
375 AddAttr("T", dtype, gnode);
376 AddAttr("index", (*ret_index)++, gnode);
377 result_.ret_types.push_back(dtypes[i]);
378 }
379 return Status::OK();
380 }
381
382 // Adds the actual node inputs to the result graph by converting indexes to
383 // the node names.
AddNodeInputs()384 void AddNodeInputs() {
385 for (int i = 0; i < result_.nodes.size(); i++) {
386 NodeInfo& node_info = nodes_[i];
387 for (const auto& p : node_info.data_inputs) {
388 result_.nodes[i].add_input(Name(p.first, p.second));
389 }
390 for (int index : node_info.control_inputs) {
391 result_.nodes[i].add_input(Dep(index));
392 }
393 }
394 }
395
396 private:
397 // This is used to build a small index for all names that can be used as a
398 // node's input arguments.
399 //
400 // If is_func_arg is true, the name is a function's argument. In
401 // this case, the produced graph def has node[nid:nid + dtype.size()].
402 //
403 // Otherwise, the name is a function body's node return value. In
404 // this case, the produced graph def has one node node[nid] and
405 // the node's output index [idx ... idx + num) corresponds to the
406 // named outputs.
407 //
408 // In all cases, "dtype" specifies the data type.
409 struct NameInfoItem {
410 bool is_func_arg;
411 int nid;
412 int idx;
413 bool is_type_list;
414 DataTypeVector dtypes;
415 };
416
417 // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)418 Status AddItem(const string& name, const NameInfoItem& item) {
419 if (!index_.insert({name, item}).second) {
420 return errors::InvalidArgument(
421 strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
422 " name: "),
423 name);
424 }
425 return Status::OK();
426 }
427
GetItemOrNull(const string & name) const428 const NameInfoItem* GetItemOrNull(const string& name) const {
429 return gtl::FindOrNull(index_, name);
430 }
431
Dep(int node_index) const432 string Dep(int node_index) const {
433 return strings::StrCat("^", Name(node_index));
434 }
435
Name(int node_index) const436 string Name(int node_index) const {
437 CHECK_LT(node_index, nodes_.size());
438 return nodes_[node_index].name;
439 }
440
Name(int node_index,int output_index) const441 string Name(int node_index, int output_index) const {
442 if (output_index == 0) {
443 return Name(node_index);
444 } else {
445 return strings::StrCat(Name(node_index), ":", output_index);
446 }
447 }
448
AddNode(const string & name)449 NodeDef* AddNode(const string& name) {
450 result_.nodes.emplace_back();
451 NodeDef* gnode = &result_.nodes.back();
452 gnode->set_name(name);
453 nodes_.push_back({name, {}, {}});
454 CHECK_EQ(result_.nodes.size(), nodes_.size());
455 return gnode;
456 }
457
AddInput(int node_index,int output_node,int output_index)458 void AddInput(int node_index, int output_node, int output_index) {
459 CHECK_LT(node_index, nodes_.size());
460 nodes_[node_index].data_inputs.push_back(
461 std::make_pair(output_node, output_index));
462 }
463
AddDep(int node_index,int dep_index)464 void AddDep(int node_index, int dep_index) {
465 CHECK_LT(node_index, nodes_.size());
466 nodes_[node_index].control_inputs.push_back(dep_index);
467 }
468
469 GetFunctionSignature get_function_;
470 InstantiationResult& result_;
471 // A small index for all names that can be used as a node's input arguments.
472 std::map<string, NameInfoItem> index_;
473 // This contains information about a node in the new graph including the node
474 // names and input nodes' indexes.
475 struct NodeInfo {
476 string name;
477 // Data inputs where <n, k> means arg k of node n.
478 std::vector<std::pair<int, int>> data_inputs;
479 // Control inputs (dependencies).
480 std::vector<int> control_inputs;
481 };
482 // nodes_[i] is the information about result_.nodes[i].
483 std::vector<NodeInfo> nodes_;
484 };
485
486 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)487 string Print(const OpDef::ArgDef& arg) {
488 string out;
489 strings::StrAppend(&out, arg.name(), ":");
490 if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
491 if (!arg.number_attr().empty()) {
492 strings::StrAppend(&out, arg.number_attr(), "*");
493 }
494 if (arg.type() != DT_INVALID) {
495 strings::StrAppend(&out, DataTypeString(arg.type()));
496 } else {
497 strings::StrAppend(&out, arg.type_attr());
498 }
499 if (arg.is_ref()) strings::StrAppend(&out, ")");
500 return out;
501 }
502
503 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)504 string Print(const AttrValue& attr_value) {
505 if (attr_value.value_case() == AttrValue::kType) {
506 return DataTypeString(attr_value.type());
507 } else if ((attr_value.value_case() == AttrValue::kList) &&
508 (attr_value.list().type_size() > 0)) {
509 string ret = "{";
510 for (int i = 0; i < attr_value.list().type_size(); ++i) {
511 if (i > 0) strings::StrAppend(&ret, ", ");
512 strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
513 }
514 strings::StrAppend(&ret, "}");
515 return ret;
516 } else if (attr_value.value_case() == AttrValue::kFunc) {
517 if (attr_value.func().attr_size() == 0) {
518 return attr_value.func().name();
519 }
520 std::vector<string> entries;
521 for (const auto& p : attr_value.func().attr()) {
522 entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
523 }
524 std::sort(entries.begin(), entries.end());
525 return strings::StrCat(attr_value.func().name(), "[",
526 absl::StrJoin(entries, ", "), "]");
527 }
528 return SummarizeAttrValue(attr_value);
529 }
530
531 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)532 string Print(const NodeDef& n) {
533 string out;
534 strings::StrAppend(&out, n.name(), " = ", n.op());
535 if (n.attr_size() > 0) {
536 std::vector<string> entries;
537 for (auto& a : n.attr()) {
538 entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
539 }
540 std::sort(entries.begin(), entries.end());
541 // Add a short device string at the end of all attributes.
542 if (!n.device().empty()) {
543 DeviceNameUtils::ParsedName parsed;
544 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
545 entries.push_back(
546 strings::StrCat("device=", parsed.type, ":", parsed.id));
547 } else {
548 entries.push_back("device=<FAILED_TO_PARSE>");
549 }
550 }
551 strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
552 }
553 strings::StrAppend(&out, "(");
554 std::vector<StringPiece> dat;
555 std::vector<string> dep;
556 for (StringPiece s : n.input()) {
557 if (absl::ConsumePrefix(&s, "^")) {
558 dep.emplace_back(s);
559 } else {
560 dat.push_back(s);
561 }
562 }
563 strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
564 if (!dep.empty()) {
565 strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
566 }
567 return out;
568 }
569
Print(const FunctionDef & fdef)570 string Print(const FunctionDef& fdef) {
571 string out;
572 const OpDef& sig = fdef.signature();
573 strings::StrAppend(&out, "\n", sig.name());
574 if (sig.attr_size() > 0) {
575 strings::StrAppend(&out, "[");
576 for (int i = 0; i < sig.attr_size(); ++i) {
577 const auto& a = sig.attr(i);
578 if (i > 0) strings::StrAppend(&out, ", ");
579 if (a.type() == "type") {
580 strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
581 } else {
582 strings::StrAppend(&out, a.name(), ":", a.type());
583 }
584 }
585 strings::StrAppend(&out, "]");
586 }
587 strings::StrAppend(&out, "(");
588 for (int i = 0; i < sig.input_arg_size(); ++i) {
589 if (i > 0) strings::StrAppend(&out, ", ");
590 strings::StrAppend(&out, Print(sig.input_arg(i)));
591 }
592 strings::StrAppend(&out, ") -> (");
593 for (int i = 0; i < sig.output_arg_size(); ++i) {
594 if (i > 0) strings::StrAppend(&out, ", ");
595 strings::StrAppend(&out, Print(sig.output_arg(i)));
596 }
597 strings::StrAppend(&out, ") {\n");
598 for (const auto& n : fdef.node_def()) {
599 strings::StrAppend(&out, " ", Print(n), "\n");
600 }
601 for (const auto& cr : fdef.control_ret()) {
602 strings::StrAppend(&out, " @return ", cr.first, " = ", cr.second, "\n");
603 }
604 for (const auto& r : fdef.ret()) {
605 strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
606 }
607 strings::StrAppend(&out, "}\n");
608 return out;
609 }
610
Print(gtl::ArraySlice<const NodeDef * > nodes)611 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
612 std::vector<const NodeDef*> arg;
613 std::vector<const NodeDef*> ret;
614 std::vector<const NodeDef*> body;
615 for (const NodeDef* n : nodes) {
616 if (n->op() == FunctionLibraryDefinition::kArgOp ||
617 n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
618 arg.push_back(n);
619 } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
620 n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
621 ret.push_back(n);
622 } else {
623 body.push_back(n);
624 }
625 }
626 auto comp = [](const NodeDef* x, const NodeDef* y) {
627 int xi;
628 TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
629 int yi;
630 TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
631 return xi < yi;
632 };
633 std::sort(arg.begin(), arg.end(), comp);
634 std::sort(ret.begin(), ret.end(), comp);
635 string out;
636 strings::StrAppend(&out, "\n(");
637 auto get_type_and_device = [](const NodeDef& n) {
638 DataType dt;
639 if (!TryGetNodeAttr(n, "T", &dt)) {
640 dt = DT_INVALID;
641 }
642 if (!n.device().empty()) {
643 DeviceNameUtils::ParsedName parsed;
644 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
645 return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
646 parsed.id);
647 } else {
648 LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
649 << n.op() << ":" << n.name();
650 return strings::StrCat(DataTypeString(dt), "@",
651 "<FAILED_TO_PARSE_DEVICE>");
652 }
653 }
654 return DataTypeString(dt);
655 };
656 for (size_t i = 0; i < arg.size(); ++i) {
657 const NodeDef* n = arg[i];
658 if (i > 0) strings::StrAppend(&out, ", ");
659 CHECK_GE(n->attr_size(), 2);
660 strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
661 }
662 strings::StrAppend(&out, ") -> (");
663 for (size_t i = 0; i < ret.size(); ++i) {
664 const NodeDef* n = ret[i];
665 if (i > 0) strings::StrAppend(&out, ", ");
666 CHECK_LE(2, n->attr_size());
667
668 // The _RetVal op should have a unique non-control input. We assert that
669 // here and add it to the output.
670 bool found_non_control_input = false;
671 for (const string& input : n->input()) {
672 if (!input.empty() && input[0] != '^') {
673 DCHECK_EQ(found_non_control_input, false)
674 << "RetVal node has more than one non-control input: "
675 << absl::StrJoin(n->input(), ", ");
676 strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
677 found_non_control_input = true;
678 }
679 }
680 DCHECK_EQ(found_non_control_input, true)
681 << "RetVal did not have any non-control inputs: "
682 << absl::StrJoin(n->input(), ", ");
683 }
684 strings::StrAppend(&out, ") {\n");
685 for (size_t i = 0; i < body.size(); ++i) {
686 strings::StrAppend(&out, " ", Print(*body[i]), "\n");
687 }
688 strings::StrAppend(&out, "}\n");
689 return out;
690 }
691
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)692 Status AddDefaultAttrs(const string& op,
693 const GetFunctionSignature& get_function,
694 AttrValueMap* attrs) {
695 const OpDef* op_def = nullptr;
696 TF_RETURN_IF_ERROR(get_function(op, &op_def));
697 AttrSlice attr_slice(attrs);
698 for (const auto& attr_def : op_def->attr()) {
699 if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
700 if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
701 return errors::Internal("Somehow duplicated: ", attr_def.name());
702 }
703 }
704 }
705 return Status::OK();
706 }
707
708 } // end namespace
709
710 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)711 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
712 GetFunctionSignature get_function,
713 InstantiationResult* result) {
714 if (VLOG_IS_ON(5)) {
715 const auto& signature = fdef.signature();
716 VLOG(5) << "Instantiate function definition: name=" << signature.name()
717 << " #input_args=" << signature.input_arg_size()
718 << " #output_args=" << signature.output_arg_size()
719 << " #control_output=" << signature.control_output_size();
720 for (const auto& line : str_util::Split(Print(fdef), '\n')) {
721 VLOG(5) << "|| " << line;
722 }
723 }
724
725 const OpDef& sig = fdef.signature();
726 TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
727
728 bool ints_on_device =
729 fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
730 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
731
732 FunctionInstantiationHelper helper(get_function, result);
733 Status s;
734 for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
735 const OpDef::ArgDef& arg_def = sig.input_arg(i);
736 auto it = fdef.arg_attr().find(i);
737 const FunctionDef::ArgAttrs* arg_attrs =
738 it != fdef.arg_attr().end() ? &it->second : nullptr;
739 auto resource_id_it = fdef.resource_arg_unique_id().find(i);
740 int64_t resource_arg_unique_id =
741 resource_id_it != fdef.resource_arg_unique_id().end()
742 ? resource_id_it->second
743 : -1LL;
744 s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
745 ints_on_device, resource_arg_unique_id);
746
747 if (!s.ok()) {
748 errors::AppendToMessage(&s, "In ", Print(arg_def));
749 return s;
750 }
751 }
752
753 auto substitute = [attr_values](StringPiece name, AttrValue* val) {
754 if (const AttrValue* v = attr_values.Find(name)) {
755 *val = *v;
756 return true;
757 }
758 return false;
759 };
760
761 // Makes a copy of all attrs in fdef and substitutes placeholders.
762 // After this step, every attr is bound to a concrete value.
763 std::vector<AttrValueMap> node_attrs;
764 node_attrs.resize(fdef.node_def_size());
765 for (int i = 0; i < fdef.node_def_size(); ++i) {
766 for (auto attr : fdef.node_def(i).attr()) {
767 if (!SubstitutePlaceholders(substitute, &attr.second)) {
768 return errors::InvalidArgument("Failed to bind all placeholders in ",
769 SummarizeAttrValue(attr.second));
770 }
771 if (!node_attrs[i].insert(attr).second) {
772 return errors::Internal("Somehow duplicated: ", attr.first);
773 }
774 }
775 TF_RETURN_IF_ERROR(
776 AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
777 }
778
779 for (int i = 0; i < fdef.node_def_size(); ++i) {
780 s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
781 result->nodes.size() + i);
782 if (!s.ok()) {
783 errors::AppendToMessage(&s, "In ",
784 FormatNodeDefForError(fdef.node_def(i)));
785 return s;
786 }
787 }
788 // Emits one node for each fdef.node_def.
789 for (int i = 0; i < fdef.node_def_size(); ++i) {
790 s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
791 if (!s.ok()) {
792 errors::AppendToMessage(&s, "In ",
793 FormatNodeDefForError(fdef.node_def(i)));
794 return s;
795 }
796 }
797
798 // Emits nodes for the function's return values.
799 int ret_index = 0;
800 for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
801 s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
802 &ret_index);
803 if (!s.ok()) {
804 errors::AppendToMessage(&s, "In function output ", Print(ret_def));
805 return s;
806 }
807 }
808
809 // Adds the actual node inputs using the input indexes.
810 helper.AddNodeInputs();
811
812 return Status::OK();
813 }
814
DebugString(const FunctionDef & func_def)815 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
816
DebugString(const GraphDef & instantiated_func_def)817 string DebugString(const GraphDef& instantiated_func_def) {
818 std::vector<const NodeDef*> ptrs;
819 for (const NodeDef& n : instantiated_func_def.node()) {
820 ptrs.push_back(&n);
821 }
822 return Print(ptrs);
823 }
824
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)825 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
826 std::vector<const NodeDef*> ptrs;
827 for (const NodeDef& n : instantiated_func_nodes) {
828 ptrs.push_back(&n);
829 }
830 return Print(ptrs);
831 }
832
DebugStringWhole(const GraphDef & gdef)833 string DebugStringWhole(const GraphDef& gdef) {
834 string ret;
835 for (const auto& fdef : gdef.library().function()) {
836 strings::StrAppend(&ret, Print(fdef));
837 }
838 strings::StrAppend(&ret, "\n");
839 for (const auto& ndef : gdef.node()) {
840 strings::StrAppend(&ret, Print(ndef), "\n");
841 }
842 return ret;
843 }
844
845 namespace {
846
847 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
848 // Python, it's possible to access unset attrs, which returns a default value
849 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)850 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
851 std::map<string, AttrValue> set_attrs;
852 for (const auto& pair : fdef.attr()) {
853 if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
854 set_attrs[pair.first] = pair.second;
855 }
856 }
857 return set_attrs;
858 }
859
860 } // end namespace
861
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)862 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
863 if (!OpDefEqual(f1.signature(), f2.signature())) return false;
864
865 std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
866 std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
867 if (f1_attrs.size() != f2_attrs.size()) return false;
868 for (const auto& iter1 : f1_attrs) {
869 auto iter2 = f2_attrs.find(iter1.first);
870 if (iter2 == f2_attrs.end()) return false;
871 if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
872 }
873
874 if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
875 return false;
876 }
877
878 std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
879 std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
880 if (ret1 != ret2) return false;
881
882 std::map<string, string> control_ret1(f1.control_ret().begin(),
883 f1.control_ret().end());
884 std::map<string, string> control_ret2(f2.control_ret().begin(),
885 f2.control_ret().end());
886 if (control_ret1 != control_ret2) return false;
887
888 return true;
889 }
890
FunctionDefHash(const FunctionDef & fdef)891 uint64 FunctionDefHash(const FunctionDef& fdef) {
892 // signature
893 uint64 h = OpDefHash(fdef.signature());
894
895 // attrs
896 std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
897 for (const auto& p : attrs) {
898 h = Hash64(p.first.data(), p.first.size(), h);
899 h = Hash64Combine(AttrValueHash(p.second), h);
900 }
901
902 // node defs
903 h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
904
905 // output names
906 std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
907 for (const auto& p : ret) {
908 h = Hash64(p.first.data(), p.first.size(), h);
909 h = Hash64(p.second.data(), p.second.size(), h);
910 }
911
912 // control output names
913 std::map<string, string> control_ret(fdef.control_ret().begin(),
914 fdef.control_ret().end());
915 for (const auto& p : control_ret) {
916 h = Hash64(p.first.data(), p.first.size(), h);
917 h = Hash64(p.second.data(), p.second.size(), h);
918 }
919
920 return h;
921 }
922
// Name of the function attr that may carry an executor type. It is consulted
// by FunctionLibraryRuntime::ExecutorType() when InstantiateOptions does not
// set one. Canonicalize() skips it among the generic attrs and re-adds the
// resolved executor type under this key instead.
static constexpr const char* const kExecutorAttr = "_executor";
924
925 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)926 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
927 AttrSlice attrs) {
928 if (!options.executor_type.empty()) {
929 return options.executor_type;
930 } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
931 return executor_attr->s();
932 } else {
933 return string();
934 }
935 }
936
937 namespace {
938 class AttrKeyAndValue {
939 public:
940 enum ValueRepresentationOp {
941 kRaw,
942 kCEscape,
943 };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)944 AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
945 ValueRepresentationOp value_op = kRaw)
946 : key_name_(key_name),
947 key_suffix_(key_suffix),
948 value_op_(value_op),
949 value_(std::move(value)) {}
950
operator <(const AttrKeyAndValue & b) const951 bool operator<(const AttrKeyAndValue& b) const {
952 if (key_name_ != b.key_name_) {
953 return key_name_ < b.key_name_;
954 } else if (key_suffix_ != b.key_suffix_) {
955 return key_suffix_ < b.key_suffix_;
956 } else {
957 return value_ < b.value_;
958 }
959 }
960
AppendTo(bool first,string * s) const961 void AppendTo(bool first, string* s) const {
962 absl::string_view v;
963 bool add_escaped = false;
964 if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
965 // Use CEscape call below
966 add_escaped = true;
967 } else {
968 // Add raw value contents directly
969 v = value_;
970 }
971 if (key_suffix_ >= 0) {
972 strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
973 } else {
974 strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
975 }
976 if (add_escaped) {
977 strings::StrAppend(s, absl::CEscape(value_));
978 }
979 }
980
981 private:
NeedsEscaping(const string & s)982 static bool NeedsEscaping(const string& s) {
983 for (auto c : s) {
984 if (!isalnum(c) && (c != ' ')) {
985 return true;
986 }
987 }
988 return false;
989 }
990
991 absl::string_view key_name_;
992 int key_suffix_; // -1 if missing
993 ValueRepresentationOp value_op_;
994 string value_;
995 };
996 } // namespace
997
GetFunctionResourceInputDevice(const Tensor & input,const int arg_index,const FunctionDef & function_def,absl::flat_hash_map<string,std::vector<string>> * composite_devices)998 string GetFunctionResourceInputDevice(
999 const Tensor& input, const int arg_index, const FunctionDef& function_def,
1000 absl::flat_hash_map<string, std::vector<string>>* composite_devices) {
1001 const auto& handles = input.flat<ResourceHandle>();
1002 const ResourceHandle& handle0 = handles(0);
1003 string composite_device;
1004 auto iter = function_def.arg_attr().find(arg_index);
1005 if (iter != function_def.arg_attr().end()) {
1006 auto arg_attr = iter->second.attr().find("_composite_device");
1007 if (arg_attr != iter->second.attr().end()) {
1008 composite_device = arg_attr->second.s();
1009 }
1010 }
1011 if (!composite_device.empty()) {
1012 if (composite_devices->find(composite_device) == composite_devices->end()) {
1013 for (int i = 0; i < handles.size(); ++i) {
1014 (*composite_devices)[composite_device].push_back(handles(i).device());
1015 }
1016 }
1017 return composite_device;
1018 } else {
1019 return handle0.device();
1020 }
1021 }
1022
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)1023 string Canonicalize(const string& funcname, AttrSlice attrs,
1024 const FunctionLibraryRuntime::InstantiateOptions& options) {
1025 absl::InlinedVector<AttrKeyAndValue, 8> entries;
1026 entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1027 options.input_devices.size());
1028 for (const auto& p : attrs) {
1029 if (p.first != kExecutorAttr) {
1030 entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
1031 }
1032 }
1033 if (!options.target.empty()) {
1034 entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1035 AttrKeyAndValue::kCEscape));
1036 }
1037 for (int i = 0; i < options.input_devices.size(); ++i) {
1038 entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1039 AttrKeyAndValue::kCEscape));
1040 }
1041 for (int i = 0; i < options.output_devices.size(); ++i) {
1042 entries.push_back(AttrKeyAndValue("_output_dev", i,
1043 options.output_devices[i],
1044 AttrKeyAndValue::kCEscape));
1045 }
1046 for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1047 entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1048 DataTypeString(iter.second.dtype)));
1049 entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1050 iter.second.shape.DebugString(),
1051 AttrKeyAndValue::kCEscape));
1052 }
1053 if (options.lib_def) {
1054 entries.push_back(AttrKeyAndValue(
1055 "_lib_def", -1,
1056 absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1057 }
1058 if (!options.state_handle.empty()) {
1059 entries.push_back(
1060 AttrKeyAndValue("_state_handle", -1, options.state_handle));
1061 }
1062 string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1063 if (!executor_type.empty()) {
1064 entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1065 }
1066 if (options.config_proto.ByteSize() > 0) {
1067 string config_proto_serialized;
1068 SerializeToStringDeterministic(options.config_proto,
1069 &config_proto_serialized);
1070 entries.push_back(AttrKeyAndValue("_config_proto", -1,
1071 config_proto_serialized,
1072 AttrKeyAndValue::kCEscape));
1073 }
1074 std::sort(entries.begin(), entries.end());
1075 string result = strings::StrCat(funcname, "[");
1076 bool first = true;
1077 for (const auto& entry : entries) {
1078 entry.AppendTo(first, &result);
1079 first = false;
1080 }
1081 result += "]";
1082 return result;
1083 }
1084
Canonicalize(const string & funcname,AttrSlice attrs)1085 string Canonicalize(const string& funcname, AttrSlice attrs) {
1086 static const FunctionLibraryRuntime::InstantiateOptions* kEmptyOptions =
1087 new FunctionLibraryRuntime::InstantiateOptions;
1088 return Canonicalize(funcname, attrs, *kEmptyOptions);
1089 }
1090
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1091 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1092 DataTypeSlice ret_types)
1093 : arg_types_(arg_types.begin(), arg_types.end()),
1094 ret_types_(ret_types.begin(), ret_types.end()) {
1095 args_.resize(arg_types_.size());
1096 rets_.resize(ret_types_.size());
1097 }
1098
~FunctionCallFrame()1099 FunctionCallFrame::~FunctionCallFrame() {}
1100
SetArgs(gtl::ArraySlice<Tensor> args)1101 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1102 // Input type checks.
1103 if (args.size() != arg_types_.size()) {
1104 return errors::InvalidArgument("Expects ", arg_types_.size(),
1105 " arguments, but ", args.size(),
1106 " is provided");
1107 }
1108 for (size_t i = 0; i < args.size(); ++i) {
1109 if (arg_types_[i] != args[i].dtype()) {
1110 return errors::InvalidArgument(
1111 "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1112 DataTypeString(args[i].dtype()), " is provided");
1113 }
1114 args_[i] = args[i];
1115 }
1116 return Status::OK();
1117 }
1118
GetRetvals(std::vector<Tensor> * rets) const1119 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1120 rets->clear();
1121 rets->reserve(rets_.size());
1122 for (size_t i = 0; i < rets_.size(); ++i) {
1123 const auto& item = rets_[i];
1124 if (item.has_val) {
1125 rets->push_back(item.val);
1126 } else {
1127 return errors::Internal("Retval[", i, "] does not have value");
1128 }
1129 }
1130 return Status::OK();
1131 }
1132
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1133 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1134 bool allow_dead_tensors) {
1135 rets->clear();
1136 rets->reserve(rets_.size());
1137 for (size_t i = 0; i < rets_.size(); ++i) {
1138 if (rets_[i].has_val) {
1139 rets->emplace_back(std::move(rets_[i].val));
1140 } else if (allow_dead_tensors) {
1141 rets->emplace_back();
1142 } else {
1143 return errors::Internal("Retval[", i, "] does not have value");
1144 }
1145 }
1146 return Status::OK();
1147 }
1148
GetArg(int index,const Tensor ** val)1149 Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
1150 if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1151 return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1152 args_.size(), ")");
1153 }
1154 *val = &args_[index];
1155 return Status::OK();
1156 }
1157
SetRetval(int index,const Tensor & val)1158 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1159 if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1160 return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1161 rets_.size(), ")");
1162 }
1163 if (val.dtype() != ret_types_[index]) {
1164 return errors::InvalidArgument(
1165 "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1166 ", but ", DataTypeString(val.dtype()), " is provided.");
1167 }
1168 Retval* item = &rets_[index];
1169 if (!item->has_val) {
1170 item->has_val = true;
1171 item->val = val;
1172 } else {
1173 return errors::Internal("Retval[", index, "] has already been set.");
1174 }
1175 return Status::OK();
1176 }
1177
// Bundles a copy of `fdef_in` with OpRegistrationData built from its signature
// (flagged as a function) and the caller-supplied `stack_traces`.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
                                 const StackTracesMap& stack_traces)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      // NOTE(review): this reads the member `fdef` initialized just above; it
      // relies on `fdef` being declared before `op_registration_data` in
      // function.h — confirm.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */),
      // The parameter `stack_traces` shadows the member of the same name here.
      stack_traces(stack_traces) {}
1187
FunctionLibraryDefinition(const FunctionLibraryDefinition & other)1188 FunctionLibraryDefinition::FunctionLibraryDefinition(
1189 const FunctionLibraryDefinition& other)
1190 : default_registry_(other.default_registry_) {
1191 tf_shared_lock l(other.mu_);
1192 function_defs_ = other.function_defs_;
1193 func_grad_ = other.func_grad_;
1194 }
1195
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1196 FunctionLibraryDefinition::FunctionLibraryDefinition(
1197 const OpRegistryInterface* default_registry,
1198 const FunctionDefLibrary& def_lib)
1199 : default_registry_(default_registry),
1200 function_defs_(def_lib.function_size()) {
1201 for (const auto& fdef : def_lib.function()) {
1202 // The latter function definition wins.
1203 auto& ptr = function_defs_[fdef.signature().name()];
1204 ptr.reset(new FunctionDefAndOpRegistration(fdef));
1205 }
1206 for (const auto& grad : def_lib.gradient()) {
1207 func_grad_[grad.function_name()] = grad.gradient_func();
1208 }
1209 }
1210
~FunctionLibraryDefinition()1211 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1212
Contains(const string & func) const1213 bool FunctionLibraryDefinition::Contains(const string& func) const {
1214 tf_shared_lock l(mu_);
1215 return function_defs_.find(func) != function_defs_.end();
1216 }
1217
Find(const string & func) const1218 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1219 tf_shared_lock l(mu_);
1220 auto result = FindHelper(func);
1221 if (result) {
1222 return &result->fdef;
1223 } else {
1224 return nullptr;
1225 }
1226 }
1227
1228 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1229 FunctionLibraryDefinition::FindHelper(const string& func) const {
1230 auto iter = function_defs_.find(func);
1231 if (iter == function_defs_.end()) {
1232 return nullptr;
1233 } else {
1234 return iter->second;
1235 }
1236 }
1237
AddFunctionDef(const FunctionDef & fdef,const StackTracesMap & stack_traces)1238 Status FunctionLibraryDefinition::AddFunctionDef(
1239 const FunctionDef& fdef, const StackTracesMap& stack_traces) {
1240 mutex_lock l(mu_);
1241 bool added;
1242 return AddFunctionDefHelper(fdef, stack_traces, &added);
1243 }
1244
AddFunctionDefHelper(const FunctionDef & fdef,const StackTracesMap & stack_traces,bool * added)1245 Status FunctionLibraryDefinition::AddFunctionDefHelper(
1246 const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
1247 *added = false;
1248 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1249 function_defs_[fdef.signature().name()];
1250 if (entry) {
1251 if (!FunctionDefsEqual(entry->fdef, fdef)) {
1252 return errors::InvalidArgument(
1253 "Cannot add function '", fdef.signature().name(),
1254 "' because a different function with the same name already "
1255 "exists.");
1256 }
1257 // Ignore duplicate FunctionDefs.
1258 return Status::OK();
1259 }
1260 const OpDef* op_def;
1261 if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1262 return errors::InvalidArgument(
1263 "Cannot add function '", fdef.signature().name(),
1264 "' because an op with the same name already exists.");
1265 }
1266 entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
1267 *added = true;
1268 return Status::OK();
1269 }
1270
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1271 Status FunctionLibraryDefinition::AddHelper(
1272 std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1273 *added = false;
1274 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1275 function_defs_[registration->fdef.signature().name()];
1276 if (entry) {
1277 if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1278 return errors::InvalidArgument(
1279 "Cannot add function '", registration->fdef.signature().name(),
1280 "' because a different function with the same name already "
1281 "exists.");
1282 }
1283 // Ignore duplicate FunctionDefs.
1284 return Status::OK();
1285 }
1286 const OpDef* op_def;
1287 if (default_registry_
1288 ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1289 .ok()) {
1290 return errors::InvalidArgument(
1291 "Cannot add function '", registration->fdef.signature().name(),
1292 "' because an op with the same name already exists.");
1293 }
1294 entry = std::move(registration);
1295 *added = true;
1296 return Status::OK();
1297 }
1298
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1299 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1300 const string& func, const FunctionLibraryDefinition& other) {
1301 if (default_registry_ != other.default_registry_) {
1302 return errors::InvalidArgument(
1303 "Cannot copy function '", func,
1304 "' because CopyFunctionDefFrom() requires that both libraries have the "
1305 "same default registry.");
1306 }
1307 std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1308 {
1309 tf_shared_lock l(other.mu_);
1310 function_def = other.FindHelper(func);
1311 }
1312 if (!function_def) {
1313 return errors::InvalidArgument(
1314 "Cannot copy function '", func,
1315 "' because no function with that name exists in the other library.");
1316 }
1317 {
1318 mutex_lock l(mu_);
1319 std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1320 if (entry) {
1321 if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1322 return errors::InvalidArgument(
1323 "Cannot copy function '", func,
1324 "' because a different function with the same name already "
1325 "exists.");
1326 }
1327 } else {
1328 entry = std::move(function_def);
1329 }
1330 }
1331 return Status::OK();
1332 }
1333
AddGradientDef(const GradientDef & grad)1334 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1335 mutex_lock l(mu_);
1336 bool added;
1337 return AddGradientDefHelper(grad, &added);
1338 }
1339
AddGradientDefHelper(const GradientDef & grad,bool * added)1340 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1341 bool* added) {
1342 *added = false;
1343 string* entry = &func_grad_[grad.function_name()];
1344 if (!entry->empty()) {
1345 if (*entry != grad.gradient_func()) {
1346 return errors::InvalidArgument(
1347 "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1348 grad.function_name(), "' because it already has gradient function ",
1349 "'", *entry, "'");
1350 }
1351 // Ignore duplicate GradientDefs
1352 return Status::OK();
1353 }
1354 *entry = grad.gradient_func();
1355 *added = true;
1356 return Status::OK();
1357 }
1358
AddLibrary(const FunctionLibraryDefinition & other)1359 Status FunctionLibraryDefinition::AddLibrary(
1360 const FunctionLibraryDefinition& other) {
1361 // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1362 // the duration of the function could lead to deadlock).
1363 FunctionLibraryDefinition clone(other);
1364 mutex_lock l(mu_);
1365 mutex_lock l2(clone.mu_);
1366 // Remember the funcs and grads that we added successfully so that
1367 // we can roll them back on error.
1368 std::vector<string> funcs;
1369 std::vector<string> funcs_with_grads;
1370 Status s;
1371 bool added;
1372 for (auto iter : clone.function_defs_) {
1373 s = AddHelper(iter.second, &added);
1374 if (!s.ok()) {
1375 Status remove_status = Remove(funcs, funcs_with_grads);
1376 if (!remove_status.ok()) {
1377 return remove_status;
1378 }
1379 return s;
1380 }
1381 if (added) {
1382 funcs.push_back(iter.second->fdef.signature().name());
1383 }
1384 }
1385 for (auto iter : clone.func_grad_) {
1386 GradientDef grad;
1387 grad.set_function_name(iter.first);
1388 grad.set_gradient_func(iter.second);
1389 s = AddGradientDefHelper(grad, &added);
1390 if (!s.ok()) {
1391 Status remove_status = Remove(funcs, funcs_with_grads);
1392 if (!remove_status.ok()) {
1393 return remove_status;
1394 }
1395 return s;
1396 }
1397 if (added) {
1398 funcs_with_grads.push_back(grad.function_name());
1399 }
1400 }
1401 return Status::OK();
1402 }
1403
AddLibrary(const FunctionDefLibrary & lib_def)1404 Status FunctionLibraryDefinition::AddLibrary(
1405 const FunctionDefLibrary& lib_def) {
1406 // Remember the funcs and grads that we added successfully so that
1407 // we can roll them back on error.
1408 mutex_lock l(mu_);
1409 std::vector<string> funcs;
1410 std::vector<string> funcs_with_grads;
1411 Status s;
1412 bool added;
1413 for (const FunctionDef& fdef : lib_def.function()) {
1414 s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
1415 if (!s.ok()) {
1416 Status remove_status = Remove(funcs, funcs_with_grads);
1417 if (!remove_status.ok()) {
1418 return remove_status;
1419 }
1420 return s;
1421 }
1422 if (added) {
1423 funcs.push_back(fdef.signature().name());
1424 }
1425 }
1426 for (const GradientDef& grad : lib_def.gradient()) {
1427 s = AddGradientDefHelper(grad, &added);
1428 if (!s.ok()) {
1429 Status remove_status = Remove(funcs, funcs_with_grads);
1430 if (!remove_status.ok()) {
1431 return remove_status;
1432 }
1433 return s;
1434 }
1435 if (added) {
1436 funcs_with_grads.push_back(grad.function_name());
1437 }
1438 }
1439 return Status::OK();
1440 }
1441
ReplaceFunction(const string & func,const FunctionDef & fdef,const StackTracesMap & stack_traces)1442 Status FunctionLibraryDefinition::ReplaceFunction(
1443 const string& func, const FunctionDef& fdef,
1444 const StackTracesMap& stack_traces) {
1445 mutex_lock l(mu_);
1446 bool added;
1447 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1448 TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, stack_traces, &added));
1449 return Status::OK();
1450 }
1451
ReplaceGradient(const GradientDef & grad)1452 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1453 mutex_lock l(mu_);
1454 bool added;
1455 TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1456 TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1457 return Status::OK();
1458 }
1459
RemoveFunction(const string & func)1460 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1461 mutex_lock l(mu_);
1462 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1463 return Status::OK();
1464 }
1465
RemoveFunctionHelper(const string & func)1466 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1467 const auto& i = function_defs_.find(func);
1468 if (i == function_defs_.end()) {
1469 return errors::InvalidArgument("Tried to remove non-existent function '",
1470 func, "'.");
1471 }
1472 function_defs_.erase(i);
1473 return Status::OK();
1474 }
1475
Clear()1476 void FunctionLibraryDefinition::Clear() {
1477 mutex_lock l(mu_);
1478 function_defs_.clear();
1479 func_grad_.clear();
1480 }
1481
RemoveGradient(const string & func)1482 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1483 const auto& i = func_grad_.find(func);
1484 if (i == func_grad_.end()) {
1485 return errors::InvalidArgument("Tried to remove non-existent gradient '",
1486 func, "'.");
1487 }
1488 func_grad_.erase(i);
1489 return Status::OK();
1490 }
1491
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1492 Status FunctionLibraryDefinition::Remove(
1493 const std::vector<string>& funcs,
1494 const std::vector<string>& funcs_with_grads) {
1495 Status s;
1496 for (const string& f : funcs) {
1497 s = RemoveFunctionHelper(f);
1498 if (!s.ok()) {
1499 return s;
1500 }
1501 }
1502 for (const string& f : funcs_with_grads) {
1503 s = RemoveGradient(f);
1504 if (!s.ok()) {
1505 return s;
1506 }
1507 }
1508 return Status::OK();
1509 }
1510
FindGradient(const string & func) const1511 string FunctionLibraryDefinition::FindGradient(const string& func) const {
1512 tf_shared_lock l(mu_);
1513 return gtl::FindWithDefault(func_grad_, func, "");
1514 }
1515
FindGradientHelper(const string & func) const1516 string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
1517 return gtl::FindWithDefault(func_grad_, func, "");
1518 }
1519
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1520 Status FunctionLibraryDefinition::LookUp(
1521 const string& op, const OpRegistrationData** op_reg_data) const {
1522 tf_shared_lock l(mu_);
1523 auto iter = function_defs_.find(op);
1524 if (iter != function_defs_.end()) {
1525 *op_reg_data = &iter->second->op_registration_data;
1526 return Status::OK();
1527 }
1528 return default_registry_->LookUp(op, op_reg_data);
1529 }
1530
UniqueFunctionName(StringPiece prefix) const1531 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1532 tf_shared_lock l(mu_);
1533 int index = 0;
1534 string name = strings::StrCat(prefix, index);
1535 while (function_defs_.find(name) != function_defs_.end()) {
1536 ++index;
1537 name = strings::StrCat(prefix, index);
1538 }
1539 return name;
1540 }
1541
GetAttrImpl(const NodeDef & ndef) const1542 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1543 const NodeDef& ndef) const {
1544 if (ndef.op() != kGradientOp) {
1545 // If 'ndef' calls a function and the function's def has the attr,
1546 // returns it.
1547 return Find(ndef.op());
1548 }
1549
1550 // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1551 // Foo's attributes.
1552 const NameAttrList* forward_func_attrs;
1553 if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1554 return nullptr;
1555 }
1556 const string& func_name = forward_func_attrs->name();
1557 {
1558 tf_shared_lock l(mu_);
1559 const string& grad_name = FindGradientHelper(func_name);
1560 // If 'func' has a user-defined gradient function, uses the grad
1561 // function's attrs to see if noinline is specified. Otherwise,
1562 // uses func's attrs.
1563 if (!grad_name.empty()) {
1564 if (const auto helper = FindHelper(grad_name)) {
1565 return &(helper->fdef);
1566 } else {
1567 return nullptr;
1568 }
1569 }
1570 if (const auto helper = FindHelper(func_name)) {
1571 return &(helper->fdef);
1572 } else {
1573 return nullptr;
1574 }
1575 }
1576 }
1577
ListFunctionNames() const1578 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1579 std::vector<string> function_names;
1580 tf_shared_lock l(mu_);
1581 function_names.reserve(function_defs_.size());
1582 for (const auto& it : function_defs_) {
1583 function_names.emplace_back(it.first);
1584 }
1585 return function_names;
1586 }
1587
ToProto() const1588 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1589 FunctionDefLibrary lib;
1590 tf_shared_lock l(mu_);
1591 for (const auto& f : function_defs_) {
1592 *lib.add_function() = f.second->fdef;
1593 }
1594 for (const auto& g : func_grad_) {
1595 GradientDef* gd = lib.add_gradient();
1596 gd->set_function_name(g.first);
1597 gd->set_gradient_func(g.second);
1598 }
1599 return lib;
1600 }
1601
1602 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1603 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1604 const string& attr, T* value) const {
1605 const FunctionDef* fdef = GetAttrImpl(ndef);
1606 if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1607 return Status::OK();
1608 }
1609 return errors::InvalidArgument("Attr ", attr, " is not defined.");
1610 }
1611
1612 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1613 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1614 T* value) const {
1615 return GetAttr(node.def(), attr, value);
1616 }
1617
// Explicitly instantiates both GetAttr() overloads (Node and NodeDef) for
// T = string and T = bool. The template bodies are defined in this file, so
// NOTE(review): code in other translation units can presumably only use the
// types instantiated here — add new types below if needed.
#define GET_ATTR(T)                                                    \
  template Status FunctionLibraryDefinition::GetAttr(const Node&,      \
                                                     const string&, T*) const; \
  template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,   \
                                                     const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
1626
1627 namespace {
1628
1629 constexpr char kApiImplements[] = "api_implements";
1630
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1631 std::set<string> ReachableFunctions(
1632 const FunctionLibraryDefinition& flib,
1633 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1634 // Functions that are reachable from the graph.
1635 std::set<string> reachable_funcs;
1636
1637 // For any functions, if it has attribute "api_implements" =
1638 // "some_interface" and it is reachable, then it means any other
1639 // function with same attribute name and value could also be potentially
1640 // reachable, eg via implementation_selector swapping the nodedef.
1641 absl::flat_hash_set<string> reachable_api_interface;
1642
1643 // Functions might be reachable from the nested function calls, so we keep a
1644 // queue of functions that we have to check.
1645 gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1646
1647 // Add reachable and not already processed functions to the functions queue.
1648 const auto add_to_func_queue = [&](const string& func_name) {
1649 const FunctionDef* func = flib.Find(func_name);
1650 if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1651 func_queue.push_back(func);
1652 }
1653 };
1654
1655 // If any function with certain API name is reachable, all the other functions
1656 // with same API name should also be checked.
1657 const auto add_function_with_api_interface = [&](const string& api_name) {
1658 if (!reachable_api_interface.contains(api_name)) {
1659 reachable_api_interface.insert(api_name);
1660 for (const auto& func_name : flib.ListFunctionNames()) {
1661 const auto& func_def = flib.Find(func_name);
1662 const auto attr_it = func_def->attr().find(kApiImplements);
1663 if (attr_it != func_def->attr().end() &&
1664 attr_it->second.s() == api_name) {
1665 add_to_func_queue(func_name);
1666 }
1667 }
1668 }
1669 };
1670
1671 // Add all the functions that are reachable from the given node to the queue.
1672 const auto process_node = [&](const NodeDef& node) {
1673 // Node itself can be a call to the function.
1674 add_to_func_queue(node.op());
1675
1676 // Or node can have an attribute referencing a function.
1677 for (const auto& attr : node.attr()) {
1678 const auto& attr_value = attr.second;
1679
1680 // 1. AttrValue.func
1681 if (attr_value.has_func()) {
1682 add_to_func_queue(attr_value.func().name());
1683 }
1684
1685 // 2. AttrValue.ListValue.func
1686 if (attr_value.has_list()) {
1687 for (const auto& func : attr_value.list().func()) {
1688 add_to_func_queue(func.name());
1689 }
1690 }
1691 }
1692 };
1693
1694 // Add all functions that are directly called from the optimized graph.
1695 std::for_each(nodes.begin(), nodes.end(), process_node);
1696
1697 // Process all reachable functions.
1698 while (!func_queue.empty()) {
1699 const FunctionDef* func = func_queue.back();
1700 func_queue.pop_back();
1701
1702 const string& func_name = func->signature().name();
1703 reachable_funcs.insert(func_name);
1704
1705 const auto attr_it = func->attr().find(kApiImplements);
1706 if (attr_it != func->attr().end()) {
1707 add_function_with_api_interface(attr_it->second.s());
1708 }
1709
1710 // Find all the functions called from the function body.
1711 const auto& func_body = func->node_def();
1712 std::for_each(func_body.begin(), func_body.end(), process_node);
1713
1714 // Check if the function has a registered gradient.
1715 const string grad_func_name = flib.FindGradient(func_name);
1716 if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1717 }
1718
1719 return reachable_funcs;
1720 }
1721
// Builds a new FunctionLibraryDefinition that contains only the functions
// reachable from `nodes` (as computed by ReachableFunctions) together with
// their registered gradients. The returned library shares `flib`'s default
// op registry.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  FunctionLibraryDefinition result(flib.default_registry(),
                                   FunctionDefLibrary());

  // ReachableFunctions returns a std::set, so every function name is visited
  // exactly once and in sorted order.
  for (const string& name : ReachableFunctions(flib, nodes)) {
    // Copying from a valid library into one with the same default registry is
    // not expected to fail.
    const Status copied = result.CopyFunctionDefFrom(name, flib);
    TF_DCHECK_OK(copied);

    const string gradient_name = flib.FindGradient(name);
    if (gradient_name.empty()) continue;

    GradientDef gradient_def;
    gradient_def.set_function_name(name);
    gradient_def.set_gradient_func(gradient_name);
    // AddGradientDef only fails if `name` already has a gradient in `result`,
    // which cannot happen because each name is visited once.
    const Status registered = result.AddGradientDef(gradient_def);
    TF_DCHECK_OK(registered);
  }

  return result;
}
1749
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1750 string AllocatorAttributesToString(
1751 const std::vector<AllocatorAttributes>& attrs) {
1752 string result("[");
1753 // AllocatorAttribute::DebugString produces around 85 bytes now.
1754 result.reserve(100 * attrs.size());
1755 for (const AllocatorAttributes& attr : attrs) {
1756 result.append(attr.DebugString());
1757 result.append(", ");
1758 }
1759 if (!attrs.empty()) {
1760 result.resize(result.size() - 2);
1761 }
1762 result.append("]");
1763 return result;
1764 }
1765
// Describes a pointer for debug output: "set" when non-null, "unset" otherwise.
const char* IsSet(void* ptr) {
  if (ptr != nullptr) return "set";
  return "unset";
}
1767
1768 } // namespace
1769
ReachableDefinitions(const GraphDef & graph) const1770 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1771 const GraphDef& graph) const {
1772 return ReachableFunctionLibraryDefinition(*this, graph.node());
1773 }
1774
ReachableDefinitions(const FunctionDef & func) const1775 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1776 const FunctionDef& func) const {
1777 return ReachableFunctionLibraryDefinition(*this, func.node_def());
1778 }
1779
DebugString() const1780 string FunctionLibraryRuntime::Options::DebugString() const {
1781 return absl::StrCat(
1782 "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1783 " cancellation_manager=", IsSet(cancellation_manager),
1784 " collective_executor=", IsSet(collective_executor),
1785 " step_container=", IsSet(step_container),
1786 " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1787 " remote_execution=", remote_execution, " source_device=", source_device,
1788 " create_rendezvous=", create_rendezvous,
1789 " allow_dead_tensors=", allow_dead_tensors,
1790 " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1791 " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1792 }
1793
InitFromString(StringPiece val)1794 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1795 if (val.size() >= 2 && val[0] == '$') {
1796 proto.set_placeholder(val.data() + 1, val.size() - 1);
1797 } else {
1798 SetAttrValue(val, &proto);
1799 }
1800 }
1801
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1802 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1803 const string& name,
1804 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1805 AttrValueWrapper ret;
1806 ret.proto.mutable_func()->set_name(name);
1807 for (const auto& a : attrs) {
1808 ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1809 }
1810 return ret;
1811 }
1812
ToNodeDef() const1813 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1814 NodeDef n;
1815 n.set_op(this->op);
1816 n.set_name(GetName());
1817 for (const auto& a : this->attr) {
1818 n.mutable_attr()->insert({a.first, a.second.proto});
1819 }
1820 for (const string& a : this->arg) {
1821 n.add_input(a);
1822 }
1823 for (const string& d : this->dep) {
1824 n.add_input(strings::StrCat("^", d));
1825 }
1826 if (!this->device.empty()) {
1827 n.set_device(this->device);
1828 }
1829 return n;
1830 }
1831
1832 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1833 FunctionDef FunctionDefHelper::Create(
1834 const string& function_name, gtl::ArraySlice<string> in_def,
1835 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1836 gtl::ArraySlice<Node> node_def,
1837 gtl::ArraySlice<std::pair<string, string>> ret_def,
1838 gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1839 FunctionDef fdef;
1840
1841 // Signature
1842 OpDefBuilder b(function_name);
1843 for (const auto& i : in_def) b.Input(i);
1844 for (const auto& o : out_def) b.Output(o);
1845 for (const auto& a : attr_def) b.Attr(a);
1846 for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1847
1848 OpRegistrationData op_reg_data;
1849 TF_CHECK_OK(b.Finalize(&op_reg_data));
1850 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1851
1852 // Function body
1853 for (const auto& n : node_def) {
1854 *(fdef.add_node_def()) = n.ToNodeDef();
1855 }
1856
1857 // Returns
1858 for (const auto& r : ret_def) {
1859 fdef.mutable_ret()->insert({r.first, r.second});
1860 }
1861
1862 // Control returns
1863 for (const auto& cr : control_ret_def) {
1864 fdef.mutable_control_ret()->insert({cr.first, cr.second});
1865 }
1866
1867 auto* op_def_registry = OpRegistry::Global();
1868 // Check if any op is stateful.
1869 for (const auto& n : node_def) {
1870 const OpDef* op_def = nullptr;
1871 auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1872 // Lookup can fail if e.g. we are calling a function that was not yet
1873 // defined. If it happens, conservatively assume the op is stateful.
1874 if (!status.ok() || op_def->is_stateful()) {
1875 fdef.mutable_signature()->set_is_stateful(true);
1876 }
1877 }
1878
1879 return fdef;
1880 }
1881
1882 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1883 FunctionDef FunctionDefHelper::Create(
1884 const string& function_name, gtl::ArraySlice<string> in_def,
1885 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1886 gtl::ArraySlice<Node> node_def,
1887 gtl::ArraySlice<std::pair<string, string>> ret_def) {
1888 return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1889 /*control_ret_def=*/{});
1890 }
1891
1892 /* static */
Define(const string & name,gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1893 FunctionDef FunctionDefHelper::Define(const string& name,
1894 gtl::ArraySlice<string> arg_def,
1895 gtl::ArraySlice<string> ret_def,
1896 gtl::ArraySlice<string> attr_def,
1897 gtl::ArraySlice<Node> node_def) {
1898 FunctionDef fdef;
1899 OpDefBuilder b(name);
1900 for (const auto& a : arg_def) b.Input(a);
1901 for (const auto& r : ret_def) b.Output(r);
1902 for (const auto& a : attr_def) b.Attr(a);
1903
1904 OpRegistrationData op_reg_data;
1905 TF_CHECK_OK(b.Finalize(&op_reg_data));
1906 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1907
1908 // Mapping from legacy output names to NodeDef outputs.
1909 std::unordered_map<string, string> ret_index;
1910 for (const auto& a : fdef.signature().input_arg()) {
1911 ret_index[a.name()] = a.name();
1912 }
1913
1914 // For looking up OpDefs
1915 auto* op_def_registry = OpRegistry::Global();
1916
1917 // Function body
1918 for (const auto& src : node_def) {
1919 NodeDef* n = fdef.add_node_def();
1920 n->set_op(src.op);
1921 n->set_name(src.GetName());
1922 for (const auto& a : src.attr) {
1923 n->mutable_attr()->insert({a.first, a.second.proto});
1924 }
1925 for (const string& a : src.arg) {
1926 const auto iter = ret_index.find(a);
1927 CHECK(iter != ret_index.end())
1928 << "Node input '" << a << "' in '" << n->name() << "' of " << name;
1929 n->add_input(iter->second);
1930 }
1931 for (const string& d : src.dep) {
1932 n->add_input(strings::StrCat("^", d));
1933 }
1934
1935 // Add the outputs of this node to ret_index.
1936 const OpDef* op_def = nullptr;
1937 TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
1938 CHECK(op_def != nullptr) << n->op();
1939 NameRangeMap output_names;
1940 TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
1941 for (const auto& o : output_names) {
1942 CHECK_LE(o.second.second, src.ret.size())
1943 << "Missing ret for output '" << o.first << "' in '" << n->name()
1944 << "' of " << name;
1945 for (int i = o.second.first; i < o.second.second; ++i) {
1946 ret_index[src.ret[i]] =
1947 strings::StrCat(n->name(), ":", o.first, ":", i - o.second.first);
1948 }
1949 }
1950 if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
1951 }
1952
1953 // Returns
1954 for (const auto& r : fdef.signature().output_arg()) {
1955 const auto iter = ret_index.find(r.name());
1956 CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
1957 fdef.mutable_ret()->insert({r.name(), iter->second});
1958 }
1959 return fdef;
1960 }
1961
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1962 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1963 gtl::ArraySlice<string> ret_def,
1964 gtl::ArraySlice<string> attr_def,
1965 gtl::ArraySlice<Node> node_def) {
1966 return Define("_", arg_def, ret_def, attr_def, node_def);
1967 }
1968
1969 namespace gradient {
1970
1971 typedef std::unordered_map<string, Creator> OpGradFactory;
1972
GetOpGradFactory()1973 OpGradFactory* GetOpGradFactory() {
1974 static OpGradFactory* factory = new OpGradFactory;
1975 return factory;
1976 }
1977
RegisterOp(const string & op,Creator func)1978 bool RegisterOp(const string& op, Creator func) {
1979 CHECK(GetOpGradFactory()->insert({op, func}).second)
1980 << "Duplicated gradient for " << op;
1981 return true;
1982 }
1983
GetOpGradientCreator(const string & op,Creator * creator)1984 Status GetOpGradientCreator(const string& op, Creator* creator) {
1985 auto fac = GetOpGradFactory();
1986 auto iter = fac->find(op);
1987 if (iter == fac->end()) {
1988 return errors::NotFound("No gradient defined for op: ", op);
1989 }
1990 *creator = iter->second;
1991 return Status::OK();
1992 }
1993
1994 } // end namespace gradient
1995
1996 } // namespace tensorflow
1997