1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17 
18 #include <functional>
19 #include <queue>
20 #include <random>
21 #include <set>
22 #include <unordered_map>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/tf2xla/sharding_util.h"
27 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/function_body.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/graph_def_util.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/node_def_builder.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/op_def_builder.h"
38 #include "tensorflow/core/framework/tensor_shape.h"
39 #include "tensorflow/core/framework/tensor_shape.pb.h"
40 #include "tensorflow/core/framework/versions.pb.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/platform/errors.h"
45 
46 namespace tensorflow {
47 
48 namespace {
49 
50 Status ValidateTensorId(const tf2xla::TensorId& id) {
51   if (id.node_name().empty()) {
52     return errors::InvalidArgument("TensorId node_name must be non-empty");
53   }
54   if (id.output_index() < 0) {
55     return errors::InvalidArgument("TensorId output_index must be non-negative");
56   }
57   return Status::OK();
58 }
59 
60 Status CheckNameDuplicates(const string& kind, const string& name,
61                            std::set<string>* names) {
62   if (!name.empty()) {
63     if (!names->insert(name).second) {
64       return errors::InvalidArgument("duplicate ", kind, " name: ", name);
65     }
66   }
67   return Status::OK();
68 }
69 
70 Status CheckFeedFetchNameConflicts(const string& kind,
71                                    const std::set<string>& names) {
72   // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
73   // since that will cause a collision in codegen symbols.
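  // For illustration: with names {"foo", "foo_data"}, the generated symbol for
  // "foo" would presumably also be derived as "foo_data", clashing with the
  // symbol for the "foo_data" feed, which is why that combination is rejected.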
74   for (const string& name : names) {
75     const string name_data(name + "_data");
76     if (names.find(name_data) != names.end()) {
77       return errors::InvalidArgument("conflicting ", kind, " name: ", name,
78                                      " and ", name_data);
79     }
80   }
81   return Status::OK();
82 }
83 
84 // For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
85 // `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
86 Status CopyAssociatedFunctions(Graph* g,
87                                const FunctionLibraryDefinition* lookup_fld,
88                                FunctionLibraryDefinition* fld) {
89   for (Node* n : g->op_nodes()) {
90     for (const auto& associated_function :
91          GetAssociatedFunctions(*n, lookup_fld)) {
92       switch (associated_function.type()) {
93         case AssociatedFunctionInfo::kFunctionCallNode: {
94           const FunctionDef* fdef =
95               lookup_fld->Find(associated_function.func_name());
96           if (!fdef) {
97             return errors::Internal(
98                 "Cannot find function ", associated_function.func_name(),
99                 " for function call node ", n->DebugString());
100           }
101           TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
102           break;
103         }
104         case AssociatedFunctionInfo::kSymbolicGradient:
105         case AssociatedFunctionInfo::kFunctionAttr:
106           break;
107       }
108     }
109   }
110   return Status::OK();
111 }
112 
113 // Replaces the single edge feeding into {dst,dst_input} with a new
114 // src/src_output specified by {with,with_output}.
115 StatusOr<Node*> ReplaceEdge(Graph* g, Node* dst, int dst_input, Node* with,
116                             int with_output) {
117   NodeDef replace_def = dst->def();
118   *replace_def.mutable_input(dst_input) = with->name();
119   TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, replace_def));
120   const Edge* usage_edge;
121   TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &usage_edge));
122   g->RemoveEdge(usage_edge);
123   g->AddEdge(with, with_output, replace_node, dst_input);
124   return replace_node;
125 }
126 
127 // Replaces usages of the given `src_output` index of the given `src` node with
128 // the given `replacement` node (assumes the :0 output of `replacement`).
129 Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output,
130                                      Node* replacement) {
131   VLOG(1) << "Replace usages of output " << src_output << " of node "
132           << (VLOG_IS_ON(3) ? src->DebugString() : src->name()) << " with "
133           << (VLOG_IS_ON(3) ? replacement->DebugString() : replacement->name());
134   // Collect all usages of the specified src output (src->out_edges() iterator
135   // will not be stable under modifications).
136   struct OutEdgeInfo {
137     int dst_node_id, dst_input;
138   };
139   std::vector<OutEdgeInfo> usages;
140   for (const Edge* e : src->out_edges()) {
141     if (e->IsControlEdge() || e->src_output() != src_output) {
142       continue;
143     }
144     usages.push_back({e->dst()->id(), e->dst_input()});
145   }
146 
147   // Now, replace each usage.
148   for (int i = 0, end = usages.size(); i < end; i++) {
149     // Make a copy of `usage_node`, and change its input to the replacement node.
150     Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
151     VLOG(2) << "  Replace usage by " << usage_node->DebugString();
152     // Note: Replacement output index is presumed to be 0.
153     TF_ASSIGN_OR_RETURN(
154         Node * replace_node,
155         ReplaceEdge(g, usage_node, usages[i].dst_input, replacement, 0));
156     // Later entries in `usages` might have `usage_node` as dst node, but
157     // `usage_node` is removed. Replace such entries with `replace_node`.
158     for (int j = i + 1, end = usages.size(); j < end; j++) {
159       if (usages[j].dst_node_id == usages[i].dst_node_id) {
160         usages[j].dst_node_id = replace_node->id();
161       }
162     }
163   }
164   return Status::OK();
165 }
166 
167 // For graph `g`, replaces _Arg nodes whose "index" attribute is in
168 // `const_input_index_to_node` with Const nodes.
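// For example (sketch): if `const_input_index_to_node` maps {1 -> Const "c"},
// then every use of _Arg(index=1):0 in `g` is rewired to read from a fresh
// copy of "c" added to the graph.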
169 Status ReplaceArgUsageWithConstNode(
170     Graph* g,
171     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
172   // Collect all _Arg nodes.
173   absl::flat_hash_map<int, Node*> arg_nodes;
174   for (Node* n : g->op_nodes()) {
175     if (n->IsArg()) {
176       int index;
177       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
178       arg_nodes[index] = n;
179     }
180   }
181 
182   for (const auto& iter : const_input_index_to_node) {
183     int arg_index = iter.first;
184     VLOG(2) << "Replace usages of _Arg " << arg_index;
185     NodeDef const_def = iter.second->def();
186     const_def.set_name(g->NewName(const_def.name()));
187     Status s;
188     Node* const_node = g->AddNode(const_def, &s);
189     TF_RETURN_IF_ERROR(s);
190     Node* arg_node = arg_nodes[arg_index];
191     TF_RETURN_IF_ERROR(
192         ReplaceSrcOutputUsageWithNode(g, arg_node, 0, const_node));
193   }
194   return Status::OK();
195 }
196 
197 // Replaces the single input to _Retval nodes with an index in the keys of
198 // const_input_index_to_node with the single output of the corresponding _Arg
199 // node.
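// For example (sketch): with {1 -> Const}, the edge feeding _Retval(index=1)
// is replaced by _Arg(index=1):0, so the rewritten body simply passes that
// argument through (used below for loop-invariant While arguments).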
200 Status ReplaceRetvalInputWithArg(
201     Graph* g,
202     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
203   absl::flat_hash_map<int, Node*> arg_nodes;
204   absl::flat_hash_map<int, Node*> ret_nodes;
205   for (Node* n : g->op_nodes()) {
206     if (n->IsRetval() || n->IsArg()) {
207       int index;
208       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
209       if (n->IsRetval()) {
210         ret_nodes[index] = n;
211       } else {
212         arg_nodes[index] = n;
213       }
214     }
215   }
216 
217   for (const auto& iter : const_input_index_to_node) {
218     int arg_index = iter.first;
219     VLOG(2) << "Bind _Retval " << arg_index << " to _Arg " << arg_index;
220     TF_RETURN_IF_ERROR(
221         ReplaceEdge(g, ret_nodes[arg_index], 0, arg_nodes[arg_index], 0)
222             .status());
223   }
224   return Status::OK();
225 }
226 
227 // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
228 // the function to replace _Arg nodes in `const_input_index_to_node` with Const
229 // inputs.
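// Rough flow, for orientation (the code below is authoritative):
//   1. Instantiate the function named by `attr_name` into a FunctionBody.
//   2. Replace usages of the selected _Arg nodes with copies of the Consts.
//   3. Optionally rewire the matching _Retval inputs to read from the _Args.
//   4. Save the rewritten graph under a new unique function name and point
//      the node's attr at it.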
230 Status PropagateConstIntoFuncAttr(
231     Node* n, const string& attr_name,
232     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node,
233     const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld,
234     bool passthrough_arg_to_retval = false) {
235   VLOG(1) << "Propagate const into " << attr_name << " of node " << n->name();
236   // Instantiate the function.
237   NameAttrList func_attr;
238   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
239   const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
240   if (!fdef) {
241     return errors::Internal("Cannot find function ", func_attr.name(),
242                             " for node ", n->name());
243   }
244   std::unique_ptr<FunctionBody> fbody;
245   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
246       *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));
247 
248   // Rewrite _Arg usages with Const node.
249   Graph* func_graph = fbody->graph;
250   TF_RETURN_IF_ERROR(
251       ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
252   if (passthrough_arg_to_retval) {
253     TF_RETURN_IF_ERROR(
254         ReplaceRetvalInputWithArg(func_graph, const_input_index_to_node));
255   }
256 
257   // Save rewritten function.
258   FunctionDef replace_fdef;
259   string new_func_name =
260       fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
261   TF_RETURN_IF_ERROR(
262       GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
263   TF_RETURN_IF_ERROR(fld->AddFunctionDef(
264       replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));
265 
266   VLOG(1) << "replace func " << func_attr.name() << " with " << new_func_name;
267   // Change the node to use rewritten function.
268   func_attr.set_name(new_func_name);
269   n->ClearAttr(attr_name);
270   n->AddAttr(attr_name, func_attr);
271 
275   // Copy associated functions.
276   TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
277 
278   return Status::OK();
279 }
280 
281 // For an "If" node in graph `g`, if it has Const node inputs, rewrite its
282 // then/else branch function to replace _Arg nodes with those Const inputs.
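// For example (sketch): in If(pred, x, Const c), node input 2 (`c`) maps to
// _Arg(index=1) of then_branch/else_branch, so that _Arg is replaced with a
// copy of `c` in both branch functions.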
283 Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
284                                 const FunctionLibraryDefinition* lookup_fld,
285                                 FunctionLibraryDefinition* fld) {
286   // Note that the first input of an If node is the predicate; the remaining
287   // inputs are function inputs.
288   absl::flat_hash_map<int, const Node*> const_input_index_to_node;
289   for (int i = 1; i < if_node->num_inputs(); i++) {
290     const Node* input_node;
291     TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
292     if (input_node->type_string() == "Const") {
293       const_input_index_to_node[i - 1] = input_node;
294     }
295   }
296   if (const_input_index_to_node.empty()) {
297     return Status::OK();
298   }
299 
300   // Rewrite "then_branch" and "else_branch" function, replace usage of those
301   // _Arg nodes with corresponding const node.
302   for (const auto& attr_name :
303        std::vector<string>{"then_branch", "else_branch"}) {
304     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
305         if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
306   }
307 
308   return Status::OK();
309 }
310 
311 using GraphCache = absl::flat_hash_map<string, std::unique_ptr<FunctionBody>>;
312 
313 StatusOr<FunctionBody*> FindOrInsert(
314     GraphCache* cache, const NameAttrList& body_attr,
315     const FunctionLibraryDefinition* lookup_fld,
316     const FunctionLibraryDefinition* fallback_fld) {
317   const string name = body_attr.name();
318   std::unique_ptr<FunctionBody>& value = (*cache)[name];
319   if (!value) {
320     const FunctionDef* body_func = lookup_fld->Find(name);
321     if (!body_func && fallback_fld != nullptr) {
322       body_func = fallback_fld->Find(name);
323     }
324     if (!body_func) {
325       return errors::Internal("Traverse: Cannot find body function ", name);
326     }
327     std::unique_ptr<FunctionBody> fbody;
328     Status s = FunctionDefToBodyHelper(*body_func, AttrSlice(&body_attr.attr()),
329                                        lookup_fld, &fbody);
330     if (!s.ok() && fallback_fld != nullptr) {
331       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
332           *body_func, AttrSlice(&body_attr.attr()), fallback_fld, &fbody));
333     }
334     value = std::move(fbody);
335   }
336   return value.get();
337 }
338 // Determines whether a loop body is invariant for the given argument index.
339 StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
340                                const FunctionLibraryDefinition* lookup_fld,
341                                const FunctionLibraryDefinition* fallback_fld,
342                                GraphCache* cache);
343 
344 // Traces backward through non-modifying ops such as Identity and loop-invariant
345 // While, to find a preceding source edge.
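// For example (sketch): given Const -> Identity -> While (with the traversed
// index loop-invariant) -> consumer, starting from the consumer's input edge
// the walk ends at the edge leaving the Const.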
346 StatusOr<const Edge*> TraverseUnmodifiedPathBackward(
347     const Edge* src, const FunctionLibraryDefinition* lookup_fld,
348     const FunctionLibraryDefinition* fallback_fld, GraphCache* cache) {
349   const Edge* e = src;
350   VLOG(2) << "Traverse: Begin at " << e->DebugString();
351   // TODO(b/184727356): Also traverse If/Case nodes.
352   // Begin walking back from the output node.
353   while (IsConstTraversableOpType(e->src())) {
354     VLOG(3) << e->DebugString();
355 
356     if (e->src()->IsWhileNode()) {
357       NameAttrList body_attr;
358       TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "body", &body_attr));
359       TF_ASSIGN_OR_RETURN(
360           FunctionBody * fbody,
361           FindOrInsert(cache, body_attr, lookup_fld, fallback_fld));
362       TF_ASSIGN_OR_RETURN(bool is_loop_invariant,
363                           IsLoopInvariant(fbody, e->src_output(), lookup_fld,
364                                           fallback_fld, cache));
365       if (!is_loop_invariant) {
366         VLOG(2) << "Non-loop-invariant: index " << e->src_output() << " of "
367                 << body_attr.name();
368         break;
369       }
370     }  // if While|StatelessWhile
371     // Proceed backward to the src's input corresponding with the output index.
372     TF_RETURN_IF_ERROR(e->src()->input_edge(e->src_output(), &e));
373   }
374   VLOG(2) << "Traverse: Finish at " << e->DebugString();
375 
376   return e;
377 }
378 
379 // Determines whether a loop body is invariant for the given argument index.
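// Concretely: argument `index` is invariant iff walking backward from
// _Retval(index)'s input through unmodified ops lands on _Arg(index) itself.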
380 StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
381                                const FunctionLibraryDefinition* lookup_fld,
382                                const FunctionLibraryDefinition* fallback_fld,
383                                GraphCache* cache) {
384   const Edge* e;
385   TF_RETURN_IF_ERROR(loop_body->ret_nodes[index]->input_edge(0, &e));
386   TF_ASSIGN_OR_RETURN(
387       const Edge* reachable,
388       TraverseUnmodifiedPathBackward(e, lookup_fld, fallback_fld, cache));
389   if (reachable->src()->id() == loop_body->arg_nodes[index]->id()) {
390     VLOG(2) << "Index " << index << " is loop invariant.";
391     return true;
392   }
393   VLOG(2) << "Index " << index << " not loop invariant: "
394           << "walk backward from " << e->src()->DebugString() << " to "
395           << reachable->src()->DebugString() << " did not reach "
396           << loop_body->arg_nodes[index]->DebugString();
397   return false;
398 }
399 
400 // For a "While" node in graph `g`, if it has Const node inputs, rewrite its
401 // cond/body function to replace _Arg nodes with those Const inputs. Then,
402 // propagate these Const to consumers of the relevant outputs of the while loop.
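// "Into and around": the Const is propagated *into* the cond/body functions
// (replacing the matching loop-invariant _Arg) and *around* the loop by
// rewiring consumers of the corresponding While output to read from the Const
// directly.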
403 Status PropagateConstIntoAndAroundWhileNode(
404     Graph* g, Node* while_node, const FunctionLibraryDefinition* lookup_fld,
405     FunctionLibraryDefinition* fld) {
406   VLOG(1) << "Propagate const into " << while_node->name();
407 
408   // For a "While" node, we should only replace _Arg nodes that are loop
409   // invariant. For such _Arg nodes, the return value's input will come
410   // directly from the corresponding arg.
411   absl::flat_hash_map<int, const Node*> const_input_index_to_node;
412   absl::flat_hash_map<int, Node*> const_input_index_to_mutable_node;
413   NameAttrList body_attr;
414   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
415   const string fn_name = body_attr.name();
416   const FunctionDef* body_func = lookup_fld->Find(fn_name);
417   if (!body_func) {
418     return errors::Internal("Propagate: Cannot find body function ", fn_name,
419                             " for While node ", while_node->name());
420   }
421   std::unique_ptr<FunctionBody> fbody;
422   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
423       *body_func, AttrSlice(&body_attr.attr()), lookup_fld, &fbody));
424   GraphCache cache;
425   for (int i = 0; i < while_node->num_inputs(); i++) {
426     // Check if i-th retval's input comes from i-th arg directly.
427     // For resource variable inputs of While nodes, the TF2XLA convention is to
428     // place them at the end of all inputs (after all data inputs) and *not*
429     // return them. So the number of While node inputs might be larger than the
430     // number of its outputs.
431     if (i >= body_func->signature().output_arg_size()) {
432       break;
433     }
434 
435     const Edge* input_edge;
436     TF_RETURN_IF_ERROR(while_node->input_edge(i, &input_edge));
437     TF_ASSIGN_OR_RETURN(input_edge, TraverseUnmodifiedPathBackward(
438                                         input_edge, lookup_fld, fld, &cache));
439     if (!input_edge->src()->IsConstant()) {
440       VLOG(2) << "Input " << i << " is not Const; is "
441               << input_edge->src()->type_string();
442       continue;
443     }
444 
445     TF_ASSIGN_OR_RETURN(
446         bool is_loop_invariant,
447         IsLoopInvariant(fbody.get(), i, lookup_fld, fld, &cache));
448     if (!is_loop_invariant) {
449       VLOG(2) << "While state not loop-invariant; not propagating Const " << i;
450       continue;
451     }
452     VLOG(2) << "While state is loop-invariant; propagating Const " << i;
453 
454     const_input_index_to_mutable_node[i] = input_edge->src();
455     const_input_index_to_node[i] = input_edge->src();
456   }
457   if (const_input_index_to_node.empty()) {
458     return Status::OK();
459   }
460 
461   // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
462   // corresponding const node.
463   for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
464     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
465         while_node, attr_name, const_input_index_to_node, lookup_fld, fld,
466         /*passthrough_arg_to_retval=*/attr_name == "body"));
467   }
468 
469   // Rewrite usages of the output edges corresponding to loop-invariant const
470   // inputs to refer instead to the Const node.
471   for (const auto& it : const_input_index_to_mutable_node) {
472     TF_RETURN_IF_ERROR(
473         ReplaceSrcOutputUsageWithNode(g, while_node, it.first, it.second));
474   }
475   return Status::OK();
476 }
477 
478 }  // namespace
479 
480 StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
481                                const FunctionLibraryDefinition* lookup_fld) {
482   GraphCache cache;
483   return IsLoopInvariant(loop_body, index, lookup_fld,
484                          /*fallback_fld=*/nullptr, &cache);
485 }
486 
487 const char kTpuReplicateAttrName[] = "_tpu_replicate";
488 const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
489 
490 Status ValidateConfig(const tf2xla::Config& config) {
491   std::set<string> names;
492   for (const tf2xla::Feed& feed : config.feed()) {
493     TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
494     TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
495     TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
496   }
497   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
498   names.clear();
499   for (const tf2xla::Fetch& fetch : config.fetch()) {
500     TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
501     TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
502   }
503   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
504   if (config.fetch().empty()) {
505     return errors::InvalidArgument("fetches must be specified");
506   }
507   return Status::OK();
508 }
509 
510 Status AddPlaceholdersForFeeds(
511     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
512     std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
513   struct PlaceholderInfo {
514     const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
515     string placeholder_name;
516     DataType data_type = DT_INVALID;
517   };
518 
519   // Put each fed tensor into a map by name:port. A map is used for determinism
520   // when creating placeholders (genrules want deterministic output).
521   std::map<string, PlaceholderInfo> placeholder_info;
522   for (int i = 0; i < config.feed_size(); ++i) {
523     const tf2xla::Feed* feed = &config.feed(i);
524     const string name_port = TensorIdToString(feed->id());
525     PlaceholderInfo& info = placeholder_info[name_port];
526     info.feed = feed;
527     info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
528                                          "/", feed->id().node_name());
529     (*feed_remapping)[name_port] = info.placeholder_name;
530   }
531 
532   // Verify node exists and determine data type.
533   std::unordered_map<string, const NodeDef*> name_to_node;
534   for (int i = 0; i < graph_def->node_size(); ++i) {
535     name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
536   }
537   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
538     PlaceholderInfo& info = it->second;
539     const tf2xla::TensorId& feed_id = info.feed->id();
540 
541     // Find the existing node and determine data type.
542     auto node_it = name_to_node.find(feed_id.node_name());
543     if (node_it == name_to_node.end()) {
544       return errors::NotFound("Can't find feed node: ",
545                               TensorIdToString(feed_id));
546     }
547     const NodeDef* existing = node_it->second;
548 
549     if (info.feed->type() != DT_INVALID) {
550       info.data_type = info.feed->type();
551     } else {
552       // Build the node in order to infer its type.
553 
554       // Must first add default attrs as well, so do this in a copied GraphDef.
555       GraphDef gd;
556       *gd.mutable_versions() = graph_def->versions();
557       *gd.add_node() = *existing;
558       MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
559       TF_RETURN_IF_ERROR(
560           AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
561 
562       // Now build the node from the copied node def.
563       Graph g(op_registry);
564       g.set_versions(graph_def->versions());
565       Status status;
566       Node* feed_node = g.AddNode(gd.node(0), &status);
567       TF_RETURN_IF_ERROR(status);
568 
569       if (info.feed->id().output_index() < feed_node->num_outputs()) {
570         info.data_type =
571             BaseType(feed_node->output_type(info.feed->id().output_index()));
572       } else {
573         return errors::InvalidArgument(
574             "Invalid output_index ", info.feed->id().output_index(),
575             " for feed node ", info.feed->id().node_name());
576       }
577     }
578   }
579 
580   // Create placeholders. Note that we could avoid creating a placeholder for
581   // feeds which are already placeholders, but we omit that to avoid more cases
582   // in this code.
583   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
584     const PlaceholderInfo& info = it->second;
585     // TODO(shikharagarwal): Add original node information.
586     NodeDef* d = graph_def->add_node();
587     d->set_name(info.placeholder_name);
588     d->set_op("Placeholder");
589     auto& attr_map = *d->mutable_attr();
590     attr_map["dtype"].set_type(info.data_type);
591     *attr_map["shape"].mutable_shape() = info.feed->shape();
592   }
593 
594   // Rewrite references to the fed tensors to refer to the placeholder.
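  // e.g. an input string "foo:1" that names a fed tensor becomes
  // "aot_feed_1/foo", matching the placeholder created above.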
595   for (int i = 0; i < graph_def->node_size(); ++i) {
596     NodeDef* node_def = graph_def->mutable_node(i);
597     for (int j = 0; j < node_def->input_size(); ++j) {
598       auto id = ParseTensorName(node_def->input(j));
599       auto it = placeholder_info.find(id.ToString());
600       if (it != placeholder_info.end()) {
601         node_def->set_input(j, it->second.placeholder_name);
602       }
603     }
604   }
605 
606   return Status::OK();
607 }
608 
609 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
610                          GraphDef* out) {
611   *out = in;
612   out->clear_node();
613 
614   // Tensors needed for feeding.
615   std::set<std::pair<string, int>> feed_tensors;
616   for (const tf2xla::Feed& feed : config.feed()) {
617     feed_tensors.insert(
618         std::make_pair(feed.id().node_name(), feed.id().output_index()));
619   }
620 
621   // Maps node name to reachability.
622   std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
623   for (const NodeDef& node : in.node()) {
624     node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
625   }
626 
627   // Traverse.
628   std::queue<string> name_queue;
629   for (int i = 0; i < config.fetch_size(); ++i) {
630     name_queue.push(config.fetch(i).id().node_name());
631   }
632   while (!name_queue.empty()) {
633     const string name = name_queue.front();
634     name_queue.pop();
635 
636     auto find_it = node_by_name.find(name);
637     if (find_it == node_by_name.end()) {
638       return errors::InvalidArgument("While pruning graph, node ", name,
639                                      " needed but not found in the graph.");
640     }
641     auto& map_entry = find_it->second;
642     if (map_entry.first) {
643       continue;
644     }
645     map_entry.first = true;
646 
647     // Push input nodes of the currently visited node to name_queue.
648     for (const string& in_edge : map_entry.second->input()) {
649       auto id = ParseTensorName(in_edge);
650       const string node_name = string(id.first);
651       if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
652           feed_tensors.end()) {
653         name_queue.push(node_name);
654       } else {
655         // The input tensor is from an edge that is being fed. Therefore,
656         // we skip recursing down that edge, to avoid requiring nodes that
657         // may not be needed (note that the input node may still be added
658         // to name_queue later if one of its output edges is not being fed).
659       }
660     }
661   }
662 
663   // Copy over, preserving order of original and only nodes that are reachable
664   // from the fetches.
665   out->mutable_node()->Reserve(in.node_size());
666   for (const NodeDef& node : in.node()) {
667     if (node_by_name[node.name()].first) {
668       *out->add_node() = node;
669     }
670   }
671   return Status::OK();
672 }
673 
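// Formats a TensorId as "<node_name>:<output_index>", e.g. {"foo", 1} -> "foo:1".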
674 string TensorIdToString(const tf2xla::TensorId& id) {
675   return absl::StrCat(id.node_name(), ":", id.output_index());
676 }
677 
678 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
679   int core = -1;
680   const Node* matching_node = nullptr;
681   for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
682     if (edge->IsControlEdge()) continue;
683     const Node* possible_match = out_edges ? edge->dst() : edge->src();
684     TF_ASSIGN_OR_RETURN(
685         absl::optional<xla::OpSharding> sharding,
686         ParseShardingFromDevice(
687             *possible_match,
688             /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
689             /*add_metadata=*/false));
690     if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
691       const int core_annotation = sharding.value().tile_assignment_devices(0);
692       if (core == -1 || core > core_annotation) {
693         core = core_annotation;
694         matching_node = possible_match;
695       }
696     }
697   }
698   if (matching_node != nullptr) {
699     n->set_assigned_device_name(matching_node->assigned_device_name());
700     n->set_requested_device(matching_node->requested_device());
701   }
702   return Status::OK();
703 }
704 
705 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
706                                    KernelDef* kdef) {
707   for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
708     if (constraint.name() == name) {
709       constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
710     }
711   }
712 }
713 
714 namespace {
715 uint32 InitialRandomSeed() {
716   // Support for plumbing the TF seed through to XLA is being worked on.
717   // If a user wants deterministic behavior, their best option
718   // is to start with a known checkpoint. This also handles issues when
719   // multiple random calls can be invoked in any order by TF executor.
720   // Another option is to use stateless random ops. They have much cleaner
721   // semantics.
722   // If a user really wants to set a deterministic seed for XLA-based
723   // devices, this is the place to do it.
724   std::random_device rd;
725   // Make the starting value odd.
726   return rd() | 1;
727 }
728 }  // namespace
729 
730 uint32 GetXLARandomSeed() {
731   // We initialize the counter with an odd number and increment it by two
732   // every time. This ensures that it will never be zero, even
733   // after an overflow. When seeded with zero, some XLA backends
734   // can return all zeros instead of random numbers.
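  // e.g. a counter value of 0xFFFFFFFF (odd) wraps to 0x00000001 after adding
  // 2, so it stays odd and never hits zero.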
735   static std::atomic<uint32> counter(InitialRandomSeed());
736   uint32 seed = counter.fetch_add(2);
737   std::srand(seed);
738   return std::rand() | 1;
739 }
740 
741 // TODO(b/77601805): add tests for associated function related stuff.
742 bool HasAssociatedFunction(const NodeDef& node_def,
743                            const FunctionLibraryDefinition* fld) {
744   if (fld->Contains(node_def.op())) {
745     return true;
746   }
747 
748   if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
749     // Gradient op has "f" attr, which is set to the function we are getting
750     // gradient for. We need to functionalize the gradient function.
751     return true;
752   }
753 
754   if (node_def.op() == "XlaHostCompute") {
755     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
756     // related to graph execution.
757     return false;
758   }
759 
760   for (const auto& iter : node_def.attr()) {
761     if (iter.second.has_func()) {
762       return true;
763     }
764   }
765 
766   return false;
767 }
768 
769 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
770     const Node& node, const FunctionLibraryDefinition* fld) {
771   std::vector<AssociatedFunctionInfo> results;
772   const string& op = node.type_string();
773   if (fld->Contains(op)) {
774     // This is a function call node.
775     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
776     results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
777   } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
778     // This is a SymbolicGradient op.
779     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
780     results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
781   } else if (node.type_string() == "XlaHostCompute") {
782     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
783     // related to graph execution.
784   } else {
785     // Collect all function attrs for the node.
786     for (auto& iter : node.attrs()) {
787       if (iter.second.has_func()) {
788         VLOG(2) << "Found function attr for node " << node.name() << ": "
789                 << iter.first << " = " << iter.second.func().name();
790         results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
791             iter.second.func().name(), iter.second.func().attr(), iter.first));
792       }
793     }
794   }
795   return results;
796 }
797 
798 Status RewriteAssociatedFunction(
799     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
800     const AssociatedFunctionInfo& associated_function,
801     const string& rewritten_function_name) {
802   switch (associated_function.type()) {
803     case AssociatedFunctionInfo::kFunctionCallNode: {
804       // Change this node to call the new function.
805       NodeDebugInfo debug_info(*node);
806       NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
807                              &debug_info);
808       for (const auto& attr : node->attrs()) {
809         builder.Attr(attr.first, attr.second);
810       }
811       for (int i = 0; i < node->num_inputs(); i++) {
812         Node* input_node;
813         TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
814         builder.Input(input_node->name(), i, node->input_type(i));
815       }
816       builder.Device(node->assigned_device_name().empty()
817                          ? node->requested_device()
818                          : node->assigned_device_name());
819       NodeDef node_def;
820       TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
821       Status s;
822       Node* new_node = graph->AddNode(node_def, &s);
823       TF_RETURN_IF_ERROR(s);
824       for (auto edge : node->in_edges()) {
825         graph->AddEdge(edge->src(), edge->src_output(), new_node,
826                        edge->dst_input());
827       }
828       for (auto edge : node->out_edges()) {
829         graph->AddEdge(new_node, edge->src_output(), edge->dst(),
830                        edge->dst_input());
831       }
832       graph->RemoveNode(node);
833       break;
834     }
835     case AssociatedFunctionInfo::kSymbolicGradient: {
836       NameAttrList func;
837       TF_RETURN_IF_ERROR(GetNodeAttr(
838           node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
839       GradientDef gradient_def;
840       gradient_def.set_function_name(func.name());
841       gradient_def.set_gradient_func(rewritten_function_name);
842       string original_grad_func = fld->FindGradient(func.name());
843       if (original_grad_func.empty()) {
844         TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
845       } else if (original_grad_func != rewritten_function_name) {
846         TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
847       }
848       break;
849     }
850     case AssociatedFunctionInfo::kFunctionAttr: {
851       // Change function attr to rewritten functions.
852       NameAttrList func;
853       TF_RETURN_IF_ERROR(
854           GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
855       node->ClearAttr(associated_function.attr_name());
856       func.set_name(rewritten_function_name);
857       node->AddAttr(associated_function.attr_name(), func);
858       break;
859     }
860   }
861 
862   return Status::OK();
863 }
864 
865 Status CachedFunctionHandles::GetOrInstantiate(
866     const string& func_name, AttrSlice attrs,
867     FunctionLibraryRuntime::Handle* handle) {
868   string canonicalized_name = Canonicalize(func_name, attrs);
869   auto iter = handles_.find(canonicalized_name);
870   if (iter != handles_.end()) {
871     *handle = iter->second;
872     return Status::OK();
873   }
874 
875   TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
876   handles_[canonicalized_name] = *handle;
877   return Status::OK();
878 }
879 
880 Status CachedFunctionHandles::ReleaseAllHandles() {
881   Status result;
882   for (const auto& iter : handles_) {
883     result.Update(flr_->ReleaseHandle(iter.second));
884   }
885   handles_.clear();
886   return result;
887 }
888 
889 StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
890   // Create the replacement node.
891   Status s;
892   Node* new_node = g->AddNode(node_def, &s);
893   if (!s.ok()) {
894     return s;
895   }
896 
897   // Record the original node's output edges and remove them first, to avoid
898   // creating multiple producers for a dst node's input.
899   std::vector<OutEdgeInfo> out_edge_info;
900   std::vector<const Edge*> out_edges;
901   for (const Edge* edge : n->out_edges()) {
902     out_edges.push_back(edge);
903     out_edge_info.push_back(
904         {edge->dst(), edge->src_output(), edge->dst_input()});
905   }
906   for (const Edge* edge : out_edges) {
907     g->RemoveEdge(edge);
908   }
909 
910   // Add original node's input and output edges to the replacement node.
911   for (const Edge* in_edge : n->in_edges()) {
912     g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
913                in_edge->dst_input());
914   }
915   for (const OutEdgeInfo& out_edge : out_edge_info) {
916     g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
917   }
918 
919   // Remove the original node.
920   g->RemoveNode(n);
921 
922   return new_node;
923 }
924 
925 StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
926                                   DataType dtype, const Node* input,
927                                   absl::optional<string> requested_device) {
928   // Create identity node.
929   NodeDef ndef;
930   ndef.set_name(node_name);
931   ndef.set_op("Identity");
932   if (input) {
933     ndef.add_input(input->name());
934   }
935   if (requested_device) {
936     ndef.set_device(*requested_device);
937   }
938   AddNodeAttr("T", dtype, &ndef);
939   Status s;
940   Node* id_node = graph->AddNode(ndef, &s);
941   TF_RETURN_IF_ERROR(s);
942   return id_node;
943 }
944 
945 Status PropagateConstIntoFunctionalNodes(
946     Graph* g, const FunctionLibraryDefinition* lookup_fld,
947     FunctionLibraryDefinition* fld) {
948   absl::flat_hash_set<int> done_node_ids;
949 
950   // Because we may propagate Const around a while node as well as into it,
951   // we restart the op_nodes() iterator after each pass and keep track of which
952   // nodes we've already dealt with.
953   bool should_continue = true;
954   while (should_continue) {
955     should_continue = false;
956     for (Node* n : g->op_nodes()) {
957       if (!done_node_ids.contains(n->id())) {
958         if (n->IsIfNode()) {
959           VLOG(1) << "PropagateConstIntoIfNode: " << n->name();
960           TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
961           done_node_ids.emplace(n->id());
962           VLOG(1) << "Done PropagateConstIntoIfNode: " << n->name();
963         } else if (n->IsWhileNode()) {
964           VLOG(1) << "PropagateConstIntoWhileNode: " << n->name();
965           TF_RETURN_IF_ERROR(
966               PropagateConstIntoAndAroundWhileNode(g, n, lookup_fld, fld));
967           done_node_ids.emplace(n->id());
968           should_continue = true;
969           VLOG(1) << "Done PropagateConstIntoWhileNode: " << n->name();
970           break;
971         }
972       }
973     }
974   }
975   return Status::OK();
976 }
977 
978 Status PruneUnreachableFunctionsFromGraph(const Graph& g,
979                                           FunctionLibraryDefinition* fld) {
980   GraphDef graph_def;
981   g.ToGraphDef(&graph_def);
982   FunctionLibraryDefinition reachable_functions =
983       fld->ReachableDefinitions(graph_def);
984   for (const string& func_name : fld->ListFunctionNames()) {
985     if (!reachable_functions.Find(func_name)) {
986       TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
987     }
988   }
989   return Status::OK();
990 }
991 
992 Status RewriteTensorListWithConstElement(Graph* g,
993                                          FunctionLibraryDefinition* fld) {
994   for (Node* n : g->nodes()) {
995     if (n->type_string() != "EmptyTensorList") {
996       continue;
997     }
998 
999     // Find the forward While op.
1000     std::vector<const Edge*> fwd_while_edges;
1001     for (const Edge* e : n->out_edges()) {
1002       if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
1003         fwd_while_edges.push_back(e);
1004       }
1005     }
1006     if (fwd_while_edges.size() != 1) {
1007       // No forward While op found, or multiple forward While ops.
1008       continue;
1009     }
1010 
1011     // Find the backward While op.
1012     Node* fwd_while = fwd_while_edges[0]->dst();
1013     int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
1014     std::vector<const Edge*> bwd_while_edges;
1015     for (const Edge* e : fwd_while->out_edges()) {
1016       if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
1017         bwd_while_edges.push_back(e);
1018       }
1019     }
1020     if (bwd_while_edges.size() != 1) {
1021       // No backward While op found, or multiple backward While ops.
1022       continue;
1023     }
1024 
1025     Node* bwd_while = bwd_while_edges[0]->dst();
1026     int bwd_while_dst_input = bwd_while_edges[0]->dst_input();
1027 
1028     // Look into forward While body function and check if TensorListPushBack op
1029     // has a Const input.
1030     NameAttrList fwd_body_attr;
1031     TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
1032     const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
1033     if (!fwd_body) {
1034       return errors::InvalidArgument("Cannot find function ",
1035                                      fwd_body_attr.name(), " for While node ",
1036                                      fwd_while->DebugString());
1037     }
1038     std::unique_ptr<FunctionBody> fwd_fbody;
1039     TF_CHECK_OK(FunctionDefToBodyHelper(
1040         *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
1041 
1042     // Find the TensorListPushBack node; it's one of fwd_arg's successors.
1043     Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
1044     std::vector<Node*> tl_push_nodes;
1045     for (const Edge* out_edge : fwd_arg->out_edges()) {
1046       if (out_edge->dst()->type_string() == "TensorListPushBack") {
1047         tl_push_nodes.push_back(out_edge->dst());
1048       }
1049     }
1050     if (tl_push_nodes.size() != 1) {
1051       // No TensorListPushBack found, or multiple TensorListPushBack.
1052       continue;
1053     }
1054 
1055     // Get input for the TensorListPushBack node.
1056     Node* input_node;
1057     TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
1058     if (input_node->type_string() != "Const") {
1059       // Input for the TensorList is not a Const node.
1060       continue;
1061     }
1062 
1063     NodeDef const_input_nodedef = input_node->def();
1064 
1065     // Rewrite backward While body function, replace usages of
1066     // TensorListPopBack with a Const node.
1067     NameAttrList bwd_body_attr;
1068     TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
1069     const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
1070     if (!bwd_body) {
1071       return errors::InvalidArgument("Cannot find function ",
1072                                      bwd_body_attr.name(), " for While node ",
1073                                      bwd_while->DebugString());
1074     }
1075     std::unique_ptr<FunctionBody> bwd_fbody;
1076     TF_CHECK_OK(FunctionDefToBodyHelper(
1077         *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
1078 
1079     // Find the TensorListPopBack node; it's one of bwd_arg's successors.
1080     Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
1081     std::vector<Node*> tl_pop_nodes;
1082     for (const Edge* out_edge : bwd_arg->out_edges()) {
1083       if (out_edge->dst()->type_string() == "TensorListPopBack") {
1084         tl_pop_nodes.push_back(out_edge->dst());
1085       }
1086     }
1087     if (tl_pop_nodes.size() != 1) {
1088       // No TensorListPopBack found, or multiple TensorListPopBack.
1089       continue;
1090     }
1091 
1092     // Replace TensorListPopBack usages with Const node.
1093     std::vector<const Edge*> edges_to_replace;
1094     for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
1095       if (e->src_output() == 1) {
1096         edges_to_replace.push_back(e);
1097       }
1098     }
1099     if (edges_to_replace.empty()) {
1100       continue;
1101     }
1102     Status s;
1103     const_input_nodedef.set_name(
1104         bwd_fbody->graph->NewName(const_input_nodedef.name()));
1105     Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
1106     TF_RETURN_IF_ERROR(s);
1107     for (const Edge* e : edges_to_replace) {
1108       Node* dst = e->dst();
1109       int dst_input = e->dst_input();
1110       bwd_fbody->graph->RemoveEdge(e);
1111       bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
1112     }
1113 
1114     // Add rewritten backward While body function.
1115     FunctionDef new_fdef;
1116     string new_name = fld->UniqueFunctionName(
1117         absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
1118     TF_RETURN_IF_ERROR(
1119         GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
1120     TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
1121 
1122     // Change backward While op to use the new body function.
1123     bwd_body_attr.set_name(new_name);
1124     bwd_while->ClearAttr("body");
1125     bwd_while->AddAttr("body", bwd_body_attr);
1126   }
1127   return Status::OK();
1128 }
1129 
1130 }  // namespace tensorflow
1131