1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function.h"
17
18 #include <ctype.h>
19
20 #include <map>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/escaping.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/lib/strings/proto_serialization.h"
41 #include "tensorflow/core/platform/fingerprint.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #include "tensorflow/core/util/equal_graph_def.h"
44
45 namespace tensorflow {
46
47 /* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
48 /* static */ constexpr const char* const
49 FunctionLibraryDefinition::kDeviceArgOp;
50 /* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
51 /* static */ constexpr const char* const
52 FunctionLibraryDefinition::kDeviceRetOp;
53 /* static */ constexpr const char* const
54 FunctionLibraryDefinition::kIntsOnDeviceAttr;
55 /* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
56 /* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
57
58 // Extracts the actual type from "attr_values" based on its definition
59 // "arg_def".
60 //
61 // If "arg_def" is a N*T type, *is_type_list is set to false, and
62 // *dtypes is set to be a vector of size N and each element is T.
63 //
64 // If "arg_def" is a list(type), *is_type_list is set to true, and
65 // *dtypes is set to be a vector of types specified in attrs for
66 // arg_def.
67 //
68 // Otherwise (arg_def is a simple type T), *is_type_list is set to
69 // false, and *dtypes is set to a single element vector, whose only
70 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)71 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
72 bool* is_type_list, DataTypeVector* dtypes) {
73 dtypes->clear();
74 if (!arg_def.type_list_attr().empty()) {
75 const AttrValue* v = attrs.FindByString(arg_def.type_list_attr());
76 if (v == nullptr) {
77 return errors::NotFound("type list attr not found: ",
78 arg_def.type_list_attr());
79 }
80 *is_type_list = true;
81 for (int i = 0; i < v->list().type_size(); ++i) {
82 dtypes->push_back(v->list().type(i));
83 }
84 return OkStatus();
85 }
86
87 *is_type_list = false;
88 int num = 1;
89 if (!arg_def.number_attr().empty()) {
90 const AttrValue* v = attrs.FindByString(arg_def.number_attr());
91 if (v == nullptr) {
92 return errors::NotFound("number attr not found: ", arg_def.number_attr());
93 }
94 num = v->i();
95 }
96
97 DataType dtype;
98 if (arg_def.type() != DT_INVALID) {
99 dtype = arg_def.type();
100 } else if (arg_def.type_attr().empty()) {
101 dtype = DT_INVALID;
102 } else {
103 const AttrValue* v = attrs.FindByString(arg_def.type_attr());
104 if (v == nullptr) {
105 return errors::NotFound("type attr not found: ", arg_def.type_attr());
106 }
107 dtype = v->type();
108 }
109 dtypes->resize(num, dtype);
110 return OkStatus();
111 }
112
113 namespace {
114
115 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)116 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
117 SetAttrValue(val, &((*ndef->mutable_attr())[name]));
118 }
119
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)120 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
121 // attr_values should specify all attrs defined in fdef, except for those
122 // which have a default value
123 for (const auto& attr : sig.attr()) {
124 const AttrValue* attr_value = attr_values.FindByString(attr.name());
125 if (attr_value) {
126 Status status = AttrValueHasType(*attr_value, attr.type());
127 if (!status.ok()) {
128 errors::AppendToMessage(&status, "for attr '", attr.name(), "'");
129 return status;
130 }
131 } else if (!attr.has_default_value()) {
132 return errors::NotFound("Attr ", attr.name(), " is not found from ",
133 SummarizeOpDef(sig));
134 }
135 }
136
137 // TODO(josh11b): Enable this code once it works with function gradients.
138 // Right now the C++ function gradient code assumes it can pass
139 // all the attrs of the function to the gradient, and any attrs that
140 // the gradient doesn't care about will be ignored.
141 #if 0
142 if (attr_values.size() != sig.attr_size()) {
143 for (const auto& a : attr_values) {
144 // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
145 bool found = false;
146 for (const auto& s : sig.attr()) {
147 if (a.first == s.name()) {
148 found = true;
149 break;
150 }
151 }
152 if (!found) {
153 return errors::NotFound("Attr ", a.first, " is not found in ",
154 SummarizeOpDef(sig));
155 }
156 }
157 }
158 #endif
159
160 return OkStatus();
161 }
162
163 // A helper class for instantiating functions. This contains shared information
164 // like the resulting graph and node name index.
165 class FunctionInstantiationHelper {
166 public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)167 FunctionInstantiationHelper(GetFunctionSignature get_function,
168 InstantiationResult* result)
169 : get_function_(std ::move(get_function)), result_(*result) {
170 result_.nodes.clear();
171 }
172
173 // Builds index for nodes that can be used as node's input arguments.
174 // `resource_arg_unique_id`: if non-negative, will be populated to the
175 // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64_t resource_arg_unique_id)176 Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
177 const FunctionDef::ArgAttrs* arg_attrs,
178 bool ints_on_device,
179 int64_t resource_arg_unique_id) {
180 bool is_type_list;
181 DataTypeVector dtypes;
182 TF_RETURN_IF_ERROR(
183 ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
184 if (dtypes.size() < size_t{1}) {
185 return errors::Internal("Expected a list of at least one dtype");
186 }
187 int arg_index = result_.nodes.size();
188 TF_RETURN_IF_ERROR(
189 AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
190 // Creates dtypes.size() nodes in the graph.
191 for (size_t i = 0; i < dtypes.size(); ++i) {
192 TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
193 {true, arg_index, 0, false, {dtypes[i]}}));
194 if (arg_index != result_.nodes.size()) {
195 return errors::Internal(
196 "Expected arg_index to be equal to the number of nodes in result.",
197 " Got ", arg_index, " and ", result_.nodes.size());
198 }
199 string name = arg_def.name();
200 if (dtypes.size() > 1) {
201 strings::StrAppend(&name, "_", i);
202 }
203 NodeDef* gnode = AddNode(name);
204 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
205 gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
206 } else {
207 gnode->set_op(FunctionLibraryDefinition::kArgOp);
208 }
209 DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
210 AddAttr("T", dtype, gnode);
211 AddAttr("index", arg_index, gnode);
212 if (resource_arg_unique_id >= 0) {
213 AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
214 }
215 if (arg_attrs) {
216 for (const auto& arg_attr : arg_attrs->attr()) {
217 AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
218 }
219 }
220 result_.arg_types.push_back(dtypes[i]);
221 ++arg_index;
222 }
223 return OkStatus();
224 }
225
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)226 Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
227 const int arg_index) {
228 const OpDef* node_sig = nullptr;
229 TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
230 if (node_sig->output_arg_size() == 0) {
231 return AddItem(node.name(), {false, arg_index, 0, false, {}});
232 }
233 const int num_retval = node_sig->output_arg_size();
234 int start = 0;
235 bool is_type_list;
236 DataTypeVector dtypes;
237 for (int i = 0; i < num_retval; ++i) {
238 TF_RETURN_IF_ERROR(
239 ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
240 // Note that we rely on the backwards-compatibility test enforcing
241 // that output_arg(*).name() doesn't change here.
242 const string base_name =
243 strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
244 TF_RETURN_IF_ERROR(
245 AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
246 for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
247 TF_RETURN_IF_ERROR(
248 AddItem(strings::StrCat(base_name, ":", j),
249 {false, arg_index, start + j, false, {dtypes[j]}}));
250 }
251 start += dtypes.size();
252 }
253 return OkStatus();
254 }
255
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)256 Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
257 const OpDef* fnode_sig = nullptr;
258 TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
259 NodeDef* gnode = AddNode(fnode.name());
260 gnode->set_op(fnode.op());
261 gnode->set_device(fnode.device());
262 int gnode_idx = nodes_.size() - 1;
263
264 // Input
265 const int num_args = fnode_sig->input_arg_size();
266 bool is_type_list; // ignored
267 DataTypeVector dtypes;
268 int fnode_arg_index = 0;
269 for (int i = 0; i < num_args; ++i) {
270 TF_RETURN_IF_ERROR(
271 ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
272 // Consume inputs (indexed by fnode_arg_index) until we have
273 // matched each element of dtypes (indexed by j).
274 for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
275 if (fnode_arg_index >= fnode.input_size()) {
276 // Should never happen if we computed dtypes correctly.
277 return errors::InvalidArgument(
278 "Attempt to access beyond input size: ", fnode_arg_index,
279 " >= ", fnode.input_size());
280 }
281 // Look up the next input.
282 const string& input_name = fnode.input(fnode_arg_index);
283 const auto* item = GetItemOrNull(input_name);
284 if (item == nullptr) {
285 return errors::InvalidArgument(
286 "input ", input_name,
287 " is not found: ", FormatNodeDefForError(fnode));
288 }
289 if (item->dtypes.size() > dtypes.size() - j) {
290 return errors::InvalidArgument("Input ", input_name, " too long for ",
291 fnode_sig->input_arg(i).name());
292 }
293 // Match up all the elements of this input (indexed by k) with
294 // elements of dtypes (advancing j).
295 for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
296 if (item->dtypes[k] != dtypes[j]) {
297 return errors::InvalidArgument(
298 "input ", fnode_sig->input_arg(i).name(), "[", j,
299 "] expected type ", DataTypeString(dtypes[j]),
300 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
301 input_name, "[", k, "]");
302 }
303 if (item->is_func_arg) {
304 AddInput(gnode_idx, item->nid + k, 0);
305 } else {
306 AddInput(gnode_idx, item->nid, item->idx + k);
307 }
308 }
309 }
310 }
311
312 // Control deps.
313 for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
314 const string& input = fnode.input(i);
315 if (input.empty() || input[0] != '^') {
316 return errors::InvalidArgument("Expected input[", i, "] == '", input,
317 "' to be a control input.");
318 }
319 int nid = -1;
320 const string node_name = input.substr(1);
321 const string node_colon = node_name + ":";
322 const string node_colon_bound = node_name + ";";
323 // index_ is a map sorted lexicographically, so the key we are looking for
324 // must lie in the range [node_name, node_colon_bound).
325 auto it = index_.lower_bound(node_name);
326 while (it != index_.end() && it->first <= node_colon_bound) {
327 if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
328 nid = it->second.nid;
329 break;
330 }
331 ++it;
332 }
333 if (nid == -1) {
334 return errors::InvalidArgument("input[", i, "] == '", input,
335 "', is not found.");
336 }
337 AddDep(gnode_idx, nid);
338 }
339
340 // Attrs.
341 for (const auto& p : attrs) {
342 (*gnode->mutable_attr())[p.first] = p.second;
343 }
344
345 // Experimental_debug_info.
346 if (fnode.has_experimental_debug_info()) {
347 gnode->mutable_experimental_debug_info()->MergeFrom(
348 fnode.experimental_debug_info());
349 }
350
351 // Tye info.
352 // TODO(mdan): Might this need adjustment at instantiation?
353 if (fnode.has_experimental_type()) {
354 *gnode->mutable_experimental_type() = fnode.experimental_type();
355 }
356
357 return OkStatus();
358 }
359
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)360 Status AddReturnNode(
361 const OpDef::ArgDef& ret_def, AttrSlice attrs,
362 const ::tensorflow::protobuf::Map<string, string>& ret_map,
363 bool ints_on_device, int* ret_index) {
364 auto ret_iter = ret_map.find(ret_def.name());
365 if (ret_iter == ret_map.end()) {
366 return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
367 }
368 bool is_type_list;
369 DataTypeVector dtypes;
370 TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
371 CHECK_GE(dtypes.size(), size_t{1});
372 const auto* item = GetItemOrNull(ret_iter->second);
373 if (item == nullptr) {
374 return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
375 ret_iter->second, " is not found.");
376 }
377 if (dtypes != item->dtypes) {
378 return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
379 " : ", DataTypeVectorString(dtypes),
380 " vs. ",
381 DataTypeVectorString(item->dtypes));
382 }
383 for (size_t i = 0; i < dtypes.size(); ++i) {
384 string name = strings::StrCat(ret_def.name(), "_RetVal");
385 if (dtypes.size() > 1) {
386 strings::StrAppend(&name, "_", i);
387 }
388 NodeDef* gnode = AddNode(name);
389 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
390 gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
391 } else {
392 gnode->set_op(FunctionLibraryDefinition::kRetOp);
393 }
394 AddInput(nodes_.size() - 1, item->nid, item->idx + i);
395 DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
396 AddAttr("T", dtype, gnode);
397 AddAttr("index", (*ret_index)++, gnode);
398 result_.ret_types.push_back(dtypes[i]);
399 }
400 return OkStatus();
401 }
402
403 // Adds the actual node inputs to the result graph by converting indexes to
404 // the node names.
AddNodeInputs()405 void AddNodeInputs() {
406 for (int i = 0; i < result_.nodes.size(); i++) {
407 NodeInfo& node_info = nodes_[i];
408 for (const auto& p : node_info.data_inputs) {
409 result_.nodes[i].add_input(Name(p.first, p.second));
410 }
411 for (int index : node_info.control_inputs) {
412 result_.nodes[i].add_input(Dep(index));
413 }
414 }
415 }
416
417 private:
418 // This is used to build a small index for all names that can be used as a
419 // node's input arguments.
420 //
421 // If is_func_arg is true, the name is a function's argument. In
422 // this case, the produced graph def has node[nid:nid + dtype.size()].
423 //
424 // Otherwise, the name is a function body's node return value. In
425 // this case, the produced graph def has one node node[nid] and
426 // the node's output index [idx ... idx + num) corresponds to the
427 // named outputs.
428 //
429 // In all cases, "dtype" specifies the data type.
430 struct NameInfoItem {
431 bool is_func_arg;
432 int nid;
433 int idx;
434 bool is_type_list;
435 DataTypeVector dtypes;
436 };
437
438 // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)439 Status AddItem(const string& name, const NameInfoItem& item) {
440 if (!index_.insert({name, item}).second) {
441 return errors::InvalidArgument(
442 strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
443 " name: "),
444 name);
445 }
446 return OkStatus();
447 }
448
GetItemOrNull(const string & name) const449 const NameInfoItem* GetItemOrNull(const string& name) const {
450 return gtl::FindOrNull(index_, name);
451 }
452
Dep(int node_index) const453 string Dep(int node_index) const {
454 return strings::StrCat("^", Name(node_index));
455 }
456
Name(int node_index) const457 string Name(int node_index) const {
458 CHECK_LT(node_index, nodes_.size());
459 return nodes_[node_index].name;
460 }
461
Name(int node_index,int output_index) const462 string Name(int node_index, int output_index) const {
463 if (output_index == 0) {
464 return Name(node_index);
465 } else {
466 return strings::StrCat(Name(node_index), ":", output_index);
467 }
468 }
469
AddNode(const string & name)470 NodeDef* AddNode(const string& name) {
471 result_.nodes.emplace_back();
472 NodeDef* gnode = &result_.nodes.back();
473 gnode->set_name(name);
474 nodes_.push_back({name, {}, {}});
475 CHECK_EQ(result_.nodes.size(), nodes_.size());
476 return gnode;
477 }
478
AddInput(int node_index,int output_node,int output_index)479 void AddInput(int node_index, int output_node, int output_index) {
480 CHECK_LT(node_index, nodes_.size());
481 nodes_[node_index].data_inputs.push_back(
482 std::make_pair(output_node, output_index));
483 }
484
AddDep(int node_index,int dep_index)485 void AddDep(int node_index, int dep_index) {
486 CHECK_LT(node_index, nodes_.size());
487 nodes_[node_index].control_inputs.push_back(dep_index);
488 }
489
490 GetFunctionSignature get_function_;
491 InstantiationResult& result_;
492 // A small index for all names that can be used as a node's input arguments.
493 std::map<string, NameInfoItem> index_;
494 // This contains information about a node in the new graph including the node
495 // names and input nodes' indexes.
496 struct NodeInfo {
497 string name;
498 // Data inputs where <n, k> means arg k of node n.
499 std::vector<std::pair<int, int>> data_inputs;
500 // Control inputs (dependencies).
501 std::vector<int> control_inputs;
502 };
503 // nodes_[i] is the information about result_.nodes[i].
504 std::vector<NodeInfo> nodes_;
505 };
506
507 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)508 string Print(const OpDef::ArgDef& arg) {
509 string out;
510 strings::StrAppend(&out, arg.name(), ":");
511 if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
512 if (!arg.number_attr().empty()) {
513 strings::StrAppend(&out, arg.number_attr(), "*");
514 }
515 if (arg.type() != DT_INVALID) {
516 strings::StrAppend(&out, DataTypeString(arg.type()));
517 } else {
518 strings::StrAppend(&out, arg.type_attr());
519 }
520 if (arg.is_ref()) strings::StrAppend(&out, ")");
521 return out;
522 }
523
524 // TODO(josh11b): Merge this with SummarizeAttrValue().
525 // When hash_string_attrs = true, string attributes are hashed instead of being
526 // truncated with ellipses. This is done to reduce the chance of collisions when
527 // looking up functions using the canonical representation.
Print(const AttrValue & attr_value,const bool hash_string_attrs=false)528 string Print(const AttrValue& attr_value,
529 const bool hash_string_attrs = false) {
530 if (attr_value.value_case() == AttrValue::kType) {
531 return DataTypeString(attr_value.type());
532 } else if ((attr_value.value_case() == AttrValue::kList) &&
533 (attr_value.list().type_size() > 0)) {
534 string ret = "{";
535 for (int i = 0; i < attr_value.list().type_size(); ++i) {
536 if (i > 0) strings::StrAppend(&ret, ", ");
537 strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
538 }
539 strings::StrAppend(&ret, "}");
540 return ret;
541 } else if (attr_value.value_case() == AttrValue::kFunc) {
542 if (attr_value.func().attr_size() == 0) {
543 return attr_value.func().name();
544 }
545 std::vector<string> entries;
546 for (const auto& p : attr_value.func().attr()) {
547 entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
548 }
549 std::sort(entries.begin(), entries.end());
550 return strings::StrCat(attr_value.func().name(), "[",
551 absl::StrJoin(entries, ", "), "]");
552 } else if (attr_value.value_case() == AttrValue::kS && hash_string_attrs) {
553 return strings::StrCat(Fingerprint64(attr_value.s()));
554 }
555 return SummarizeAttrValue(attr_value);
556 }
557
558 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)559 string Print(const NodeDef& n) {
560 string out;
561 strings::StrAppend(&out, n.name(), " = ", n.op());
562 if (n.attr_size() > 0) {
563 std::vector<string> entries;
564 for (auto& a : n.attr()) {
565 entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
566 }
567 std::sort(entries.begin(), entries.end());
568 // Add a short device string at the end of all attributes.
569 if (!n.device().empty()) {
570 DeviceNameUtils::ParsedName parsed;
571 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
572 entries.push_back(
573 strings::StrCat("device=", parsed.type, ":", parsed.id));
574 } else {
575 entries.push_back("device=<FAILED_TO_PARSE>");
576 }
577 }
578 strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
579 }
580 strings::StrAppend(&out, "(");
581 std::vector<StringPiece> dat;
582 std::vector<string> dep;
583 for (StringPiece s : n.input()) {
584 if (absl::ConsumePrefix(&s, "^")) {
585 dep.emplace_back(s);
586 } else {
587 dat.push_back(s);
588 }
589 }
590 strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
591 if (!dep.empty()) {
592 strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
593 }
594 return out;
595 }
596
Print(const FunctionDef & fdef)597 string Print(const FunctionDef& fdef) {
598 string out;
599 const OpDef& sig = fdef.signature();
600 strings::StrAppend(&out, "\n", sig.name());
601 if (sig.attr_size() > 0) {
602 strings::StrAppend(&out, "[");
603 for (int i = 0; i < sig.attr_size(); ++i) {
604 const auto& a = sig.attr(i);
605 if (i > 0) strings::StrAppend(&out, ", ");
606 if (a.type() == "type") {
607 strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
608 } else {
609 strings::StrAppend(&out, a.name(), ":", a.type());
610 }
611 }
612 strings::StrAppend(&out, "]");
613 }
614 strings::StrAppend(&out, "(");
615 for (int i = 0; i < sig.input_arg_size(); ++i) {
616 if (i > 0) strings::StrAppend(&out, ", ");
617 strings::StrAppend(&out, Print(sig.input_arg(i)));
618 }
619 strings::StrAppend(&out, ") -> (");
620 for (int i = 0; i < sig.output_arg_size(); ++i) {
621 if (i > 0) strings::StrAppend(&out, ", ");
622 strings::StrAppend(&out, Print(sig.output_arg(i)));
623 }
624 strings::StrAppend(&out, ") {\n");
625 for (const auto& n : fdef.node_def()) {
626 strings::StrAppend(&out, " ", Print(n), "\n");
627 }
628 for (const auto& cr : fdef.control_ret()) {
629 strings::StrAppend(&out, " @return ", cr.first, " = ", cr.second, "\n");
630 }
631 for (const auto& r : fdef.ret()) {
632 strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
633 }
634 strings::StrAppend(&out, "}\n");
635 return out;
636 }
637
Print(gtl::ArraySlice<const NodeDef * > nodes)638 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
639 std::vector<const NodeDef*> arg;
640 std::vector<const NodeDef*> ret;
641 std::vector<const NodeDef*> body;
642 for (const NodeDef* n : nodes) {
643 if (n->op() == FunctionLibraryDefinition::kArgOp ||
644 n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
645 arg.push_back(n);
646 } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
647 n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
648 ret.push_back(n);
649 } else {
650 body.push_back(n);
651 }
652 }
653 auto comp = [](const NodeDef* x, const NodeDef* y) {
654 int xi;
655 TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
656 int yi;
657 TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
658 return xi < yi;
659 };
660 std::sort(arg.begin(), arg.end(), comp);
661 std::sort(ret.begin(), ret.end(), comp);
662 string out;
663 strings::StrAppend(&out, "\n(");
664 auto get_type_and_device = [](const NodeDef& n) {
665 DataType dt;
666 if (!TryGetNodeAttr(n, "T", &dt)) {
667 dt = DT_INVALID;
668 }
669 if (!n.device().empty()) {
670 DeviceNameUtils::ParsedName parsed;
671 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
672 return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
673 parsed.id);
674 } else {
675 LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
676 << n.op() << ":" << n.name();
677 return strings::StrCat(DataTypeString(dt), "@",
678 "<FAILED_TO_PARSE_DEVICE>");
679 }
680 }
681 return DataTypeString(dt);
682 };
683 for (size_t i = 0; i < arg.size(); ++i) {
684 const NodeDef* n = arg[i];
685 if (i > 0) strings::StrAppend(&out, ", ");
686 CHECK_GE(n->attr_size(), 2);
687 strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
688 }
689 strings::StrAppend(&out, ") -> (");
690 for (size_t i = 0; i < ret.size(); ++i) {
691 const NodeDef* n = ret[i];
692 if (i > 0) strings::StrAppend(&out, ", ");
693 CHECK_LE(2, n->attr_size());
694
695 // The _RetVal op should have a unique non-control input. We assert that
696 // here and add it to the output.
697 bool found_non_control_input = false;
698 for (const string& input : n->input()) {
699 if (!input.empty() && input[0] != '^') {
700 DCHECK_EQ(found_non_control_input, false)
701 << "RetVal node has more than one non-control input: "
702 << absl::StrJoin(n->input(), ", ");
703 strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
704 found_non_control_input = true;
705 }
706 }
707 DCHECK_EQ(found_non_control_input, true)
708 << "RetVal did not have any non-control inputs: "
709 << absl::StrJoin(n->input(), ", ");
710 }
711 strings::StrAppend(&out, ") {\n");
712 for (size_t i = 0; i < body.size(); ++i) {
713 strings::StrAppend(&out, " ", Print(*body[i]), "\n");
714 }
715 strings::StrAppend(&out, "}\n");
716 return out;
717 }
718
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)719 Status AddDefaultAttrs(const string& op,
720 const GetFunctionSignature& get_function,
721 AttrValueMap* attrs) {
722 const OpDef* op_def = nullptr;
723 TF_RETURN_IF_ERROR(get_function(op, &op_def));
724 AttrSlice attr_slice(attrs);
725 for (const auto& attr_def : op_def->attr()) {
726 if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
727 if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
728 return errors::Internal("Somehow duplicated: ", attr_def.name());
729 }
730 }
731 }
732 return OkStatus();
733 }
734
735 } // end namespace
736
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)737 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
738 GetFunctionSignature get_function,
739 InstantiationResult* result) {
740 if (VLOG_IS_ON(5)) {
741 const auto& signature = fdef.signature();
742 VLOG(5) << "Instantiate function definition: name=" << signature.name()
743 << " #input_args=" << signature.input_arg_size()
744 << " #output_args=" << signature.output_arg_size()
745 << " #control_output=" << signature.control_output_size();
746 for (const auto& line : str_util::Split(Print(fdef), '\n')) {
747 VLOG(5) << "|| " << line;
748 }
749 }
750
751 const OpDef& sig = fdef.signature();
752 TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
753
754 const AttrValue* attr_values_ints_on_device =
755 attr_values.Find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
756 bool ints_on_device =
757 (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
758 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) ||
759 (attr_values_ints_on_device != nullptr &&
760 attr_values_ints_on_device->b());
761
762 FunctionInstantiationHelper helper(get_function, result);
763 Status s;
764 for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
765 const OpDef::ArgDef& arg_def = sig.input_arg(i);
766 auto it = fdef.arg_attr().find(i);
767 const FunctionDef::ArgAttrs* arg_attrs =
768 it != fdef.arg_attr().end() ? &it->second : nullptr;
769 auto resource_id_it = fdef.resource_arg_unique_id().find(i);
770 int64_t resource_arg_unique_id =
771 resource_id_it != fdef.resource_arg_unique_id().end()
772 ? resource_id_it->second
773 : -1LL;
774 s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
775 ints_on_device, resource_arg_unique_id);
776
777 if (!s.ok()) {
778 errors::AppendToMessage(&s, "In ", Print(arg_def));
779 return s;
780 }
781 }
782
783 auto substitute = [attr_values, &sig](const string& name, AttrValue* val) {
784 // Look for a specified value...
785 if (const AttrValue* v = attr_values.FindByString(name)) {
786 *val = *v;
787 return true;
788 }
789 // .. and if not, then check for a default value.
790 if (const OpDef::AttrDef* attr = FindAttr(name, sig)) {
791 if (attr->has_default_value()) {
792 *val = attr->default_value();
793 return true;
794 }
795 }
796 // No luck finding a substitution.
797 return false;
798 };
799
800 // Makes a copy of all attrs in fdef and substitutes placeholders.
801 // After this step, every attr is bound to a concrete value.
802 std::vector<AttrValueMap> node_attrs;
803 node_attrs.resize(fdef.node_def_size());
804 for (int i = 0; i < fdef.node_def_size(); ++i) {
805 for (auto attr : fdef.node_def(i).attr()) {
806 if (!SubstitutePlaceholders(substitute, &attr.second)) {
807 return errors::InvalidArgument("Failed to bind all placeholders in ",
808 SummarizeAttrValue(attr.second));
809 }
810 if (!node_attrs[i].insert(attr).second) {
811 return errors::Internal("Somehow duplicated: ", attr.first);
812 }
813 }
814 TF_RETURN_IF_ERROR(
815 AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
816 }
817
818 for (int i = 0; i < fdef.node_def_size(); ++i) {
819 s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
820 result->nodes.size() + i);
821 if (!s.ok()) {
822 errors::AppendToMessage(&s, "In ",
823 FormatNodeDefForError(fdef.node_def(i)));
824 return s;
825 }
826 }
827 // Emits one node for each fdef.node_def.
828 for (int i = 0; i < fdef.node_def_size(); ++i) {
829 s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
830 if (!s.ok()) {
831 errors::AppendToMessage(&s, "In ",
832 FormatNodeDefForError(fdef.node_def(i)));
833 return s;
834 }
835 }
836
837 // Emits nodes for the function's return values.
838 int ret_index = 0;
839 for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
840 s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
841 &ret_index);
842 if (!s.ok()) {
843 errors::AppendToMessage(&s, "In function output ", Print(ret_def));
844 return s;
845 }
846 }
847
848 // Adds the actual node inputs using the input indexes.
849 helper.AddNodeInputs();
850
851 return OkStatus();
852 }
853
DebugString(const FunctionDef & func_def)854 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
855
DebugString(const GraphDef & instantiated_func_def)856 string DebugString(const GraphDef& instantiated_func_def) {
857 std::vector<const NodeDef*> ptrs;
858 for (const NodeDef& n : instantiated_func_def.node()) {
859 ptrs.push_back(&n);
860 }
861 return Print(ptrs);
862 }
863
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)864 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
865 std::vector<const NodeDef*> ptrs;
866 for (const NodeDef& n : instantiated_func_nodes) {
867 ptrs.push_back(&n);
868 }
869 return Print(ptrs);
870 }
871
DebugStringWhole(const GraphDef & gdef)872 string DebugStringWhole(const GraphDef& gdef) {
873 string ret;
874 for (const auto& fdef : gdef.library().function()) {
875 strings::StrAppend(&ret, Print(fdef));
876 }
877 strings::StrAppend(&ret, "\n");
878 for (const auto& ndef : gdef.node()) {
879 strings::StrAppend(&ret, Print(ndef), "\n");
880 }
881 return ret;
882 }
883
884 namespace {
885
886 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
887 // Python, it's possible to access unset attrs, which returns a default value
888 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)889 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
890 std::map<string, AttrValue> set_attrs;
891 for (const auto& pair : fdef.attr()) {
892 if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
893 set_attrs[pair.first] = pair.second;
894 }
895 }
896 return set_attrs;
897 }
898
899 } // end namespace
900
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)901 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
902 if (!OpDefEqual(f1.signature(), f2.signature())) return false;
903
904 std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
905 std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
906 if (f1_attrs.size() != f2_attrs.size()) return false;
907 for (const auto& iter1 : f1_attrs) {
908 auto iter2 = f2_attrs.find(iter1.first);
909 if (iter2 == f2_attrs.end()) return false;
910 if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
911 }
912
913 if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
914 return false;
915 }
916
917 std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
918 std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
919 if (ret1 != ret2) return false;
920
921 std::map<string, string> control_ret1(f1.control_ret().begin(),
922 f1.control_ret().end());
923 std::map<string, string> control_ret2(f2.control_ret().begin(),
924 f2.control_ret().end());
925 if (control_ret1 != control_ret2) return false;
926
927 return true;
928 }
929
FunctionDefHash(const FunctionDef & fdef)930 uint64 FunctionDefHash(const FunctionDef& fdef) {
931 // signature
932 uint64 h = OpDefHash(fdef.signature());
933
934 // attrs
935 std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
936 for (const auto& p : attrs) {
937 h = Hash64(p.first.data(), p.first.size(), h);
938 h = Hash64Combine(AttrValueHash(p.second), h);
939 }
940
941 // node defs
942 h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
943
944 // output names
945 std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
946 for (const auto& p : ret) {
947 h = Hash64(p.first.data(), p.first.size(), h);
948 h = Hash64(p.second.data(), p.second.size(), h);
949 }
950
951 // control output names
952 std::map<string, string> control_ret(fdef.control_ret().begin(),
953 fdef.control_ret().end());
954 for (const auto& p : control_ret) {
955 h = Hash64(p.first.data(), p.first.size(), h);
956 h = Hash64(p.second.data(), p.second.size(), h);
957 }
958
959 return h;
960 }
961
// Name of the attr that selects an executor type. It is read by
// FunctionLibraryRuntime::ExecutorType(). Canonicalize() skips the raw attr and
// appends the resolved executor type under this name instead.
static constexpr const char* const kExecutorAttr = "_executor";
963
964 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)965 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
966 AttrSlice attrs) {
967 if (!options.executor_type.empty()) {
968 return options.executor_type;
969 } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
970 return executor_attr->s();
971 } else {
972 return string();
973 }
974 }
975
976 namespace {
977 class AttrKeyAndValue {
978 public:
979 enum ValueRepresentationOp {
980 kRaw,
981 kCEscape,
982 };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)983 AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
984 ValueRepresentationOp value_op = kRaw)
985 : key_name_(key_name),
986 key_suffix_(key_suffix),
987 value_op_(value_op),
988 value_(std::move(value)) {}
989
operator <(const AttrKeyAndValue & b) const990 bool operator<(const AttrKeyAndValue& b) const {
991 if (key_name_ != b.key_name_) {
992 return key_name_ < b.key_name_;
993 } else if (key_suffix_ != b.key_suffix_) {
994 return key_suffix_ < b.key_suffix_;
995 } else {
996 return value_ < b.value_;
997 }
998 }
999
AppendTo(bool first,string * s) const1000 void AppendTo(bool first, string* s) const {
1001 absl::string_view v;
1002 bool add_escaped = false;
1003 if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
1004 // Use CEscape call below
1005 add_escaped = true;
1006 } else {
1007 // Add raw value contents directly
1008 v = value_;
1009 }
1010 if (key_suffix_ >= 0) {
1011 strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
1012 } else {
1013 strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
1014 }
1015 if (add_escaped) {
1016 strings::StrAppend(s, absl::CEscape(value_));
1017 }
1018 }
1019
1020 private:
NeedsEscaping(const string & s)1021 static bool NeedsEscaping(const string& s) {
1022 for (auto c : s) {
1023 if (!isalnum(c) && (c != ' ')) {
1024 return true;
1025 }
1026 }
1027 return false;
1028 }
1029
1030 absl::string_view key_name_;
1031 int key_suffix_; // -1 if missing
1032 ValueRepresentationOp value_op_;
1033 string value_;
1034 };
1035 } // namespace
1036
GetFunctionResourceInputDevice(const Tensor & input,const int arg_index,const FunctionDef & function_def,absl::flat_hash_map<string,std::vector<string>> * composite_devices)1037 string GetFunctionResourceInputDevice(
1038 const Tensor& input, const int arg_index, const FunctionDef& function_def,
1039 absl::flat_hash_map<string, std::vector<string>>* composite_devices) {
1040 const auto& handles = input.flat<ResourceHandle>();
1041 const ResourceHandle& handle0 = handles(0);
1042 string composite_device;
1043 auto iter = function_def.arg_attr().find(arg_index);
1044 if (iter != function_def.arg_attr().end()) {
1045 auto arg_attr = iter->second.attr().find("_composite_device");
1046 if (arg_attr != iter->second.attr().end()) {
1047 composite_device = arg_attr->second.s();
1048 }
1049 }
1050 if (!composite_device.empty()) {
1051 if (composite_devices->find(composite_device) == composite_devices->end()) {
1052 for (int i = 0; i < handles.size(); ++i) {
1053 (*composite_devices)[composite_device].push_back(handles(i).device());
1054 }
1055 }
1056 return composite_device;
1057 } else {
1058 return handle0.device();
1059 }
1060 }
1061
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)1062 string Canonicalize(const string& funcname, AttrSlice attrs,
1063 const FunctionLibraryRuntime::InstantiateOptions& options) {
1064 absl::InlinedVector<AttrKeyAndValue, 8> entries;
1065 entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1066 options.input_devices.size());
1067 for (const auto& p : attrs) {
1068 if (p.first != kExecutorAttr) {
1069 entries.push_back(AttrKeyAndValue(
1070 p.first, -1, Print(p.second, /*hash_string_attrs=*/true)));
1071 }
1072 }
1073 if (!options.target.empty()) {
1074 entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1075 AttrKeyAndValue::kCEscape));
1076 }
1077 for (int i = 0; i < options.input_devices.size(); ++i) {
1078 entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1079 AttrKeyAndValue::kCEscape));
1080 }
1081 for (int i = 0; i < options.output_devices.size(); ++i) {
1082 entries.push_back(AttrKeyAndValue("_output_dev", i,
1083 options.output_devices[i],
1084 AttrKeyAndValue::kCEscape));
1085 }
1086 for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1087 entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1088 DataTypeString(iter.second.dtype)));
1089 entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1090 iter.second.shape.DebugString(),
1091 AttrKeyAndValue::kCEscape));
1092 }
1093 if (options.lib_def) {
1094 entries.push_back(AttrKeyAndValue(
1095 "_lib_def", -1,
1096 absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1097 }
1098 if (!options.state_handle.empty()) {
1099 entries.push_back(
1100 AttrKeyAndValue("_state_handle", -1, options.state_handle));
1101 }
1102 string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1103 if (!executor_type.empty()) {
1104 entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1105 }
1106 if (options.config_proto.ByteSize() > 0) {
1107 string config_proto_serialized;
1108 SerializeToStringDeterministic(options.config_proto,
1109 &config_proto_serialized);
1110 entries.push_back(AttrKeyAndValue("_config_proto", -1,
1111 config_proto_serialized,
1112 AttrKeyAndValue::kCEscape));
1113 }
1114 std::sort(entries.begin(), entries.end());
1115 string result = strings::StrCat(funcname, "[");
1116 bool first = true;
1117 for (const auto& entry : entries) {
1118 entry.AppendTo(first, &result);
1119 first = false;
1120 }
1121 result += "]";
1122 return result;
1123 }
1124
Canonicalize(const string & funcname,AttrSlice attrs)1125 string Canonicalize(const string& funcname, AttrSlice attrs) {
1126 static const FunctionLibraryRuntime::InstantiateOptions* kEmptyOptions =
1127 new FunctionLibraryRuntime::InstantiateOptions;
1128 return Canonicalize(funcname, attrs, *kEmptyOptions);
1129 }
1130
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1131 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1132 DataTypeSlice ret_types)
1133 : arg_types_(arg_types.begin(), arg_types.end()),
1134 ret_types_(ret_types.begin(), ret_types.end()) {
1135 args_.resize(arg_types_.size());
1136 rets_.resize(ret_types_.size());
1137 }
1138
~FunctionCallFrame()1139 FunctionCallFrame::~FunctionCallFrame() {}
1140
SetArgs(gtl::ArraySlice<Tensor> args)1141 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1142 // Input type checks.
1143 if (args.size() != arg_types_.size()) {
1144 return errors::InvalidArgument("Expects ", arg_types_.size(),
1145 " arguments, but ", args.size(),
1146 " is provided");
1147 }
1148 for (size_t i = 0; i < args.size(); ++i) {
1149 if (arg_types_[i] != args[i].dtype()) {
1150 return errors::InvalidArgument(
1151 "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1152 DataTypeString(args[i].dtype()), " is provided");
1153 }
1154 args_[i] = args[i];
1155 }
1156 return OkStatus();
1157 }
1158
GetRetvals(std::vector<Tensor> * rets) const1159 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1160 rets->clear();
1161 rets->reserve(rets_.size());
1162 for (size_t i = 0; i < rets_.size(); ++i) {
1163 const auto& item = rets_[i];
1164 if (item.has_val) {
1165 rets->push_back(item.val);
1166 } else {
1167 return errors::Internal("Retval[", i, "] does not have value");
1168 }
1169 }
1170 return OkStatus();
1171 }
1172
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1173 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1174 bool allow_dead_tensors) {
1175 rets->clear();
1176 rets->reserve(rets_.size());
1177 for (size_t i = 0; i < rets_.size(); ++i) {
1178 if (rets_[i].has_val) {
1179 rets->emplace_back(std::move(rets_[i].val));
1180 } else if (allow_dead_tensors) {
1181 rets->emplace_back();
1182 } else {
1183 return errors::Internal("Retval[", i, "] does not have value");
1184 }
1185 }
1186 return OkStatus();
1187 }
1188
GetArg(int index,const Tensor ** val)1189 Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
1190 if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1191 return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1192 args_.size(), ")");
1193 }
1194 *val = &args_[index];
1195 return OkStatus();
1196 }
1197
SetRetval(int index,const Tensor & val)1198 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1199 if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1200 return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1201 rets_.size(), ")");
1202 }
1203 if (val.dtype() != ret_types_[index]) {
1204 return errors::InvalidArgument(
1205 "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1206 ", but ", DataTypeString(val.dtype()), " is provided.");
1207 }
1208 Retval* item = &rets_[index];
1209 if (!item->has_val) {
1210 item->has_val = true;
1211 item->val = val;
1212 } else {
1213 return errors::Internal("Retval[", index, "] has already been set.");
1214 }
1215 return OkStatus();
1216 }
1217
// Stores a copy of `fdef_in` together with an OpRegistrationData built from
// its signature (flagged as a function). LookUp() hands out that registration
// data so that callers can resolve a library function like an op.
// NOTE(review): `op_registration_data` is initialized from the member `fdef`,
// not from `fdef_in`. That relies on `fdef` being declared before
// `op_registration_data` in the struct; confirm in function.h.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
                                 const StackTracesMap& stack_traces)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */),
      stack_traces(stack_traces) {}
