1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function.h"
17
18 #include <ctype.h>
19
20 #include <map>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/escaping.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/inlined_vector.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/util/device_name_utils.h"
41 #include "tensorflow/core/util/equal_graph_def.h"
42
43 namespace tensorflow {
44
// Out-of-line definitions of FunctionLibraryDefinition's static constexpr
// string-constant members; their initializers live in the class declaration
// (a constexpr static data member must be initialized in-class). Before
// C++17, an ODR-used static constexpr data member additionally needed exactly
// one namespace-scope definition like these.
/* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceArgOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kIntsOnDeviceAttr;
/* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
55
56 // Extracts the actual type from "attr_values" based on its definition
57 // "arg_def".
58 //
59 // If "arg_def" is a N*T type, *is_type_list is set to false, and
60 // *dtypes is set to be a vector of size N and each element is T.
61 //
62 // If "arg_def" is a list(type), *is_type_list is set to true, and
63 // *dtypes is set to be a vector of types specified in attrs for
64 // arg_def.
65 //
66 // Otherwise (arg_def is a simple type T), *is_type_list is set to
67 // false, and *dtypes is set to a single element vector, whose only
68 // element is T.
ArgNumType(AttrSlice attrs,const OpDef::ArgDef & arg_def,bool * is_type_list,DataTypeVector * dtypes)69 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
70 bool* is_type_list, DataTypeVector* dtypes) {
71 dtypes->clear();
72 if (!arg_def.type_list_attr().empty()) {
73 const AttrValue* v = attrs.Find(arg_def.type_list_attr());
74 if (v == nullptr) {
75 return errors::NotFound("type attr not found: ",
76 arg_def.type_list_attr());
77 }
78 *is_type_list = true;
79 for (int i = 0; i < v->list().type_size(); ++i) {
80 dtypes->push_back(v->list().type(i));
81 }
82 return Status::OK();
83 }
84
85 *is_type_list = false;
86 int num = 1;
87 if (!arg_def.number_attr().empty()) {
88 const AttrValue* v = attrs.Find(arg_def.number_attr());
89 if (v == nullptr) {
90 return errors::NotFound("type attr not found: ", arg_def.type_attr());
91 }
92 num = v->i();
93 }
94
95 DataType dtype;
96 if (arg_def.type() != DT_INVALID) {
97 dtype = arg_def.type();
98 } else if (arg_def.type_attr().empty()) {
99 dtype = DT_INVALID;
100 } else {
101 const AttrValue* v = attrs.Find(arg_def.type_attr());
102 if (v == nullptr) {
103 return errors::NotFound("type attr not found: ", arg_def.type_attr());
104 }
105 dtype = v->type();
106 }
107 dtypes->resize(num, dtype);
108 return Status::OK();
109 }
110
111 namespace {
112
113 template <typename T>
AddAttr(const string & name,const T & val,NodeDef * ndef)114 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
115 SetAttrValue(val, &((*ndef->mutable_attr())[name]));
116 }
117
ValidateSignatureWithAttrs(const OpDef & sig,AttrSlice attr_values)118 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
119 // attr_values should specify all attrs defined in fdef.
120 for (const auto& a : sig.attr()) {
121 const AttrValue* v = attr_values.Find(a.name());
122 if (!v) {
123 return errors::NotFound("Attr ", a.name(), " is not found from ",
124 SummarizeOpDef(sig));
125 }
126 Status status = AttrValueHasType(*v, a.type());
127 if (!status.ok()) {
128 errors::AppendToMessage(&status, "for attr '", a.name(), "'");
129 return status;
130 }
131 }
132
133 // TODO(josh11b): Enable this code once it works with function gradients.
134 // Right now the C++ function gradient code assumes it can pass
135 // all the attrs of the function to the gradient, and any attrs that
136 // the gradient doesn't care about will be ignored.
137 #if 0
138 if (attr_values.size() != sig.attr_size()) {
139 for (const auto& a : attr_values) {
140 // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
141 bool found = false;
142 for (const auto& s : sig.attr()) {
143 if (a.first == s.name()) {
144 found = true;
145 break;
146 }
147 }
148 if (!found) {
149 return errors::NotFound("Attr ", a.first, " is not found in ",
150 SummarizeOpDef(sig));
151 }
152 }
153 }
154 #endif
155
156 return Status::OK();
157 }
158
159 // A helper class for instantiating functions. This contains shared information
160 // like the resulting graph and node name index.
161 class FunctionInstantiationHelper {
162 public:
FunctionInstantiationHelper(GetFunctionSignature get_function,InstantiationResult * result)163 FunctionInstantiationHelper(GetFunctionSignature get_function,
164 InstantiationResult* result)
165 : get_function_(std ::move(get_function)), result_(*result) {
166 result_.nodes.clear();
167 }
168
169 // Builds index for nodes that can be used as node's input arguments.
170 // `resource_arg_unique_id`: if non-negative, will be populated to the
171 // "_resource_arg_unique_id" attribute of the arg node.
BuildInputArgIndex(const OpDef::ArgDef & arg_def,AttrSlice attr_values,const FunctionDef::ArgAttrs * arg_attrs,bool ints_on_device,int64 resource_arg_unique_id)172 Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
173 const FunctionDef::ArgAttrs* arg_attrs,
174 bool ints_on_device, int64 resource_arg_unique_id) {
175 bool is_type_list;
176 DataTypeVector dtypes;
177 TF_RETURN_IF_ERROR(
178 ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
179 CHECK_GE(dtypes.size(), size_t{1});
180 int arg_index = result_.nodes.size();
181 TF_RETURN_IF_ERROR(
182 AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
183 // Creates dtypes.size() nodes in the graph.
184 for (size_t i = 0; i < dtypes.size(); ++i) {
185 TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
186 {true, arg_index, 0, false, {dtypes[i]}}));
187 DCHECK_EQ(arg_index, result_.nodes.size());
188 string name = arg_def.name();
189 if (dtypes.size() > 1) {
190 strings::StrAppend(&name, "_", i);
191 }
192 NodeDef* gnode = AddNode(name);
193 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
194 gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
195 } else {
196 gnode->set_op(FunctionLibraryDefinition::kArgOp);
197 }
198 DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
199 AddAttr("T", dtype, gnode);
200 AddAttr("index", arg_index, gnode);
201 if (resource_arg_unique_id >= 0) {
202 AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
203 }
204 if (arg_attrs) {
205 for (const auto& arg_attr : arg_attrs->attr()) {
206 AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
207 }
208 }
209 result_.arg_types.push_back(dtypes[i]);
210 ++arg_index;
211 }
212 return Status::OK();
213 }
214
BuildNodeOutputIndex(const NodeDef & node,AttrSlice attrs,const int arg_index)215 Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
216 const int arg_index) {
217 const OpDef* node_sig = nullptr;
218 TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
219 if (node_sig->output_arg_size() == 0) {
220 return AddItem(node.name(), {false, arg_index, 0, false, {}});
221 }
222 const int num_retval = node_sig->output_arg_size();
223 int start = 0;
224 bool is_type_list;
225 DataTypeVector dtypes;
226 for (int i = 0; i < num_retval; ++i) {
227 TF_RETURN_IF_ERROR(
228 ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
229 // Note that we rely on the backwards-compatibility test enforcing
230 // that output_arg(*).name() doesn't change here.
231 const string base_name =
232 strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
233 TF_RETURN_IF_ERROR(
234 AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
235 for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
236 TF_RETURN_IF_ERROR(
237 AddItem(strings::StrCat(base_name, ":", j),
238 {false, arg_index, start + j, false, {dtypes[j]}}));
239 }
240 start += dtypes.size();
241 }
242 return Status::OK();
243 }
244
InstantiateNode(const NodeDef & fnode,AttrSlice attrs)245 Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
246 const OpDef* fnode_sig = nullptr;
247 TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
248 NodeDef* gnode = AddNode(fnode.name());
249 gnode->set_op(fnode.op());
250 gnode->set_device(fnode.device());
251 int gnode_idx = nodes_.size() - 1;
252
253 // Input
254 const int num_args = fnode_sig->input_arg_size();
255 bool is_type_list; // ignored
256 DataTypeVector dtypes;
257 int fnode_arg_index = 0;
258 for (int i = 0; i < num_args; ++i) {
259 TF_RETURN_IF_ERROR(
260 ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
261 // Consume inputs (indexed by fnode_arg_index) until we have
262 // matched each element of dtypes (indexed by j).
263 for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
264 if (fnode_arg_index >= fnode.input_size()) {
265 // Should never happen if we computed dtypes correctly.
266 return errors::InvalidArgument(
267 "Attempt to access beyond input size: ", fnode_arg_index,
268 " >= ", fnode.input_size());
269 }
270 // Look up the next input.
271 const string& input_name = fnode.input(fnode_arg_index);
272 const auto* item = GetItemOrNull(input_name);
273 if (item == nullptr) {
274 return errors::InvalidArgument(
275 "input ", input_name,
276 " is not found: ", FormatNodeDefForError(fnode));
277 }
278 if (item->dtypes.size() > dtypes.size() - j) {
279 return errors::InvalidArgument("Input ", input_name, " too long for ",
280 fnode_sig->input_arg(i).name());
281 }
282 // Match up all the elements of this input (indexed by k) with
283 // elements of dtypes (advancing j).
284 for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
285 if (item->dtypes[k] != dtypes[j]) {
286 return errors::InvalidArgument(
287 "input ", fnode_sig->input_arg(i).name(), "[", j,
288 "] expected type ", DataTypeString(dtypes[j]),
289 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
290 input_name, "[", k, "]");
291 }
292 if (item->is_func_arg) {
293 AddInput(gnode_idx, item->nid + k, 0);
294 } else {
295 AddInput(gnode_idx, item->nid, item->idx + k);
296 }
297 }
298 }
299 }
300
301 // Control deps.
302 for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
303 const string& input = fnode.input(i);
304 if (input.empty() || input[0] != '^') {
305 return errors::InvalidArgument("Expected input[", i, "] == '", input,
306 "' to be a control input.");
307 }
308 int nid = -1;
309 const string node_name = input.substr(1);
310 const string node_colon = node_name + ":";
311 const string node_colon_bound = node_name + ";";
312 // index_ is a map sorted lexicographically, so the key we are looking for
313 // must lie in the range [node_name, node_colon_bound).
314 auto it = index_.lower_bound(node_name);
315 while (it != index_.end() && it->first <= node_colon_bound) {
316 if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
317 nid = it->second.nid;
318 break;
319 }
320 ++it;
321 }
322 if (nid == -1) {
323 return errors::InvalidArgument("input[", i, "] == '", input,
324 "', is not found.");
325 }
326 AddDep(gnode_idx, nid);
327 }
328
329 // Attrs.
330 for (const auto& p : attrs) {
331 (*gnode->mutable_attr())[p.first] = p.second;
332 }
333
334 return Status::OK();
335 }
336
AddReturnNode(const OpDef::ArgDef & ret_def,AttrSlice attrs,const::tensorflow::protobuf::Map<string,string> & ret_map,bool ints_on_device,int * ret_index)337 Status AddReturnNode(
338 const OpDef::ArgDef& ret_def, AttrSlice attrs,
339 const ::tensorflow::protobuf::Map<string, string>& ret_map,
340 bool ints_on_device, int* ret_index) {
341 auto ret_iter = ret_map.find(ret_def.name());
342 if (ret_iter == ret_map.end()) {
343 return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
344 }
345 bool is_type_list;
346 DataTypeVector dtypes;
347 TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
348 CHECK_GE(dtypes.size(), size_t{1});
349 const auto* item = GetItemOrNull(ret_iter->second);
350 if (item == nullptr) {
351 return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
352 ret_iter->second, " is not found.");
353 }
354 if (dtypes != item->dtypes) {
355 return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
356 " : ", DataTypeVectorString(dtypes),
357 " vs. ",
358 DataTypeVectorString(item->dtypes));
359 }
360 for (size_t i = 0; i < dtypes.size(); ++i) {
361 string name = strings::StrCat(ret_def.name(), "_RetVal");
362 if (dtypes.size() > 1) {
363 strings::StrAppend(&name, "_", i);
364 }
365 NodeDef* gnode = AddNode(name);
366 if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
367 gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
368 } else {
369 gnode->set_op(FunctionLibraryDefinition::kRetOp);
370 }
371 AddInput(nodes_.size() - 1, item->nid, item->idx + i);
372 DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
373 AddAttr("T", dtype, gnode);
374 AddAttr("index", (*ret_index)++, gnode);
375 result_.ret_types.push_back(dtypes[i]);
376 }
377 return Status::OK();
378 }
379
380 // Adds the actual node inputs to the result graph by converting indexes to
381 // the node names.
AddNodeInputs()382 void AddNodeInputs() {
383 for (int i = 0; i < result_.nodes.size(); i++) {
384 NodeInfo& node_info = nodes_[i];
385 for (const auto& p : node_info.data_inputs) {
386 result_.nodes[i].add_input(Name(p.first, p.second));
387 }
388 for (int index : node_info.control_inputs) {
389 result_.nodes[i].add_input(Dep(index));
390 }
391 }
392 }
393
394 private:
395 // This is used to build a small index for all names that can be used as a
396 // node's input arguments.
397 //
398 // If is_func_arg is true, the name is a function's argument. In
399 // this case, the produced graph def has node[nid:nid + dtype.size()].
400 //
401 // Otherwise, the name is a function body's node return value. In
402 // this case, the produced graph def has one node node[nid] and
403 // the node's output index [idx ... idx + num) corresponds to the
404 // named outputs.
405 //
406 // In all cases, "dtype" specifies the data type.
407 struct NameInfoItem {
408 bool is_func_arg;
409 int nid;
410 int idx;
411 bool is_type_list;
412 DataTypeVector dtypes;
413 };
414
415 // Adds an item into the input name index.
AddItem(const string & name,const NameInfoItem & item)416 Status AddItem(const string& name, const NameInfoItem& item) {
417 if (!index_.insert({name, item}).second) {
418 return errors::InvalidArgument(
419 strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
420 " name: "),
421 name);
422 }
423 return Status::OK();
424 }
425
GetItemOrNull(const string & name) const426 const NameInfoItem* GetItemOrNull(const string& name) const {
427 return gtl::FindOrNull(index_, name);
428 }
429
Dep(int node_index) const430 string Dep(int node_index) const {
431 return strings::StrCat("^", Name(node_index));
432 }
433
Name(int node_index) const434 string Name(int node_index) const {
435 CHECK_LT(node_index, nodes_.size());
436 return nodes_[node_index].name;
437 }
438
Name(int node_index,int output_index) const439 string Name(int node_index, int output_index) const {
440 if (output_index == 0) {
441 return Name(node_index);
442 } else {
443 return strings::StrCat(Name(node_index), ":", output_index);
444 }
445 }
446
AddNode(const string & name)447 NodeDef* AddNode(const string& name) {
448 result_.nodes.emplace_back();
449 NodeDef* gnode = &result_.nodes.back();
450 gnode->set_name(name);
451 nodes_.push_back({name, {}, {}});
452 CHECK_EQ(result_.nodes.size(), nodes_.size());
453 return gnode;
454 }
455
AddInput(int node_index,int output_node,int output_index)456 void AddInput(int node_index, int output_node, int output_index) {
457 CHECK_LT(node_index, nodes_.size());
458 nodes_[node_index].data_inputs.push_back(
459 std::make_pair(output_node, output_index));
460 }
461
AddDep(int node_index,int dep_index)462 void AddDep(int node_index, int dep_index) {
463 CHECK_LT(node_index, nodes_.size());
464 nodes_[node_index].control_inputs.push_back(dep_index);
465 }
466
467 GetFunctionSignature get_function_;
468 InstantiationResult& result_;
469 // A small index for all names that can be used as a node's input arguments.
470 std::map<string, NameInfoItem> index_;
471 // This contains information about a node in the new graph including the node
472 // names and input nodes' indexes.
473 struct NodeInfo {
474 string name;
475 // Data inputs where <n, k> means arg k of node n.
476 std::vector<std::pair<int, int>> data_inputs;
477 // Control inputs (dependencies).
478 std::vector<int> control_inputs;
479 };
480 // nodes_[i] is the information about result_.nodes[i].
481 std::vector<NodeInfo> nodes_;
482 };
483
484 // Various helpers Print(proto) to print relevant protos to ascii.
Print(const OpDef::ArgDef & arg)485 string Print(const OpDef::ArgDef& arg) {
486 string out;
487 strings::StrAppend(&out, arg.name(), ":");
488 if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
489 if (!arg.number_attr().empty()) {
490 strings::StrAppend(&out, arg.number_attr(), "*");
491 }
492 if (arg.type() != DT_INVALID) {
493 strings::StrAppend(&out, DataTypeString(arg.type()));
494 } else {
495 strings::StrAppend(&out, arg.type_attr());
496 }
497 if (arg.is_ref()) strings::StrAppend(&out, ")");
498 return out;
499 }
500
501 // TODO(josh11b): Merge this with SummarizeAttrValue().
Print(const AttrValue & attr_value)502 string Print(const AttrValue& attr_value) {
503 if (attr_value.value_case() == AttrValue::kType) {
504 return DataTypeString(attr_value.type());
505 } else if ((attr_value.value_case() == AttrValue::kList) &&
506 (attr_value.list().type_size() > 0)) {
507 string ret = "{";
508 for (int i = 0; i < attr_value.list().type_size(); ++i) {
509 if (i > 0) strings::StrAppend(&ret, ", ");
510 strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
511 }
512 strings::StrAppend(&ret, "}");
513 return ret;
514 } else if (attr_value.value_case() == AttrValue::kFunc) {
515 if (attr_value.func().attr_size() == 0) {
516 return attr_value.func().name();
517 }
518 std::vector<string> entries;
519 for (const auto& p : attr_value.func().attr()) {
520 entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
521 }
522 std::sort(entries.begin(), entries.end());
523 return strings::StrCat(attr_value.func().name(), "[",
524 absl::StrJoin(entries, ", "), "]");
525 }
526 return SummarizeAttrValue(attr_value);
527 }
528
529 // TODO(josh11b): Merge this with SummarizeNodeDef().
Print(const NodeDef & n)530 string Print(const NodeDef& n) {
531 string out;
532 strings::StrAppend(&out, n.name(), " = ", n.op());
533 if (n.attr_size() > 0) {
534 std::vector<string> entries;
535 for (auto& a : n.attr()) {
536 entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
537 }
538 std::sort(entries.begin(), entries.end());
539 // Add a short device string at the end of all attributes.
540 if (!n.device().empty()) {
541 DeviceNameUtils::ParsedName parsed;
542 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
543 entries.push_back(
544 strings::StrCat("device=", parsed.type, ":", parsed.id));
545 } else {
546 entries.push_back("device=<FAILED_TO_PARSE>");
547 }
548 }
549 strings::StrAppend(&out, "[", absl::StrJoin(entries, ", "), "]");
550 }
551 strings::StrAppend(&out, "(");
552 std::vector<StringPiece> dat;
553 std::vector<string> dep;
554 for (StringPiece s : n.input()) {
555 if (absl::ConsumePrefix(&s, "^")) {
556 dep.emplace_back(s);
557 } else {
558 dat.push_back(s);
559 }
560 }
561 strings::StrAppend(&out, absl::StrJoin(dat, ", "), ")");
562 if (!dep.empty()) {
563 strings::StrAppend(&out, " @ ", absl::StrJoin(dep, ", "));
564 }
565 return out;
566 }
567
Print(const FunctionDef & fdef)568 string Print(const FunctionDef& fdef) {
569 string out;
570 const OpDef& sig = fdef.signature();
571 strings::StrAppend(&out, "\n", sig.name());
572 if (sig.attr_size() > 0) {
573 strings::StrAppend(&out, "[");
574 for (int i = 0; i < sig.attr_size(); ++i) {
575 const auto& a = sig.attr(i);
576 if (i > 0) strings::StrAppend(&out, ", ");
577 if (a.type() == "type") {
578 strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
579 } else {
580 strings::StrAppend(&out, a.name(), ":", a.type());
581 }
582 }
583 strings::StrAppend(&out, "]");
584 }
585 strings::StrAppend(&out, "(");
586 for (int i = 0; i < sig.input_arg_size(); ++i) {
587 if (i > 0) strings::StrAppend(&out, ", ");
588 strings::StrAppend(&out, Print(sig.input_arg(i)));
589 }
590 strings::StrAppend(&out, ") -> (");
591 for (int i = 0; i < sig.output_arg_size(); ++i) {
592 if (i > 0) strings::StrAppend(&out, ", ");
593 strings::StrAppend(&out, Print(sig.output_arg(i)));
594 }
595 strings::StrAppend(&out, ") {\n");
596 for (const auto& n : fdef.node_def()) {
597 strings::StrAppend(&out, " ", Print(n), "\n");
598 }
599 for (const auto& cr : fdef.control_ret()) {
600 strings::StrAppend(&out, " @return ", cr.first, " = ", cr.second, "\n");
601 }
602 for (const auto& r : fdef.ret()) {
603 strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n");
604 }
605 strings::StrAppend(&out, "}\n");
606 return out;
607 }
608
Print(gtl::ArraySlice<const NodeDef * > nodes)609 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
610 std::vector<const NodeDef*> arg;
611 std::vector<const NodeDef*> ret;
612 std::vector<const NodeDef*> body;
613 for (const NodeDef* n : nodes) {
614 if (n->op() == FunctionLibraryDefinition::kArgOp ||
615 n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
616 arg.push_back(n);
617 } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
618 n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
619 ret.push_back(n);
620 } else {
621 body.push_back(n);
622 }
623 }
624 auto comp = [](const NodeDef* x, const NodeDef* y) {
625 int xi;
626 TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
627 int yi;
628 TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
629 return xi < yi;
630 };
631 std::sort(arg.begin(), arg.end(), comp);
632 std::sort(ret.begin(), ret.end(), comp);
633 string out;
634 strings::StrAppend(&out, "\n(");
635 auto get_type_and_device = [](const NodeDef& n) {
636 DataType dt;
637 if (!TryGetNodeAttr(n, "T", &dt)) {
638 dt = DT_INVALID;
639 }
640 if (!n.device().empty()) {
641 DeviceNameUtils::ParsedName parsed;
642 if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
643 return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
644 parsed.id);
645 } else {
646 LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
647 << n.op() << ":" << n.name();
648 return strings::StrCat(DataTypeString(dt), "@",
649 "<FAILED_TO_PARSE_DEVICE>");
650 }
651 }
652 return DataTypeString(dt);
653 };
654 for (size_t i = 0; i < arg.size(); ++i) {
655 const NodeDef* n = arg[i];
656 if (i > 0) strings::StrAppend(&out, ", ");
657 CHECK_GE(n->attr_size(), 2);
658 strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
659 }
660 strings::StrAppend(&out, ") -> (");
661 for (size_t i = 0; i < ret.size(); ++i) {
662 const NodeDef* n = ret[i];
663 if (i > 0) strings::StrAppend(&out, ", ");
664 CHECK_LE(2, n->attr_size());
665
666 // The _RetVal op should have a unique non-control input. We assert that
667 // here and add it to the output.
668 bool found_non_control_input = false;
669 for (const string& input : n->input()) {
670 if (!input.empty() && input[0] != '^') {
671 DCHECK_EQ(found_non_control_input, false)
672 << "RetVal node has more than one non-control input: "
673 << absl::StrJoin(n->input(), ", ");
674 strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
675 found_non_control_input = true;
676 }
677 }
678 DCHECK_EQ(found_non_control_input, true)
679 << "RetVal did not have any non-control inputs: "
680 << absl::StrJoin(n->input(), ", ");
681 }
682 strings::StrAppend(&out, ") {\n");
683 for (size_t i = 0; i < body.size(); ++i) {
684 strings::StrAppend(&out, " ", Print(*body[i]), "\n");
685 }
686 strings::StrAppend(&out, "}\n");
687 return out;
688 }
689
AddDefaultAttrs(const string & op,const GetFunctionSignature & get_function,AttrValueMap * attrs)690 Status AddDefaultAttrs(const string& op,
691 const GetFunctionSignature& get_function,
692 AttrValueMap* attrs) {
693 const OpDef* op_def = nullptr;
694 TF_RETURN_IF_ERROR(get_function(op, &op_def));
695 AttrSlice attr_slice(attrs);
696 for (const auto& attr_def : op_def->attr()) {
697 if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
698 if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
699 return errors::Internal("Somehow duplicated: ", attr_def.name());
700 }
701 }
702 }
703 return Status::OK();
704 }
705
706 } // end namespace
707
708 // TODO(shikharagarwal): Transmit original node names correctly in file.
InstantiateFunction(const FunctionDef & fdef,AttrSlice attr_values,GetFunctionSignature get_function,InstantiationResult * result)709 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
710 GetFunctionSignature get_function,
711 InstantiationResult* result) {
712 if (VLOG_IS_ON(5)) {
713 const auto& signature = fdef.signature();
714 VLOG(5) << "Instantiate function definition: name=" << signature.name()
715 << " #input_args=" << signature.input_arg_size()
716 << " #output_args=" << signature.output_arg_size()
717 << " #control_output=" << signature.control_output_size();
718 for (const auto& line : str_util::Split(Print(fdef), '\n')) {
719 VLOG(5) << "|| " << line;
720 }
721 }
722
723 const OpDef& sig = fdef.signature();
724 TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
725
726 bool ints_on_device =
727 fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
728 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
729
730 FunctionInstantiationHelper helper(get_function, result);
731 Status s;
732 for (int i = 0, e = sig.input_arg_size(); i < e; ++i) {
733 const OpDef::ArgDef& arg_def = sig.input_arg(i);
734 auto it = fdef.arg_attr().find(i);
735 const FunctionDef::ArgAttrs* arg_attrs =
736 it != fdef.arg_attr().end() ? &it->second : nullptr;
737 auto resource_id_it = fdef.resource_arg_unique_id().find(i);
738 int64 resource_arg_unique_id =
739 resource_id_it != fdef.resource_arg_unique_id().end()
740 ? resource_id_it->second
741 : -1LL;
742 s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
743 ints_on_device, resource_arg_unique_id);
744
745 if (!s.ok()) {
746 errors::AppendToMessage(&s, "In ", Print(arg_def));
747 return s;
748 }
749 }
750
751 auto substitute = [attr_values](StringPiece name, AttrValue* val) {
752 if (const AttrValue* v = attr_values.Find(name)) {
753 *val = *v;
754 return true;
755 }
756 return false;
757 };
758
759 // Makes a copy of all attrs in fdef and substitutes placeholders.
760 // After this step, every attr is bound to a concrete value.
761 std::vector<AttrValueMap> node_attrs;
762 node_attrs.resize(fdef.node_def_size());
763 for (int i = 0; i < fdef.node_def_size(); ++i) {
764 for (auto attr : fdef.node_def(i).attr()) {
765 if (!SubstitutePlaceholders(substitute, &attr.second)) {
766 return errors::InvalidArgument("Failed to bind all placeholders in ",
767 SummarizeAttrValue(attr.second));
768 }
769 if (!node_attrs[i].insert(attr).second) {
770 return errors::Internal("Somehow duplicated: ", attr.first);
771 }
772 }
773 TF_RETURN_IF_ERROR(
774 AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
775 }
776
777 for (int i = 0; i < fdef.node_def_size(); ++i) {
778 s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
779 result->nodes.size() + i);
780 if (!s.ok()) {
781 errors::AppendToMessage(&s, "In ",
782 FormatNodeDefForError(fdef.node_def(i)));
783 return s;
784 }
785 }
786 // Emits one node for each fdef.node_def.
787 for (int i = 0; i < fdef.node_def_size(); ++i) {
788 s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
789 if (!s.ok()) {
790 errors::AppendToMessage(&s, "In ",
791 FormatNodeDefForError(fdef.node_def(i)));
792 return s;
793 }
794 }
795
796 // Emits nodes for the function's return values.
797 int ret_index = 0;
798 for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
799 s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
800 &ret_index);
801 if (!s.ok()) {
802 errors::AppendToMessage(&s, "In function output ", Print(ret_def));
803 return s;
804 }
805 }
806
807 // Adds the actual node inputs using the input indexes.
808 helper.AddNodeInputs();
809
810 return Status::OK();
811 }
812
DebugString(const FunctionDef & func_def)813 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
814
DebugString(const GraphDef & instantiated_func_def)815 string DebugString(const GraphDef& instantiated_func_def) {
816 std::vector<const NodeDef*> ptrs;
817 for (const NodeDef& n : instantiated_func_def.node()) {
818 ptrs.push_back(&n);
819 }
820 return Print(ptrs);
821 }
822
DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes)823 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
824 std::vector<const NodeDef*> ptrs;
825 for (const NodeDef& n : instantiated_func_nodes) {
826 ptrs.push_back(&n);
827 }
828 return Print(ptrs);
829 }
830
DebugStringWhole(const GraphDef & gdef)831 string DebugStringWhole(const GraphDef& gdef) {
832 string ret;
833 for (const auto& fdef : gdef.library().function()) {
834 strings::StrAppend(&ret, Print(fdef));
835 }
836 strings::StrAppend(&ret, "\n");
837 for (const auto& ndef : gdef.node()) {
838 strings::StrAppend(&ret, Print(ndef), "\n");
839 }
840 return ret;
841 }
842
843 namespace {
844
845 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
846 // Python, it's possible to access unset attrs, which returns a default value
847 // and adds an unset attr to the map.
GetSetAttrs(const FunctionDef & fdef)848 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
849 std::map<string, AttrValue> set_attrs;
850 for (const auto& pair : fdef.attr()) {
851 if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
852 set_attrs[pair.first] = pair.second;
853 }
854 }
855 return set_attrs;
856 }
857
858 } // end namespace
859
FunctionDefsEqual(const FunctionDef & f1,const FunctionDef & f2)860 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
861 if (!OpDefEqual(f1.signature(), f2.signature())) return false;
862
863 std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
864 std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
865 if (f1_attrs.size() != f2_attrs.size()) return false;
866 for (const auto& iter1 : f1_attrs) {
867 auto iter2 = f2_attrs.find(iter1.first);
868 if (iter2 == f2_attrs.end()) return false;
869 if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
870 }
871
872 if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
873 return false;
874 }
875
876 std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
877 std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
878 if (ret1 != ret2) return false;
879
880 std::map<string, string> control_ret1(f1.control_ret().begin(),
881 f1.control_ret().end());
882 std::map<string, string> control_ret2(f2.control_ret().begin(),
883 f2.control_ret().end());
884 if (control_ret1 != control_ret2) return false;
885
886 return true;
887 }
888
FunctionDefHash(const FunctionDef & fdef)889 uint64 FunctionDefHash(const FunctionDef& fdef) {
890 // signature
891 uint64 h = OpDefHash(fdef.signature());
892
893 // attrs
894 std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
895 for (const auto& p : attrs) {
896 h = Hash64(p.first.data(), p.first.size(), h);
897 h = Hash64Combine(AttrValueHash(p.second), h);
898 }
899
900 // node defs
901 h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
902
903 // output names
904 std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
905 for (const auto& p : ret) {
906 h = Hash64(p.first.data(), p.first.size(), h);
907 h = Hash64(p.second.data(), p.second.size(), h);
908 }
909
910 // control output names
911 std::map<string, string> control_ret(fdef.control_ret().begin(),
912 fdef.control_ret().end());
913 for (const auto& p : control_ret) {
914 h = Hash64(p.first.data(), p.first.size(), h);
915 h = Hash64(p.second.data(), p.second.size(), h);
916 }
917
918 return h;
919 }
920
921 static constexpr const char* const kExecutorAttr = "_executor";
922
923 /* static */
ExecutorType(const InstantiateOptions & options,AttrSlice attrs)924 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
925 AttrSlice attrs) {
926 if (!options.executor_type.empty()) {
927 return options.executor_type;
928 } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
929 return executor_attr->s();
930 } else {
931 return string();
932 }
933 }
934
935 namespace {
936 class AttrKeyAndValue {
937 public:
938 enum ValueRepresentationOp {
939 kRaw,
940 kCEscape,
941 };
AttrKeyAndValue(absl::string_view key_name,int key_suffix,string value,ValueRepresentationOp value_op=kRaw)942 AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value,
943 ValueRepresentationOp value_op = kRaw)
944 : key_name_(key_name),
945 key_suffix_(key_suffix),
946 value_op_(value_op),
947 value_(std::move(value)) {}
948
operator <(const AttrKeyAndValue & b) const949 bool operator<(const AttrKeyAndValue& b) const {
950 if (key_name_ != b.key_name_) {
951 return key_name_ < b.key_name_;
952 } else if (key_suffix_ != b.key_suffix_) {
953 return key_suffix_ < b.key_suffix_;
954 } else {
955 return value_ < b.value_;
956 }
957 }
958
AppendTo(bool first,string * s) const959 void AppendTo(bool first, string* s) const {
960 absl::string_view v;
961 bool add_escaped = false;
962 if ((value_op_ == kCEscape) && NeedsEscaping(value_)) {
963 // Use CEscape call below
964 add_escaped = true;
965 } else {
966 // Add raw value contents directly
967 v = value_;
968 }
969 if (key_suffix_ >= 0) {
970 strings::StrAppend(s, first ? "" : ",", key_name_, key_suffix_, "=", v);
971 } else {
972 strings::StrAppend(s, first ? "" : ",", key_name_, "=", v);
973 }
974 if (add_escaped) {
975 strings::StrAppend(s, absl::CEscape(value_));
976 }
977 }
978
979 private:
NeedsEscaping(const string & s)980 static bool NeedsEscaping(const string& s) {
981 for (auto c : s) {
982 if (!isalnum(c) && (c != ' ')) {
983 return true;
984 }
985 }
986 return false;
987 }
988
989 absl::string_view key_name_;
990 int key_suffix_; // -1 if missing
991 ValueRepresentationOp value_op_;
992 string value_;
993 };
994 } // namespace
995
Canonicalize(const string & funcname,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options)996 string Canonicalize(const string& funcname, AttrSlice attrs,
997 const FunctionLibraryRuntime::InstantiateOptions& options) {
998 absl::InlinedVector<AttrKeyAndValue, 8> entries;
999 entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) +
1000 options.input_devices.size());
1001 for (const auto& p : attrs) {
1002 if (p.first != kExecutorAttr) {
1003 entries.push_back(AttrKeyAndValue(p.first, -1, Print(p.second)));
1004 }
1005 }
1006 if (!options.target.empty()) {
1007 entries.push_back(AttrKeyAndValue("_target", -1, options.target,
1008 AttrKeyAndValue::kCEscape));
1009 }
1010 for (int i = 0; i < options.input_devices.size(); ++i) {
1011 entries.push_back(AttrKeyAndValue("_input_dev", i, options.input_devices[i],
1012 AttrKeyAndValue::kCEscape));
1013 }
1014 for (int i = 0; i < options.output_devices.size(); ++i) {
1015 entries.push_back(AttrKeyAndValue("_output_dev", i,
1016 options.output_devices[i],
1017 AttrKeyAndValue::kCEscape));
1018 }
1019 for (const auto& iter : options.input_resource_dtypes_and_shapes) {
1020 entries.push_back(AttrKeyAndValue("_input_resource_dtype", iter.first,
1021 DataTypeString(iter.second.dtype)));
1022 entries.push_back(AttrKeyAndValue("_input_resource_shape", iter.first,
1023 iter.second.shape.DebugString(),
1024 AttrKeyAndValue::kCEscape));
1025 }
1026 if (options.lib_def) {
1027 entries.push_back(AttrKeyAndValue(
1028 "_lib_def", -1,
1029 absl::StrCat("", reinterpret_cast<uintptr_t>(options.lib_def))));
1030 }
1031 if (!options.state_handle.empty()) {
1032 entries.push_back(
1033 AttrKeyAndValue("_state_handle", -1, options.state_handle));
1034 }
1035 string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
1036 if (!executor_type.empty()) {
1037 entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type));
1038 }
1039 if (options.config_proto.ByteSize() > 0) {
1040 string config_proto_serialized;
1041 options.config_proto.SerializeToString(&config_proto_serialized);
1042 entries.push_back(AttrKeyAndValue("_config_proto", -1,
1043 config_proto_serialized,
1044 AttrKeyAndValue::kCEscape));
1045 }
1046 std::sort(entries.begin(), entries.end());
1047 string result = strings::StrCat(funcname, "[");
1048 bool first = true;
1049 for (const auto& entry : entries) {
1050 entry.AppendTo(first, &result);
1051 first = false;
1052 }
1053 result += "]";
1054 return result;
1055 }
1056
Canonicalize(const string & funcname,AttrSlice attrs)1057 string Canonicalize(const string& funcname, AttrSlice attrs) {
1058 static const FunctionLibraryRuntime::InstantiateOptions* kEmptyOptions =
1059 new FunctionLibraryRuntime::InstantiateOptions;
1060 return Canonicalize(funcname, attrs, *kEmptyOptions);
1061 }
1062
FunctionCallFrame(DataTypeSlice arg_types,DataTypeSlice ret_types)1063 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
1064 DataTypeSlice ret_types)
1065 : arg_types_(arg_types.begin(), arg_types.end()),
1066 ret_types_(ret_types.begin(), ret_types.end()) {
1067 args_.resize(arg_types_.size());
1068 rets_.resize(ret_types_.size());
1069 }
1070
~FunctionCallFrame()1071 FunctionCallFrame::~FunctionCallFrame() {}
1072
SetArgs(gtl::ArraySlice<Tensor> args)1073 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
1074 // Input type checks.
1075 if (args.size() != arg_types_.size()) {
1076 return errors::InvalidArgument("Expects ", arg_types_.size(),
1077 " arguments, but ", args.size(),
1078 " is provided");
1079 }
1080 for (size_t i = 0; i < args.size(); ++i) {
1081 if (arg_types_[i] != args[i].dtype()) {
1082 return errors::InvalidArgument(
1083 "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
1084 DataTypeString(args[i].dtype()), " is provided");
1085 }
1086 args_[i] = args[i];
1087 }
1088 return Status::OK();
1089 }
1090
GetRetvals(std::vector<Tensor> * rets) const1091 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
1092 rets->clear();
1093 rets->reserve(rets_.size());
1094 for (size_t i = 0; i < rets_.size(); ++i) {
1095 const auto& item = rets_[i];
1096 if (item.has_val) {
1097 rets->push_back(item.val);
1098 } else {
1099 return errors::Internal("Retval[", i, "] does not have value");
1100 }
1101 }
1102 return Status::OK();
1103 }
1104
ConsumeRetvals(std::vector<Tensor> * rets,bool allow_dead_tensors)1105 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
1106 bool allow_dead_tensors) {
1107 rets->clear();
1108 rets->reserve(rets_.size());
1109 for (size_t i = 0; i < rets_.size(); ++i) {
1110 if (rets_[i].has_val) {
1111 rets->emplace_back(std::move(rets_[i].val));
1112 } else if (allow_dead_tensors) {
1113 rets->emplace_back();
1114 } else {
1115 return errors::Internal("Retval[", i, "] does not have value");
1116 }
1117 }
1118 return Status::OK();
1119 }
1120
GetArg(int index,Tensor * val) const1121 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
1122 if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
1123 return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
1124 args_.size(), ")");
1125 }
1126 *val = args_[index];
1127 return Status::OK();
1128 }
1129
SetRetval(int index,const Tensor & val)1130 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
1131 if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
1132 return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
1133 rets_.size(), ")");
1134 }
1135 if (val.dtype() != ret_types_[index]) {
1136 return errors::InvalidArgument(
1137 "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
1138 ", but ", DataTypeString(val.dtype()), " is provided.");
1139 }
1140 Retval* item = &rets_[index];
1141 if (!item->has_val) {
1142 item->has_val = true;
1143 item->val = val;
1144 } else {
1145 return errors::Internal("Retval[", index, "] has already been set.");
1146 }
1147 return Status::OK();
1148 }
1149
// Stores a copy of `fdef_in` together with OpRegistrationData built from its
// signature and flagged as a function.
//
// op_registration_data is initialized from the member `fdef`, not from
// `fdef_in`. That is only valid if `fdef` is declared before
// `op_registration_data` in the class, because members are initialized in
// declaration order. NOTE(review): confirm the order in function.h.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */) {}
1157
FunctionLibraryDefinition(const FunctionLibraryDefinition & other)1158 FunctionLibraryDefinition::FunctionLibraryDefinition(
1159 const FunctionLibraryDefinition& other)
1160 : default_registry_(other.default_registry_) {
1161 tf_shared_lock l(other.mu_);
1162 function_defs_ = other.function_defs_;
1163 func_grad_ = other.func_grad_;
1164 }
1165
FunctionLibraryDefinition(const OpRegistryInterface * default_registry,const FunctionDefLibrary & def_lib)1166 FunctionLibraryDefinition::FunctionLibraryDefinition(
1167 const OpRegistryInterface* default_registry,
1168 const FunctionDefLibrary& def_lib)
1169 : default_registry_(default_registry),
1170 function_defs_(def_lib.function_size()) {
1171 for (const auto& fdef : def_lib.function()) {
1172 // The latter function definition wins.
1173 auto& ptr = function_defs_[fdef.signature().name()];
1174 ptr.reset(new FunctionDefAndOpRegistration(fdef));
1175 }
1176 for (const auto& grad : def_lib.gradient()) {
1177 func_grad_[grad.function_name()] = grad.gradient_func();
1178 }
1179 }
1180
~FunctionLibraryDefinition()1181 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
1182
Contains(const string & func) const1183 bool FunctionLibraryDefinition::Contains(const string& func) const {
1184 tf_shared_lock l(mu_);
1185 return function_defs_.find(func) != function_defs_.end();
1186 }
1187
Find(const string & func) const1188 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
1189 tf_shared_lock l(mu_);
1190 auto result = FindHelper(func);
1191 if (result) {
1192 return &result->fdef;
1193 } else {
1194 return nullptr;
1195 }
1196 }
1197
1198 std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FindHelper(const string & func) const1199 FunctionLibraryDefinition::FindHelper(const string& func) const {
1200 auto iter = function_defs_.find(func);
1201 if (iter == function_defs_.end()) {
1202 return nullptr;
1203 } else {
1204 return iter->second;
1205 }
1206 }
1207
AddFunctionDef(const FunctionDef & fdef)1208 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
1209 mutex_lock l(mu_);
1210 bool added;
1211 return AddFunctionDefHelper(fdef, &added);
1212 }
1213
AddFunctionDefHelper(const FunctionDef & fdef,bool * added)1214 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
1215 bool* added) {
1216 *added = false;
1217 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1218 function_defs_[fdef.signature().name()];
1219 if (entry) {
1220 if (!FunctionDefsEqual(entry->fdef, fdef)) {
1221 return errors::InvalidArgument(
1222 "Cannot add function '", fdef.signature().name(),
1223 "' because a different function with the same name already "
1224 "exists.");
1225 }
1226 // Ignore duplicate FunctionDefs.
1227 return Status::OK();
1228 }
1229 const OpDef* op_def;
1230 if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
1231 return errors::InvalidArgument(
1232 "Cannot add function '", fdef.signature().name(),
1233 "' because an op with the same name already exists.");
1234 }
1235 entry = std::make_shared<FunctionDefAndOpRegistration>(fdef);
1236 *added = true;
1237 return Status::OK();
1238 }
1239
AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,bool * added)1240 Status FunctionLibraryDefinition::AddHelper(
1241 std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
1242 *added = false;
1243 std::shared_ptr<FunctionDefAndOpRegistration>& entry =
1244 function_defs_[registration->fdef.signature().name()];
1245 if (entry) {
1246 if (!FunctionDefsEqual(entry->fdef, registration->fdef)) {
1247 return errors::InvalidArgument(
1248 "Cannot add function '", registration->fdef.signature().name(),
1249 "' because a different function with the same name already "
1250 "exists.");
1251 }
1252 // Ignore duplicate FunctionDefs.
1253 return Status::OK();
1254 }
1255 const OpDef* op_def;
1256 if (default_registry_
1257 ->LookUpOpDef(registration->fdef.signature().name(), &op_def)
1258 .ok()) {
1259 return errors::InvalidArgument(
1260 "Cannot add function '", registration->fdef.signature().name(),
1261 "' because an op with the same name already exists.");
1262 }
1263 entry = std::move(registration);
1264 *added = true;
1265 return Status::OK();
1266 }
1267
CopyFunctionDefFrom(const string & func,const FunctionLibraryDefinition & other)1268 Status FunctionLibraryDefinition::CopyFunctionDefFrom(
1269 const string& func, const FunctionLibraryDefinition& other) {
1270 if (default_registry_ != other.default_registry_) {
1271 return errors::InvalidArgument(
1272 "Cannot copy function '", func,
1273 "' because CopyFunctionDefFrom() requires that both libraries have the "
1274 "same default registry.");
1275 }
1276 std::shared_ptr<FunctionDefAndOpRegistration> function_def;
1277 {
1278 tf_shared_lock l(other.mu_);
1279 function_def = other.FindHelper(func);
1280 }
1281 if (!function_def) {
1282 return errors::InvalidArgument(
1283 "Cannot copy function '", func,
1284 "' because no function with that name exists in the other library.");
1285 }
1286 {
1287 mutex_lock l(mu_);
1288 std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func];
1289 if (entry) {
1290 if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) {
1291 return errors::InvalidArgument(
1292 "Cannot copy function '", func,
1293 "' because a different function with the same name already "
1294 "exists.");
1295 }
1296 } else {
1297 entry = std::move(function_def);
1298 }
1299 }
1300 return Status::OK();
1301 }
1302
AddGradientDef(const GradientDef & grad)1303 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
1304 mutex_lock l(mu_);
1305 bool added;
1306 return AddGradientDefHelper(grad, &added);
1307 }
1308
AddGradientDefHelper(const GradientDef & grad,bool * added)1309 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
1310 bool* added) {
1311 *added = false;
1312 string* entry = &func_grad_[grad.function_name()];
1313 if (!entry->empty()) {
1314 if (*entry != grad.gradient_func()) {
1315 return errors::InvalidArgument(
1316 "Cannot assign gradient function '", grad.gradient_func(), "' to '",
1317 grad.function_name(), "' because it already has gradient function ",
1318 "'", *entry, "'");
1319 }
1320 // Ignore duplicate GradientDefs
1321 return Status::OK();
1322 }
1323 *entry = grad.gradient_func();
1324 *added = true;
1325 return Status::OK();
1326 }
1327
AddLibrary(const FunctionLibraryDefinition & other)1328 Status FunctionLibraryDefinition::AddLibrary(
1329 const FunctionLibraryDefinition& other) {
1330 // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
1331 // the duration of the function could lead to deadlock).
1332 FunctionLibraryDefinition clone(other);
1333 mutex_lock l(mu_);
1334 mutex_lock l2(clone.mu_);
1335 // Remember the funcs and grads that we added successfully so that
1336 // we can roll them back on error.
1337 std::vector<string> funcs;
1338 std::vector<string> funcs_with_grads;
1339 Status s;
1340 bool added;
1341 for (auto iter : clone.function_defs_) {
1342 s = AddHelper(iter.second, &added);
1343 if (!s.ok()) {
1344 Remove(funcs, funcs_with_grads);
1345 return s;
1346 }
1347 if (added) {
1348 funcs.push_back(iter.second->fdef.signature().name());
1349 }
1350 }
1351 for (auto iter : clone.func_grad_) {
1352 GradientDef grad;
1353 grad.set_function_name(iter.first);
1354 grad.set_gradient_func(iter.second);
1355 s = AddGradientDefHelper(grad, &added);
1356 if (!s.ok()) {
1357 Remove(funcs, funcs_with_grads);
1358 return s;
1359 }
1360 if (added) {
1361 funcs_with_grads.push_back(grad.function_name());
1362 }
1363 }
1364 return Status::OK();
1365 }
1366
AddLibrary(const FunctionDefLibrary & lib_def)1367 Status FunctionLibraryDefinition::AddLibrary(
1368 const FunctionDefLibrary& lib_def) {
1369 // Remember the funcs and grads that we added successfully so that
1370 // we can roll them back on error.
1371 mutex_lock l(mu_);
1372 std::vector<string> funcs;
1373 std::vector<string> funcs_with_grads;
1374 Status s;
1375 bool added;
1376 for (const FunctionDef& fdef : lib_def.function()) {
1377 s = AddFunctionDefHelper(fdef, &added);
1378 if (!s.ok()) {
1379 Remove(funcs, funcs_with_grads);
1380 return s;
1381 }
1382 if (added) {
1383 funcs.push_back(fdef.signature().name());
1384 }
1385 }
1386 for (const GradientDef& grad : lib_def.gradient()) {
1387 s = AddGradientDefHelper(grad, &added);
1388 if (!s.ok()) {
1389 Remove(funcs, funcs_with_grads);
1390 return s;
1391 }
1392 if (added) {
1393 funcs_with_grads.push_back(grad.function_name());
1394 }
1395 }
1396 return Status::OK();
1397 }
1398
ReplaceFunction(const string & func,const FunctionDef & fdef)1399 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
1400 const FunctionDef& fdef) {
1401 mutex_lock l(mu_);
1402 bool added;
1403 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1404 TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
1405 return Status::OK();
1406 }
1407
ReplaceGradient(const GradientDef & grad)1408 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
1409 mutex_lock l(mu_);
1410 bool added;
1411 TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
1412 TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
1413 return Status::OK();
1414 }
1415
RemoveFunction(const string & func)1416 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
1417 mutex_lock l(mu_);
1418 TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
1419 return Status::OK();
1420 }
1421
RemoveFunctionHelper(const string & func)1422 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
1423 const auto& i = function_defs_.find(func);
1424 if (i == function_defs_.end()) {
1425 return errors::InvalidArgument("Tried to remove non-existent function '",
1426 func, "'.");
1427 }
1428 function_defs_.erase(i);
1429 return Status::OK();
1430 }
1431
RemoveGradient(const string & func)1432 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
1433 const auto& i = func_grad_.find(func);
1434 if (i == func_grad_.end()) {
1435 return errors::InvalidArgument("Tried to remove non-existent gradient '",
1436 func, "'.");
1437 }
1438 func_grad_.erase(i);
1439 return Status::OK();
1440 }
1441
Remove(const std::vector<string> & funcs,const std::vector<string> & funcs_with_grads)1442 void FunctionLibraryDefinition::Remove(
1443 const std::vector<string>& funcs,
1444 const std::vector<string>& funcs_with_grads) {
1445 for (const string& f : funcs) {
1446 Status s = RemoveFunctionHelper(f);
1447 DCHECK(s.ok());
1448 }
1449 for (const string& f : funcs_with_grads) {
1450 Status s = RemoveGradient(f);
1451 DCHECK(s.ok());
1452 }
1453 }
1454
FindGradient(const string & func) const1455 string FunctionLibraryDefinition::FindGradient(const string& func) const {
1456 tf_shared_lock l(mu_);
1457 return gtl::FindWithDefault(func_grad_, func, "");
1458 }
1459
FindGradientHelper(const string & func) const1460 string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
1461 return gtl::FindWithDefault(func_grad_, func, "");
1462 }
1463
LookUp(const string & op,const OpRegistrationData ** op_reg_data) const1464 Status FunctionLibraryDefinition::LookUp(
1465 const string& op, const OpRegistrationData** op_reg_data) const {
1466 tf_shared_lock l(mu_);
1467 auto iter = function_defs_.find(op);
1468 if (iter != function_defs_.end()) {
1469 *op_reg_data = &iter->second->op_registration_data;
1470 return Status::OK();
1471 }
1472 return default_registry_->LookUp(op, op_reg_data);
1473 }
1474
UniqueFunctionName(StringPiece prefix) const1475 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
1476 tf_shared_lock l(mu_);
1477 int index = 0;
1478 string name = strings::StrCat(prefix, index);
1479 while (function_defs_.find(name) != function_defs_.end()) {
1480 ++index;
1481 name = strings::StrCat(prefix, index);
1482 }
1483 return name;
1484 }
1485
GetAttrImpl(const NodeDef & ndef) const1486 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
1487 const NodeDef& ndef) const {
1488 if (ndef.op() != kGradientOp) {
1489 // If 'ndef' calls a function and the function's def has the attr,
1490 // returns it.
1491 return Find(ndef.op());
1492 }
1493
1494 // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
1495 // Foo's attributes.
1496 const NameAttrList* forward_func_attrs;
1497 if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
1498 return nullptr;
1499 }
1500 const string& func_name = forward_func_attrs->name();
1501 {
1502 tf_shared_lock l(mu_);
1503 const string& grad_name = FindGradientHelper(func_name);
1504 // If 'func' has a user-defined gradient function, uses the grad
1505 // function's attrs to see if noinline is specified. Otherwise,
1506 // uses func's attrs.
1507 if (!grad_name.empty()) {
1508 return &(FindHelper(grad_name)->fdef);
1509 }
1510 return &(FindHelper(func_name)->fdef);
1511 }
1512 }
1513
ListFunctionNames() const1514 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
1515 std::vector<string> function_names;
1516 tf_shared_lock l(mu_);
1517 function_names.reserve(function_defs_.size());
1518 for (const auto& it : function_defs_) {
1519 function_names.emplace_back(it.first);
1520 }
1521 return function_names;
1522 }
1523
ToProto() const1524 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
1525 FunctionDefLibrary lib;
1526 tf_shared_lock l(mu_);
1527 for (const auto& f : function_defs_) {
1528 *lib.add_function() = f.second->fdef;
1529 }
1530 for (const auto& g : func_grad_) {
1531 GradientDef* gd = lib.add_gradient();
1532 gd->set_function_name(g.first);
1533 gd->set_gradient_func(g.second);
1534 }
1535 return lib;
1536 }
1537
1538 template <typename T>
GetAttr(const NodeDef & ndef,const string & attr,T * value) const1539 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
1540 const string& attr, T* value) const {
1541 const FunctionDef* fdef = GetAttrImpl(ndef);
1542 if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) {
1543 return Status::OK();
1544 }
1545 return errors::InvalidArgument("Attr ", attr, " is not defined.");
1546 }
1547
1548 template <typename T>
GetAttr(const Node & node,const string & attr,T * value) const1549 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
1550 T* value) const {
1551 return GetAttr(node.def(), attr, value);
1552 }
1553
// Explicitly instantiates both GetAttr overloads (Node and NodeDef) for
// T = string and T = bool. The templates are defined in this file, so these
// are the only value types instantiated here.
#define GET_ATTR(T)                                                    \
  template Status FunctionLibraryDefinition::GetAttr(const Node&,     \
                                                     const string&, T*) const; \
  template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,  \
                                                     const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
