1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17 
18 #include <functional>
19 #include <queue>
20 #include <random>
21 #include <set>
22 #include <unordered_map>
23 
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/tf2xla/sharding_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/common_runtime/function.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/graph_def_util.h"
31 #include "tensorflow/core/framework/graph_to_functiondef.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_builder.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_shape.pb.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/tensor_id.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 
42 namespace tensorflow {
43 
44 namespace {
45 
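// Returns an error unless `id` has a non-empty node_name and a non-negative
// output_index.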
46 Status ValidateTensorId(const tf2xla::TensorId& id) {
47   if (id.node_name().empty()) {
48     return errors::InvalidArgument("TensorId node_name must be non-empty");
49   }
50   if (id.output_index() < 0) {
51     return errors::InvalidArgument("TensorId output_index must be non-negative");
52   }
53   return Status::OK();
54 }
55 
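// If `name` is non-empty, inserts it into `names` and returns an error if a
// `kind` with that name was already seen.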
56 Status CheckNameDuplicates(const string& kind, const string& name,
57                            std::set<string>* names) {
58   if (!name.empty()) {
59     if (!names->insert(name).second) {
60       return errors::InvalidArgument("duplicate ", kind, " name: ", name);
61     }
62   }
63   return Status::OK();
64 }
65 
66 Status CheckFeedFetchNameConflicts(const string& kind,
67                                    const std::set<string>& names) {
68   // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
69   // since that will cause a collision in codegen symbols.
70   for (const string& name : names) {
71     const string name_data(name + "_data");
72     if (names.find(name_data) != names.end()) {
73       return errors::InvalidArgument("conflicting ", kind, " name: ", name,
74                                      " and ", name_data);
75     }
76   }
77   return Status::OK();
78 }
79 
80 // For graph `g`, copy all function call nodes' FunctionDefs from `lookup_fld`
81 // to `fld`. This ensures that `fld` can instantiate every function called in `g`.
82 Status CopyAssociatedFunctions(Graph* g,
83                                const FunctionLibraryDefinition* lookup_fld,
84                                FunctionLibraryDefinition* fld) {
85   for (Node* n : g->op_nodes()) {
86     for (const auto& associated_function :
87          GetAssociatedFunctions(*n, lookup_fld)) {
88       switch (associated_function.type()) {
89         case AssociatedFunctionInfo::kFunctionCallNode: {
90           const FunctionDef* fdef =
91               lookup_fld->Find(associated_function.func_name());
92           if (!fdef) {
93             return errors::Internal(
94                 "Cannot find function ", associated_function.func_name(),
95                 " for function call node ", n->DebugString());
96           }
97           TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
98           break;
99         }
100         case AssociatedFunctionInfo::kSymbolicGradient:
101         case AssociatedFunctionInfo::kFunctionAttr:
102           break;
103       }
104     }
105   }
106   return Status::OK();
107 }
108 
109 // For graph `g`, replaces usages of _Arg nodes whose "index" attribute is in
110 // `const_input_index_to_node` with the corresponding Const nodes.
111 Status ReplaceArgUsageWithConstNode(
112     Graph* g,
113     const std::unordered_map<int, const Node*>& const_input_index_to_node) {
114   // Collect all _Arg nodes.
115   std::unordered_map<int, Node*> arg_nodes;
116   for (Node* n : g->op_nodes()) {
117     if (n->IsArg()) {
118       int index;
119       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
120       arg_nodes[index] = n;
121     }
122   }
123 
124   for (const auto& iter : const_input_index_to_node) {
125     int arg_index = iter.first;
126     NodeDef const_def = iter.second->def();
127     const_def.set_name(g->NewName(const_def.name()));
128     Status s;
129     Node* const_node = g->AddNode(const_def, &s);
130     TF_RETURN_IF_ERROR(s);
131 
132     Node* arg_node = arg_nodes[arg_index];
133 
134     // Collect all usages of the _Arg node.
135     struct OutEdgeInfo {
136       int dst_node_id, dst_input;
137     };
138     std::vector<OutEdgeInfo> usages;
139     for (const Edge* e : arg_node->out_edges()) {
140       if (e->IsControlEdge()) {
141         continue;
142       }
143       usages.push_back({e->dst()->id(), e->dst_input()});
144     }
145 
146     for (int i = 0; i < usages.size(); i++) {
147       // Make a copy of `usage_node`, and change its input to const node.
148       Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
149       NodeDef replace_def = usage_node->def();
150       *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
151       TF_ASSIGN_OR_RETURN(Node * replace_node,
152                           ReplaceNode(g, usage_node, replace_def));
153       const Edge* usage_edge;
154       TF_RETURN_IF_ERROR(
155           replace_node->input_edge(usages[i].dst_input, &usage_edge));
156       g->RemoveEdge(usage_edge);
157       g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);
158 
159       // Later entries in `usages` might have `usage_node` as dst node, but
160       // `usage_node` is removed. Replace such entries with `replace_node`.
161       for (int j = i + 1; j < usages.size(); j++) {
162         if (usages[j].dst_node_id == usages[i].dst_node_id) {
163           usages[j].dst_node_id = replace_node->id();
164         }
165       }
166     }
167   }
168   return Status::OK();
169 }
170 
171 // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
172 // the function to replace _Arg nodes in `const_input_index_to_node` with Const
173 // inputs.
174 Status PropagateConstIntoFuncAttr(
175     Node* n, const string& attr_name,
176     const std::unordered_map<int, const Node*>& const_input_index_to_node,
177     const FunctionLibraryDefinition* lookup_fld,
178     FunctionLibraryDefinition* fld) {
179   // Instantiate the function.
180   NameAttrList func_attr;
181   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
182   const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
183   if (!fdef) {
184     return errors::Internal("Cannot find function ", func_attr.name(),
185                             " for node ", n->name());
186   }
187   std::unique_ptr<FunctionBody> fbody;
188   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
189       *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));
190 
191   // Rewrite _Arg usages with Const node.
192   Graph* func_graph = fbody->graph;
193   TF_RETURN_IF_ERROR(
194       ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
195 
196   // Save rewritten function.
197   FunctionDef replace_fdef;
198   string new_func_name =
199       fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
200   TF_RETURN_IF_ERROR(
201       GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
202   TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
203 
204   // Change the node to use rewritten function.
205   func_attr.set_name(new_func_name);
206   n->ClearAttr(attr_name);
207   n->AddAttr(attr_name, func_attr);
208 
209   // Copy associated functions.
210   TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
211 
212   return Status::OK();
213 }
214 
215 // For an "If" node in graph `g`, if it has Const node inputs, rewrite its
216 // then/else branch function to replace _Arg nodes with those Const inputs.
217 Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
218                                 const FunctionLibraryDefinition* lookup_fld,
219                                 FunctionLibraryDefinition* fld) {
220   // Note that the first input of an If node is the predicate; the remaining
221   // inputs are the function inputs.
222   std::unordered_map<int, const Node*> const_input_index_to_node;
223   for (int i = 1; i < if_node->num_inputs(); i++) {
224     const Node* input_node;
225     TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
226     if (input_node->type_string() == "Const") {
227       const_input_index_to_node[i - 1] = input_node;
228     }
229   }
230   if (const_input_index_to_node.empty()) {
231     return Status::OK();
232   }
233 
234   // Rewrite the "then_branch" and "else_branch" functions, replacing usages of
235   // those _Arg nodes with the corresponding Const nodes.
236   for (const auto& attr_name :
237        std::vector<string>{"then_branch", "else_branch"}) {
238     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
239         if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
240   }
241 
242   return Status::OK();
243 }
244 
245 // For a "While" node in graph `g`, if it has Const node inputs, rewrite its
246 // cond/body function to replace _Arg nodes with those Const inputs.
247 Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
248                                    const FunctionLibraryDefinition* lookup_fld,
249                                    FunctionLibraryDefinition* fld) {
250   // For a "While" node, we should only replace _Arg nodes that are loop
251   // invariant. For such _Arg nodes, the corresponding retval's input comes
252   // directly from the arg.
253   std::unordered_map<int, const Node*> const_input_index_to_node;
254   NameAttrList body_attr;
255   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
256   const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
257   if (!body_func) {
258     return errors::Internal("Cannot find body function ", body_attr.name(),
259                             " for While node ", while_node->name());
260   }
261   for (int i = 0; i < while_node->num_inputs(); i++) {
262     const Node* input_node;
263     TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
264     if (input_node->type_string() != "Const") {
265       continue;
266     }
267 
268     // Check whether the i-th retval's input comes directly from the i-th arg.
269     // For resource variable inputs of While nodes, the TF2XLA convention is to
270     // place them at the end of all inputs (after all data inputs) and *not*
271     // return them, so the number of While node inputs may be larger than the
272     // number of its outputs.
273     if (i >= body_func->signature().output_arg_size()) {
274       continue;
275     }
276     const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
277     auto output_arg_input = body_func->ret().find(output_arg.name());
278     if (output_arg_input == body_func->ret().end()) {
279       return errors::Internal("Cannot find input for output arg ",
280                               output_arg.name(), " in function ",
281                               body_attr.name());
282     }
283     const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
284     if (output_arg_input->second != input_arg.name()) {
285       continue;
286     }
287 
288     const_input_index_to_node[i] = input_node;
289   }
290   if (const_input_index_to_node.empty()) {
291     return Status::OK();
292   }
293 
294   // Rewrite the "cond" and "body" functions, replacing usages of those _Arg
295   // nodes with the corresponding Const nodes.
296   for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
297     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
298         while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
299   }
300   return Status::OK();
301 }
302 
303 }  // namespace
304 
305 const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
306 
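// Validates `config`: feed/fetch tensor ids and feed shapes must be valid,
// names must be unique and free of codegen conflicts, and at least one fetch
// must be specified.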
307 Status ValidateConfig(const tf2xla::Config& config) {
308   std::set<string> names;
309   for (const tf2xla::Feed& feed : config.feed()) {
310     TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
311     TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
312     TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
313   }
314   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
315   names.clear();
316   for (const tf2xla::Fetch& fetch : config.fetch()) {
317     TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
318     TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
319   }
320   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
321   if (config.fetch().empty()) {
322     return errors::InvalidArgument("fetches must be specified");
323   }
324   return Status::OK();
325 }
326 
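// Adds a PlaceholderV2 node to `graph_def` for every feed in `config`, records
// the feed-to-placeholder name mapping in `feed_remapping`, and rewrites inputs
// that referenced the fed tensors to read from the placeholders instead.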
327 Status AddPlaceholdersForFeeds(
328     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
329     std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
330   struct PlaceholderInfo {
331     const tf2xla::Feed* feed = nullptr;  // points to the Feed in <config>.
332     string placeholder_name;
333     DataType data_type = DT_INVALID;
334   };
335 
336   // Put each fed tensor into a map by name:port. A map is used for determinism
337   // when creating placeholders (genrules want deterministic output).
338   std::map<string, PlaceholderInfo> placeholder_info;
339   for (int i = 0; i < config.feed_size(); ++i) {
340     const tf2xla::Feed* feed = &config.feed(i);
341     const string name_port = TensorIdToString(feed->id());
342     PlaceholderInfo& info = placeholder_info[name_port];
343     info.feed = feed;
344     info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
345                                          "/", feed->id().node_name());
346     (*feed_remapping)[name_port] = info.placeholder_name;
347   }
348 
349   // Verify node exists and determine data type.
350   std::unordered_map<string, const NodeDef*> name_to_node;
351   for (int i = 0; i < graph_def->node_size(); ++i) {
352     name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
353   }
354   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
355     PlaceholderInfo& info = it->second;
356     const tf2xla::TensorId& feed_id = info.feed->id();
357 
358     // Find the existing node and determine data type.
359     auto node_it = name_to_node.find(feed_id.node_name());
360     if (node_it == name_to_node.end()) {
361       return errors::NotFound("Can't find feed node: ",
362                               TensorIdToString(feed_id));
363     }
364     const NodeDef* existing = node_it->second;
365 
366     if (info.feed->type() != DT_INVALID) {
367       info.data_type = info.feed->type();
368     } else {
369       // Build the node in order to infer its type.
370 
371       // Must first add default attrs as well, so do this in a copied GraphDef.
372       GraphDef gd;
373       *gd.mutable_versions() = graph_def->versions();
374       *gd.add_node() = *existing;
375       MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
376       TF_RETURN_IF_ERROR(
377           AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
378 
379       // Now build the node from the copied node def.
380       Graph g(op_registry);
381       g.set_versions(graph_def->versions());
382       Status status;
383       Node* feed_node = g.AddNode(gd.node(0), &status);
384       TF_RETURN_IF_ERROR(status);
385 
386       if (info.feed->id().output_index() < feed_node->num_outputs()) {
387         info.data_type =
388             BaseType(feed_node->output_type(info.feed->id().output_index()));
389       } else {
390         return errors::InvalidArgument(
391             "Invalid output_index ", info.feed->id().output_index(),
392             " for feed node ", info.feed->id().node_name());
393       }
394     }
395   }
396 
397   // Create placeholders. Note that we could avoid creating a placeholder for
398   // feeds which are already placeholders, but we omit that to avoid more cases
399   // in this code.
400   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
401     const PlaceholderInfo& info = it->second;
402     // TODO(shikharagarwal): Add original node information.
403     NodeDef* d = graph_def->add_node();
404     d->set_name(info.placeholder_name);
405     d->set_op("PlaceholderV2");
406     auto& attr_map = *d->mutable_attr();
407     attr_map["dtype"].set_type(info.data_type);
408     *attr_map["shape"].mutable_shape() = info.feed->shape();
409   }
410 
411   // Rewrite references to the fed tensors to refer to the placeholder.
412   for (int i = 0; i < graph_def->node_size(); ++i) {
413     NodeDef* node_def = graph_def->mutable_node(i);
414     for (int j = 0; j < node_def->input_size(); ++j) {
415       auto id = ParseTensorName(node_def->input(j));
416       auto it = placeholder_info.find(id.ToString());
417       if (it != placeholder_info.end()) {
418         node_def->set_input(j, it->second.placeholder_name);
419       }
420     }
421   }
422 
423   return Status::OK();
424 }
425 
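// Copies `in` to `out`, keeping (in their original order) only the nodes
// reachable from the fetches in `config`; the backward traversal does not
// follow edges that are fed.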
426 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
427                          GraphDef* out) {
428   *out = in;
429   out->clear_node();
430 
431   // Tensors needed for feeding.
432   std::set<std::pair<string, int>> feed_tensors;
433   for (const tf2xla::Feed& feed : config.feed()) {
434     feed_tensors.insert(
435         std::make_pair(feed.id().node_name(), feed.id().output_index()));
436   }
437 
438   // Maps node name to reachability.
439   std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
440   for (const NodeDef& node : in.node()) {
441     node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
442   }
443 
444   // Traverse.
445   std::queue<string> name_queue;
446   for (int i = 0; i < config.fetch_size(); ++i) {
447     name_queue.push(config.fetch(i).id().node_name());
448   }
449   while (!name_queue.empty()) {
450     const string name = name_queue.front();
451     name_queue.pop();
452 
453     auto find_it = node_by_name.find(name);
454     if (find_it == node_by_name.end()) {
455       return errors::InvalidArgument("While pruning graph, node ", name,
456                                      " needed but not found in the graph.");
457     }
458     auto& map_entry = find_it->second;
459     if (map_entry.first) {
460       continue;
461     }
462     map_entry.first = true;
463 
464     // Push input nodes of the currently visited node to name_queue.
465     for (const string& in_edge : map_entry.second->input()) {
466       auto id = ParseTensorName(in_edge);
467       const string node_name = string(id.first);
468       if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
469           feed_tensors.end()) {
470         name_queue.push(node_name);
471       } else {
472         // The input tensor is from an edge that is being fed. Therefore,
473         // we skip recursing down that edge, to avoid requiring nodes that
474         // may not be needed (note that the input node may still be added
475         // to name_queue later if one of its output edges is not being fed).
476       }
477     }
478   }
479 
480   // Copy over, preserving the original node order and keeping only nodes that
481   // are reachable from the fetches.
482   out->mutable_node()->Reserve(in.node_size());
483   for (const NodeDef& node : in.node()) {
484     if (node_by_name[node.name()].first) {
485       *out->add_node() = node;
486     }
487   }
488   return Status::OK();
489 }
490 
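// Renders `id` as "node_name:output_index".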
491 string TensorIdToString(const tf2xla::TensorId& id) {
492   return absl::StrCat(id.node_name(), ":", id.output_index());
493 }
494 
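// Copies the assigned and requested device of the neighbor with the smallest
// MAXIMAL sharding core annotation onto `n`, looking at destination nodes if
// `out_edges` is true and source nodes otherwise.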
495 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
496   int core = -1;
497   const Node* matching_node = nullptr;
498   for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
499     if (edge->IsControlEdge()) continue;
500     const Node* possible_match = out_edges ? edge->dst() : edge->src();
501     TF_ASSIGN_OR_RETURN(
502         absl::optional<xla::OpSharding> sharding,
503         ParseShardingFromDevice(
504             *possible_match,
505             /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
506     if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
507       const int core_annotation = sharding.value().tile_assignment_devices(0);
508       if (core == -1 || core > core_annotation) {
509         core = core_annotation;
510         matching_node = possible_match;
511       }
512     }
513   }
514   if (matching_node != nullptr) {
515     n->set_assigned_device_name(matching_node->assigned_device_name());
516     n->set_requested_device(matching_node->requested_device());
517   }
518   return Status::OK();
519 }
520 
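// Appends `dtype` to the allowed values of every attr constraint in `kdef`
// named `name`.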
521 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
522                                    KernelDef* kdef) {
523   for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
524     if (constraint.name() == name) {
525       constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
526     }
527   }
528 }
529 
530 namespace {
531 uint32 InitialRandomSeed() {
532   // Support for plumbing the TF seed through to XLA is being worked on.
533   // If a user wants deterministic behavior, their best option
534   // is to start with a known checkpoint. This also handles issues when
535   // multiple random calls can be invoked in any order by the TF executor.
536   // Another option is to use stateless random ops. They have much cleaner
537   // semantics.
538   // If a user really wants to set a deterministic seed for XLA-based
539   // devices, this is the place to do it.
540   std::random_device rd;
541   // Make the starting value odd.
542   return rd() | 1;
543 }
544 }  // namespace
545 
546 uint32 GetXLARandomSeed() {
547   // We initialize the counter with an odd number and increment it by two
548   // every time. This ensures that it will never be zero, even
549   // after an overflow. When seeded with zero, some XLA backends
550   // can return all zeros instead of random numbers.
551   static std::atomic<uint32> counter(InitialRandomSeed());
552   uint32 seed = counter.fetch_add(2);
553   std::srand(seed);
554   return std::rand() | 1;
555 }
556 
557 // TODO(b/77601805): add tests for associated function related stuff.
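// Returns true if `node_def` calls a function in `fld`, is a SymbolicGradient
// op, or carries a function-valued attr. XlaHostCompute is excluded because its
// "shape_inference_graph" attr is not related to graph execution.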
558 bool HasAssociatedFunction(const NodeDef& node_def,
559                            const FunctionLibraryDefinition* fld) {
560   if (fld->Contains(node_def.op())) {
561     return true;
562   }
563 
564   if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
565     // The Gradient op has an "f" attr, which is set to the function we are
566     // taking the gradient of. We need to functionalize the gradient function.
567     return true;
568   }
569 
570   if (node_def.op() == "XlaHostCompute") {
571     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
572     // related to graph execution.
573     return false;
574   }
575 
576   for (const auto& iter : node_def.attr()) {
577     if (iter.second.has_func()) {
578       return true;
579     }
580   }
581 
582   return false;
583 }
584 
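// Returns the functions associated with `node`: the callee for a function call
// node, a symbolic gradient entry for SymbolicGradient ops, or one entry per
// function-valued attr otherwise; XlaHostCompute nodes yield no entries.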
585 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
586     const Node& node, const FunctionLibraryDefinition* fld) {
587   std::vector<AssociatedFunctionInfo> results;
588   const string& op = node.type_string();
589   if (fld->Contains(op)) {
590     // This is a function call node.
591     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
592     results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
593   } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
594     // This is a SymbolicGradient op.
595     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
596     results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
597   } else if (node.type_string() == "XlaHostCompute") {
598     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
599     // related to graph execution.
600   } else {
601     // Collect all function attrs for the node.
602     for (auto& iter : node.attrs()) {
603       if (iter.second.has_func()) {
604         VLOG(2) << "Found function attr for node " << node.name() << ": "
605                 << iter.first << " = " << iter.second.func().name();
606         results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
607             iter.second.func().name(), iter.second.func().attr(), iter.first));
608       }
609     }
610   }
611   return results;
612 }
613 
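// Points `node` at `rewritten_function_name`: function call nodes are rebuilt
// to call the new function, SymbolicGradient registrations in `fld` are added
// or replaced, and plain function attrs are overwritten in place.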
614 Status RewriteAssociatedFunction(
615     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
616     const AssociatedFunctionInfo& associated_function,
617     const string& rewritten_function_name) {
618   switch (associated_function.type()) {
619     case AssociatedFunctionInfo::kFunctionCallNode: {
620       // Change this node to call the new function.
621       NodeDebugInfo debug_info(*node);
622       NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
623                              &debug_info);
624       for (auto attr : node->attrs()) {
625         builder.Attr(attr.first, attr.second);
626       }
627       for (int i = 0; i < node->num_inputs(); i++) {
628         Node* input_node;
629         TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
630         builder.Input(input_node->name(), i, node->input_type(i));
631       }
632       builder.Device(node->assigned_device_name().empty()
633                          ? node->requested_device()
634                          : node->assigned_device_name());
635       NodeDef node_def;
636       TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
637       Status s;
638       Node* new_node = graph->AddNode(node_def, &s);
639       TF_RETURN_IF_ERROR(s);
640       for (auto edge : node->in_edges()) {
641         graph->AddEdge(edge->src(), edge->src_output(), new_node,
642                        edge->dst_input());
643       }
644       for (auto edge : node->out_edges()) {
645         graph->AddEdge(new_node, edge->src_output(), edge->dst(),
646                        edge->dst_input());
647       }
648       graph->RemoveNode(node);
649       break;
650     }
651     case AssociatedFunctionInfo::kSymbolicGradient: {
652       NameAttrList func;
653       TF_RETURN_IF_ERROR(GetNodeAttr(
654           node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
655       GradientDef gradient_def;
656       gradient_def.set_function_name(func.name());
657       gradient_def.set_gradient_func(rewritten_function_name);
658       string original_grad_func = fld->FindGradient(func.name());
659       if (original_grad_func.empty()) {
660         TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
661       } else if (original_grad_func != rewritten_function_name) {
662         TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
663       }
664       break;
665     }
666     case AssociatedFunctionInfo::kFunctionAttr: {
667       // Change function attr to rewritten functions.
668       NameAttrList func;
669       TF_RETURN_IF_ERROR(
670           GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
671       node->ClearAttr(associated_function.attr_name());
672       func.set_name(rewritten_function_name);
673       node->AddAttr(associated_function.attr_name(), func);
674       break;
675     }
676   }
677 
678   return Status::OK();
679 }
680 
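// Returns a handle for the canonicalized (func_name, attrs) pair, reusing a
// cached handle if present and instantiating (then caching) one otherwise.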
681 Status CachedFunctionHandles::GetOrInstantiate(
682     const string& func_name, AttrSlice attrs,
683     FunctionLibraryRuntime::Handle* handle) {
684   string canonicalized_name = Canonicalize(func_name, attrs);
685   auto iter = handles_.find(canonicalized_name);
686   if (iter != handles_.end()) {
687     *handle = iter->second;
688     return Status::OK();
689   }
690 
691   TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
692   handles_[canonicalized_name] = *handle;
693   return Status::OK();
694 }
695 
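// Releases every cached handle, accumulating any errors, and clears the cache.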
696 Status CachedFunctionHandles::ReleaseAllHandles() {
697   Status result;
698   for (auto iter : handles_) {
699     result.Update(flr_->ReleaseHandle(iter.second));
700   }
701   handles_.clear();
702   return result;
703 }
704 
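// Replaces `n` with a new node built from `node_def`, transferring all of
// `n`'s input and output edges to the replacement before removing `n`.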
705 xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
706   // Create the replacement node.
707   Status s;
708   Node* new_node = g->AddNode(node_def, &s);
709   if (!s.ok()) {
710     return s;
711   }
712 
713   // Record the original node's output edges and remove them first, so that the
714   // destination nodes' inputs do not end up with multiple producers.
715   std::vector<OutEdgeInfo> out_edge_info;
716   std::vector<const Edge*> out_edges;
717   for (const Edge* edge : n->out_edges()) {
718     out_edges.push_back(edge);
719     out_edge_info.push_back(
720         {edge->dst(), edge->src_output(), edge->dst_input()});
721   }
722   for (const Edge* edge : out_edges) {
723     g->RemoveEdge(edge);
724   }
725 
726   // Add original node's input and output edges to the replacement node.
727   for (const Edge* in_edge : n->in_edges()) {
728     g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
729                in_edge->dst_input());
730   }
731   for (const OutEdgeInfo& out_edge : out_edge_info) {
732     g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
733   }
734 
735   // Remove the original node.
736   g->RemoveNode(n);
737 
738   return new_node;
739 }
740 
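// Creates an Identity node named `node_name` with dtype `dtype` in `graph`,
// optionally reading from `input` and placed on `requested_device` if given.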
741 xla::StatusOr<Node*> BuildIdentityNode(
742     Graph* graph, const string& node_name, DataType dtype, const Node* input,
743     absl::optional<string> requested_device) {
744   // Create identity node.
745   NodeDef ndef;
746   ndef.set_name(node_name);
747   ndef.set_op("Identity");
748   if (input) {
749     ndef.add_input(input->name());
750   }
751   if (requested_device) {
752     ndef.set_device(*requested_device);
753   }
754   AddNodeAttr("T", dtype, &ndef);
755   Status s;
756   Node* id_node = graph->AddNode(ndef, &s);
757   TF_RETURN_IF_ERROR(s);
758   return id_node;
759 }
760 
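// For every If and While node in `g`, propagates Const inputs into the
// corresponding then/else or cond/body functions.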
761 Status PropagateConstIntoFunctionalNodes(
762     Graph* g, const FunctionLibraryDefinition* lookup_fld,
763     FunctionLibraryDefinition* fld) {
764   for (Node* n : g->op_nodes()) {
765     if (n->IsIfNode()) {
766       TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
767     } else if (n->IsWhileNode()) {
768       TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
769     }
770   }
771   return Status::OK();
772 }
773 
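// Removes from `fld` every function that is not reachable from graph `g`.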
774 Status PruneUnreachableFunctionsFromGraph(const Graph& g,
775                                           FunctionLibraryDefinition* fld) {
776   GraphDef graph_def;
777   g.ToGraphDef(&graph_def);
778   FunctionLibraryDefinition reachable_functions =
779       fld->ReachableDefinitions(graph_def);
780   for (const string& func_name : fld->ListFunctionNames()) {
781     if (!reachable_functions.Find(func_name)) {
782       TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
783     }
784   }
785   return Status::OK();
786 }
787 
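// For an EmptyTensorList feeding a forward/backward While pair where the
// forward body pushes a Const element, rewrites the backward body so that the
// popped element is replaced by that Const.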
788 Status RewriteTensorListWithConstElement(Graph* g,
789                                          FunctionLibraryDefinition* fld) {
790   for (Node* n : g->nodes()) {
791     if (n->type_string() != "EmptyTensorList") {
792       continue;
793     }
794 
795     // Find the forward While op.
796     std::vector<const Edge*> fwd_while_edges;
797     for (const Edge* e : n->out_edges()) {
798       if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
799         fwd_while_edges.push_back(e);
800       }
801     }
802     if (fwd_while_edges.size() != 1) {
803       // No forward While op found, or multiple forward While ops.
804       continue;
805     }
806 
807     // Find the backward While op.
808     Node* fwd_while = fwd_while_edges[0]->dst();
809     int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
810     std::vector<const Edge*> bwd_while_edges;
811     for (const Edge* e : fwd_while->out_edges()) {
812       if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
813         bwd_while_edges.push_back(e);
814       }
815     }
816     if (bwd_while_edges.size() != 1) {
817       // No backward While op found, or multiple backward While ops.
818       continue;
819     }
820 
821     Node* bwd_while = bwd_while_edges[0]->dst();
822     int bwd_while_dst_input = bwd_while_edges[0]->dst_input();
823 
824     // Look into the forward While body function and check whether the
825     // TensorListPushBack op has a Const input.
826     NameAttrList fwd_body_attr;
827     TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
828     const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
829     if (!fwd_body) {
830       return errors::InvalidArgument("Cannot find function ",
831                                      fwd_body_attr.name(), " for While node ",
832                                      fwd_while->DebugString());
833     }
834     std::unique_ptr<FunctionBody> fwd_fbody;
835     TF_CHECK_OK(FunctionDefToBodyHelper(
836         *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
837 
838     // Find the TensorListPushBack node; it's one of fwd_arg's successors.
839     Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
840     std::vector<Node*> tl_push_nodes;
841     for (const Edge* out_edge : fwd_arg->out_edges()) {
842       if (out_edge->dst()->type_string() == "TensorListPushBack") {
843         tl_push_nodes.push_back(out_edge->dst());
844       }
845     }
846     if (tl_push_nodes.size() != 1) {
847       // No TensorListPushBack found, or multiple TensorListPushBack.
848       continue;
849     }
850 
851     // Get input for the TensorListPushBack node.
852     Node* input_node;
853     TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
854     if (input_node->type_string() != "Const") {
855       // The input for the TensorList is not a Const node.
856       continue;
857     }
858 
859     NodeDef const_input_nodedef = input_node->def();
860 
861     // Rewrite the backward While body function, replacing usages of
862     // TensorListPopBack with a Const node.
863     NameAttrList bwd_body_attr;
864     TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
865     const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
866     if (!bwd_body) {
867       return errors::InvalidArgument("Cannot find function ",
868                                      bwd_body_attr.name(), " for While node ",
869                                      bwd_while->DebugString());
870     }
871     std::unique_ptr<FunctionBody> bwd_fbody;
872     TF_CHECK_OK(FunctionDefToBodyHelper(
873         *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
874 
875     // Find the TensorListPopBack node; it's one of bwd_arg's successors.
876     Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
877     std::vector<Node*> tl_pop_nodes;
878     for (const Edge* out_edge : bwd_arg->out_edges()) {
879       if (out_edge->dst()->type_string() == "TensorListPopBack") {
880         tl_pop_nodes.push_back(out_edge->dst());
881       }
882     }
883     if (tl_pop_nodes.size() != 1) {
884       // No TensorListPopBack found, or multiple TensorListPopBack.
885       continue;
886     }
887 
888     // Replace TensorListPopBack usages with Const node.
889     std::vector<const Edge*> edges_to_replace;
890     for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
891       if (e->src_output() == 1) {
892         edges_to_replace.push_back(e);
893       }
894     }
895     if (edges_to_replace.empty()) {
896       continue;
897     }
898     Status s;
899     const_input_nodedef.set_name(
900         bwd_fbody->graph->NewName(const_input_nodedef.name()));
901     Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
902     TF_RETURN_IF_ERROR(s);
903     for (const Edge* e : edges_to_replace) {
904       Node* dst = e->dst();
905       int dst_input = e->dst_input();
906       bwd_fbody->graph->RemoveEdge(e);
907       bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
908     }
909 
910     // Add rewritten backward While body function.
911     FunctionDef new_fdef;
912     string new_name = fld->UniqueFunctionName(
913         absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
914     TF_RETURN_IF_ERROR(
915         GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
916     TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
917 
918     // Change backward While op to use the new body function.
919     bwd_body_attr.set_name(new_name);
920     bwd_while->ClearAttr("body");
921     bwd_while->AddAttr("body", bwd_body_attr);
922   }
923   return Status::OK();
924 }
925 
926 }  // namespace tensorflow
927