1227
FunctionLibraryDefinition(const FunctionLibraryDefinition & other)1228 FunctionLibraryDefinition::FunctionLibraryDefinition(
1229 const FunctionLibraryDefinition& other)
1230 : default_registry_(other.default_registry_) {
1231 tf_shared_lock l(other.mu_);
1232 function_defs_ = other.function_defs_;
1233 func_grad_ = other.func_grad_;
1234 }
1235
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1236 FunctionLibraryDefinition::FunctionLibraryDefinition(
1237 const OpRegistryInterface* default_registry,
1238 const FunctionDefLibrary& def_lib)
1239 : default_registry_(default_registry),
1240 function_defs_(def_lib.function_size()) {
1241 for (const auto& fdef : def_lib.function()) {
1242 // The latter function definition wins.
1243 auto& ptr = function_defs_[fdef.signature().name()];
1244 ptr.reset(new FunctionDefAndOpRegistration(fdef));
1245 }
1246 for (const auto& grad : def_lib.gradient()) {
1247 func_grad_[grad.function_name()] = grad.gradient_func();
1248 }
1249 }
1250
~FunctionLibraryDefinition()1251 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1252
Contains(const string & func) const1253 bool FunctionLibraryDefinition::Contains(const string& func) const {
1254 tf_shared_lock l(mu_);
1255 return function_defs_.find(func) != function_defs_.end();
1256 }
1257
Find(const string & func) const1258 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1259 tf_shared_lock l(mu_);
1260 auto result = FindHelper(func);
1261 if (result) {
1262 return &result->fdef;
1263 } else {
1264 return nullptr;
1265 }
1266 }
1267
1268 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1269 FunctionLibraryDefinition::FindHelper(const string& func) const {
1270 auto iter = function_defs_.find(func);
1271 if (iter == function_defs_.end()) {
1272 return nullptr;
1273 } else {
1274 return iter->second;
1275 }
1276 }
1277
AddFunctionDef(const FunctionDef & fdef,const StackTracesMap & stack_traces)1278 Status FunctionLibraryDefinition::AddFunctionDef(
1279 const FunctionDef& fdef, const StackTracesMap& stack_traces) {
1280 mutex_lock l(mu_);
1281 bool added;
1282 return AddFunctionDefHelper(fdef, stack_traces, &added);
1283 }
1284
AddFunctionDefHelper(const FunctionDef & fdef,const StackTracesMap & stack_traces,bool * added)1285 Status FunctionLibraryDefinition::AddFunctionDefHelper(
1286 const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) {
1287 *added = false;
1288 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1289 function_defs_[fdef.signature().name()];
1290 if (entry) {
1291 if (!FunctionDefsEqual(entry->fdef, fdef)) {
1292 return errors::InvalidArgument(
1293 "Cannot add function '", fdef.signature().name(),
1294 "' because a different function with the same name already "
1295 "exists.");
1296 }
1297 // Ignore duplicate FunctionDefs.
1298 return OkStatus();
1299 }
1300 const OpDef* op_def;
1301 if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1302 return errors::InvalidArgument(
1303 "Cannot add function '", fdef.signature().name(),
1304 "' because an op with the same name already exists.");
1305 }
1306 entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces);
1307 *added = true;
1308 return OkStatus();
1309 }
1310
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1311 Status FunctionLibraryDefinition::AddHelper(
1312 std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1313 *added = false;
1314 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1315 function_defs_[registration->fdef.signature().name()];
1316 if (entry) {
1317 if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1318 return errors::InvalidArgument(
1319 "Cannot add function '", registration->fdef.signature().name(),
1320 "' because a different function with the same name already "
1321 "exists.");
1322 }
1323 // Ignore duplicate FunctionDefs.
1324 return OkStatus();
1325 }
1326 const OpDef* op_def;
1327 if (default_registry_
1328 ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1329 .ok()) {
1330 return errors::InvalidArgument(
1331 "Cannot add function '", registration->fdef.signature().name(),
1332 "' because an op with the same name already exists.");
1333 }
1334 entry = std::move(registration);
1335 *added = true;
1336 return OkStatus();
1337 }
1338
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1339 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1340 const string& func, const FunctionLibraryDefinition& other) {
1341 if (default_registry_ != other.default_registry_) {
1342 return errors::InvalidArgument(
1343 "Cannot copy function '", func,
1344 "' because CopyFunctionDefFrom() requires that both libraries have the "
1345 "same default registry.");
1346 }
1347 std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1348 {
1349 tf_shared_lock l(other.mu_);
1350 function_def = other.FindHelper(func);
1351 }
1352 if (!function_def) {
1353 return errors::InvalidArgument(
1354 "Cannot copy function '", func,
1355 "' because no function with that name exists in the other library.");
1356 }
1357 {
1358 mutex_lock l(mu_);
1359 std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1360 if (entry) {
1361 if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1362 return errors::InvalidArgument(
1363 "Cannot copy function '", func,
1364 "' because a different function with the same name already "
1365 "exists.");
1366 }
1367 } else {
1368 entry = std::move(function_def);
1369 }
1370 }
1371 return OkStatus();
1372 }
1373
AddGradientDef(const GradientDef & grad)1374 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1375 mutex_lock l(mu_);
1376 bool added;
1377 return AddGradientDefHelper(grad, &added);
1378 }
1379
AddGradientDefHelper(const GradientDef & grad,bool * added)1380 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1381 bool* added) {
1382 *added = false;
1383 string* entry = &func_grad_[grad.function_name()];
1384 if (!entry->empty()) {
1385 if (*entry != grad.gradient_func()) {
1386 return errors::InvalidArgument(
1387 "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1388 grad.function_name(), "' because it already has gradient function ",
1389 "'", *entry, "'");
1390 }
1391 // Ignore duplicate GradientDefs
1392 return OkStatus();
1393 }
1394 *entry = grad.gradient_func();
1395 *added = true;
1396 return OkStatus();
1397 }
1398
AddLibrary(const FunctionLibraryDefinition & other)1399 Status FunctionLibraryDefinition::AddLibrary(
1400 const FunctionLibraryDefinition& other) {
1401 // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1402 // the duration of the function could lead to deadlock).
1403 FunctionLibraryDefinition clone(other);
1404 mutex_lock l(mu_);
1405 mutex_lock l2(clone.mu_);
1406 // Remember the funcs and grads that we added successfully so that
1407 // we can roll them back on error.
1408 std::vector<string> funcs;
1409 std::vector<string> funcs_with_grads;
1410 Status s;
1411 bool added;
1412 for (auto iter : clone.function_defs_) {
1413 s = AddHelper(iter.second, &added);
1414 if (!s.ok()) {
1415 Status remove_status = Remove(funcs, funcs_with_grads);
1416 if (!remove_status.ok()) {
1417 return remove_status;
1418 }
1419 return s;
1420 }
1421 if (added) {
1422 funcs.push_back(iter.second->fdef.signature().name());
1423 }
1424 }
1425 for (auto iter : clone.func_grad_) {
1426 GradientDef grad;
1427 grad.set_function_name(iter.first);
1428 grad.set_gradient_func(iter.second);
1429 s = AddGradientDefHelper(grad, &added);
1430 if (!s.ok()) {
1431 Status remove_status = Remove(funcs, funcs_with_grads);
1432 if (!remove_status.ok()) {
1433 return remove_status;
1434 }
1435 return s;
1436 }
1437 if (added) {
1438 funcs_with_grads.push_back(grad.function_name());
1439 }
1440 }
1441 return OkStatus();
1442 }
1443
AddLibrary(const FunctionDefLibrary & lib_def)1444 Status FunctionLibraryDefinition::AddLibrary(
1445 const FunctionDefLibrary& lib_def) {
1446 // Remember the funcs and grads that we added successfully so that
1447 // we can roll them back on error.
1448 mutex_lock l(mu_);
1449 std::vector<string> funcs;
1450 std::vector<string> funcs_with_grads;
1451 Status s;
1452 bool added;
1453 for (const FunctionDef& fdef : lib_def.function()) {
1454 s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added);
1455 if (!s.ok()) {
1456 Status remove_status = Remove(funcs, funcs_with_grads);
1457 if (!remove_status.ok()) {
1458 return remove_status;
1459 }
1460 return s;
1461 }
1462 if (added) {
1463 funcs.push_back(fdef.signature().name());
1464 }
1465 }
1466 for (const GradientDef& grad : lib_def.gradient()) {
1467 s = AddGradientDefHelper(grad, &added);
1468 if (!s.ok()) {
1469 Status remove_status = Remove(funcs, funcs_with_grads);
1470 if (!remove_status.ok()) {
1471 return remove_status;
1472 }
1473 return s;
1474 }
1475 if (added) {
1476 funcs_with_grads.push_back(grad.function_name());
1477 }
1478 }
1479 return OkStatus();
1480 }
1481
ReplaceFunction(const string & func,const FunctionDef & fdef,const StackTracesMap & stack_traces)1482 Status FunctionLibraryDefinition::ReplaceFunction(
1483 const string& func, const FunctionDef& fdef,
1484 const StackTracesMap& stack_traces) {
1485 mutex_lock l(mu_);
1486 bool added;
1487 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1488 TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, stack_traces, &added));
1489 return OkStatus();
1490 }
1491
ReplaceGradient(const GradientDef & grad)1492 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1493 mutex_lock l(mu_);
1494 bool added;
1495 TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1496 TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1497 return OkStatus();
1498 }
1499
RemoveFunction(const string & func)1500 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1501 mutex_lock l(mu_);
1502 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1503 return OkStatus();
1504 }
1505
RemoveFunctionHelper(const string & func)1506 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1507 const auto& i = function_defs_.find(func);
1508 if (i == function_defs_.end()) {
1509 return errors::InvalidArgument("Tried to remove non-existent function '",
1510 func, "'.");
1511 }
1512 function_defs_.erase(i);
1513 return OkStatus();
1514 }
1515
Clear()1516 void FunctionLibraryDefinition::Clear() {
1517 mutex_lock l(mu_);
1518 function_defs_.clear();
1519 func_grad_.clear();
1520 }
1521
RemoveGradient(const string & func)1522 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1523 const auto& i = func_grad_.find(func);
1524 if (i == func_grad_.end()) {
1525 return errors::InvalidArgument("Tried to remove non-existent gradient '",
1526 func, "'.");
1527 }
1528 func_grad_.erase(i);
1529 return OkStatus();
1530 }
1531
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1532 Status FunctionLibraryDefinition::Remove(
1533 const std::vector<string>& funcs,
1534 const std::vector<string>& funcs_with_grads) {
1535 Status s;
1536 for (const string& f : funcs) {
1537 s = RemoveFunctionHelper(f);
1538 if (!s.ok()) {
1539 return s;
1540 }
1541 }
1542 for (const string& f : funcs_with_grads) {
1543 s = RemoveGradient(f);
1544 if (!s.ok()) {
1545 return s;
1546 }
1547 }
1548 return OkStatus();
1549 }
1550
FindGradient(const string & func) const1551 string FunctionLibraryDefinition::FindGradient(const string& func) const {
1552 tf_shared_lock l(mu_);
1553 return gtl::FindWithDefault(func_grad_, func, "");
1554 }
1555
FindGradientHelper(const string & func) const1556 string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
1557 return gtl::FindWithDefault(func_grad_, func, "");
1558 }
1559
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1560 Status FunctionLibraryDefinition::LookUp(
1561 const string& op, const OpRegistrationData** op_reg_data) const {
1562 tf_shared_lock l(mu_);
1563 auto iter = function_defs_.find(op);
1564 if (iter != function_defs_.end()) {
1565 *op_reg_data = &iter->second->op_registration_data;
1566 return OkStatus();
1567 }
1568 return default_registry_->LookUp(op, op_reg_data);
1569 }
1570
UniqueFunctionName(StringPiece prefix) const1571 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1572 tf_shared_lock l(mu_);
1573 int index = 0;
1574 string name = strings::StrCat(prefix, index);
1575 while (function_defs_.find(name) != function_defs_.end()) {
1576 ++index;
1577 name = strings::StrCat(prefix, index);
1578 }
1579 return name;
1580 }
1581
GetAttrImpl(const NodeDef & ndef) const1582 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1583 const NodeDef& ndef) const {
1584 if (ndef.op() != kGradientOp) {
1585 // If 'ndef' calls a function and the function's def has the attr,
1586 // returns it.
1587 return Find(ndef.op());
1588 }
1589
1590 // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1591 // Foo's attributes.
1592 const NameAttrList* forward_func_attrs;
1593 if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1594 return nullptr;
1595 }
1596 const string& func_name = forward_func_attrs->name();
1597 {
1598 tf_shared_lock l(mu_);
1599 const string& grad_name = FindGradientHelper(func_name);
1600 // If 'func' has a user-defined gradient function, uses the grad
1601 // function's attrs to see if noinline is specified. Otherwise,
1602 // uses func's attrs.
1603 if (!grad_name.empty()) {
1604 if (const auto helper = FindHelper(grad_name)) {
1605 return &(helper->fdef);
1606 } else {
1607 return nullptr;
1608 }
1609 }
1610 if (const auto helper = FindHelper(func_name)) {
1611 return &(helper->fdef);
1612 } else {
1613 return nullptr;
1614 }
1615 }
1616 }
1617
ListFunctionNames() const1618 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1619 std::vector<string> function_names;
1620 tf_shared_lock l(mu_);
1621 function_names.reserve(function_defs_.size());
1622 for (const auto& it : function_defs_) {
1623 function_names.emplace_back(it.first);
1624 }
1625 return function_names;
1626 }
1627
ToProto() const1628 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1629 FunctionDefLibrary lib;
1630 tf_shared_lock l(mu_);
1631 for (const auto& f : function_defs_) {
1632 *lib.add_function() = f.second->fdef;
1633 }
1634 for (const auto& g : func_grad_) {
1635 GradientDef* gd = lib.add_gradient();
1636 gd->set_function_name(g.first);
1637 gd->set_gradient_func(g.second);
1638 }
1639 return lib;
1640 }
1641
1642 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1643 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1644 const string& attr, T* value) const {
1645 const FunctionDef* fdef = GetAttrImpl(ndef);
1646 if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1647 return OkStatus();
1648 }
1649 return errors::InvalidArgument("Attr ", attr, " is not defined.");
1650 }
1651
1652 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1653 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1654 T* value) const {
1655 return GetAttr(node.def(), attr, value);
1656 }
1657
// Explicitly instantiates both GetAttr() overloads (Node and NodeDef) for
// string and bool. The template bodies are defined in this file, so
// presumably these are the only instantiations available to other
// translation units. Adding a value type requires another GET_ATTR line.
#define GET_ATTR(T)                                                            \
  template Status FunctionLibraryDefinition::GetAttr(const Node&,             \
                                                     const string&, T*) const; \
  template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,          \
                                                     const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
