/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include <functional>
#include <queue>
#include <random>
#include <set>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

namespace {

Status ValidateTensorId(const tf2xla::TensorId& id) {
  if (id.node_name().empty()) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
  if (id.output_index() < 0) {
    return errors::InvalidArgument("TensorId output_index must be non-negative");
  }
  return Status::OK();
}

Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (!name.empty()) {
    if (!names->insert(name).second) {
      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
    }
  }
  return Status::OK();
}

Status CheckFeedFetchNameConflicts(const string& kind,
                                   const std::set<string>& names) {
  // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
  // since that will cause a collision in codegen symbols.
  for (const string& name : names) {
    const string name_data(name + "_data");
    if (names.find(name_data) != names.end()) {
      return errors::InvalidArgument("conflicting ", kind, " name: ", name,
                                     " and ", name_data);
    }
  }
  return Status::OK();
}

// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    for (const auto& associated_function :
         GetAssociatedFunctions(*n, lookup_fld)) {
      switch (associated_function.type()) {
        case AssociatedFunctionInfo::kFunctionCallNode: {
          const FunctionDef* fdef =
              lookup_fld->Find(associated_function.func_name());
          if (!fdef) {
            return errors::Internal(
                "Cannot find function ", associated_function.func_name(),
                " for function call node ", n->DebugString());
          }
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
          break;
        }
        case AssociatedFunctionInfo::kSymbolicGradient:
        case AssociatedFunctionInfo::kFunctionAttr:
          break;
      }
    }
  }
  return Status::OK();
}

// Replaces the single edge feeding into {dst,dst_input} with a new
// src/src_output specified by {with,with_output}.
StatusOr<Node*> ReplaceEdge(Graph* g, Node* dst, int dst_input, Node* with,
                            int with_output) {
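  // A Node's inputs can't be edited in place, so build a modified copy of
  // `dst`'s NodeDef, swap the node out via ReplaceNode, and then rewire the
  // freshly created input edge so it comes from `with:with_output`.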
  NodeDef replace_def = dst->def();
  *replace_def.mutable_input(dst_input) = with->name();
  TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, replace_def));
  const Edge* usage_edge;
  TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &usage_edge));
  g->RemoveEdge(usage_edge);
  g->AddEdge(with, with_output, replace_node, dst_input);
  return replace_node;
}

// Replaces usages of the given `src_output` index of the given `src` node with
// the given `replacement` node (assumes the :0 output of `replacement`).
Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output,
                                     Node* replacement) {
  VLOG(1) << "Replace usages of output " << src_output << " of node "
          << (VLOG_IS_ON(3) ? src->DebugString() : src->name()) << " with "
          << (VLOG_IS_ON(3) ? replacement->DebugString() : replacement->name());
  // Collect all usages of the specified src output (src->out_edges() iterator
  // will not be stable under modifications).
  struct OutEdgeInfo {
    int dst_node_id, dst_input;
  };
  std::vector<OutEdgeInfo> usages;
  for (const Edge* e : src->out_edges()) {
    if (e->IsControlEdge() || e->src_output() != src_output) {
      continue;
    }
    usages.push_back({e->dst()->id(), e->dst_input()});
  }

  // Now, replace each usage.
  for (int i = 0, end = usages.size(); i < end; i++) {
    // Make a copy of `usage_node`, and change its input to the replacement
    // node.
    Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
    VLOG(2) << " Replace usage by " << usage_node->DebugString();
    // Note: Replacement output index is presumed to be 0.
    TF_ASSIGN_OR_RETURN(
        Node * replace_node,
        ReplaceEdge(g, usage_node, usages[i].dst_input, replacement, 0));
    // Later entries in `usages` might have `usage_node` as dst node, but
    // `usage_node` is removed. Replace such entries with `replace_node`.
    for (int j = i + 1, end = usages.size(); j < end; j++) {
      if (usages[j].dst_node_id == usages[i].dst_node_id) {
        usages[j].dst_node_id = replace_node->id();
      }
    }
  }
  return Status::OK();
}

// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes.
  absl::flat_hash_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    VLOG(2) << "Replace usages of _Arg " << arg_index;
    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    Status s;
    Node* const_node = g->AddNode(const_def, &s);
    TF_RETURN_IF_ERROR(s);
    Node* arg_node = arg_nodes[arg_index];
    TF_RETURN_IF_ERROR(
        ReplaceSrcOutputUsageWithNode(g, arg_node, 0, const_node));
  }
  return Status::OK();
}

// For each index in the keys of `const_input_index_to_node`, replaces the
// single input of the corresponding _Retval node with the single output of the
// matching _Arg node.
Status ReplaceRetvalInputWithArg(
    Graph* g,
    const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
  absl::flat_hash_map<int, Node*> arg_nodes;
  absl::flat_hash_map<int, Node*> ret_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsRetval() || n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      if (n->IsRetval()) {
        ret_nodes[index] = n;
      } else {
        arg_nodes[index] = n;
      }
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    VLOG(2) << "Bind _Retval " << arg_index << " to _Arg " << arg_index;
    TF_RETURN_IF_ERROR(
        ReplaceEdge(g, ret_nodes[arg_index], 0, arg_nodes[arg_index], 0)
            .status());
  }
  return Status::OK();
}

// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const absl::flat_hash_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld,
    bool passthrough_arg_to_retval = false) {
  VLOG(1) << "Propagate const into " << attr_name << " of node " << n->name();
  // Instantiate the function.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));

  // Rewrite _Arg usages with Const node.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
  if (passthrough_arg_to_retval) {
    TF_RETURN_IF_ERROR(
        ReplaceRetvalInputWithArg(func_graph, const_input_index_to_node));
  }

  // Save rewritten function.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(
      replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));

  VLOG(1) << "replace func " << func_attr.name() << " with " << new_func_name;
  // Change the node to use rewritten function.
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);

  // Copy associated functions.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));

  return Status::OK();
}

// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
                                const FunctionLibraryDefinition* lookup_fld,
                                FunctionLibraryDefinition* fld) {
  // Note that the first input of an If node is the predicate; the remaining
  // inputs are passed to the branch functions.
  absl::flat_hash_map<int, const Node*> const_input_index_to_node;
  for (int i = 1; i < if_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
    if (input_node->type_string() == "Const") {
      const_input_index_to_node[i - 1] = input_node;
    }
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "then_branch" and "else_branch" functions, replacing usages of
  // those _Arg nodes with the corresponding Const nodes.
  for (const auto& attr_name :
       std::vector<string>{"then_branch", "else_branch"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }

  return Status::OK();
}

using GraphCache = absl::flat_hash_map<string, std::unique_ptr<FunctionBody>>;

StatusOr<FunctionBody*> FindOrInsert(
    GraphCache* cache, const NameAttrList& body_attr,
    const FunctionLibraryDefinition* lookup_fld,
    const FunctionLibraryDefinition* fallback_fld) {
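  // Look up the FunctionBody for `body_attr` in `cache`, instantiating it from
  // `lookup_fld` (or from `fallback_fld` when that fails) on first use.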
  const string name = body_attr.name();
  std::unique_ptr<FunctionBody>& value = (*cache)[name];
  if (!value) {
    const FunctionDef* body_func = lookup_fld->Find(name);
    if (!body_func && fallback_fld != nullptr) {
      body_func = fallback_fld->Find(name);
    }
    if (!body_func) {
      return errors::Internal("Traverse: Cannot find body function ", name);
    }
    std::unique_ptr<FunctionBody> fbody;
    Status s = FunctionDefToBodyHelper(*body_func, AttrSlice(&body_attr.attr()),
                                       lookup_fld, &fbody);
    if (!s.ok() && fallback_fld != nullptr) {
      TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
          *body_func, AttrSlice(&body_attr.attr()), fallback_fld, &fbody));
    }
    value = std::move(fbody);
  }
  return value.get();
}
// Determines whether a loop body is invariant for the given argument index.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld,
                               const FunctionLibraryDefinition* fallback_fld,
                               GraphCache* cache);

// Traces backward through non-modifying ops such as Identity and loop-invariant
// While, to find a preceding source edge.
StatusOr<const Edge*> TraverseUnmodifiedPathBackward(
    const Edge* src, const FunctionLibraryDefinition* lookup_fld,
    const FunctionLibraryDefinition* fallback_fld, GraphCache* cache) {
  const Edge* e = src;
  VLOG(2) << "Traverse: Begin at " << e->DebugString();
  // TODO(b/184727356): Also traverse If/Case nodes.
  // Begin walking back from the output node.
  while (IsConstTraversableOpType(e->src())) {
    VLOG(3) << e->DebugString();

    if (e->src()->IsWhileNode()) {
      NameAttrList body_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "body", &body_attr));
      TF_ASSIGN_OR_RETURN(
          FunctionBody * fbody,
          FindOrInsert(cache, body_attr, lookup_fld, fallback_fld));
      TF_ASSIGN_OR_RETURN(bool is_loop_invariant,
                          IsLoopInvariant(fbody, e->src_output(), lookup_fld,
                                          fallback_fld, cache));
      if (!is_loop_invariant) {
        VLOG(2) << "Non-loop-invariant: index " << e->src_output() << " of "
                << body_attr.name();
        break;
      }
    }  // if While|StatelessWhile
    // Proceed backward to the src's input corresponding with the output index.
    TF_RETURN_IF_ERROR(e->src()->input_edge(e->src_output(), &e));
  }
  VLOG(2) << "Traverse: Finish at " << e->DebugString();

  return e;
}

// Determines whether a loop body is invariant for the given argument index.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld,
                               const FunctionLibraryDefinition* fallback_fld,
                               GraphCache* cache) {
  const Edge* e;
  TF_RETURN_IF_ERROR(loop_body->ret_nodes[index]->input_edge(0, &e));
  TF_ASSIGN_OR_RETURN(
      const Edge* reachable,
      TraverseUnmodifiedPathBackward(e, lookup_fld, fallback_fld, cache));
  if (reachable->src()->id() == loop_body->arg_nodes[index]->id()) {
    VLOG(2) << "Index " << index << " is loop invariant.";
    return true;
  }
  VLOG(2) << "Index " << index << " not loop invariant: "
          << "walk backward from " << e->src()->DebugString() << " to "
          << reachable->src()->DebugString() << " did not reach "
          << loop_body->arg_nodes[index]->DebugString();
  return false;
}

// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs. Then,
// propagate these Const to consumers of the relevant outputs of the while loop.
Status PropagateConstIntoAndAroundWhileNode(
    Graph* g, Node* while_node, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  VLOG(1) << "Propagate const into " << while_node->name();

  // For a "While" node, we should only replace _Arg nodes that are loop
  // invariant, i.e. where the corresponding _Retval's input comes directly
  // from that _Arg.
  absl::flat_hash_map<int, const Node*> const_input_index_to_node;
  absl::flat_hash_map<int, Node*> const_input_index_to_mutable_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const string fn_name = body_attr.name();
  const FunctionDef* body_func = lookup_fld->Find(fn_name);
  if (!body_func) {
    return errors::Internal("Propagate: Cannot find body function ", fn_name,
                            " for While node ", while_node->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *body_func, AttrSlice(&body_attr.attr()), lookup_fld, &fbody));
  GraphCache cache;
  for (int i = 0; i < while_node->num_inputs(); i++) {
    // Check if the i-th retval's input comes directly from the i-th arg.
    // For resource variable inputs of While nodes, the TF2XLA convention is to
    // place them at the end of all inputs (after all data inputs) and *not*
    // return them, so the number of While node inputs might be larger than the
    // number of its outputs.
    if (i >= body_func->signature().output_arg_size()) {
      break;
    }

    const Edge* input_edge;
    TF_RETURN_IF_ERROR(while_node->input_edge(i, &input_edge));
    TF_ASSIGN_OR_RETURN(input_edge, TraverseUnmodifiedPathBackward(
                                        input_edge, lookup_fld, fld, &cache));
    if (!input_edge->src()->IsConstant()) {
      VLOG(2) << "Input " << i << " is not Const; is "
              << input_edge->src()->type_string();
      continue;
    }

    TF_ASSIGN_OR_RETURN(
        bool is_loop_invariant,
        IsLoopInvariant(fbody.get(), i, lookup_fld, fld, &cache));
    if (!is_loop_invariant) {
      VLOG(2) << "While state not loop-invariant; not propagating Const " << i;
      continue;
    }
    VLOG(2) << "While state is loop-invariant; propagating Const " << i;

    const_input_index_to_mutable_node[i] = input_edge->src();
    const_input_index_to_node[i] = input_edge->src();
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "cond" and "body" functions, replacing usages of those _Arg
  // nodes with the corresponding Const nodes.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld,
        /*passthrough_arg_to_retval=*/attr_name == "body"));
  }

  // Rewrite usages of the output edges corresponding to loop-invariant const
  // inputs to refer instead to the Const node.
  for (const auto& it : const_input_index_to_mutable_node) {
    TF_RETURN_IF_ERROR(
        ReplaceSrcOutputUsageWithNode(g, while_node, it.first, it.second));
  }
  return Status::OK();
}

}  // namespace

StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld) {
  GraphCache cache;
  return IsLoopInvariant(loop_body, index, lookup_fld,
                         /*fallback_fld=*/nullptr, &cache);
}

const char kTpuReplicateAttrName[] = "_tpu_replicate";
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";

Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> names;
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
  names.clear();
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
  if (config.fetch().empty()) {
    return errors::InvalidArgument("fetches must be specified");
  }
  return Status::OK();
}

Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for determinism
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
                                         "/", feed->id().node_name());
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify node exists and determine data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // Must first add default attrs as well, so do this in a copied GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      Status status;
      Node* feed_node = g.AddNode(gd.node(0), &status);
      TF_RETURN_IF_ERROR(status);

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more cases
  // in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("Placeholder");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return Status::OK();
}

Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over, preserving the order of the original, and keeping only nodes
  // that are reachable from the fetches.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}

string TensorIdToString(const tf2xla::TensorId& id) {
  return absl::StrCat(id.node_name(), ":", id.output_index());
}

Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;
  const Node* matching_node = nullptr;
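  // Scan the neighbors (edge destinations when `out_edges` is true, otherwise
  // edge sources) and remember the one with MAXIMAL sharding on the lowest
  // core.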
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
            /*add_metadata=*/false));
    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return Status::OK();
}

void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef) {
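  // Append `dtype` to the allowed values of every attr constraint in `kdef`
  // whose name matches `name`.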
  for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
    if (constraint.name() == name) {
      constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
    }
  }
}

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops. They have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}
}  // namespace

uint32 GetXLARandomSeed() {
  // We initialize counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  uint32 seed = counter.fetch_add(2);
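  // Scramble the counter value through the C RNG and force the result odd so
  // that the returned seed is never zero.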
  std::srand(seed);
  return std::rand() | 1;
}

// TODO(b/77601805): add tests for associated function related stuff.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld) {
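  // The op name itself is registered in the function library, i.e. this node
  // is a direct function call.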
  if (fld->Contains(node_def.op())) {
    return true;
  }

  if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
    // Gradient op has "f" attr, which is set to the function whose gradient we
    // are taking. We need to functionalize the gradient function.
    return true;
  }

  if (node_def.op() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
    return false;
  }

  for (const auto& iter : node_def.attr()) {
    if (iter.second.has_func()) {
      return true;
    }
  }

  return false;
}

std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld) {
  std::vector<AssociatedFunctionInfo> results;
  const string& op = node.type_string();
  if (fld->Contains(op)) {
    // This is a function call node.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
  } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
    // This is a SymbolicGradient op.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
  } else if (node.type_string() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
  } else {
    // Collect all function attrs for the node.
    for (auto& iter : node.attrs()) {
      if (iter.second.has_func()) {
        VLOG(2) << "Found function attr for node " << node.name() << ": "
                << iter.first << " = " << iter.second.func().name();
        results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
            iter.second.func().name(), iter.second.func().attr(), iter.first));
      }
    }
  }
  return results;
}

Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      for (const auto& attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      Status s;
      Node* new_node = graph->AddNode(node_def, &s);
      TF_RETURN_IF_ERROR(s);
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change function attr to rewritten functions.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return Status::OK();
}

Status CachedFunctionHandles::GetOrInstantiate(
    const string& func_name, AttrSlice attrs,
    FunctionLibraryRuntime::Handle* handle) {
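  // Reuse the handle if this function has already been instantiated with the
  // same attrs; otherwise instantiate it and cache the new handle.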
  string canonicalized_name = Canonicalize(func_name, attrs);
  auto iter = handles_.find(canonicalized_name);
  if (iter != handles_.end()) {
    *handle = iter->second;
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
  handles_[canonicalized_name] = *handle;
  return Status::OK();
}

Status CachedFunctionHandles::ReleaseAllHandles() {
  Status result;
  for (const auto& iter : handles_) {
    result.Update(flr_->ReleaseHandle(iter.second));
  }
  handles_.clear();
  return result;
}

StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  Status s;
  Node* new_node = g->AddNode(node_def, &s);
  if (!s.ok()) {
    return s;
  }

  // Record original node's output edges and remove them first. This is to avoid
  // multiple producers for dst nodes' input.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}

StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
                                  DataType dtype, const Node* input,
                                  absl::optional<string> requested_device) {
  // Create identity node.
  NodeDef ndef;
  ndef.set_name(node_name);
  ndef.set_op("Identity");
  if (input) {
    ndef.add_input(input->name());
  }
  if (requested_device) {
    ndef.set_device(*requested_device);
  }
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* id_node = graph->AddNode(ndef, &s);
  TF_RETURN_IF_ERROR(s);
  return id_node;
}

Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  absl::flat_hash_set<int> done_node_ids;

  // Because we may propagate Const around a while node as well as into it,
  // we restart the op_nodes() iterator after each pass and keep track of which
  // nodes we've already dealt with.
  bool should_continue = true;
  while (should_continue) {
    should_continue = false;
    for (Node* n : g->op_nodes()) {
      if (!done_node_ids.contains(n->id())) {
        if (n->IsIfNode()) {
          VLOG(1) << "PropagateConstIntoIfNode: " << n->name();
          TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
          done_node_ids.emplace(n->id());
          VLOG(1) << "Done PropagateConstIntoIfNode: " << n->name();
        } else if (n->IsWhileNode()) {
          VLOG(1) << "PropagateConstIntoWhileNode: " << n->name();
          TF_RETURN_IF_ERROR(
              PropagateConstIntoAndAroundWhileNode(g, n, lookup_fld, fld));
          done_node_ids.emplace(n->id());
          should_continue = true;
          VLOG(1) << "Done PropagateConstIntoWhileNode: " << n->name();
          break;
        }
      }
    }
  }
  return Status::OK();
}

Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld) {
  GraphDef graph_def;
  g.ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_functions =
      fld->ReachableDefinitions(graph_def);
  for (const string& func_name : fld->ListFunctionNames()) {
    if (!reachable_functions.Find(func_name)) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
    }
  }
  return Status::OK();
}

Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld) {
  for (Node* n : g->nodes()) {
    if (n->type_string() != "EmptyTensorList") {
      continue;
    }
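    // Pattern being matched: this EmptyTensorList feeds a forward While whose
    // body pushes a Const element onto the list, and the resulting list then
    // feeds a backward While whose body pops that element back off. When the
    // pattern is found, usages of the popped value inside the backward body
    // are replaced with the Const directly.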

    // Find the forward While op.
    std::vector<const Edge*> fwd_while_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
        fwd_while_edges.push_back(e);
      }
    }
    if (fwd_while_edges.size() != 1) {
      // No forward While op found, or multiple forward While ops.
      continue;
    }

    // Find the backward While op.
    Node* fwd_while = fwd_while_edges[0]->dst();
    int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
    std::vector<const Edge*> bwd_while_edges;
    for (const Edge* e : fwd_while->out_edges()) {
      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
        bwd_while_edges.push_back(e);
      }
    }
    if (bwd_while_edges.size() != 1) {
      // No backward While op found, or multiple backward While ops.
      continue;
    }

    Node* bwd_while = bwd_while_edges[0]->dst();
    int bwd_while_dst_input = bwd_while_edges[0]->dst_input();

    // Look into forward While body function and check if TensorListPushBack op
    // has a Const input.
    NameAttrList fwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
    const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
    if (!fwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     fwd_body_attr.name(), " for While node ",
                                     fwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> fwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));

    // Find the TensorListPushBack node; it's one of fwd_arg's successors.
    Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
    std::vector<Node*> tl_push_nodes;
    for (const Edge* out_edge : fwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPushBack") {
        tl_push_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_push_nodes.size() != 1) {
      // No TensorListPushBack found, or multiple TensorListPushBack.
      continue;
    }

    // Get input for the TensorListPushBack node.
    Node* input_node;
    TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
    if (input_node->type_string() != "Const") {
      // The input fed into the TensorList is not a Const node.
      continue;
    }

    NodeDef const_input_nodedef = input_node->def();

    // Rewrite backward While body function, replace usages of
    // TensorListPopBack with a Const node.
    NameAttrList bwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
    const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
    if (!bwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     bwd_body_attr.name(), " for While node ",
                                     bwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> bwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));

    // Find the TensorListPopBack node; it's one of bwd_arg's successors.
    Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
    std::vector<Node*> tl_pop_nodes;
    for (const Edge* out_edge : bwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPopBack") {
        tl_pop_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_pop_nodes.size() != 1) {
      // No TensorListPopBack found, or multiple TensorListPopBack.
      continue;
    }

    // Replace TensorListPopBack usages with Const node.
    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
      if (e->src_output() == 1) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }
    Status s;
    const_input_nodedef.set_name(
        bwd_fbody->graph->NewName(const_input_nodedef.name()));
    Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
    TF_RETURN_IF_ERROR(s);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      bwd_fbody->graph->RemoveEdge(e);
      bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
    }

    // Add rewritten backward While body function.
    FunctionDef new_fdef;
    string new_name = fld->UniqueFunctionName(
        absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));

    // Change backward While op to use the new body function.
    bwd_body_attr.set_name(new_name);
    bwd_while->ClearAttr("body");
    bwd_while->AddAttr("body", bwd_body_attr);
  }
  return Status::OK();
}

}  // namespace tensorflow