1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/utils/functions.h"
16
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_replace.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph_def_util.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/utils.h"
34 #include "tensorflow/core/lib/strings/scanner.h"
35
36 namespace tensorflow {
37 namespace grappler {
38
GrapplerFunctionItem(string func_name,string description,AttrSlice func_attr,std::vector<const FunctionDef::ArgAttrs * > arg_attr,std::vector<InputArgInstantiation> input_args,std::vector<OutputArgInstantiation> output_args,std::vector<ControlOutput> control_outputs,const int graph_def_version,const bool is_stateful,GraphDef && function_body)39 GrapplerFunctionItem::GrapplerFunctionItem(
40 string func_name, string description, AttrSlice func_attr,
41 std::vector<const FunctionDef::ArgAttrs*> arg_attr,
42 std::vector<InputArgInstantiation> input_args,
43 std::vector<OutputArgInstantiation> output_args,
44 std::vector<ControlOutput> control_outputs, const int graph_def_version,
45 const bool is_stateful, GraphDef&& function_body)
46 : description_(std::move(description)),
47 func_attr_(func_attr),
48 arg_attr_(std::move(arg_attr)),
49 input_args_(std::move(input_args)),
50 output_args_(std::move(output_args)),
51 control_outputs_(std::move(control_outputs)),
52 is_stateful_(is_stateful) {
53 id = std::move(func_name);
54 graph = std::move(function_body);
55 graph.mutable_versions()->set_producer(graph_def_version);
56
57 // Fill the feed nodes with function input arguments.
58 for (const InputArgInstantiation& input_arg : input_args_) {
59 feed.push_back({input_arg.node_name, Tensor()});
60 }
61 // Fill the fetch nodes with outputs.
62 for (const OutputArgInstantiation& output_arg : output_args_) {
63 fetch.push_back(output_arg.node_name);
64 }
65 // We must keep all control output nodes.
66 for (const ControlOutput& control_output : control_outputs_) {
67 keep_ops.push_back(control_output.node_name);
68 }
69
70 // Tensorflow functions execution semantics is different from the main graph,
71 // and we need to preserve it when we do graph optimizations.
72 optimization_options().allow_pruning_stateful_and_dataset_ops = false;
73 }
74
description() const75 const string& GrapplerFunctionItem::description() const { return description_; }
76
inputs() const77 const std::vector<InputArgInstantiation>& GrapplerFunctionItem::inputs() const {
78 return input_args_;
79 }
80
input(int i) const81 const InputArgInstantiation& GrapplerFunctionItem::input(int i) const {
82 return input_args_[i];
83 }
84
input_size() const85 const std::size_t GrapplerFunctionItem::input_size() const {
86 return input_args_.size();
87 }
88
outputs() const89 const std::vector<OutputArgInstantiation>& GrapplerFunctionItem::outputs()
90 const {
91 return output_args_;
92 }
93
output(int i) const94 const OutputArgInstantiation& GrapplerFunctionItem::output(int i) const {
95 return output_args_[i];
96 }
97
output_size() const98 const std::size_t GrapplerFunctionItem::output_size() const {
99 return output_args_.size();
100 }
101
control_outputs() const102 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
103 const {
104 return control_outputs_;
105 }
106
control_output_size() const107 const std::size_t GrapplerFunctionItem::control_output_size() const {
108 return control_outputs_.size();
109 }
110
func_attr() const111 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
112
113 const std::vector<const FunctionDef::ArgAttrs*>&
arg_attr() const114 GrapplerFunctionItem::arg_attr() const {
115 return arg_attr_;
116 }
117
function_body() const118 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
119
mutable_function_body()120 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
121
is_stateful() const122 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }
123
SwapFunctionBody(GraphDef && other)124 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
125 graph.Swap(&other);
126 return *this;
127 }
128
HasParametrizedType(const FunctionDef & func)129 bool HasParametrizedType(const FunctionDef& func) {
130 const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
131 return !arg.type_attr().empty() || !arg.number_attr().empty() ||
132 !arg.type_list_attr().empty();
133 };
134
135 const auto& input = func.signature().input_arg();
136 const auto& output = func.signature().output_arg();
137 return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
138 std::any_of(output.begin(), output.end(), is_type_parametrized);
139 }
140
HasParametrizedBody(const FunctionDef & func)141 bool HasParametrizedBody(const FunctionDef& func) {
142 const auto is_parametrized = [&](const NodeDef& node) {
143 for (const auto& attr : node.attr()) {
144 if (!attr.second.placeholder().empty()) return true;
145 }
146 return false;
147 };
148 return std::any_of(func.node_def().begin(), func.node_def().end(),
149 is_parametrized);
150 }
151
IsParametrized(const FunctionDef & func)152 bool IsParametrized(const FunctionDef& func) {
153 return HasParametrizedType(func) || HasParametrizedBody(func);
154 }
155
InstantiationTypeParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,DataType> * type_parameters)156 Status InstantiationTypeParameters(
157 const FunctionDef& func, const AttrSlice& func_instantiation_attr,
158 absl::flat_hash_map<string, DataType>* type_parameters) {
159 if (!type_parameters->empty()) {
160 return errors::InvalidArgument("Type parameters output map must be empty");
161 }
162
163 const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) -> Status {
164 if (!arg.type_attr().empty()) {
165 DataType dtype;
166 TF_RETURN_IF_ERROR(
167 GetNodeAttr(func_instantiation_attr, arg.type_attr(), &dtype));
168 type_parameters->emplace(arg.type_attr(), dtype);
169
170 } else if (!arg.type_list_attr().empty()) {
171 std::vector<DataType> dtypes;
172 TF_RETURN_IF_ERROR(
173 GetNodeAttr(func_instantiation_attr, arg.type_list_attr(), &dtypes));
174 int index = 0;
175 for (const DataType& dtype : dtypes) {
176 type_parameters->emplace(absl::StrCat(arg.type_list_attr(), ":", index),
177 dtype);
178 ++index;
179 }
180 }
181 return Status::OK();
182 };
183
184 for (const auto& input : func.signature().input_arg())
185 TF_RETURN_IF_ERROR(resolve_type_attr(input));
186 for (const auto& output : func.signature().output_arg())
187 TF_RETURN_IF_ERROR(resolve_type_attr(output));
188
189 return Status::OK();
190 }
191
InstantiationBodyParameters(const FunctionDef & func,const AttrSlice & func_instantiation_attr,absl::flat_hash_map<string,AttrValue> * body_parameters)192 Status InstantiationBodyParameters(
193 const FunctionDef& func, const AttrSlice& func_instantiation_attr,
194 absl::flat_hash_map<string, AttrValue>* body_parameters) {
195 if (!body_parameters->empty()) {
196 return errors::InvalidArgument("Body parameters output map must be empty");
197 }
198
199 for (const NodeDef& func_body_node : func.node_def()) {
200 for (auto& attr : func_body_node.attr()) {
201 const string& placeholder = attr.second.placeholder();
202
203 if (placeholder.empty() || body_parameters->contains(placeholder)) {
204 continue;
205 }
206
207 const AttrValue* placeholder_value =
208 func_instantiation_attr.Find(placeholder);
209 if (placeholder_value) {
210 body_parameters->insert({placeholder, *placeholder_value});
211 } else {
212 return errors::InvalidArgument("Can't resolve placeholder: ",
213 placeholder);
214 }
215 }
216 }
217
218 return Status::OK();
219 }
220
MakeGrapplerFunctionItem(const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)221 Status MakeGrapplerFunctionItem(const FunctionDef& func,
222 const AttrSlice& func_instantiation_attr,
223 const FunctionLibraryDefinition& flib,
224 const int graph_def_version,
225 GrapplerFunctionItem* item) {
226 const OpDef& signature = func.signature();
227
228 if (signature.name().empty()) {
229 return errors::InvalidArgument("Function name must be specified");
230 }
231
232 // Function types will be resolved from function instantiation attributes. All
233 // other attributes will be lost during conversion to FunctionDef.
234 for (const OpDef::AttrDef& attr : signature.attr()) {
235 if (attr.type() != "type") {
236 return errors::InvalidArgument(
237 "Function signature must have only type attributes");
238 }
239 }
240
241 // Instantiate function into a statically defined FunctionBody Graph.
242 std::unique_ptr<FunctionBody> fbody;
243 TF_RETURN_IF_ERROR(
244 FunctionDefToBodyHelper(func, func_instantiation_attr, &flib, &fbody));
245
246 GraphDef function_body;
247 fbody->graph->ToGraphDef(&function_body);
248
249 // Function body shares the library with the graph that instantiated it. We do
250 // not need a full copy of the function library, just the reachable subset.
251 *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
252
253 VLOG(3) << absl::Substitute(
254 "Deleted $0 unreachable functions from the Grappler function item "
255 "instantiation of $1 (library size = $2)",
256 flib.num_functions() - function_body.library().function_size(),
257 signature.name(), function_body.library().function_size());
258
259 const int num_instantiated_inputs = fbody->arg_types.size();
260 const int num_instantiated_outputs = fbody->ret_types.size();
261
262 std::vector<InputArgInstantiation> inputs;
263 inputs.reserve(num_instantiated_inputs);
264
265 for (int in_id = 0; in_id < num_instantiated_inputs; ++in_id) {
266 const Node* node = fbody->arg_nodes[in_id];
267 const DataType& dtype = fbody->arg_types[in_id];
268 inputs.emplace_back(node->name(), dtype);
269 }
270
271 std::vector<OutputArgInstantiation> outputs;
272 outputs.reserve(num_instantiated_outputs);
273
274 for (int out_id = 0; out_id < num_instantiated_outputs; ++out_id) {
275 const Node* node = fbody->ret_nodes[out_id];
276 const DataType& dtype = fbody->ret_types[out_id];
277 outputs.emplace_back(node->name(), dtype);
278 }
279
280 // Control outputs ensure that all side-effectful nodes in the function body
281 // will execute, even if they are not required to compute regular output args.
282 std::vector<ControlOutput> control_outputs;
283 control_outputs.reserve(func.control_ret_size());
284 for (const auto& control_ret : func.control_ret()) {
285 control_outputs.push_back({control_ret.first, control_ret.second});
286 }
287
288 std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
289 for (const auto& attr : func.arg_attr()) {
290 arg_attr.at(attr.first) = &attr.second;
291 }
292
293 *item = GrapplerFunctionItem(
294 /*func_name=*/signature.name(),
295 /*description=*/signature.description(),
296 /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
297 std::move(inputs), std::move(outputs), std::move(control_outputs),
298 graph_def_version, signature.is_stateful(), std::move(function_body));
299 return Status::OK();
300 }
301
MakeGrapplerFunctionItem(const FunctionDef & func,const FunctionLibraryDefinition & flib,const int graph_def_version,GrapplerFunctionItem * item)302 Status MakeGrapplerFunctionItem(const FunctionDef& func,
303 const FunctionLibraryDefinition& flib,
304 const int graph_def_version,
305 GrapplerFunctionItem* item) {
306 return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
307 item);
308 }
309
ReplaceInputWithConst(const NodeDef & input_const,int input_index,GrapplerFunctionItem * item)310 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
311 GrapplerFunctionItem* item) {
312 if (!IsConstant(input_const)) {
313 return errors::InvalidArgument("Input node is not a constant: ",
314 SummarizeNodeDef(input_const));
315 }
316 const int item_input_size = item->input_size();
317 if (input_index < 0 || input_index >= item_input_size) {
318 return errors::InvalidArgument(
319 "Function input index is out of bound: index=", input_index,
320 " input_size=", item->input_size());
321 }
322
323 const InputArgInstantiation& input_arg = item->input(input_index);
324
325 for (NodeDef& node : *item->graph.mutable_node()) {
326 // Replace '_Arg' node in the function body with a 'Const' node.
327 if (node.name() == input_arg.node_name) {
328 node = input_const;
329 node.set_name(input_arg.node_name);
330 node.clear_input();
331 node.clear_device(); // device placement is defined by instantiating node
332 }
333
334 // Update index in all inputs after the removed const input.
335 if (IsArg(node)) {
336 auto attrs = AttrSlice(node);
337 int index;
338 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
339 if (index >= input_index) {
340 (*node.mutable_attr())["index"].set_i(index - 1);
341 }
342 }
343 }
344
345 item->input_args_.erase(item->input_args_.begin() + input_index);
346 item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
347
348 return Status::OK();
349 }
350
RemoveFunctionOutputs(const absl::flat_hash_set<int> & remove_outputs,GrapplerFunctionItem * item,std::vector<std::pair<int,int>> * output_mapping)351 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
352 GrapplerFunctionItem* item,
353 std::vector<std::pair<int, int>>* output_mapping) {
354 DCHECK(output_mapping->empty());
355
356 // Do some sanity checking of the removed outputs positions.
357 for (int remove_output : remove_outputs) {
358 const int item_output_size = item->output_size();
359 if (remove_output < 0 || remove_output >= item_output_size) {
360 return errors::InvalidArgument(
361 "Function output index is out of bound: index=", remove_output,
362 " output_size=", item->output_size());
363 }
364 }
365
366 absl::flat_hash_set<const OutputArgInstantiation*> remove_output_args;
367 const auto is_remove_output_arg = [&](const OutputArgInstantiation& output) {
368 return remove_output_args.find(&output) != remove_output_args.end();
369 };
370
371 for (int i = 0, end = item->output_size(); i < end; ++i) {
372 const OutputArgInstantiation& output = item->output(i);
373 if (remove_outputs.contains(i)) {
374 VLOG(3) << "Remove functions output: name=" << output.node_name
375 << "(index = " << i << ")";
376 remove_output_args.insert(&output);
377 } else if (!remove_output_args.empty()) {
378 // Add output mapping only if output position changed.
379 output_mapping->push_back({i, i - remove_output_args.size()});
380 }
381 }
382
383 // Update 'index' attribute in all '_Retval' nodes that are in output mapping.
384 for (NodeDef& node : *item->graph.mutable_node()) {
385 if (IsRetval(node)) {
386 auto attrs = AttrSlice(node);
387 int index;
388 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "index", &index));
389
390 for (const auto& mapping : *output_mapping) {
391 const int from = mapping.first;
392 const int to = mapping.second;
393 if (index == from) {
394 (*node.mutable_attr())["index"].set_i(to);
395 }
396 }
397 }
398 }
399
400 auto& o = item->output_args_;
401 o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
402
403 return Status::OK();
404 }
405
406 namespace {
407
408 // FunctionDef uses different connectivity encoding for the function body nodes,
409 // than a GraphDef (see function.proto for details). This is a helper class that
410 // converts inputs in GraphDef format (node[:position]) to the FunctionDef
411 // format (node:output[:position]).
412 class MakeFunctionDefHelper {
413 public:
414 MakeFunctionDefHelper() = default;
415
416 Status Initialize(const GrapplerFunctionItem& item,
417 const FunctionLibraryDefinition& flib);
418
419 // Converts input name from GraphDef format (name[:position]) to the
420 // FunctionDef input format (name[:output][:position]) using registered input
421 // arg instantiations and function body outputs.
422 Status AsFunctionDefInput(const string& graph_def_input,
423 string* func_def_input) const;
424
425 // Updates Node inputs from GraphDef to FunctionDef format.
426 Status AsFunctionDefNode(NodeDef* function_body_node) const;
427
IsInputNode(const NodeDef & node) const428 bool IsInputNode(const NodeDef& node) const {
429 return input_nodes_.contains(node.name());
430 }
431
IsOutputNode(const NodeDef & node) const432 bool IsOutputNode(const NodeDef& node) const {
433 return output_nodes_.contains(node.name());
434 }
435
436 private:
437 absl::flat_hash_set<absl::string_view> input_nodes_;
438 absl::flat_hash_set<absl::string_view> output_nodes_;
439 // Mapping from function body node name to output names range map.
440 absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
441 };
442
Initialize(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib)443 Status MakeFunctionDefHelper::Initialize(
444 const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib) {
445 for (const InputArgInstantiation& input_arg : item.inputs()) {
446 input_nodes_.insert(input_arg.node_name);
447 }
448 for (const OutputArgInstantiation& output_arg : item.outputs()) {
449 output_nodes_.insert(output_arg.node_name);
450 }
451
452 for (const NodeDef& node : item.function_body().node()) {
453 const OpRegistrationData* registration;
454 TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration));
455
456 tensorflow::NameRangeMap outputs_range_map;
457 TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
458 node, registration->op_def, nullptr, &outputs_range_map));
459
460 function_body_outputs_.emplace(node.name(), std::move(outputs_range_map));
461 }
462
463 return Status::OK();
464 }
465
AsFunctionDefInput(const string & graph_def_input,string * func_def_input) const466 Status MakeFunctionDefHelper::AsFunctionDefInput(const string& graph_def_input,
467 string* func_def_input) const {
468 if (IsControlInput(graph_def_input)) {
469 *func_def_input = graph_def_input;
470 return Status::OK();
471 }
472
473 const SafeTensorId tensor = ParseTensorName(graph_def_input);
474 DCHECK_GE(tensor.index(), 0);
475
476 // Graph def input corresponds to one of the function inputs.
477 const auto is_input = input_nodes_.find(tensor.node());
478 if (is_input != input_nodes_.end()) {
479 DCHECK_EQ(tensor.index(), 0);
480 *func_def_input = tensor.node();
481 return Status::OK();
482 }
483
484 // Or it must be output from one of the function body nodes
485 const auto is_body_output = function_body_outputs_.find(tensor.node());
486 if (is_body_output != function_body_outputs_.end()) {
487 const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
488
489 for (const auto& el : outputs_range_map) {
490 const auto& output_name = el.first;
491 const auto& output_range = el.second;
492 if (tensor.index() >= output_range.first &&
493 tensor.index() < output_range.second) {
494 *func_def_input = absl::StrCat(tensor.node(), ":", output_name, ":",
495 tensor.index() - output_range.first);
496 return Status::OK();
497 }
498 }
499 }
500
501 return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
502 }
503
AsFunctionDefNode(NodeDef * function_body_node) const504 Status MakeFunctionDefHelper::AsFunctionDefNode(
505 NodeDef* function_body_node) const {
506 string func_def_input;
507
508 for (int i = 0; i < function_body_node->input_size(); ++i) {
509 TF_RETURN_IF_ERROR(
510 AsFunctionDefInput(function_body_node->input(i), &func_def_input));
511 function_body_node->set_input(i, func_def_input);
512 }
513
514 return Status::OK();
515 }
516
517 } // namespace
518
MakeFunctionDef(const GrapplerFunctionItem & item,const FunctionLibraryDefinition & flib,FunctionDef * func)519 Status MakeFunctionDef(const GrapplerFunctionItem& item,
520 const FunctionLibraryDefinition& flib,
521 FunctionDef* func) {
522 func->mutable_signature()->set_name(item.id);
523 func->mutable_signature()->set_description(item.description());
524 func->mutable_signature()->set_is_stateful(item.is_stateful());
525
526 MakeFunctionDefHelper helper;
527 TF_RETURN_IF_ERROR(helper.Initialize(item, flib));
528
529 // Mapping from the '_Retval' node name to the output tensor.
530 absl::flat_hash_map<absl::string_view, string> output_tensors;
531 for (const NodeDef& func_body_node : item.function_body().node()) {
532 if (!helper.IsOutputNode(func_body_node)) continue;
533 if (func_body_node.input_size() != 1) {
534 return errors::Internal("_Retval node must have single input: ",
535 SummarizeNodeDef(func_body_node));
536 }
537 output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
538 }
539
540 for (const InputArgInstantiation& input_arg : item.inputs()) {
541 OpDef::ArgDef arg_def;
542 arg_def.set_name(input_arg.node_name);
543 arg_def.set_type(input_arg.data_type);
544 arg_def.set_is_ref(IsRefType(input_arg.data_type));
545 *func->mutable_signature()->add_input_arg() = arg_def;
546 }
547
548 // Add function output arguments.
549 for (const OutputArgInstantiation& output_arg : item.outputs()) {
550 const string output_name =
551 absl::StrReplaceAll(output_arg.node_name, {{"_RetVal", ""}});
552
553 OpDef::ArgDef arg_def;
554 arg_def.set_name(output_name);
555 arg_def.set_type(output_arg.data_type);
556 arg_def.set_is_ref(IsRefType(output_arg.data_type));
557 *func->mutable_signature()->add_output_arg() = arg_def;
558
559 auto it = output_tensors.find(output_arg.node_name);
560 if (it == output_tensors.end()) {
561 return errors::Internal(
562 "Can't find an output tensor for the output node: ",
563 output_arg.node_name);
564 }
565
566 TF_RETURN_IF_ERROR(helper.AsFunctionDefInput(
567 it->second, &(*func->mutable_ret())[output_name]));
568 }
569
570 // Add function control outputs.
571 for (const ControlOutput& control_out : item.control_outputs()) {
572 func->mutable_control_ret()->insert(
573 {control_out.output_name, control_out.node_name});
574 *func->mutable_signature()->add_control_output() = control_out.output_name;
575 }
576
577 // Copy function definition specific attributes.
578 for (const auto& attr : item.func_attr()) {
579 const auto& attr_name = attr.first;
580 const auto& attr_value = attr.second;
581 (*func->mutable_attr())[attr_name] = attr_value;
582 }
583
584 // Copy function arg attributes.
585 for (int i = 0, end = item.arg_attr().size(); i < end; ++i) {
586 const auto* attr = item.arg_attr().at(i);
587 if (attr != nullptr) {
588 (*func->mutable_arg_attr())[i] = *attr;
589 }
590 }
591
592 // Copy function body nodes to the FunctionDef and update input format
593 for (const NodeDef& func_node : item.function_body().node()) {
594 // Skip original `_Arg` and `_Retval` nodes. If node was converted to some
595 // other type (e.g. inputs converted to placeholders), we need to check that
596 // it's not registered as function input or output node.
597 if (IsArg(func_node) || IsRetval(func_node) ||
598 helper.IsInputNode(func_node) || helper.IsOutputNode(func_node))
599 continue;
600
601 NodeDef* func_def_node = func->add_node_def();
602 *func_def_node = func_node;
603 TF_RETURN_IF_ERROR(helper.AsFunctionDefNode(func_def_node));
604 }
605
606 return Status::OK();
607 }
608
609 } // end namespace grappler
610 } // end namespace tensorflow
611