1666
1667 namespace {
1668
// Function attr whose string value names an API interface. ReachableFunctions()
// treats every function sharing a reachable function's value as reachable too.
constexpr char kApiImplements[] = "api_implements";
1670
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1671 std::set<string> ReachableFunctions(
1672 const FunctionLibraryDefinition& flib,
1673 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1674 // Functions that are reachable from the graph.
1675 std::set<string> reachable_funcs;
1676
1677 // For any functions, if it has attribute "api_implements" =
1678 // "some_interface" and it is reachable, then it means any other
1679 // function with same attribute name and value could also be potentially
1680 // reachable, eg via implementation_selector swapping the nodedef.
1681 absl::flat_hash_set<string> reachable_api_interface;
1682
1683 // Functions might be reachable from the nested function calls, so we keep a
1684 // queue of functions that we have to check.
1685 gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1686
1687 // Add reachable and not already processed functions to the functions queue.
1688 const auto add_to_func_queue = [&](const string& func_name) {
1689 const FunctionDef* func = flib.Find(func_name);
1690 if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1691 func_queue.push_back(func);
1692 }
1693 };
1694
1695 // If any function with certain API name is reachable, all the other functions
1696 // with same API name should also be checked.
1697 const auto add_function_with_api_interface = [&](const string& api_name) {
1698 if (!reachable_api_interface.contains(api_name)) {
1699 reachable_api_interface.insert(api_name);
1700 for (const auto& func_name : flib.ListFunctionNames()) {
1701 const auto& func_def = flib.Find(func_name);
1702 const auto attr_it = func_def->attr().find(kApiImplements);
1703 if (attr_it != func_def->attr().end() &&
1704 attr_it->second.s() == api_name) {
1705 add_to_func_queue(func_name);
1706 }
1707 }
1708 }
1709 };
1710
1711 // Add all the functions that are reachable from the given node to the queue.
1712 const auto process_node = [&](const NodeDef& node) {
1713 // Node itself can be a call to the function.
1714 add_to_func_queue(node.op());
1715
1716 // Or node can have an attribute referencing a function.
1717 for (const auto& attr : node.attr()) {
1718 const auto& attr_value = attr.second;
1719
1720 // 1. AttrValue.func
1721 if (attr_value.has_func()) {
1722 add_to_func_queue(attr_value.func().name());
1723 }
1724
1725 // 2. AttrValue.ListValue.func
1726 if (attr_value.has_list()) {
1727 for (const auto& func : attr_value.list().func()) {
1728 add_to_func_queue(func.name());
1729 }
1730 }
1731 }
1732 };
1733
1734 // Add all functions that are directly called from the optimized graph.
1735 std::for_each(nodes.begin(), nodes.end(), process_node);
1736
1737 // Process all reachable functions.
1738 while (!func_queue.empty()) {
1739 const FunctionDef* func = func_queue.back();
1740 func_queue.pop_back();
1741
1742 const string& func_name = func->signature().name();
1743 reachable_funcs.insert(func_name);
1744
1745 const auto attr_it = func->attr().find(kApiImplements);
1746 if (attr_it != func->attr().end()) {
1747 add_function_with_api_interface(attr_it->second.s());
1748 }
1749
1750 // Find all the functions called from the function body.
1751 const auto& func_body = func->node_def();
1752 std::for_each(func_body.begin(), func_body.end(), process_node);
1753
1754 // Check if the function has a registered gradient.
1755 const string grad_func_name = flib.FindGradient(func_name);
1756 if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1757 }
1758
1759 return reachable_funcs;
1760 }
1761
// Returns a new FunctionLibraryDefinition holding only the functions of `flib`
// that ReachableFunctions() reports as reachable from `nodes`, together with
// the gradient registrations of those functions. The new library is built on
// the same default op registry as `flib`.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  const std::set<string> reachable = ReachableFunctions(flib, nodes);

  FunctionLibraryDefinition subset(flib.default_registry(),
                                   FunctionDefLibrary());

  for (const string& name : reachable) {
    // Copying from a valid library into one that shares its default registry
    // is not expected to fail; this is only verified via TF_DCHECK_OK.
    const Status copy_status = subset.CopyFunctionDefFrom(name, flib);
    TF_DCHECK_OK(copy_status);

    const string gradient_name = flib.FindGradient(name);
    if (gradient_name.empty()) continue;

    GradientDef gradient_def;
    gradient_def.set_function_name(name);
    gradient_def.set_gradient_func(gradient_name);
    // AddGradientDef only fails when `name` already has a gradient. The set
    // yields each name once and `subset` started out empty, so that cannot
    // happen here.
    const Status gradient_status = subset.AddGradientDef(gradient_def);
    TF_DCHECK_OK(gradient_status);
  }

  return subset;
}
1789
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1790 string AllocatorAttributesToString(
1791 const std::vector<AllocatorAttributes>& attrs) {
1792 string result("[");
1793 // AllocatorAttribute::DebugString produces around 85 bytes now.
1794 result.reserve(100 * attrs.size());
1795 for (const AllocatorAttributes& attr : attrs) {
1796 result.append(attr.DebugString());
1797 result.append(", ");
1798 }
1799 if (!attrs.empty()) {
1800 result.resize(result.size() - 2);
1801 }
1802 result.append("]");
1803 return result;
1804 }
1805
// Describes whether an optional pointer is populated: "set" for a non-null
// pointer, "unset" for nullptr. Returns a pointer to a static string literal.
const char* IsSet(void* ptr) {
  if (ptr != nullptr) return "set";
  return "unset";
}
1807
1808 } // namespace
1809
ReachableDefinitions(const GraphDef & graph) const1810 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1811 const GraphDef& graph) const {
1812 return ReachableFunctionLibraryDefinition(*this, graph.node());
1813 }
1814
ReachableDefinitions(const FunctionDef & func) const1815 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1816 const FunctionDef& func) const {
1817 return ReachableFunctionLibraryDefinition(*this, func.node_def());
1818 }
1819
DebugString() const1820 string FunctionLibraryRuntime::Options::DebugString() const {
1821 return absl::StrCat(
1822 "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1823 " cancellation_manager=", IsSet(cancellation_manager),
1824 " collective_executor=", IsSet(collective_executor),
1825 " step_container=", IsSet(step_container),
1826 " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1827 " remote_execution=", remote_execution, " source_device=", source_device,
1828 " create_rendezvous=", create_rendezvous,
1829 " allow_dead_tensors=", allow_dead_tensors,
1830 " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1831 " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1832 }
1833
InitFromString(StringPiece val)1834 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1835 if (val.size() >= 2 && val[0] == '$') {
1836 proto.set_placeholder(val.data() + 1, val.size() - 1);
1837 } else {
1838 SetAttrValue(val, &proto);
1839 }
1840 }
1841
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1842 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1843 const string& name,
1844 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1845 AttrValueWrapper ret;
1846 ret.proto.mutable_func()->set_name(name);
1847 for (const auto& a : attrs) {
1848 ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1849 }
1850 return ret;
1851 }
1852
ToNodeDef() const1853 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1854 NodeDef n;
1855 n.set_op(this->op);
1856 n.set_name(GetName());
1857 for (const auto& a : this->attr) {
1858 n.mutable_attr()->insert({a.first, a.second.proto});
1859 }
1860 for (const string& a : this->arg) {
1861 n.add_input(a);
1862 }
1863 for (const string& d : this->dep) {
1864 n.add_input(strings::StrCat("^", d));
1865 }
1866 if (!this->device.empty()) {
1867 n.set_device(this->device);
1868 }
1869 if (!this->original_node_names.empty()) {
1870 *n.mutable_experimental_debug_info()->mutable_original_node_names() = {
1871 this->original_node_names.begin(), this->original_node_names.end()};
1872 }
1873 if (!this->original_func_names.empty()) {
1874 *n.mutable_experimental_debug_info()->mutable_original_func_names() = {
1875 this->original_func_names.begin(), this->original_func_names.end()};
1876 }
1877 return n;
1878 }
1879
1880 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1881 FunctionDef FunctionDefHelper::Create(
1882 const string& function_name, gtl::ArraySlice<string> in_def,
1883 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1884 gtl::ArraySlice<Node> node_def,
1885 gtl::ArraySlice<std::pair<string, string>> ret_def,
1886 gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1887 FunctionDef fdef;
1888
1889 // Signature
1890 OpDefBuilder b(function_name);
1891 for (const auto& i : in_def) b.Input(i);
1892 for (const auto& o : out_def) b.Output(o);
1893 for (const auto& a : attr_def) b.Attr(a);
1894 for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1895
1896 OpRegistrationData op_reg_data;
1897 TF_CHECK_OK(b.Finalize(&op_reg_data));
1898 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1899
1900 // Function body
1901 for (const auto& n : node_def) {
1902 *(fdef.add_node_def()) = n.ToNodeDef();
1903 }
1904
1905 // Returns
1906 for (const auto& r : ret_def) {
1907 fdef.mutable_ret()->insert({r.first, r.second});
1908 }
1909
1910 // Control returns
1911 for (const auto& cr : control_ret_def) {
1912 fdef.mutable_control_ret()->insert({cr.first, cr.second});
1913 }
1914
1915 auto* op_def_registry = OpRegistry::Global();
1916 // Check if any op is stateful.
1917 for (const auto& n : node_def) {
1918 const OpDef* op_def = nullptr;
1919 auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1920 // Lookup can fail if e.g. we are calling a function that was not yet
1921 // defined. If it happens, conservatively assume the op is stateful.
1922 if (!status.ok() || op_def->is_stateful()) {
1923 fdef.mutable_signature()->set_is_stateful(true);
1924 }
1925 }
1926
1927 return fdef;
1928 }
1929
1930 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1931 FunctionDef FunctionDefHelper::Create(
1932 const string& function_name, gtl::ArraySlice<string> in_def,
1933 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1934 gtl::ArraySlice<Node> node_def,
1935 gtl::ArraySlice<std::pair<string, string>> ret_def) {
1936 return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1937 /*control_ret_def=*/{});
1938 }
1939
/* static */
// Builds a FunctionDef from the "legacy" helper format, where node inputs and
// function outputs are written as flat names instead of full
// "<node>:<output_arg>:<index>" tensor references.
// - The signature is built with OpDefBuilder from `arg_def` (inputs),
//   `ret_def` (outputs) and `attr_def`. A malformed spec CHECK-fails.
// - Each body Node's `ret` list names that node's outputs positionally, in the
//   order of the output ranges computed by NameRangesForNode. Those names,
//   together with the function's input names, can be used as `arg` entries of
//   later nodes and as function outputs.
// - Every body op must be registered in OpRegistry::Global(). Unknown ops,
//   unknown input names, missing `ret` entries and unresolved function
//   outputs all CHECK-fail. The signature is marked stateful if any body op
//   is stateful.
FunctionDef FunctionDefHelper::Define(const string& name,
                                      gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  FunctionDef fdef;
  OpDefBuilder b(name);
  for (const auto& a : arg_def) b.Input(a);
  for (const auto& r : ret_def) b.Output(r);
  for (const auto& a : attr_def) b.Attr(a);

  OpRegistrationData op_reg_data;
  TF_CHECK_OK(b.Finalize(&op_reg_data));
  fdef.mutable_signature()->Swap(&op_reg_data.op_def);

  // Mapping from legacy output names to NodeDef outputs.
  // Function inputs are referenced in the body by their bare names, so each
  // maps to itself.
  std::unordered_map<string, string> ret_index;
  for (const auto& a : fdef.signature().input_arg()) {
    ret_index[a.name()] = a.name();
  }

  // For looking up OpDefs
  auto* op_def_registry = OpRegistry::Global();

  // Function body
  // Nodes are processed in order. A node's outputs are added to ret_index only
  // after its inputs are resolved, so a node may consume only function inputs
  // or outputs of nodes listed before it.
  for (const auto& src : node_def) {
    NodeDef* n = fdef.add_node_def();
    n->set_op(src.op);
    n->set_name(src.GetName());
    for (const auto& a : src.attr) {
      n->mutable_attr()->insert({a.first, a.second.proto});
    }
    // Translate each legacy input name into its tensor reference.
    for (const string& a : src.arg) {
      const auto iter = ret_index.find(a);
      CHECK(iter != ret_index.end())
          << "Node input '" << a << "' in '" << n->name() << "' of " << name;
      n->add_input(iter->second);
    }
    // Control dependencies become "^"-prefixed inputs after the data inputs.
    for (const string& d : src.dep) {
      n->add_input(strings::StrCat("^", d));
    }

    // Add the outputs of this node to ret_index.
    const OpDef* op_def = nullptr;
    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
    CHECK(op_def != nullptr) << n->op();
    NameRangeMap output_names;
    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
    for (const auto& o : output_names) {
      // o.first is the output arg name. o.second is its [first, second) range
      // in the node's flattened output list, which also indexes src.ret.
      CHECK_LE(o.second.second, src.ret.size())
          << "Missing ret for output '" << o.first << "' in '" << n->name()
          << "' of " << name;
      for (int i = o.second.first; i < o.second.second; ++i) {
        // Record the reference as "<node>:<output_arg>:<index within arg>".
        ret_index[src.ret[i]] =
            strings::StrCat(n->name(), ":", o.first, ":", i - o.second.first);
      }
    }
    if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
  }

  // Returns
  // Each signature output must resolve to a name recorded in ret_index.
  for (const auto& r : fdef.signature().output_arg()) {
    const auto iter = ret_index.find(r.name());
    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
    fdef.mutable_ret()->insert({r.name(), iter->second});
  }
  return fdef;
}
2009
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)2010 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
2011 gtl::ArraySlice<string> ret_def,
2012 gtl::ArraySlice<string> attr_def,
2013 gtl::ArraySlice<Node> node_def) {
2014 return Define("_", arg_def, ret_def, attr_def, node_def);
2015 }
2016
2017 namespace gradient {
2018
2019 typedef std::unordered_map<string, Creator> OpGradFactory;
2020
GetOpGradFactory()2021 OpGradFactory* GetOpGradFactory() {
2022 static OpGradFactory* factory = new OpGradFactory;
2023 return factory;
2024 }
2025
RegisterOp(const string & op,Creator func)2026 bool RegisterOp(const string& op, Creator func) {
2027 CHECK(GetOpGradFactory()->insert({op, func}).second)
2028 << "Duplicated gradient for " << op;
2029 return true;
2030 }
2031
GetOpGradientCreator(const string & op,Creator * creator)2032 Status GetOpGradientCreator(const string& op, Creator* creator) {
2033 auto fac = GetOpGradFactory();
2034 auto iter = fac->find(op);
2035 if (iter == fac->end()) {
2036 return errors::NotFound("No gradient defined for op: ", op);
2037 }
2038 *creator = iter->second;
2039 return OkStatus();
2040 }
2041
2042 } // end namespace gradient
2043
2044 } // namespace tensorflow
2045