1562
1563 namespace {
1564
1565 constexpr char kApiImplements[] = "api_implements";
1566
ReachableFunctions(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1567 std::set<string> ReachableFunctions(
1568 const FunctionLibraryDefinition& flib,
1569 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1570 // Functions that are reachable from the graph.
1571 std::set<string> reachable_funcs;
1572
1573 // For any functions, if it has attribute "api_implements" =
1574 // "some_interface" and it is reachable, then it means any other
1575 // function with same attribute name and value could also be potentially
1576 // reachable, eg via implementation_selector swapping the nodedef.
1577 absl::flat_hash_set<string> reachable_api_interface;
1578
1579 // Functions might be reachable from the nested function calls, so we keep a
1580 // queue of functions that we have to check.
1581 gtl::InlinedVector<const FunctionDef*, 4> func_queue;
1582
1583 // Add reachable and not already processed functions to the functions queue.
1584 const auto add_to_func_queue = [&](const string& func_name) {
1585 const FunctionDef* func = flib.Find(func_name);
1586 if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
1587 func_queue.push_back(func);
1588 }
1589 };
1590
1591 // If any function with certain API name is reachable, all the other functions
1592 // with same API name should also be checked.
1593 const auto add_function_with_api_interface = [&](const string& api_name) {
1594 if (!reachable_api_interface.contains(api_name)) {
1595 reachable_api_interface.insert(api_name);
1596 for (const auto& func_name : flib.ListFunctionNames()) {
1597 const auto& func_def = flib.Find(func_name);
1598 const auto attr_it = func_def->attr().find(kApiImplements);
1599 if (attr_it != func_def->attr().end() &&
1600 attr_it->second.s() == api_name) {
1601 add_to_func_queue(func_name);
1602 }
1603 }
1604 }
1605 };
1606
1607 // Add all the functions that are reachable from the given node to the queue.
1608 const auto process_node = [&](const NodeDef& node) {
1609 // Node itself can be a call to the function.
1610 add_to_func_queue(node.op());
1611
1612 // Or node can have an attribute referencing a function.
1613 for (const auto& attr : node.attr()) {
1614 const auto& attr_value = attr.second;
1615
1616 // 1. AttrValue.func
1617 if (attr_value.has_func()) {
1618 add_to_func_queue(attr_value.func().name());
1619 }
1620
1621 // 2. AttrValue.ListValue.func
1622 if (attr_value.has_list()) {
1623 for (const auto& func : attr_value.list().func()) {
1624 add_to_func_queue(func.name());
1625 }
1626 }
1627 }
1628 };
1629
1630 // Add all functions that are directly called from the optimized graph.
1631 std::for_each(nodes.begin(), nodes.end(), process_node);
1632
1633 // Process all reachable functions.
1634 while (!func_queue.empty()) {
1635 const FunctionDef* func = func_queue.back();
1636 func_queue.pop_back();
1637
1638 const string& func_name = func->signature().name();
1639 reachable_funcs.insert(func_name);
1640
1641 const auto attr_it = func->attr().find(kApiImplements);
1642 if (attr_it != func->attr().end()) {
1643 add_function_with_api_interface(attr_it->second.s());
1644 }
1645
1646 // Find all the functions called from the function body.
1647 const auto& func_body = func->node_def();
1648 std::for_each(func_body.begin(), func_body.end(), process_node);
1649
1650 // Check if the function has a registered gradient.
1651 const string grad_func_name = flib.FindGradient(func_name);
1652 if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
1653 }
1654
1655 return reachable_funcs;
1656 }
1657
ReachableFunctionLibraryDefinition(const FunctionLibraryDefinition & flib,const protobuf::RepeatedPtrField<NodeDef> & nodes)1658 FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
1659 const FunctionLibraryDefinition& flib,
1660 const protobuf::RepeatedPtrField<NodeDef>& nodes) {
1661 std::set<string> reachable_funcs = ReachableFunctions(flib, nodes);
1662
1663 FunctionLibraryDefinition reachable_flib(flib.default_registry(),
1664 FunctionDefLibrary());
1665
1666 for (const string& func_name : reachable_funcs) {
1667 // This should never fail, because we copy functions from a valid flib and
1668 // use the same default registry.
1669 Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib);
1670 TF_DCHECK_OK(added);
1671
1672 const string grad_func_name = flib.FindGradient(func_name);
1673 if (!grad_func_name.empty()) {
1674 GradientDef grad;
1675 grad.set_function_name(func_name);
1676 grad.set_gradient_func(grad_func_name);
1677 // It can only fail if function already has a gradient function.
1678 const Status added_grad = reachable_flib.AddGradientDef(grad);
1679 TF_DCHECK_OK(added_grad);
1680 }
1681 }
1682
1683 return reachable_flib;
1684 }
1685
AllocatorAttributesToString(const std::vector<AllocatorAttributes> & attrs)1686 string AllocatorAttributesToString(
1687 const std::vector<AllocatorAttributes>& attrs) {
1688 string result("[");
1689 // AllocatorAttribute::DebugString produces around 85 bytes now.
1690 result.reserve(100 * attrs.size());
1691 for (const AllocatorAttributes& attr : attrs) {
1692 result.append(attr.DebugString());
1693 result.append(", ");
1694 }
1695 if (!attrs.empty()) {
1696 result.resize(result.size() - 2);
1697 }
1698 result.append("]");
1699 return result;
1700 }
1701
IsSet(void * ptr)1702 const char* IsSet(void* ptr) { return ptr == nullptr ? "unset" : "set"; }
1703
1704 } // namespace
1705
ReachableDefinitions(const GraphDef & graph) const1706 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1707 const GraphDef& graph) const {
1708 return ReachableFunctionLibraryDefinition(*this, graph.node());
1709 }
1710
ReachableDefinitions(const FunctionDef & func) const1711 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
1712 const FunctionDef& func) const {
1713 return ReachableFunctionLibraryDefinition(*this, func.node_def());
1714 }
1715
DebugString() const1716 string FunctionLibraryRuntime::Options::DebugString() const {
1717 return absl::StrCat(
1718 "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous),
1719 " cancellation_manager=", IsSet(cancellation_manager),
1720 " collective_executor=", IsSet(collective_executor),
1721 " step_container=", IsSet(step_container),
1722 " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner),
1723 " remote_execution=", remote_execution, " source_device=", source_device,
1724 " create_rendezvous=", create_rendezvous,
1725 " allow_dead_tensors=", allow_dead_tensors,
1726 " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
1727 " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
1728 }
1729
InitFromString(StringPiece val)1730 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
1731 if (val.size() >= 2 && val[0] == '$') {
1732 proto.set_placeholder(val.data() + 1, val.size() - 1);
1733 } else {
1734 SetAttrValue(val, &proto);
1735 }
1736 }
1737
FunctionRef(const string & name,gtl::ArraySlice<std::pair<string,AttrValueWrapper>> attrs)1738 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
1739 const string& name,
1740 gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
1741 AttrValueWrapper ret;
1742 ret.proto.mutable_func()->set_name(name);
1743 for (const auto& a : attrs) {
1744 ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
1745 }
1746 return ret;
1747 }
1748
ToNodeDef() const1749 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
1750 NodeDef n;
1751 n.set_op(this->op);
1752 n.set_name(this->ret[0]);
1753 for (const auto& a : this->attr) {
1754 n.mutable_attr()->insert({a.first, a.second.proto});
1755 }
1756 for (const string& a : this->arg) {
1757 n.add_input(a);
1758 }
1759 for (const string& d : this->dep) {
1760 n.add_input(strings::StrCat("^", d));
1761 }
1762 if (!this->device.empty()) {
1763 n.set_device(this->device);
1764 }
1765 return n;
1766 }
1767
1768 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def,gtl::ArraySlice<std::pair<string,string>> control_ret_def)1769 FunctionDef FunctionDefHelper::Create(
1770 const string& function_name, gtl::ArraySlice<string> in_def,
1771 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1772 gtl::ArraySlice<Node> node_def,
1773 gtl::ArraySlice<std::pair<string, string>> ret_def,
1774 gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
1775 FunctionDef fdef;
1776
1777 // Signature
1778 OpDefBuilder b(function_name);
1779 for (const auto& i : in_def) b.Input(i);
1780 for (const auto& o : out_def) b.Output(o);
1781 for (const auto& a : attr_def) b.Attr(a);
1782 for (const auto& c : control_ret_def) b.ControlOutput(c.first);
1783
1784 OpRegistrationData op_reg_data;
1785 TF_CHECK_OK(b.Finalize(&op_reg_data));
1786 fdef.mutable_signature()->Swap(&op_reg_data.op_def);
1787
1788 // Function body
1789 for (const auto& n : node_def) {
1790 *(fdef.add_node_def()) = n.ToNodeDef();
1791 }
1792
1793 // Returns
1794 for (const auto& r : ret_def) {
1795 fdef.mutable_ret()->insert({r.first, r.second});
1796 }
1797
1798 // Control returns
1799 for (const auto& cr : control_ret_def) {
1800 fdef.mutable_control_ret()->insert({cr.first, cr.second});
1801 }
1802
1803 auto* op_def_registry = OpRegistry::Global();
1804 // Check if any op is stateful.
1805 for (const auto& n : node_def) {
1806 const OpDef* op_def = nullptr;
1807 auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
1808 // Lookup can fail if e.g. we are calling a function that was not yet
1809 // defined. If it happens, conservatively assume the op is stateful.
1810 if (!status.ok() || op_def->is_stateful()) {
1811 fdef.mutable_signature()->set_is_stateful(true);
1812 }
1813 }
1814
1815 return fdef;
1816 }
1817
1818 /* static */
Create(const string & function_name,gtl::ArraySlice<string> in_def,gtl::ArraySlice<string> out_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def,gtl::ArraySlice<std::pair<string,string>> ret_def)1819 FunctionDef FunctionDefHelper::Create(
1820 const string& function_name, gtl::ArraySlice<string> in_def,
1821 gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
1822 gtl::ArraySlice<Node> node_def,
1823 gtl::ArraySlice<std::pair<string, string>> ret_def) {
1824 return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
1825 /*control_ret_def=*/{});
1826 }
1827
/* static */
// Legacy FunctionDef builder. The signature is built by OpDefBuilder from
// arg_def / ret_def / attr_def. Node inputs in `node_def` refer to values by
// "legacy" names: either a function input name, or an entry of an earlier
// node's `ret` list. These names are rewritten into NodeDef-style references
// "<node name>:<output arg name>:<index within that arg>". Each signature
// output is then bound to the value with the same legacy name.
//
// CHECK-fails (crashes) on:
//   - an invalid signature spec;
//   - an unregistered op;
//   - an input name not defined by a function input or an earlier node;
//   - a node whose `ret` list is shorter than its op's flattened outputs;
//   - a signature output that no input or node defines.
//
// NOTE(review): unlike Node::ToNodeDef() (used by Create), this builder does
// not copy `src.device` into the emitted NodeDef. Confirm whether that is
// intentional.
FunctionDef FunctionDefHelper::Define(const string& name,
                                      gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  FunctionDef fdef;
  OpDefBuilder b(name);
  for (const auto& a : arg_def) b.Input(a);
  for (const auto& r : ret_def) b.Output(r);
  for (const auto& a : attr_def) b.Attr(a);

  OpRegistrationData op_reg_data;
  TF_CHECK_OK(b.Finalize(&op_reg_data));
  fdef.mutable_signature()->Swap(&op_reg_data.op_def);

  // Mapping from legacy output names to NodeDef outputs.
  // Function inputs are referenced by their own name, so they map to
  // themselves.
  std::unordered_map<string, string> ret_index;
  for (const auto& a : fdef.signature().input_arg()) {
    ret_index[a.name()] = a.name();
  }

  // For looking up OpDefs
  auto* op_def_registry = OpRegistry::Global();

  // Function body
  // Order matters: a node's inputs are resolved against ret_index BEFORE its
  // own outputs are added. Nodes may therefore only consume values produced
  // by function inputs or by nodes listed earlier in `node_def`.
  for (const auto& src : node_def) {
    NodeDef* n = fdef.add_node_def();
    n->set_op(src.op);
    // The first ret name doubles as the node name.
    n->set_name(src.ret[0]);
    for (const auto& a : src.attr) {
      n->mutable_attr()->insert({a.first, a.second.proto});
    }
    for (const string& a : src.arg) {
      const auto iter = ret_index.find(a);
      CHECK(iter != ret_index.end())
          << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
      n->add_input(iter->second);
    }
    for (const string& d : src.dep) {
      // Control dependencies are emitted by name, without ret_index lookup.
      n->add_input(strings::StrCat("^", d));
    }

    // Add the outputs of this node to ret_index.
    const OpDef* op_def = nullptr;
    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
    CHECK(op_def != nullptr) << n->op();
    // output_names maps each output arg name to the half-open range
    // [first, second) of flat output positions it occupies. Those positions
    // index into src.ret, so src.ret must be at least that long.
    NameRangeMap output_names;
    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
    for (const auto& o : output_names) {
      CHECK_LE(o.second.second, src.ret.size())
          << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
          << "' of " << name;
      for (int i = o.second.first; i < o.second.second; ++i) {
        ret_index[src.ret[i]] =
            strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
      }
    }
    // A single stateful op makes the whole function stateful.
    if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
  }

  // Returns
  // Bind each signature output to the value registered under the same name.
  for (const auto& r : fdef.signature().output_arg()) {
    const auto iter = ret_index.find(r.name());
    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
    fdef.mutable_ret()->insert({r.name(), iter->second});
  }
  return fdef;
}
1897
Define(gtl::ArraySlice<string> arg_def,gtl::ArraySlice<string> ret_def,gtl::ArraySlice<string> attr_def,gtl::ArraySlice<Node> node_def)1898 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
1899 gtl::ArraySlice<string> ret_def,
1900 gtl::ArraySlice<string> attr_def,
1901 gtl::ArraySlice<Node> node_def) {
1902 return Define("_", arg_def, ret_def, attr_def, node_def);
1903 }
1904
1905 namespace gradient {
1906
1907 typedef std::unordered_map<string, Creator> OpGradFactory;
1908
GetOpGradFactory()1909 OpGradFactory* GetOpGradFactory() {
1910 static OpGradFactory* factory = new OpGradFactory;
1911 return factory;
1912 }
1913
RegisterOp(const string & op,Creator func)1914 bool RegisterOp(const string& op, Creator func) {
1915 CHECK(GetOpGradFactory()->insert({op, func}).second)
1916 << "Duplicated gradient for " << op;
1917 return true;
1918 }
1919
GetOpGradientCreator(const string & op,Creator * creator)1920 Status GetOpGradientCreator(const string& op, Creator* creator) {
1921 auto fac = GetOpGradFactory();
1922 auto iter = fac->find(op);
1923 if (iter == fac->end()) {
1924 return errors::NotFound("No gradient defined for op: ", op);
1925 }
1926 *creator = iter->second;
1927 return Status::OK();
1928 }
1929
1930 } // end namespace gradient
1931
1932 } // namespace tensorflow
1933