1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/match.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
31 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/core/common_runtime/device_factory.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/optimization_registry.h"
39 #include "tensorflow/core/common_runtime/shape_refiner.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/graph_to_functiondef.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/tensor.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/control_flow.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_def_builder.h"
50 #include "tensorflow/core/graph/tensor_id.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 
58 namespace tensorflow {
59 
60 const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
61 const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
62 const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
63 const char* const kXlaHostTransferSequencerAttr =
64     "_xla_host_transfer_sequencer";
65 const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";
66 
67 void SortControlInputs(GraphDef* gdef) {
68   int64_t num_nodes = gdef->node_size();
69   for (int64_t i = 0; i < num_nodes; ++i) {
70     NodeDef* node = gdef->mutable_node(i);
71     // Stable sort control inputs and leave the order of data inputs unchanged.
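    // (Control inputs are the names prefixed with '^'; the comparator below
    // keeps data inputs first in their original order and sorts control
    // inputs lexicographically after them.)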
72     std::stable_sort(node->mutable_input()->begin(),
73                      node->mutable_input()->end(),
74                      [](const string& a, const string& b) {
75                        bool a_is_control = absl::StartsWith(a, "^");
76                        bool b_is_control = absl::StartsWith(b, "^");
77                        return (!a_is_control && b_is_control) ||
78                               (a_is_control && b_is_control && a < b);
79                      });
80   }
81 }
82 
83 namespace {
84 
85 bool AreAllParentsGuaranteedConst(
86     const Node& n,
87     const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
88   if (n.type_string() == "GuaranteeConst") {
89     // If the current node is itself a cast-to-const, no need
90     // to look at the incoming edges.
91     return true;
92   }
93 
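  // A node only qualifies if every incoming data edge comes from an
  // already-marked node and it has at least one data (non-control) input;
  // nodes fed solely by control edges are not treated as guaranteed const.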
94   bool all_parents_const = true;
95   bool atleast_one_non_control_edge = false;
96   for (const Edge* in : n.in_edges()) {
97     atleast_one_non_control_edge =
98         atleast_one_non_control_edge || !in->IsControlEdge();
99     if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
100       all_parents_const = false;
101       break;
102     }
103   }
104   return all_parents_const && atleast_one_non_control_edge;
105 }
106 
107 void MarkGuaranteedConstants(
108     const Graph& graph,
109     const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
110   absl::flat_hash_set<const Node*> guaranteed_const_nodes;
111   std::vector<const Node*> srcs;
112   srcs.reserve(src_arg_pairs.size());
113   for (const auto& src_arg : src_arg_pairs) {
114     srcs.push_back(src_arg.first);
115   }
116   ReverseDFSFrom(
117       graph, srcs, /*enter=*/nullptr,
118       /*leave=*/[&guaranteed_const_nodes](const Node* n) {
119         // TODO(vinuraja): Doesn't work in the presence of loops.
120         if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
121           guaranteed_const_nodes.insert(n);
122         }
123       });
124 
125   for (auto& src_arg : src_arg_pairs) {
126     if (guaranteed_const_nodes.count(src_arg.first) != 0) {
127       VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
128       src_arg.second->AddAttr("_is_guaranteed_constant", true);
129     }
130   }
131 }
132 
133 struct OutputInputTensorPairHasher {
134   uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
135     return Hash64Combine(OutputTensor::Hash()(s.first),
136                          InputTensor::Hash()(s.second));
137   }
138 };
139 
140 // TODO(phawkins) add a canonical copy of these operator names and refactor
141 // everything to use it.
142 static const char* const kArgOp = "_Arg";
143 static const char* const kRetValOp = "_Retval";
144 
145 class Encapsulator {
146  public:
147   Encapsulator(string group_attribute, Graph const* graph_in)
148       : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
149 
150   // Find subgraphs marked with 'group_attribute', and build a new
151   // subgraph, one for each value of 'group_attribute'.
152   Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
153 
154   // Build a FunctionDef for each subgraph, and add it to 'library'. The values of
155   // the 'group_attribute' annotations become the function names.
156   // If 'reuse_existing_functions' is set, use an existing function with the
157   // same name, if any.
158   // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
159   // function conversion.
160   Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
161                            bool reuse_existing_functions,
162                            FunctionLibraryDefinition* library);
163 
164   // Write a copy of the input graph to 'graph_out', where the subgraphs are
165   // replaced with calls to the new functions.
166   Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
167 
168  private:
169   // A subgraph of the input, all marked with a common 'group_attribute'
170   // value.
171   //
172   // In the following simple example, A, B, ..., E are nodes in the original
173   // graph. The group attributes g are each shown as either 0 or empty.
174   //
175   //  A  -->  B  -->  C  -->  D  -->  E
176   //  g:      g:0     g:0     g:0     g:
177   //
178   // The example is rewritten to two graphs; one on the host and one to be
179   // compiled. The host graph is as follows.
180   //
181   //  A  -->  Call  -->  E
182   //
183   // The compiled cluster is as follows.
184   //
185   //  Arg  --> B  --> C  --> D --> Retval
186   class Subgraph {
187    public:
188     // Creates a graph to build the subgraph in, if it doesn't already exist,
189     // using the same op registry and versions as graph_in.
190     Node* MakeNodeImage(const Graph* graph_in, Node* node);
191 
192     // Returns the graph the subgraph is being built in.
193     Graph* GetGraph() const;
194 
195     // Builds a FunctionDef, and adds it to 'library'. The value of the
196     // 'group_attribute' annotations becomes the function name.  If
197     // 'reuse_existing_functions' is set, use an existing function with the same
198     // name, if any.  If 'rewrite_subgraph_fn' is set, it is applied to the
199     // subgraph before function conversion.
200     Status BuildFunctionDef(const string& name_in,
201                             const RewriteSubgraphFn& rewrite_subgraph_fn,
202                             bool reuse_existing_functions,
203                             FunctionLibraryDefinition* library);
204 
205     // Adds the function call node to graph_out.
206     Status AddFunctionCallNode(
207         const std::unordered_map<const Node*, Node*>& node_images,
208         Graph* graph_out);
209 
210     // Returns the Node that the inputs and outputs of the function should be
211     // wired up to.
212     Node* GetCallNode() const;
213 
214     // Returns the index of the arg that the dst of edge should connect to.
215     int GetArgIndexForEdge(const Edge* edge) const;
216 
217     // Returns the index of the result that the src of edge should connect to.
218     int GetResultIndexForEdge(const Edge* edge) const;
219 
224     // Creates an _Arg node for the src node of edge, and adds its index to
221     // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
222     // and adds the edge within the subgraph from the _Arg node to the image of
223     // the dst node.
224     Status RecordArg(const Edge* edge,
225                      const std::unordered_map<const Node*, Node*>& node_images,
226                      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
227 
228     // Records the src of the given edge as a control result of the graph.
229     // Used during graph to function conversion to tie control results to
230     // the function signature.
231     Status RecordControlResult(
232         const Edge* edge,
233         const std::unordered_map<const Node*, Node*>& node_images);
234 
235     // Creates a _Retval node for the src node of edge, and adds it to results_,
236     // if none exists yet. If a new _Retval node is created, also adds the edge
237     // within the subgraph from the src to the _Retval node.
238     Status RecordResult(
239         const Edge* edge,
240         const std::unordered_map<const Node*, Node*>& node_images);
241 
242     // Creates the sequencer node if it doesn't exist, adding it to graph_out.
243     Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
244 
245     // If there is a sequencer node, adds a control edge from the sequencer to
246     // the call node.
247     void ConnectSequencerToCallNode(Graph* graph_out);
248 
249     Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
250 
251    private:
252     // The subgraph extracted from the input graph, suitable for being turned
253     // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
254     // returned by _Retval nodes.
255     std::unique_ptr<Graph> graph_;
256 
257     // Which device are these nodes on? Used to assign a device to the call
258     // node.
259     string device_;
260 
261     // NodeDef for the function call node.
262     NodeDef call_node_def_;
263 
264     // Name that is used for the call node. This may not be
265     // call_node_def_.name() if the client supplies a rewrite lambda.
266     string function_def_name_;
267 
268     // Placeholder node simulating the host compute key in the output graph.
269     // Not owned.
270     Node* host_compute_key_placeholder_ = nullptr;
271 
272     // Function call node in the output graph. Not owned.
273     Node* call_node_;
274 
275     // Maps from source (producer node/slot) and destination
276     // (consumer node/slot) tensors in the input graph to _Arg numbers in
277     // the subgraph. The source map is one-to-one, whereas the dest map may be
278     // many-to-one.
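    // For example, if one producer output feeds two different consumer inputs
    // inside the subgraph, args_by_dst_ holds two entries that share a single
    // _Arg index.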
279     std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
280     std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
281 
282     // The arguments to the subgraph, in order.
283     std::vector<Node*> args_;
284 
285     // Map from source tensor in the input graph to result #.
286     std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
287 
288     // Set of node names that are the source of a control output of the
289     // subgraph. We store strings here so that we can tolerate nodes being
290     // removed from the graph.
291     absl::flat_hash_set<string> control_output_nodes_;
292 
293     // NoOp node in the output graph that the call node is sequenced after.
294     Node* sequencer_ = nullptr;
295   };
296 
297   // Returns in 'attr' the value of the group attribute associated with a node.
298   // Sets the result to the empty string if the attribute is not found.
299   Status GetFunctionNameAttr(Node const* node, string* attr) const;
300 
301   // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
302   // subgraphs for data edges that cross subgraph boundaries.
303   Status CopySubgraphEdges(
304       const std::unordered_map<const Node*, Node*>& node_images,
305       std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
306 
307   // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
308   Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
309 
310   // Copies all nodes that aren't in a compiled subgraph to the output graph.
311   Status CopyNodesToOutputGraph(
312       Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
313 
314   // Adds function call nodes for each compiled subgraph.
315   Status AddFunctionCallNodes(
316       const std::unordered_map<const Node*, Node*>& node_images,
317       Graph* graph_out);
318 
319   // Finds the image of an edge source in the output graph. If the edge crosses
320   // a subgraph boundary it is the output of a call node, otherwise it is a node
321   // in the output graph.
322   Status FindOutputImageOfEdgeSrc(
323       const string& src_func_id, const string& dst_func_id,
324       const std::unordered_map<const Node*, Node*>& node_images,
325       const Node* original_src_node, Node** src_image);
326 
327   // Finds an edge source slot in the output graph. If the edge crosses a
328   // subgraph boundary it is a slot on the output of a call node, otherwise it
329   // is a slot on a node in the output graph.
330   int FindOutputSlotOfEdgeSrc(const string& src_func_id,
331                               const string& dst_func_id,
332                               const Edge* edge);
333 
334   // Finds the image of an edge destination in the output graph. If the edge
335   // crosses a subgraph boundary it is the input of a call node, otherwise it is
336   // a node in the output graph.
337   Status FindOutputImageOfEdgeDst(
338       const string& src_func_id, const string& dst_func_id,
339       const std::unordered_map<const Node*, Node*>& node_images,
340       const Node* original_dst_node, Node** dst_image);
341 
342   // Finds an edge destination slot in the output graph. If the edge crosses a
343   // subgraph boundary it is a slot on the input of a call node, otherwise it is
344   // a slot on a node in the output graph.
345   int FindOutputSlotOfEdgeDst(const string& src_func_id,
346                               const string& dst_func_id,
347                               const Edge* edge);
348 
349   // Copies a single edge to the output graph. The edge is either entirely
350   // within the output graph, or crosses into or out of a compiled subgraph.
351   Status CopyEdgeToOutputGraph(
352       const Edge* edge, const string& src_func_id, const string& dst_func_id,
353       const std::unordered_map<const Node*, Node*>& node_images,
354       Graph* graph_out,
355       std::unordered_set<std::pair<OutputTensor, InputTensor>,
356                          OutputInputTensorPairHasher>* edges_added);
357 
358   // Adds all edges to the output graph.
359   Status AddEdgesToOutputGraph(
360       const std::unordered_map<const Node*, Node*>& node_images,
361       Graph* graph_out);
362 
363   // Makes a copy of graph containing only nodes that are ancestors of at least
364   // one node in sink_nodes, and stores it in pruned_graph. On exit
365   // node_images contains a mapping from nodes in graph to nodes in
366   // pruned_graph. All functions in the copied graph are inlined.
367   Status MakePrunedGraphCopyAndInline(
368       const Graph& graph, const std::vector<Node*>& sink_nodes,
369       std::unique_ptr<Graph>* pruned_graph,
370       std::unordered_map<const Node*, Node*>* node_images,
371       FunctionLibraryDefinition* library);
372 
373   const string group_attribute_;
374   const Graph* graph_in_;
375 
376   std::unordered_map<string, Subgraph> subgraphs_;
377 
378   TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
379 };
380 
381 namespace {
382 
383 // Return in 'sorted' a topological sort of clusters according to the
384 // dependencies encoded in ancestors. clusters is the list of all clusters
385 // including clusters that are not present in the ancestors map. has_successors
386 // is the set of clusters that are ancestors of some other cluster.
387 void TopologicalClusterSort(
388     const std::unordered_set<string>& clusters,
389     const std::unordered_set<string>& has_successors,
390     const std::unordered_map<string, std::unordered_set<string>>& ancestors,
391     std::vector<string>* sorted) {
392   // The nodes are placed in 'sorted' in topological order.
393   sorted->clear();
394   // We don't use the standard DFS because we are not operating on Node*
395   // objects.
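  // Each cluster is pushed twice: once to expand its ancestors ('leave' is
  // false) and once to emit it ('leave' is true), so a cluster is appended to
  // 'sorted' only after all of its ancestors have been appended.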
396   struct Work {
397     string cluster;
398     bool leave;
399   };
400   std::set<string> visited;
401   std::vector<Work> stack;
402   // Seed the processing list with clusters that have no successors.
403   for (const auto& cluster : clusters) {
404     if (has_successors.find(cluster) == has_successors.end()) {
405       stack.push_back({cluster, false});
406     }
407   }
408   while (!stack.empty()) {
409     const Work item = stack.back();
410     stack.pop_back();
411     if (item.leave) {
412       sorted->push_back(item.cluster);
413       continue;
414     }
415 
416     if (visited.find(item.cluster) != visited.end()) continue;
417     visited.insert(item.cluster);
418 
419     stack.push_back({item.cluster, true});
420     const auto& iter = ancestors.find(item.cluster);
421     if (iter != ancestors.end()) {
422       for (const auto& ancestor : iter->second) {
423         stack.push_back({ancestor, false});
424       }
425     }
426   }
427   CHECK(sorted->size() == clusters.size());
428 }
429 
430 }  // namespace
431 
432 Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
433 
434 int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
435   return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
436 }
437 
438 int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
439   return results_.at(OutputTensor(edge->src(), edge->src_output()));
440 }
441 
442 Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
443   if (!graph_) {
444     graph_.reset(new Graph(graph_in->op_registry()));
445     graph_->set_versions(graph_in->versions());
446   }
447 
448   // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
449   // determined. In case of hard placement, ensure all the encapsulated nodes
450   // have the same requested device, which in turn will be the requested device
451   // for the entire encapsulated subgraph. In case of soft placement, use a
452   // deterministic approach to fill in the requested device. Handle co-location
453   // constraints similarly if they exist.
454   if (device_.empty()) {
455     device_ = node->assigned_device_name().empty()
456                   ? node->requested_device()
457                   : node->assigned_device_name();
458   }
459 
460   return graph_->CopyNode(node);
461 }
462 
463 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
464 
465 Status Encapsulator::Subgraph::RecordArg(
466     const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
467     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
468   Node* src_node = edge->src();
469   int src_slot = edge->src_output();
470   std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
471   bool inserted;
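  // Reuse the existing _Arg index if this source tensor was recorded before;
  // otherwise assign the next index, i.e. the current size of args_by_src_.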
472   std::tie(iter, inserted) = args_by_src_.emplace(
473       OutputTensor(src_node, src_slot), args_by_src_.size());
474   int arg_index = iter->second;
475   if (inserted) {
476     NodeDef arg_def;
477     NodeDefBuilder builder(
478         absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
479         NodeDebugInfo(src_node->def()));
480     DataType dtype = edge->dst()->input_type(edge->dst_input());
481     builder.Attr("T", dtype);
482     builder.Attr("index", arg_index);
483     Status s = builder.Finalize(&arg_def);
484     if (!s.ok()) return s;
485 
486     Node* arg = graph_->AddNode(arg_def, &s);
487     if (!s.ok()) return s;
488 
489     src_arg_pairs->push_back({src_node, arg});
490     args_.push_back(arg);
491   }
492   Node* dst_node = edge->dst();
493   Node* dst_image = node_images.at(dst_node);
494   int dst_slot = edge->dst_input();
495   args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
496   graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
497   return Status::OK();
498 }
499 
500 Status Encapsulator::Subgraph::RecordControlResult(
501     const Edge* edge,
502     const std::unordered_map<const Node*, Node*>& node_images) {
503   Node* src_node = edge->src();
504   Node* src_image = node_images.at(src_node);
505   control_output_nodes_.insert(src_image->name());
506   return Status::OK();
507 }
508 
509 Status Encapsulator::Subgraph::RecordResult(
510     const Edge* edge,
511     const std::unordered_map<const Node*, Node*>& node_images) {
512   Node* src_node = edge->src();
513   Node* src_image = node_images.at(src_node);
514   int src_slot = edge->src_output();
515   std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
516   bool inserted;
517   std::tie(iter, inserted) =
518       results_.emplace(OutputTensor(src_node, src_slot), results_.size());
519   int ret_index = iter->second;
520   if (inserted) {
521     NodeDef ret_def;
522     NodeDefBuilder builder(
523         absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
524         NodeDebugInfo(src_node->def()));
525     DataType dtype = src_node->output_type(src_slot);
526     builder.Attr("T", dtype);
527     builder.Attr("index", ret_index);
528     builder.Input(src_image->name(), src_slot, dtype);
529     Status s = builder.Finalize(&ret_def);
530     if (!s.ok()) return s;
531     Node* ret = graph_->AddNode(ret_def, &s);
532     if (!s.ok()) return s;
533 
534     graph_->AddEdge(src_image, src_slot, ret, 0);
535   }
536   return Status::OK();
537 }
538 
539 Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
540                                                   Graph* graph_out) {
541   if (sequencer_ == nullptr) {
542     NodeDef seq_def;
543     // TODO(shikharagarwal): What source node should we use for errors?
544     NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
545     builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
546     builder.Device(device_);
547     Status s = builder.Finalize(&seq_def);
548     if (!s.ok()) return s;
549 
550     sequencer_ = graph_out->AddNode(seq_def, &s);
551     if (!s.ok()) return s;
552   }
553   return Status::OK();
554 }
555 
556 void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
557   if (sequencer_ != nullptr) {
558     VLOG(2) << "ConnectSequencerToCallNode";
559     graph_out->AddControlEdge(sequencer_, call_node_,
560                               /* allow_duplicates= */ true);
561   }
562 }
563 
564 Status Encapsulator::Subgraph::BuildFunctionDef(
565     const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
566     bool reuse_existing_functions, FunctionLibraryDefinition* library) {
567   // name_in is copied here because name may be modified below if
568   // rewrite_subgraph_fn is set.
569   string name = name_in;
570   call_node_def_.set_op(name);
571   call_node_def_.set_name(name);
572   call_node_def_.set_device(device_);
573 
574   if (rewrite_subgraph_fn) {
575     std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
576     for (const auto& arg : args_by_src_) {
577       arg_source_tensors.at(arg.second) = arg.first;
578     }
579     // Initialize the input and output permutations to the identity.
580     std::vector<int> input_permutation(args_by_src_.size());
581     std::iota(input_permutation.begin(), input_permutation.end(), 0);
582     std::vector<int> output_permutation(results_.size());
583     std::iota(output_permutation.begin(), output_permutation.end(), 0);
584 
585     TF_RETURN_IF_ERROR(
586         rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
587                             &output_permutation, &call_node_def_));
588 
589     // Apply the input/output permutations to the 'args_by_...' and 'results_'
590     // mappings, so when we build edges in BuildOutputGraph() we
591     // connect them to the right input/output positions.
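    // input_permutation[i] gives the new position of the argument that was at
    // position i before the rewrite, and likewise for output_permutation.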
592     if (input_permutation.size() != args_by_src_.size()) {
593       return errors::InvalidArgument("Input permutation has incorrect size.");
594     }
595     if (output_permutation.size() != results_.size()) {
596       return errors::InvalidArgument("Output permutation has incorrect size.");
597     }
598     for (auto& arg : args_by_src_) {
599       arg.second = input_permutation[arg.second];
600     }
601     for (auto& arg : args_by_dst_) {
602       arg.second = input_permutation[arg.second];
603     }
604     for (auto& result : results_) {
605       result.second = output_permutation[result.second];
606     }
607 
608     name = call_node_def_.op();
609   }
610 
611   function_def_name_ = name;
612 
613   FunctionDef fdef;
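  // The 'lookup' callback below tells GraphToFunctionDef which nodes to expose
  // as control outputs of the generated function: exactly the nodes recorded
  // in control_output_nodes_.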
614   auto lookup = [this](const Node* node) -> absl::optional<string> {
615     if (control_output_nodes_.contains(node->name())) {
616       return absl::make_optional(node->name());
617     }
618     return absl::nullopt;
619   };
620   // Verify that the graph has well-formed control flow structure.
621   std::vector<ControlFlowInfo> dummy;
622   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
623   TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));
624 
625   if (VLOG_IS_ON(1)) {
626     VLOG(2) << "Build function def " << name;
627     DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
628                     library);
629     DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
630   }
631 
632   const FunctionDef* original_fdef = library->Find(name);
633   if (!reuse_existing_functions || original_fdef == nullptr) {
634     TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
635   } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
636     TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
637   }
638   return Status::OK();
639 }
640 
641 Status Encapsulator::Subgraph::ReplaceFunctionDef(
642     FunctionLibraryDefinition* library) {
643   const string& name = function_def_name_;
644 
645   FunctionDef fdef;
646   TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
647 
648   if (VLOG_IS_ON(1)) {
649     VLOG(2) << "Replace function def " << name;
650     DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
651                     *graph_, library);
652     DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
653                           fdef);
654   }
655 
656   TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
657   return Status::OK();
658 }
659 
660 Status Encapsulator::Subgraph::AddFunctionCallNode(
661     const std::unordered_map<const Node*, Node*>& node_images,
662     Graph* graph_out) {
663   Status s;
664   call_node_ = graph_out->AddNode(call_node_def_, &s);
665   if (!s.ok()) return s;
666 
667   // Copy the assigned device over.
668   call_node_->set_assigned_device_name(device_);
669 
670   return Status::OK();
671 }
672 
673 Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
674   AttrSlice attrs = node->attrs();
675   attr->clear();
676   for (const auto& node_attr : attrs) {
677     if (node_attr.first == group_attribute_) {
678       TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
679       *attr = node_attr.second.s();
680       break;
681     }
682   }
683   return Status::OK();
684 }
685 
686 bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }
687 
688 Status Encapsulator::CopySubgraphNodes(
689     std::unordered_map<const Node*, Node*>* node_images) {
690   for (Node* node : graph_in_->op_nodes()) {
691     string func_id;
692     TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
693     if (!IsInSubgraph(func_id)) continue;
694 
695     Subgraph& subgraph = subgraphs_[func_id];
696     Node* image = subgraph.MakeNodeImage(graph_in_, node);
697     image->ClearAttr(group_attribute_);
698     (*node_images)[node] = image;
699   }
700   return Status::OK();
701 }
702 
703 Status Encapsulator::CopySubgraphEdges(
704     const std::unordered_map<const Node*, Node*>& node_images,
705     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
706   for (const Edge* edge : graph_in_->edges()) {
707     string src_func_id;
708     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
709     string dst_func_id;
710     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
711     Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
712     Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
713 
714     // Copy edges that are local to a subgraph.
715     if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
716         src_func_id == dst_func_id) {
717       Graph* g = subgraphs_[src_func_id].GetGraph();
718       if (edge->IsControlEdge()) {
719         g->AddControlEdge(src_image, dst_image,
720                           /* allow_duplicates= */ true);
721       } else {
722         g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
723       }
724       continue;
725     }
726 
727     // Record 'src' as an output of its subgraph, if applicable.
728     if (IsInSubgraph(src_func_id)) {
729       if (!edge->IsControlEdge()) {
730         DataType dtype = edge->src()->output_type(edge->src_output());
731         if (IsRefType(dtype)) {
732           return errors::InvalidArgument(
733               "Ref Tensors (e.g., Variables) are not supported as results: "
734               "tensor ",
735               edge->src()->name(), ":", edge->src_output());
736         }
737       }
738 
739       Subgraph& src_subgraph = subgraphs_[src_func_id];
740       if (edge->IsControlEdge()) {
741         TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images));
742       } else {
743         TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
744       }
745     }
746 
747     // Record 'dst' as an input of its subgraph, if applicable.
748     if (IsInSubgraph(dst_func_id)) {
749       // Look at the type of the destination not the source, since Ref output
750       // Tensors can be automatically cast to non-Ref Tensors at the
751       // destination.
752       if (!edge->IsControlEdge()) {
753         DataType dtype = edge->dst()->input_type(edge->dst_input());
754         if (IsRefType(dtype)) {
755           return errors::InvalidArgument(
756               "Ref Tensors (e.g., Variables) are not supported as args: "
757               "tensor ",
758               edge->src()->name(), ":", edge->src_output());
759         }
760       }
761 
762       Subgraph& dst_subgraph = subgraphs_[dst_func_id];
763       // Ignore control edges entering the subgraph. We will lift them onto
764       // the enclosing call operators in BuildOutputGraph().
765       if (!edge->IsControlEdge()) {
766         TF_RETURN_IF_ERROR(
767             dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
768       }
769     }
770   }
771   return Status::OK();
772 }
773 
774 Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
775   Status s;
776 
777   // Map from input graph nodes to subgraph nodes.
778   std::unordered_map<const Node*, Node*> node_images;
779 
780   // Each entry of src_arg_pairs is a pair whose first element is a node in the
781   // original graph that has an output edge in the subgraph, and whose second
782   // element is the arg node in the subgraph that it sends to. The vector will
783   // be filled in below by CopySubgraphEdges (via RecordArg).
784   std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
785 
786   TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
787   TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
788   MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
789 
790   for (auto& entry : subgraphs_) {
791     Subgraph& subgraph = entry.second;
792     FixupSourceAndSinkEdges(subgraph.GetGraph());
793   }
794 
795   if (VLOG_IS_ON(1)) {
796     // Dump subgraphs.
797     for (auto& entry : subgraphs_) {
798       DumpGraphToFile(
799           absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
800           *entry.second.GetGraph(), library);
801     }
802   }
803 
804   return s;
805 }
806 
807 Status Encapsulator::BuildFunctionDefs(
808     const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
809     FunctionLibraryDefinition* library) {
810   for (auto& subgraph_entry : subgraphs_) {
811     string name = subgraph_entry.first;
812     Subgraph& subgraph = subgraph_entry.second;
813     TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
814         name, rewrite_subgraph_fn, reuse_existing_functions, library));
815   }
816   return Status::OK();
817 }
818 
819 Status Encapsulator::CopyNodesToOutputGraph(
820     Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
821   for (Node* node : graph_in_->op_nodes()) {
822     string func_id;
823     TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
824 
825     // Don't copy nodes that are going to be encapsulated.
826     if (IsInSubgraph(func_id)) continue;
827 
828     Node* image = graph_out->CopyNode(node);
829     (*node_images)[node] = image;
830   }
831   (*node_images)[graph_in_->source_node()] = graph_out->source_node();
832   (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
833   return Status::OK();
834 }
835 
836 Status Encapsulator::AddFunctionCallNodes(
837     const std::unordered_map<const Node*, Node*>& node_images,
838     Graph* graph_out) {
839   for (auto& subgraph_entry : subgraphs_) {
840     TF_RETURN_IF_ERROR(
841         subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
842   }
843   return Status::OK();
844 }
845 
846 Status Encapsulator::FindOutputImageOfEdgeSrc(
847     const string& src_func_id, const string& dst_func_id,
848     const std::unordered_map<const Node*, Node*>& node_images,
849     const Node* original_src_node, Node** src_image) {
850   if (IsInSubgraph(src_func_id)) {
851     // The edge is from a subgraph to a regular node in the output graph so
852     // use the subgraph's call node output.
853     *src_image = subgraphs_.at(src_func_id).GetCallNode();
854   } else {
855     // The source of the edge is in the output graph so use the node image in
856     // the output graph.
857     *src_image = node_images.at(original_src_node);
858   }
859   return Status::OK();
860 }
861 
862 int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
863                                           const string& dst_func_id,
864                                           const Edge* edge) {
865   if (IsInSubgraph(src_func_id)) {
866     const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
867     // 'src' is in a subgraph and 'dst' is a regular node in the output
868     // graph. Use the corresponding call output instead.
869     return src_subgraph.GetResultIndexForEdge(edge);
870   } else {
871     // The source of the edge is in the output graph so use the regular edge
872     // slot.
873     return edge->src_output();
874   }
875 }
876 
877 Status Encapsulator::FindOutputImageOfEdgeDst(
878     const string& src_func_id, const string& dst_func_id,
879     const std::unordered_map<const Node*, Node*>& node_images,
880     const Node* original_dst_node, Node** dst_image) {
881   if (IsInSubgraph(dst_func_id)) {
882     // The edge is to a subgraph from a regular node in the output graph so
883     // use the subgraph's call node input.
884     *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
885   } else {
886     // The destination of the edge is in the output graph so use the node image
887     // in the output graph.
888     *dst_image = node_images.at(original_dst_node);
889   }
890   return Status::OK();
891 }
892 
893 int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
894                                           const string& dst_func_id,
895                                           const Edge* edge) {
896   if (IsInSubgraph(dst_func_id)) {
897     const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
898     // 'dst' is in a subgraph and 'src' is a regular node in the output
899     // graph. Use the corresponding call input instead.
900     return dst_subgraph.GetArgIndexForEdge(edge);
901   } else {
902     // The destination of the edge is in the output graph so use the regular
903     // edge slot.
904     return edge->dst_input();
905   }
906 }
907 
908 Status Encapsulator::CopyEdgeToOutputGraph(
909     const Edge* edge, const string& src_func_id, const string& dst_func_id,
910     const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
911     std::unordered_set<std::pair<OutputTensor, InputTensor>,
912                        OutputInputTensorPairHasher>* edges_added) {
913   Node* src_image;
914   TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
915       src_func_id, dst_func_id, node_images, edge->src(), &src_image));
916   Node* dst_image;
917   TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
918       src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));
919 
920   // If this is a control edge then copy it and return. Lift control edges onto
921   // the enclosing call operator.
922   if (edge->IsControlEdge()) {
923     // Add the control edge, if we have not already added it, using the images
924     // determined above (potentially call operators).
925     if (edges_added
926             ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
927             .second) {
928       graph_out->AddControlEdge(src_image, dst_image,
929                                 /* allow_duplicates= */ true);
930     }
931 
932     return Status::OK();
933   }
934 
935   int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);
936 
937   int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);
938 
939   // Add the edge, if we have not already added it.
940   if (edges_added
941           ->emplace(OutputTensor(src_image, src_output),
942                     InputTensor(dst_image, dst_input))
943           .second) {
944     graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
945   }
946   return Status::OK();
947 }
948 
949 Status Encapsulator::AddEdgesToOutputGraph(
950     const std::unordered_map<const Node*, Node*>& node_images,
951     Graph* graph_out) {
952   // Set of edges already added to the output graph, represented as (src, dst)
953   // pairs. We use the set to deduplicate edges; multiple edges in the input
954   // graph may map to one edge in the output graph.
955   std::unordered_set<std::pair<OutputTensor, InputTensor>,
956                      OutputInputTensorPairHasher>
957       edges_added;
958 
959   for (const Edge* edge : graph_in_->edges()) {
960     string src_func_id;
961     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
962     string dst_func_id;
963     TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
964 
965     // Ignore edges that are strictly contained within one subgraph; they were
966     // already copied into the subgraphs by CopySubgraphEdges.
967     if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
968         src_func_id == dst_func_id) {
969       continue;
970     }
971 
972     // We have an edge that crosses a cluster boundary or is entirely within the
973     // unclustered graph.
974     TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
975         edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added));
976   }
977 
978   for (auto& subgraph_entry : subgraphs_) {
979     Subgraph& subgraph = subgraph_entry.second;
980     subgraph.ConnectSequencerToCallNode(graph_out);
981   }
982 
983   return Status::OK();
984 }
985 
986 namespace {
987 
988 // Adds a dummy Const node to graph_out. The "constant" has the type of
989 // data_type and the shape indicated in 'shape'. The dummy node is not a valid
990 // Const node because it does not have any value defined, but this doesn't
991 // matter because it will only be used subsequently for shape inference. (It
992 // would be possible to add a switch statement over data_type to create a value
993 // for the constant, but that would entail maintaining the logic as new types
994 // are added, and is not necessary.) If the node being replaced was within a
995 // control flow frame, adds appropriate Enter nodes so that the use of the Const
996 // is well-formed.
997 Node* AddDummyShapedNode(const Node* src_node, int src_port,
998                          const std::vector<ControlFlowInfo>& control_flow_info,
999                          const TensorShapeProto& shape, Graph* graph_out) {
1000   DataType data_type = src_node->output_type(src_port);
1001   TensorProto dummy_proto;
1002   dummy_proto.set_dtype(data_type);
1003   *dummy_proto.mutable_tensor_shape() = shape;
1004   // Don't set any value field in the proto, since it is only going to be used
1005   // for shape inference.
1006 
1007   GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
1008   NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
1009                            options.op_registry());
1010   node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
1011   Node* node = options.FinalizeBuilder(&node_builder);
1012   // Add any Enter nodes required to bring the constant to the correct control
1013   // flow frame.
1014   while (!control_flow_info[src_node->id()].frame_name.empty()) {
1015     NodeDebugInfo debug_info(*src_node);
1016     NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
1017                               options.op_registry(), &debug_info);
1018     enter_builder.Attr("frame_name",
1019                        control_flow_info[src_node->id()].frame_name);
1020     enter_builder.Attr("is_constant", true);
1021     enter_builder.Input(node, 0);
1022     Node* enter_node = options.FinalizeBuilder(&enter_builder);
1023     // Adopt the new Enter node as the value in the current frame.
1024     node = enter_node;
1025     // Recurse to the parent frame to see if more Enter nodes need to be added.
1026     src_node = control_flow_info[src_node->id()].parent_frame;
1027   }
1028   return node;
1029 }
1030 
1031 }  // namespace
1032 
1033 Status Encapsulator::MakePrunedGraphCopyAndInline(
1034     const Graph& graph, const std::vector<Node*>& sink_nodes,
1035     std::unique_ptr<Graph>* pruned_graph,
1036     std::unordered_map<const Node*, Node*>* node_images,
1037     FunctionLibraryDefinition* library) {
1038   // First copy all ancestor nodes of sink_nodes into a new graph.
1039   pruned_graph->reset(new Graph(library));
1040   (*pruned_graph)->set_versions(graph.versions());
1041   ReverseDFSFrom(graph, sink_nodes,
1042                  /*enter=*/nullptr,
1043                  /*leave=*/[&](Node* n) {
1044                    if (!n->IsSource()) {
1045                      Node* copied = (*pruned_graph)->CopyNode(n);
1046                      node_images->emplace(n, copied);
1047                    }
1048                  });
1049 
1050   // Add all the edges between copied nodes.
1051   for (auto entry : *node_images) {
1052     const Node* orig = entry.first;
1053     Node* image = entry.second;
1054     for (const Edge* out_edge : orig->out_edges()) {
1055       auto iter = node_images->find(out_edge->dst());
1056       if (iter != node_images->end()) {
1057         // The source and destination are both in the copied graph.
1058         (*pruned_graph)
1059             ->AddEdge(image, out_edge->src_output(), iter->second,
1060                       out_edge->dst_input());
1061       }
1062     }
1063   }
1064 
1065   // Find all the function call nodes, and inline them.
1066   std::vector<Node*> function_nodes;
1067   for (auto node : (*pruned_graph)->nodes()) {
1068     const OpRegistrationData* op_reg_data;
1069     TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1070     if (op_reg_data->is_function_op) {
1071       function_nodes.push_back(node);
1072     }
1073   }
1074   for (auto node : function_nodes) {
1075     VLOG(2) << "Inlining function " << node->name();
1076     const FunctionDef* fdef = library->Find(node->type_string());
1077     if (fdef == nullptr) {
1078       return errors::Internal("Failed to find function ", node->type_string(),
1079                               " in function library.");
1080     }
1081     std::unique_ptr<FunctionBody> fbody;
1082     TF_RETURN_IF_ERROR(
1083         FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
1084 
1085     InlineFunctionBodyOptions inline_opts;
1086     TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
1087                                           fbody.get(), inline_opts));
1088   }
1089 
1090   return Status::OK();
1091 }
1092 
1093 Status Encapsulator::BuildOutputGraph(Graph* graph_out,
1094                                       FunctionLibraryDefinition* library) {
1095   // Map from nodes in the input graph to nodes in the output graph.
1096   std::unordered_map<const Node*, Node*> node_images;
1097 
1098   TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
1099   TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
1100   TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));
1101 
1102   return Status::OK();
1103 }
1104 
1105 }  // anonymous namespace
1106 
1107 Status EncapsulateSubgraphsInFunctions(
1108     string group_attribute, const Graph& graph_in,
1109     const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1110     std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1111   Encapsulator encapsulator(std::move(group_attribute),
1112                             &graph_in);
1113   TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
1114 
1115   TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1116       rewrite_subgraph_fn, reuse_existing_functions, library));
1117 
1118   std::unique_ptr<Graph> out(new Graph(library));
1119   out->set_versions(graph_in.versions());
1120   TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));
1121 
1122   *graph_out = std::move(out);
1123   return Status::OK();
1124 }
1125 
1126 // Finds the types of the _Arg nodes, indexed by position.
1127 static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
1128   for (Node* n : graph.op_nodes()) {
1129     if (n->type_string() == kArgOp) {
1130       int index;
1131       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1132       const int num_types = types->size();
1133       if (index < 0 || index >= num_types) {
1134         return errors::InvalidArgument("Invalid argument number");
1135       }
1136       (*types)[index] = n->output_type(0);
1137     }
1138   }
1139   return Status::OK();
1140 }
1141 
1142 // Renumber the indices of _Arg nodes in a graph, according to
1143 // 'permutation' that maps old indices to new indices.
1144 static Status RenumberArguments(Graph* graph,
1145                                 const std::vector<int>& permutation) {
1146   for (Node* n : graph->op_nodes()) {
1147     if (n->type_string() == kArgOp) {
1148       int index;
1149       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1150       const int permutation_size = permutation.size();
1151       if (index < 0 || index >= permutation_size) {
1152         return errors::InvalidArgument("Invalid argument number");
1153       }
1154       n->AddAttr("index", permutation[index]);
1155     }
1156   }
1157   return Status::OK();
1158 }
1159 
1160 Status EncapsulateSubgraphsPass::Run(
1161     const GraphOptimizationPassOptions& options) {
1162   VLOG(1) << "EncapsulateSubgraphsPass::Run";
1163   if (VLOG_IS_ON(1)) {
1164     DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
1165                     options.flib_def);
1166   }
1167 
1168   std::unique_ptr<Graph> graph_out;
1169   FunctionLibraryDefinition* const library = options.flib_def;
1170 
1171   // Constant folding below might need to run part of the function to compute
1172   // constants. Create a FunctionLibraryRuntime with a single CPU device
1173   // that can run the part of the function.
1174   // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
1175   // `options`.
1176   SessionOptions session_options;
1177   auto* device_count = session_options.config.mutable_device_count();
1178   device_count->insert({"CPU", 1});
1179   std::vector<std::unique_ptr<Device>> devices;
1180 
1181   DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
1182   if (!cpu_factory) {
1183     return errors::NotFound(
1184         "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
1185   }
1186   TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
1187       session_options, "/job:localhost/replica:0/task:0", &devices));
1188   if (devices.empty()) {
1189     return errors::NotFound(
1190         "Failed to create a CPU device for EncapsulateSubgraphsPass");
1191   }
1192 
1193   std::unique_ptr<DeviceMgr> device_mgr =
1194       absl::make_unique<StaticDeviceMgr>(std::move(devices));
1195   const auto* config = &options.session_options->config;
1196   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1197       new ProcessFunctionLibraryRuntime(
1198           device_mgr.get(), options.session_options->env,
1199           /*config=*/config, TF_GRAPH_DEF_VERSION, library,
1200           config->graph_options().optimizer_options()));
1201   FunctionLibraryRuntime* flr =
1202       pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
1203   if (flr == nullptr) {
1204     return errors::Internal(
1205         "Failed to create and retrieve function library runtime to run "
1206         "constant folding");
1207   }
1208 
1209   auto rewrite_subgraph =
1210       [flr](const std::vector<OutputTensor>& arg_source_tensors,
1211             std::unique_ptr<Graph>* subgraph,
1212             std::vector<int>* input_permutation,
1213             std::vector<int>* output_permutation, NodeDef* node) {
1214         // Optimize the subgraph.
1215         // Do not constant fold nodes that output DT_VARIANT type tensors.
1216         // XLA does not support Const nodes of Variant type since it needs
1217         // to know the original ops to be able to compile them to the relevant
1218         // XLA form.
1219         // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
1220         // the form:
1221         //                          Const
1222         //                            |
1223         // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
1224         //                                                  |
1225         //                                        (Discard popped list)
1226         //
1227         // Would have been reduced to "Const -> Op" without this filter.
1228         // However since we are only allowed to specify the filter at the "Node"
1229         // level there is no good way to allow the above behavior. So we
1230         // disallow any sort of constant folding on Variant nodes for now.
1231         bool disable_constant_folding =
1232             GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
1233         auto cf_consider_fn = [disable_constant_folding](const Node* n) {
1234           if (disable_constant_folding) return false;
1235           for (const auto& output_arg : n->op_def().output_arg()) {
1236             if (output_arg.type() == DT_VARIANT) {
1237               return false;
1238             }
1239           }
1240           return true;
1241         };
1242         GraphOptimizer::Options graph_optimizer_options;
1243         graph_optimizer_options.cf_consider_fn = cf_consider_fn;
1244         OptimizeGraph(flr, subgraph, graph_optimizer_options);
1245 
1246         const int num_args = input_permutation->size();
1247         std::vector<bool> const_args(num_args);
1248         TF_RETURN_IF_ERROR(
1249             BackwardsConstAnalysis(**subgraph, &const_args,
1250                                    /*compile_time_const_nodes=*/nullptr, flr));
1251 
1252         DataTypeVector arg_types(num_args);
1253         TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
1254 
1255         // Compute a permutation of the arguments such that the constant
1256         // arguments are first.
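        // The resulting argument order is: compile-time constants first, then
        // regular (non-const, non-resource) arguments, then DT_RESOURCE
        // arguments.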
1257         const int num_consts =
1258             std::count(const_args.begin(), const_args.end(), true);
1259 
1260         const int num_resources =
1261             std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
1262         const int num_nonconsts = num_args - num_resources - num_consts;
1263         if (num_nonconsts < 0) {
1264           return errors::Internal("num_nonconsts should be >= 0, was ",
1265                                   num_nonconsts);
1266         }
1267 
1268         int const_pos = 0;
1269         int arg_pos = num_consts;
1270         int resource_pos = num_consts + num_nonconsts;
1271         for (int i = 0; i < num_args; ++i) {
1272           if (const_args[i]) {
1273             if (arg_types[i] == DT_RESOURCE) {
1274               return errors::Internal(
1275                   "Resource arguments cannot be constant (argument ", i, ")");
1276             }
1277             (*input_permutation)[i] = const_pos;
1278             ++const_pos;
1279           } else if (arg_types[i] == DT_RESOURCE) {
1280             (*input_permutation)[i] = resource_pos;
1281             ++resource_pos;
1282           } else {
1283             (*input_permutation)[i] = arg_pos;
1284             ++arg_pos;
1285           }
1286         }
1287 
1288         // Renumber argument nodes in the graph.
1289         TF_RETURN_IF_ERROR(
1290             RenumberArguments(subgraph->get(), *input_permutation));
1291 
1292         // TODO(phawkins): add a forward is-constant analysis, similarly split
1293         // outputs into host-memory constants and device-memory non-constants.
1294 
1295         AddNodeAttr(kXlaCompiledKernelAttr, true, node);
1296         AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
1297         AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
1298         return Status::OK();
1299       };
1300 
1301   // Don't EncapsulateSubgraphs if graph doesn't contain nodes with
1302   // kXlaClusterAttr.
1303   bool has_xla_cluster_attribute = false;
1304   for (Node* node : (*options.graph)->nodes()) {
1305     if (HasNodeAttr(node->def(), kXlaClusterAttr)) {
1306       has_xla_cluster_attribute = true;
1307       break;
1308     }
1309   }
1310 
1311   if (has_xla_cluster_attribute) {
1312     TF_RETURN_WITH_CONTEXT_IF_ERROR(
1313         EncapsulateSubgraphsInFunctions(
1314             kXlaClusterAttr, **options.graph, rewrite_subgraph,
1315             /*reuse_existing_functions=*/false, &graph_out, library),
1316         "EncapsulateSubgraphsPass failed");
1317     if (VLOG_IS_ON(1)) {
1318       DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
1319                       options.flib_def);
1320     }
1321 
1322     *options.graph = std::move(graph_out);
1323   }
1324 
1325   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
1326                       GetNodesRelatedToRefVariables(**options.graph, flr));
1327   for (Node* node : (*options.graph)->nodes()) {
1328     bool has_ref_vars = ref_related_nodes.contains(node);
1329     node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
1330     VLOG(3) << "Has ref vars = " << has_ref_vars
1331             << ", node: " << node->def().SerializeAsString();
1332   }
1333   return Status::OK();
1334 }
1335 
1336 bool IsXlaCompiledKernel(const Node& node) {
1337   bool is_compiled = false;
1338   bool has_compilation_attr =
1339       TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
1340       is_compiled;
1341   return has_compilation_attr ? is_compiled : false;
1342 }
1343 
1344 }  // namespace tensorflow
1345