/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/functions.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/scanner.h"

namespace tensorflow {
namespace grappler {

namespace {

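// Registers the output name ranges of a function body node in the
// GrapplerFunctionConnectivity, using the node's op registration to resolve
// how many tensors each named output produces.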
Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
                                   const NodeDef& node,
                                   GrapplerFunctionConnectivity* connectivity) {
  tensorflow::NameRangeMap outputs_range_map;
  TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
      node, registration.op_def, nullptr, &outputs_range_map));
  connectivity->RegisterFunctionBodyOutputs(node.name(),
                                            std::move(outputs_range_map));
  return Status::OK();
}

Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
                                   const NodeDef& node,
                                   GrapplerFunctionConnectivity* connectivity) {
  const OpRegistrationData* registration;
  TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
  return RegisterFunctionBodyOutputs(*registration, node, connectivity);
}

// Replace the placeholder attribute values with the values specified in
// instantiation attributes.
Status ResolveFunctionBodyNodeAttrPlaceholders(
    const AttrSlice& func_instantiation_attr, NodeDef* node) {
  for (auto& attr : *node->mutable_attr()) {
    const string& placeholder = attr.second.placeholder();
    if (placeholder.empty()) continue;

    const AttrValue* attr_value = func_instantiation_attr.Find(placeholder);
    if (attr_value) {
      attr.second = *attr_value;
    } else {
      return errors::InvalidArgument("Can't resolve placeholder: ",
                                     placeholder);
    }
  }
  return Status::OK();
}

}  // namespace

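// Maps each placeholder of the input argument expansion to its position inside
// the expansion, and records the expansion itself under its input name.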
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
    InputArgExpansion input_arg_expansion) {
  string input_name = input_arg_expansion.input_name;
  const auto& placeholders = input_arg_expansion.placeholders;

  for (int i = 0; i < placeholders.size(); ++i) {
    const string& placeholder = input_arg_expansion.placeholders[i];
    input_arg_placeholders_.insert(
        {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
  }
  input_arg_expansions_.insert(
      {std::move(input_name), std::move(input_arg_expansion)});
}

void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
    const string& node_name, tensorflow::NameRangeMap&& outputs) {
  function_body_outputs_[node_name] = std::move(outputs);
}

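// Expands a single FunctionDef input ("node_name[:node_output][:position]")
// into the corresponding GraphDef inputs ("node_name[:position]"). For
// illustration: if a body node "my_node" was registered with an output "z"
// covering the position range [2, 5), the input "my_node:z:1" expands to the
// GraphDef input "my_node:3".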
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
    const string& func_def_input, std::vector<string>* graph_def_inputs) const {
  using ::tensorflow::strings::Scanner;

  if (IsControlInput(func_def_input)) {
    graph_def_inputs->push_back(func_def_input);
    return Status::OK();
  }

  // Parse input format: "node_name[:node_output][:position]"
  string node_name;
  string node_output;
  int position = -1;

  StringPiece capture;
  StringPiece remaining;

  // Parse "node_name".
  if (Scanner(func_def_input)
          .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
          .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
          .GetResult(&remaining, &capture)) {
    node_name = string(capture.data(), capture.size());
  }

  // Parse "node_output" if it exists.
  if (Scanner(remaining)
          .OneLiteral(":")
          .RestartCapture()
          .One(strings::Scanner::LETTER)
          .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
          .GetResult(&remaining, &capture)) {
    node_output = string(capture.data(), capture.size());
  }

  // Parse "position" if it exists.
  if (Scanner(remaining)
          .OneLiteral(":")
          .RestartCapture()
          .Many(strings::Scanner::DIGIT)
          .GetResult(nullptr, &capture)) {
    CHECK(strings::safe_strto32(capture, &position));
  }

  // If "node_output" is not empty, it must be an output of a function body
  // node.
  bool is_function_body_output = !node_output.empty();

  // Function input argument: "node_name[:position]"
  if (!is_function_body_output) {
    auto input_arg = input_arg_expansions_.find(node_name);
    if (input_arg != input_arg_expansions_.end()) {
      const InputArgExpansion& input_arg_expansion = input_arg->second;
      const auto& placeholders = input_arg_expansion.placeholders;

      if (position == -1) {
        // If position is not defined, use all placeholders.
        graph_def_inputs->reserve(placeholders.size());
        for (const string& placeholder : placeholders) {
          graph_def_inputs->push_back(placeholder);
        }
      } else {
        if (position >= placeholders.size()) {
          return errors::InvalidArgument("Invalid input ", node_name,
                                         " position: ", position,
                                         " (out of range)");
        }
        graph_def_inputs->push_back(placeholders[position]);
      }

      return Status::OK();
    }
  }

  // Function body output: "node_name:node_output[:position]"
  if (is_function_body_output) {
    auto function_body_outputs = function_body_outputs_.find(node_name);
    if (function_body_outputs != function_body_outputs_.end()) {
      const tensorflow::NameRangeMap& outputs = function_body_outputs->second;
      auto output = outputs.find(node_output);
      if (output != outputs.end()) {
        const auto& output_range = output->second;

        if (position == -1) {
          // If position is not defined, expand the full node output range.
          graph_def_inputs->reserve(graph_def_inputs->size() +
                                    output_range.second - output_range.first);
          for (int i = output_range.first; i < output_range.second; ++i) {
            graph_def_inputs->push_back(
                i == 0 ? node_name : absl::StrCat(node_name, ":", i));
          }
        } else {
          if (position >= (output_range.second - output_range.first)) {
            return errors::InvalidArgument(
                "Invalid node ", node_name, " output ", node_output,
                " position: ", position, " (out of range)");
          }
          int pos = output_range.first + position;
          graph_def_inputs->push_back(
              pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
        }

        return Status::OK();
      }
    }
  }

  return errors::InvalidArgument("Failed to expand a function def input: ",
                                 func_def_input);
}

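// Rewrites all inputs of a function body node from FunctionDef format to
// GraphDef format by expanding each of them with ExpandFunctionDefInput.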
Status GrapplerFunctionConnectivity::ExpandNodeInputs(
    NodeDef* function_body_node) const {
  std::vector<string> expanded_inputs;

  for (const string& function_def_input : function_body_node->input()) {
    TF_RETURN_IF_ERROR(
        ExpandFunctionDefInput(function_def_input, &expanded_inputs));
  }

  function_body_node->clear_input();
  for (string& expanded_input : expanded_inputs)
    function_body_node->add_input(std::move(expanded_input));
  return Status::OK();
}

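// Converts a GraphDef input ("node_name[:position]") back into FunctionDef
// format: input placeholders become "input_name:index", and function body
// outputs become "node_name:output_name:position_within_output_range".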
Status GrapplerFunctionConnectivity::AsFunctionDefInput(
    const string& graph_def_input, string* func_def_input) const {
  if (IsControlInput(graph_def_input)) {
    *func_def_input = graph_def_input;
    return Status::OK();
  }

  const TensorId tensor = ParseTensorName(graph_def_input);
  DCHECK_GE(tensor.index(), 0);

  const absl::string_view node_name = tensor.node();
  const int index = tensor.index();

  // Check if it's an input arg placeholder.
  if (tensor.index() == 0) {
    const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
    if (is_input_placeholder != input_arg_placeholders_.end()) {
      const InputArgPlaceholder& placeholder = is_input_placeholder->second;
      *func_def_input =
          absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
      return Status::OK();
    }
  }

  // It must be an output of one of the function body nodes.
  const auto is_body_output = function_body_outputs_.find(tensor.node());
  if (is_body_output != function_body_outputs_.end()) {
    const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;

    for (const auto& el : outputs_range_map) {
      const auto& output_name = el.first;
      const auto& output_range = el.second;
      if (index >= output_range.first && index < output_range.second) {
        int pos = index - output_range.first;
        *func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
        return Status::OK();
      }
    }
  }

  return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
}

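// Rewrites all inputs of a function body node from GraphDef format back to
// FunctionDef format.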
Status GrapplerFunctionConnectivity::AsFunctionDefNode(
    NodeDef* function_body_node) const {
  string func_def_input;

  for (int i = 0; i < function_body_node->input_size(); ++i) {
    TF_RETURN_IF_ERROR(
        AsFunctionDefInput(function_body_node->input(i), &func_def_input));
    function_body_node->set_input(i, func_def_input);
  }

  return Status::OK();
}

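// Looks up a type attribute in the function instantiation attributes and
// returns an error if it is missing or does not hold a valid type.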
Status GrapplerFunctionItemInstantiation::GetTypeAttr(
    const string& type_attr_name, DataType* data_type) const {
  const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name);
  if (type_attr == nullptr) {
    return errors::InvalidArgument("Type attribute ", type_attr_name,
                                   " is not defined");
  } else if (type_attr->type() == DT_INVALID) {
    return errors::InvalidArgument("Type attribute ", type_attr_name,
                                   " is not defined with a valid type");
  } else {
    *data_type = type_attr->type();
  }
  return Status::OK();
}

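// Resolves the data type of an input/output argument: either directly from the
// ArgDef or from the instantiation attribute named by its type_attr.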
Status GrapplerFunctionItemInstantiation::GetArgType(
    const OpDef::ArgDef& arg, DataType* data_type) const {
  if (arg.type() != DT_INVALID) {
    *data_type = arg.type();
  } else {
    if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
      return errors::InvalidArgument(
          "Arguments with sequences of tensors are not supported. Unsupported "
          "argument name: ",
          arg.name());
    }
    TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
  }
  return Status::OK();
}

GrapplerFunctionItem::GrapplerFunctionItem(
    string func_name, string description, AttrSlice func_attr,
    std::vector<InputArgExpansion> input_arg_expansions,
    std::vector<OutputArgExpansion> output_arg_expansions,
    std::vector<ControlOutput> control_outputs, const int graph_def_version,
    const bool is_stateful, GraphDef&& function_body)
    : description_(std::move(description)),
      func_attr_(func_attr),
      input_arg_expansions_(std::move(input_arg_expansions)),
      output_arg_expansions_(std::move(output_arg_expansions)),
      control_outputs_(std::move(control_outputs)),
      is_stateful_(is_stateful) {
  id = std::move(func_name);
  graph = std::move(function_body);

  graph.mutable_versions()->set_producer(graph_def_version);
  // Fill the feed nodes with input placeholders.
  for (const InputArgExpansion& input_arg : input_arg_expansions_) {
    for (const string& placeholder : input_arg.placeholders) {
      feed.push_back({placeholder, Tensor()});
    }
  }
  // Fill the fetch nodes with outputs.
  for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
    for (const string& output_node : output_arg.output_nodes) {
      fetch.push_back(output_node);
    }
  }
  // We must keep all control output nodes.
  for (const ControlOutput& control_output : control_outputs_) {
    keep_ops.push_back(control_output.node_name);
  }

  // TensorFlow function execution semantics are different from the main
  // graph's, and we need to preserve them when doing graph optimizations.
  optimization_options().allow_pruning_stateful_and_dataset_ops = false;
}

const string& GrapplerFunctionItem::description() const { return description_; }

const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
  return input_arg_expansions_;
}

const InputArgExpansion& GrapplerFunctionItem::input(int i) const {
  return input_arg_expansions_[i];
}

const std::size_t GrapplerFunctionItem::input_size() const {
  return input_arg_expansions_.size();
}

const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
  return output_arg_expansions_;
}

const OutputArgExpansion& GrapplerFunctionItem::output(int i) const {
  return output_arg_expansions_[i];
}

const std::size_t GrapplerFunctionItem::output_size() const {
  return output_arg_expansions_.size();
}

const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
    const {
  return control_outputs_;
}

const std::size_t GrapplerFunctionItem::control_output_size() const {
  return control_outputs_.size();
}

const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }

const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }

GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }

bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }

GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
  graph.Swap(&other);
  return *this;
}

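// Returns true if any input or output argument of the function signature has a
// parametrized type (type_attr, number_attr or type_list_attr).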
bool HasParametrizedType(const FunctionDef& func) {
  const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
    return !arg.type_attr().empty() || !arg.number_attr().empty() ||
           !arg.type_list_attr().empty();
  };

  const auto& input = func.signature().input_arg();
  const auto& output = func.signature().output_arg();
  return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
         std::any_of(output.begin(), output.end(), is_type_parametrized);
}

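// Returns true if any function body node has an attribute defined as a
// placeholder that must be resolved at instantiation time.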
bool HasParametrizedBody(const FunctionDef& func) {
  const auto is_parametrized = [&](const NodeDef& node) {
    for (const auto& attr : node.attr()) {
      if (!attr.second.placeholder().empty()) return true;
    }
    return false;
  };
  return std::any_of(func.node_def().begin(), func.node_def().end(),
                     is_parametrized);
}

bool IsParametrized(const FunctionDef& func) {
  return HasParametrizedType(func) || HasParametrizedBody(func);
}

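// Collects the data types that resolve the function's type parameters for the
// given instantiation attributes (e.g. {"T": DT_FLOAT}).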
Status InstantiationTypeParameters(
    const FunctionDef& func, const AttrSlice& func_instantiation_attr,
    absl::flat_hash_map<string, DataType>* type_parameters) {
  if (!type_parameters->empty()) {
    return errors::InvalidArgument("Type parameters output map must be empty");
  }

  GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);

  const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
    // Check if the type is unknown and not yet resolved.
    if (arg.type() == DT_INVALID &&
        type_parameters->find(arg.type_attr()) == type_parameters->end()) {
      DataType data_type;
      TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
      type_parameters->insert({arg.type_attr(), data_type});
    }
    return Status::OK();
  };

  for (const auto& input : func.signature().input_arg())
    TF_RETURN_IF_ERROR(resolve_type_attr(input));
  for (const auto& output : func.signature().output_arg())
    TF_RETURN_IF_ERROR(resolve_type_attr(output));

  return Status::OK();
}

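// Collects the attribute values that resolve the placeholders used by function
// body nodes for the given instantiation attributes.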
Status InstantiationBodyParameters(
    const FunctionDef& func, const AttrSlice& func_instantiation_attr,
    absl::flat_hash_map<string, AttrValue>* body_parameters) {
  if (!body_parameters->empty()) {
    return errors::InvalidArgument("Body parameters output map must be empty");
  }

  for (const NodeDef& func_body_node : func.node_def()) {
    for (const auto& attr : func_body_node.attr()) {
      const string& placeholder = attr.second.placeholder();

      if (placeholder.empty() ||
          body_parameters->find(placeholder) != body_parameters->end()) {
        continue;
      }

      const AttrValue* placeholder_value =
          func_instantiation_attr.Find(placeholder);
      if (placeholder_value) {
        body_parameters->insert({placeholder, *placeholder_value});
      } else {
        return errors::InvalidArgument("Can't resolve placeholder: ",
                                       placeholder);
      }
    }
  }

  return Status::OK();
}

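// Instantiates a FunctionDef into a GrapplerFunctionItem: input arguments
// become Placeholder nodes, output arguments become Identity nodes, body node
// inputs are rewritten into GraphDef format, and the reachable part of the
// function library is attached to the function body graph.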
Status MakeGrapplerFunctionItem(const FunctionDef& func,
                                const AttrSlice& func_instantiation_attr,
                                const FunctionLibraryDefinition& flib,
                                const int graph_def_version,
                                GrapplerFunctionItem* item) {
  const OpDef& signature = func.signature();

  if (signature.name().empty()) {
    return errors::InvalidArgument("Function name must be specified");
  }

  // Function types will be resolved from function instantiation attributes.
  // All other attributes will be lost during conversion to FunctionDef.
  for (const OpDef::AttrDef& attr : signature.attr()) {
    if (attr.type() != "type") {
      return errors::InvalidArgument(
          "Function signature must have only type attributes");
    }
  }

  // Helper methods to look up function instantiation attributes.
  GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);

  // Mapping from FunctionDef input format (name[:output][:position]) to
  // GraphDef input format (name[:position]).
  GrapplerFunctionConnectivity connectivity;

  // Instantiate the function body into a statically defined GraphDef.
  GraphDef function_body;

  // The function body shares the library with the graph that instantiated it.
  // We do not need a full copy of the function library, just the reachable
  // subset.
  *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();

  VLOG(3) << absl::Substitute(
      "Deleted $0 unreachable functions from the Grappler function item "
      "instantiation of $1 (library size = $2)",
      flib.num_functions() - function_body.library().function_size(),
      signature.name(), function_body.library().function_size());

  // TODO(ezhulenev): support functions with tensor sequence inputs/outputs.

  // Make sure that there are no tensor lists in inputs or outputs.
  for (const OpDef::ArgDef& input : signature.input_arg()) {
    if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
      return errors::InvalidArgument(
          "Inputs with lists of tensors are not supported. Input: ",
          input.name());
    }
  }
  for (const OpDef::ArgDef& output : signature.output_arg()) {
    if (!output.type_list_attr().empty() || !output.number_attr().empty()) {
      return errors::InvalidArgument(
          "Outputs with lists of tensors are not supported. Output: ",
          output.name());
    }
  }

  std::vector<InputArgExpansion> inputs;
  inputs.reserve(signature.input_arg_size());

  // For each input argument create a placeholder in the function body.
  for (const OpDef::ArgDef& input : signature.input_arg()) {
    DataType input_data_type;
    TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));

    NodeDef* placeholder = function_body.add_node();
    placeholder->set_name(input.name());
    placeholder->set_op("Placeholder");
    (*placeholder->mutable_attr())["dtype"].set_type(input_data_type);
    (*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank(
        true);

    InputArgExpansion input_expansion{/*input_name=*/input.name(),
                                      /*data_type=*/input_data_type,
                                      /*is_ref=*/input.is_ref(),
                                      /*placeholders=*/{input.name()}};
    connectivity.RegisterInputArgExpansion(input_expansion);
    inputs.push_back(std::move(input_expansion));
  }

  // Keep names of all nodes in the function body to guarantee that we do not
  // add an identity with a duplicate name.
  absl::flat_hash_set<absl::string_view> func_body_nodes;

  // Generate unique output node name: "${out_arg_name}_output_node_${index}".
  const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out,
                                                   int index) -> string {
    string name = absl::StrCat(out.name(), "_output_node_", index);
    int i = 1;
    while (func_body_nodes.find(name) != func_body_nodes.end()) {
      name = absl::StrCat(out.name(), "_output_node_", index, "_", i++);
    }
    return name;
  };

  // Add all function nodes to the function body.
  for (const NodeDef& func_def_node : func.node_def()) {
    func_body_nodes.insert(func_def_node.name());

    NodeDef* new_node = function_body.add_node();
    *new_node = func_def_node;

    const OpRegistrationData* registration;
    TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), &registration));

    // Resolve all placeholder values using function instantiation attributes.
    TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
        func_instantiation_attr, new_node));

    // Register the node output range in the function connectivity.
    TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
                                                   &connectivity));
  }

  // Rewrite inputs to use the GraphDef format.
  for (NodeDef& node : *function_body.mutable_node()) {
    TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
  }

  std::vector<OutputArgExpansion> outputs;
  outputs.reserve(signature.output_arg_size());

  // For each function output argument we create an Identity node in the
  // function body that reads the output tensor from a function body node.
  for (const OpDef::ArgDef& out : signature.output_arg()) {
    DataType output_data_type;
    TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));

    std::vector<string> output_tensors;
    auto ret = func.ret().find(out.name());
    TF_RETURN_IF_ERROR(
        ret != func.ret().end()
            // Expand the output using the provided output mapping.
            ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
            // Otherwise the output must be one of the function inputs.
            : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));

    absl::InlinedVector<string, 1> output_nodes;
    for (int i = 0; i < output_tensors.size(); ++i) {
      const string& output_tensor = output_tensors[i];

      NodeDef* identity = function_body.add_node();
      identity->set_name(output_node_name(out, i));
      identity->set_op("Identity");
      (*identity->mutable_attr())["T"].set_type(output_data_type);
      identity->add_input(output_tensor);

      output_nodes.push_back(identity->name());
    }

    OutputArgExpansion output{/*output_name=*/out.name(),
                              /*data_type=*/output_data_type,
                              /*is_ref=*/out.is_ref(),
                              /*output_nodes=*/std::move(output_nodes)};
    outputs.push_back(std::move(output));
  }

  // Control outputs ensure that all side-effectful nodes in the function body
  // will execute, even if they are not required to compute regular output
  // arguments.
  std::vector<ControlOutput> control_outputs;
  control_outputs.reserve(func.control_ret_size());
  for (const auto& control_ret : func.control_ret()) {
    control_outputs.push_back({control_ret.first, control_ret.second});
  }

  *item = GrapplerFunctionItem(
      /*func_name=*/signature.name(),
      /*description=*/signature.description(),
      /*func_attr=*/AttrSlice(&func.attr()), std::move(inputs),
      std::move(outputs), std::move(control_outputs), graph_def_version,
      signature.is_stateful(), std::move(function_body));
  return Status::OK();
}

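// Makes a GrapplerFunctionItem from a function that does not require any
// instantiation attributes (an empty AttrSlice is used), which is only valid
// when the function has no type parameters or attribute placeholders left to
// resolve.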
Status MakeGrapplerFunctionItem(const FunctionDef& func,
                                const FunctionLibraryDefinition& flib,
                                const int graph_def_version,
                                GrapplerFunctionItem* item) {
  return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
                                  item);
}

// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity.
Status RegisterGrapplerFunctionConnectivity(
    const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
    GrapplerFunctionConnectivity* connectivity) {
  for (const InputArgExpansion& input : item.inputs()) {
    connectivity->RegisterInputArgExpansion(input);
  }
  for (const NodeDef& func_body_node : item.function_body().node()) {
    TF_RETURN_IF_ERROR(
        RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
  }
  return Status::OK();
}

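// Replaces the input placeholder at the given function input position with the
// provided constant node, removing the placeholder from the corresponding
// input argument expansion.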
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
                             GrapplerFunctionItem* item) {
  if (!IsConstant(input_const)) {
    return errors::InvalidArgument("Input node ", input_const.name(),
                                   " is not a constant");
  }

  auto& inputs = item->input_arg_expansions_;

  // Find input arg expansion and input placeholder position in it for the
  // given function input position.
  InputArgExpansion* input_arg_expansion = nullptr;
  int placeholder_idx = input_index;

  for (InputArgExpansion& input : inputs) {
    if (placeholder_idx < input.placeholders.size()) {
      input_arg_expansion = &input;
      break;
    }
    placeholder_idx -= input.placeholders.size();
  }

  if (input_arg_expansion == nullptr) {
    return errors::InvalidArgument("Input placeholder not found: input_index=",
                                   input_index, " function=", item->id);
  }

  // Delete placeholder from input expansion.
  string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
  input_arg_expansion->placeholders.erase(
      input_arg_expansion->placeholders.begin() + placeholder_idx);

  // Delete empty input expansions.
  inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
                              [](const InputArgExpansion& input) {
                                return input.placeholders.empty();
                              }),
               inputs.end());

  // Replace placeholder node in the function body with a const node.
  for (NodeDef& node : *item->graph.mutable_node()) {
    if (node.name() == placeholder_name) {
      node = input_const;
      node.set_name(placeholder_name);
      node.clear_input();   // remove potential control inputs
      node.clear_device();  // device placement is defined by instantiating node
    }
  }

  return Status::OK();
}

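// Removes the specified output arguments from the function item and records in
// `output_mapping` how the positions of the remaining outputs have shifted.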
Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
                             GrapplerFunctionItem* item,
                             std::vector<std::pair<int, int>>* output_mapping) {
  DCHECK(output_mapping->empty());

  // Code below assumes that we do not support tensor list outputs, and that
  // there is a 1-to-1 mapping between output tensor and output argument
  // expansion.
  for (const OutputArgExpansion& out_arg : item->outputs()) {
    DCHECK(out_arg.output_nodes.size() == 1)
        << "Output arg expansion must have a single output";
  }

  // Do some sanity checking of the removed output positions.
  for (int remove_output : remove_outputs) {
    if (remove_output < 0 || remove_output >= item->output_size()) {
      return errors::InvalidArgument(
          "Function output index is out of bounds: index=", remove_output,
          " max_output_index=", item->output_size());
    }
  }

  absl::flat_hash_set<const OutputArgExpansion*> remove_output_args;
  const auto is_remove_output_arg = [&](const OutputArgExpansion& output) {
    return remove_output_args.find(&output) != remove_output_args.end();
  };

  for (int i = 0; i < item->output_size(); ++i) {
    const OutputArgExpansion& output = item->output(i);
    if (remove_outputs.find(i) != remove_outputs.end()) {
      VLOG(3) << "Remove function output: output_name=" << output.output_name
              << " (index = " << i << ")";
      remove_output_args.insert(&output);
    } else if (!remove_output_args.empty()) {
      // Add an output mapping only if the output position changed.
      output_mapping->push_back({i, i - remove_output_args.size()});
    }
  }

  auto& o = item->output_arg_expansions_;
  o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());

  return Status::OK();
}

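// Converts a GrapplerFunctionItem back into a FunctionDef: input placeholders
// and bypassed output Identity nodes are dropped, and body node inputs are
// rewritten back into FunctionDef format.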
Status MakeFunctionDef(const GrapplerFunctionItem& item,
                       const FunctionLibraryDefinition& flib,
                       FunctionDef* func) {
  func->mutable_signature()->set_name(item.id);
  func->mutable_signature()->set_description(item.description());
  func->mutable_signature()->set_is_stateful(item.is_stateful());

  // Keep track of placeholders that were added to the graph in place of
  // expanded function input arguments.
  absl::flat_hash_set<absl::string_view> input_placeholders;
  for (const InputArgExpansion& input_arg : item.inputs()) {
    for (const string& placeholder : input_arg.placeholders) {
      input_placeholders.insert(placeholder);
    }
  }

  // Keep track of identity nodes that were added to the graph in place of
  // expanded function output arguments.
  absl::flat_hash_set<absl::string_view> output_nodes;
  for (const OutputArgExpansion& output_arg : item.outputs()) {
    for (const string& output_node : output_arg.output_nodes) {
      output_nodes.insert(output_node);
    }
  }

  // If an output identity node was not modified by any optimizer, we can
  // bypass it and return the function value from its input.
  absl::flat_hash_map<absl::string_view, string> output_tensors;
  for (const NodeDef& func_body_node : item.function_body().node()) {
    if (!IsIdentity(func_body_node)) continue;

    const string& node_name = func_body_node.name();
    if (output_nodes.find(node_name) != output_nodes.end()) {
      // Grappler optimizers might optimize nodes in the fanin of the output
      // node, and forward their control dependencies. We can't express control
      // dependencies in a function signature, so we have to keep the node.
      if (func_body_node.input_size() == 1) {
        VLOG(3) << "Bypass function output node: " << node_name << " -> "
                << func_body_node.input(0);
        output_tensors.emplace(node_name, func_body_node.input(0));
      } else {
        VLOG(3) << "Keep function output node: " << node_name;
      }
    }
  }

  // Returns the output tensor name (the input of the output node) if it is
  // safe to bypass the output node, otherwise returns the output node name.
  const auto output_tensor =
      [&output_tensors](const OutputArgExpansion& output_arg) -> const string& {
    const string& output_node = output_arg.output_nodes[0];
    const auto is_output_tensor = output_tensors.find(output_node);
    return is_output_tensor == output_tensors.end() ? output_node
                                                    : is_output_tensor->second;
  };

  // Build a GrapplerFunctionConnectivity from inputs and new function body.
  GrapplerFunctionConnectivity connectivity;
  TF_RETURN_IF_ERROR(
      RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));

  // Add function input arguments.
  for (const InputArgExpansion& input_arg : item.inputs()) {
    DCHECK(input_arg.placeholders.size() == 1)  // do some sanity checking
        << "Inputs of tensor lists are not supported";

    OpDef::ArgDef arg_def;
    arg_def.set_name(input_arg.input_name);
    arg_def.set_type(input_arg.data_type);
    arg_def.set_is_ref(input_arg.is_ref);
    *func->mutable_signature()->add_input_arg() = arg_def;
  }

  // Add function output arguments.
  for (const OutputArgExpansion& output_arg : item.outputs()) {
    DCHECK(output_arg.output_nodes.size() == 1)  // do some sanity checking
        << "Outputs of tensor lists are not supported";

    OpDef::ArgDef arg_def;
    arg_def.set_name(output_arg.output_name);
    arg_def.set_type(output_arg.data_type);
    arg_def.set_is_ref(output_arg.is_ref);
    *func->mutable_signature()->add_output_arg() = arg_def;

    TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(
        output_tensor(output_arg),
        &(*func->mutable_ret())[output_arg.output_name]));
  }

  // Add function control outputs.
  for (const ControlOutput& control_out : item.control_outputs()) {
    func->mutable_control_ret()->insert(
        {control_out.output_name, control_out.node_name});
    *func->mutable_signature()->add_control_output() = control_out.output_name;
  }

  // Copy function definition specific attributes.
  for (const auto& attr : item.func_attr()) {
    const auto& attr_name = attr.first;
    const auto& attr_value = attr.second;
    (*func->mutable_attr())[attr_name] = attr_value;
  }

  // Copy function body nodes to the FunctionDef and update the input format.
  for (const NodeDef& func_node : item.function_body().node()) {
    const string& name = func_node.name();

    // Do not copy input placeholders.
    if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue;
    // Do not copy output nodes that we bypassed.
    if (IsIdentity(func_node) && output_tensors.count(name)) continue;

    NodeDef* func_def_node = func->add_node_def();
    *func_def_node = func_node;
    TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
  }

  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow