/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"

#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
    "_xla_host_transfer_sequencer";
const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";

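// For example, SortControlInputs turns the input list {"^b", "a:0", "^a"}
// into {"a:0", "^a", "^b"}: data inputs keep their original relative order
// and come first, followed by control inputs sorted by name.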
void SortControlInputs(GraphDef* gdef) {
68 int64_t num_nodes = gdef->node_size();
69 for (int64_t i = 0; i < num_nodes; ++i) {
70 NodeDef* node = gdef->mutable_node(i);
71 // Stable sort control inputs and leave the order of data inputs unchanged.
72 std::stable_sort(node->mutable_input()->begin(),
73 node->mutable_input()->end(),
74 [](const string& a, const string& b) {
75 bool a_is_control = absl::StartsWith(a, "^");
76 bool b_is_control = absl::StartsWith(b, "^");
77 return (!a_is_control && b_is_control) ||
78 (a_is_control && b_is_control && a < b);
79 });
80 }
81 }
82
83 namespace {
84
bool AreAllParentsGuaranteedConst(
86 const Node& n,
87 const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
88 if (n.type_string() == "GuaranteeConst") {
89 // If the current node is itself a cast-to-const, no need
90 // to look at the incoming edges.
91 return true;
92 }
93
94 bool all_parents_const = true;
95 bool atleast_one_non_control_edge = false;
96 for (const Edge* in : n.in_edges()) {
97 atleast_one_non_control_edge =
98 atleast_one_non_control_edge || !in->IsControlEdge();
99 if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
100 all_parents_const = false;
101 break;
102 }
103 }
104 return all_parents_const && atleast_one_non_control_edge;
105 }
106
void MarkGuaranteedConstants(
108 const Graph& graph,
109 const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
110 absl::flat_hash_set<const Node*> guaranteed_const_nodes;
111 std::vector<const Node*> srcs;
112 srcs.reserve(src_arg_pairs.size());
113 for (const auto& src_arg : src_arg_pairs) {
114 srcs.push_back(src_arg.first);
115 }
116 ReverseDFSFrom(
117 graph, srcs, /*enter=*/nullptr,
118 /*leave=*/[&guaranteed_const_nodes](const Node* n) {
119 // TODO(vinuraja): Doesn't work in the presence of loops.
120 if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
121 guaranteed_const_nodes.insert(n);
122 }
123 });
124
125 for (auto& src_arg : src_arg_pairs) {
126 if (guaranteed_const_nodes.count(src_arg.first) != 0) {
127 VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
128 src_arg.second->AddAttr("_is_guaranteed_constant", true);
129 }
130 }
131 }
132
133 struct OutputInputTensorPairHasher {
  uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
135 return Hash64Combine(OutputTensor::Hash()(s.first),
136 InputTensor::Hash()(s.second));
137 }
138 };
139
140 // TODO(phawkins) add a canonical copy of these operator names and refactor
141 // everything to use it.
142 static const char* const kArgOp = "_Arg";
143 static const char* const kRetValOp = "_Retval";
144
145 class Encapsulator {
146 public:
  Encapsulator(string group_attribute, Graph const* graph_in)
148 : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
149
150 // Find subgraphs marked with 'group_attribute', and build a new
151 // subgraph, one for each value of 'group_attribute'.
152 Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
153
  // Build a FunctionDef for each subgraph, and add it to 'library'. The values
  // of the 'group_attribute' annotations become the function names.
  // If 'reuse_existing_functions' is set, use an existing function with the
  // same name, if any.
  // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
  // function conversion.
160 Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
161 bool reuse_existing_functions,
162 FunctionLibraryDefinition* library);
163
164 // Write a copy of the input graph to 'graph_out', where the subgraphs are
165 // replaced with calls to the new functions.
166 Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
167
168 private:
169 // A subgraph of the input, all marked with a common 'group_attribute'
170 // value.
171 //
172 // In the following simple example, A, B, ..., E are nodes in the original
173 // graph. The group attributes g are each shown as either 0 or empty.
174 //
175 // A --> B --> C --> D --> E
176 // g: g:0 g:0 g:0 g:
177 //
178 // The example is rewritten to two graphs; one on the host and one to be
179 // compiled. The host graph is as follows.
180 //
181 // A --> Call --> E
182 //
183 // The compiled cluster is as follows.
184 //
185 // Arg --> B --> C --> D --> Retval
186 class Subgraph {
187 public:
188 // Creates a graph to build the subgraph in, if it doesn't already exist,
189 // using the same op registry and versions as graph_in.
190 Node* MakeNodeImage(const Graph* graph_in, Node* node);
191
192 // Returns the graph the subgraph is being built in.
193 Graph* GetGraph() const;
194
195 // Builds a FunctionDef, and adds it to 'library'. The value of the
196 // 'group_attribute' annotations becomes the function name. If
197 // 'reuse_existing_functions' is set, use an existing function with the same
198 // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the
199 // subgraph before function conversion.
200 Status BuildFunctionDef(const string& name_in,
201 const RewriteSubgraphFn& rewrite_subgraph_fn,
202 bool reuse_existing_functions,
203 FunctionLibraryDefinition* library);
204
205 // Adds the function call node to graph_out.
206 Status AddFunctionCallNode(
207 const std::unordered_map<const Node*, Node*>& node_images,
208 Graph* graph_out);
209
210 // Returns the Node that the inputs and outputs of the function should be
211 // wired up to.
212 Node* GetCallNode() const;
213
214 // Returns the index of the arg that the dst of edge should connect to.
215 int GetArgIndexForEdge(const Edge* edge) const;
216
217 // Returns the index of the result that the src of edge should connect to.
218 int GetResultIndexForEdge(const Edge* edge) const;
219
    // Creates an _Arg node for the src node of edge, and adds its index to
    // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
    // and adds the edge within the subgraph from the _Arg node to the image of
    // the dst node.
224 Status RecordArg(const Edge* edge,
225 const std::unordered_map<const Node*, Node*>& node_images,
226 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
227
228 // Records the src of the given edge as a control result of the graph.
229 // Used during graph to function conversion to tie control results to
230 // the function signature.
231 Status RecordControlResult(
232 const Edge* edge,
233 const std::unordered_map<const Node*, Node*>& node_images);
234
    // Creates a _Retval node for the src node of edge, and adds it to results_,
    // if none exists yet. If a new _Retval node is created, also adds the edge
    // within the subgraph from the src to the _Retval node.
238 Status RecordResult(
239 const Edge* edge,
240 const std::unordered_map<const Node*, Node*>& node_images);
241
242 // Creates the sequencer node if it doesn't exist, adding it to graph_out.
243 Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
244
245 // If there is a sequencer node, adds a control edge from the sequencer to
246 // the call node.
247 void ConnectSequencerToCallNode(Graph* graph_out);
248
249 Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
250
251 private:
252 // The subgraph extracted from the input graph, suitable for being turned
253 // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
254 // returned by _Retval nodes.
255 std::unique_ptr<Graph> graph_;
256
257 // Which device are these nodes on? Used to assign a device to the call
258 // node.
259 string device_;
260
261 // NodeDef for the function call node.
262 NodeDef call_node_def_;
263
264 // Name that is used for the call node. This may not be
265 // call_node_def_.name() if the client supplies a rewrite lambda.
266 string function_def_name_;
267
268 // Placeholder node simulating the host compute key in the output graph.
269 // Not owned.
270 Node* host_compute_key_placeholder_ = nullptr;
271
272 // Function call node in the output graph. Not owned.
273 Node* call_node_;
274
275 // Maps from source (producer node/slot) and destination
276 // (consumer node/slot) tensors in the input graph to _Arg numbers in
277 // the subgraph. The source map is one-to-one, whereas the dest map may be
278 // many-to-one.
279 std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
280 std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
281
282 // The arguments to the subgraph, in order.
283 std::vector<Node*> args_;
284
285 // Map from source tensor in the input graph to result #.
286 std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
287
288 // Set of node names that are the source of a control output of the
289 // subgraph. We store strings here so that we can tolerate nodes being
290 // removed from the graph.
291 absl::flat_hash_set<string> control_output_nodes_;
292
293 // NoOp node in the output graph that is sequenced after the call node.
294 Node* sequencer_ = nullptr;
295 };
296
  // Looks up the value of the 'group_attribute_' annotation on 'node' and
  // returns it in 'attr'. Sets 'attr' to the empty string if the attribute is
  // not found.
299 Status GetFunctionNameAttr(Node const* node, string* attr) const;
300
301 // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
302 // subgraphs for data edges that cross subgraph boundaries.
303 Status CopySubgraphEdges(
304 const std::unordered_map<const Node*, Node*>& node_images,
305 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
306
307 // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
308 Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
309
310 // Copies all nodes that aren't in a compiled subgraph to the output graph.
311 Status CopyNodesToOutputGraph(
312 Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
313
314 // Adds function call nodes for each compiled subgraph.
315 Status AddFunctionCallNodes(
316 const std::unordered_map<const Node*, Node*>& node_images,
317 Graph* graph_out);
318
319 // Finds the image of an edge source in the output graph. If the edge crosses
320 // a subgraph boundary it is the output of a call node, otherwise it is a node
321 // in the output graph.
322 Status FindOutputImageOfEdgeSrc(
323 const string& src_func_id, const string& dst_func_id,
324 const std::unordered_map<const Node*, Node*>& node_images,
325 const Node* original_src_node, Node** src_image);
326
327 // Finds an edge source slot in the output graph. If the edge crosses a
328 // subgraph boundary it is a slot on the output of a call node, otherwise it
329 // is a slot on a node in the output graph.
330 int FindOutputSlotOfEdgeSrc(const string& src_func_id,
331 const string& dst_func_id,
332 const Edge* edge);
333
334 // Finds the image of an edge destination in the output graph. If the edge
335 // crosses a subgraph boundary it is the input of a call node, otherwise it is
336 // a node in the output graph.
337 Status FindOutputImageOfEdgeDst(
338 const string& src_func_id, const string& dst_func_id,
339 const std::unordered_map<const Node*, Node*>& node_images,
340 const Node* original_dst_node, Node** dst_image);
341
342 // Finds an edge destination slot in the output graph. If the edge crosses a
343 // subgraph boundary it is a slot on the input of a call node, otherwise it is
344 // a slot on a node in the output graph.
345 int FindOutputSlotOfEdgeDst(const string& src_func_id,
346 const string& dst_func_id,
347 const Edge* edge);
348
349 // Copies a single edge to the output graph. The edge is either entirely
350 // within the output graph, or crosses into or out of a compiled subgraph.
351 Status CopyEdgeToOutputGraph(
352 const Edge* edge, const string& src_func_id, const string& dst_func_id,
353 const std::unordered_map<const Node*, Node*>& node_images,
354 Graph* graph_out,
355 std::unordered_set<std::pair<OutputTensor, InputTensor>,
356 OutputInputTensorPairHasher>* edges_added);
357
358 // Adds all edges to the output graph.
359 Status AddEdgesToOutputGraph(
360 const std::unordered_map<const Node*, Node*>& node_images,
361 Graph* graph_out);
362
  // Makes a copy of graph containing only nodes that are ancestors of at least
  // one node in sink_nodes, and stores it in pruned_graph. On exit node_images
  // contains a mapping from nodes in graph to nodes in pruned_graph. All
  // functions in the copied graph are inlined.
367 Status MakePrunedGraphCopyAndInline(
368 const Graph& graph, const std::vector<Node*>& sink_nodes,
369 std::unique_ptr<Graph>* pruned_graph,
370 std::unordered_map<const Node*, Node*>* node_images,
371 FunctionLibraryDefinition* library);
372
373 const string group_attribute_;
374 const Graph* graph_in_;
375
376 std::unordered_map<string, Subgraph> subgraphs_;
377
378 TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
379 };
380
381 namespace {
382
383 // Return in 'sorted' a topological sort of clusters according to the
384 // dependencies encoded in ancestors. clusters is the list of all clusters
385 // including clusters that are not present in the ancestors map. has_successors
386 // is the set of clusters that are ancestors of some other cluster.
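// For example, with clusters = {"a", "b", "c"}, ancestors = {"b": {"a"},
// "c": {"b"}} and has_successors = {"a", "b"}, 'sorted' becomes
// ["a", "b", "c"], i.e. every cluster appears after all of its ancestors.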
void TopologicalClusterSort(
388 const std::unordered_set<string>& clusters,
389 const std::unordered_set<string>& has_successors,
390 const std::unordered_map<string, std::unordered_set<string>>& ancestors,
391 std::vector<string>* sorted) {
392 // The nodes are placed in 'sorted' in topological order.
393 sorted->clear();
394 // We don't use the standard DFS because we are not operating on Node*
395 // objects.
396 struct Work {
397 string cluster;
398 bool leave;
399 };
400 std::set<string> visited;
401 std::vector<Work> stack;
402 // Seed the processing list with clusters that have no successors.
403 for (const auto& cluster : clusters) {
404 if (has_successors.find(cluster) == has_successors.end()) {
405 stack.push_back({cluster, false});
406 }
407 }
408 while (!stack.empty()) {
409 const Work item = stack.back();
410 stack.pop_back();
411 if (item.leave) {
412 sorted->push_back(item.cluster);
413 continue;
414 }
415
416 if (visited.find(item.cluster) != visited.end()) continue;
417 visited.insert(item.cluster);
418
419 stack.push_back({item.cluster, true});
420 const auto& iter = ancestors.find(item.cluster);
421 if (iter != ancestors.end()) {
422 for (const auto& ancestor : iter->second) {
423 stack.push_back({ancestor, false});
424 }
425 }
426 }
427 CHECK(sorted->size() == clusters.size());
428 }
429
430 } // namespace
431
Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
433
int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
435 return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
436 }
437
int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
439 return results_.at(OutputTensor(edge->src(), edge->src_output()));
440 }
441
Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
443 if (!graph_) {
444 graph_.reset(new Graph(graph_in->op_registry()));
445 graph_->set_versions(graph_in->versions());
446 }
447
448 // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
449 // determined. In case of hard placement, ensure all the encapsulated nodes
450 // have the same requested device, which in turn will be the requested device
451 // for the entire encapsulated subgraph. In case of soft placement, use a
452 // deterministic approach to fill in the requested device. Handle co-location
453 // constraints similarly if they exist.
454 if (device_.empty()) {
455 device_ = node->assigned_device_name().empty()
456 ? node->requested_device()
457 : node->assigned_device_name();
458 }
459
460 return graph_->CopyNode(node);
461 }
462
Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
464
Status Encapsulator::Subgraph::RecordArg(
466 const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
467 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
468 Node* src_node = edge->src();
469 int src_slot = edge->src_output();
470 std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
471 bool inserted;
472 std::tie(iter, inserted) = args_by_src_.emplace(
473 OutputTensor(src_node, src_slot), args_by_src_.size());
474 int arg_index = iter->second;
475 if (inserted) {
476 NodeDef arg_def;
477 NodeDefBuilder builder(
478 absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
479 NodeDebugInfo(src_node->def()));
480 DataType dtype = edge->dst()->input_type(edge->dst_input());
481 builder.Attr("T", dtype);
482 builder.Attr("index", arg_index);
483 Status s = builder.Finalize(&arg_def);
484 if (!s.ok()) return s;
485
486 Node* arg = graph_->AddNode(arg_def, &s);
487 if (!s.ok()) return s;
488
489 src_arg_pairs->push_back({src_node, arg});
490 args_.push_back(arg);
491 }
492 Node* dst_node = edge->dst();
493 Node* dst_image = node_images.at(dst_node);
494 int dst_slot = edge->dst_input();
495 args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
496 graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
497 return Status::OK();
498 }
499
Status Encapsulator::Subgraph::RecordControlResult(
501 const Edge* edge,
502 const std::unordered_map<const Node*, Node*>& node_images) {
503 Node* src_node = edge->src();
504 Node* src_image = node_images.at(src_node);
505 control_output_nodes_.insert(src_image->name());
506 return Status::OK();
507 }
508
Status Encapsulator::Subgraph::RecordResult(
510 const Edge* edge,
511 const std::unordered_map<const Node*, Node*>& node_images) {
512 Node* src_node = edge->src();
513 Node* src_image = node_images.at(src_node);
514 int src_slot = edge->src_output();
515 std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
516 bool inserted;
517 std::tie(iter, inserted) =
518 results_.emplace(OutputTensor(src_node, src_slot), results_.size());
519 int ret_index = iter->second;
520 if (inserted) {
521 NodeDef ret_def;
522 NodeDefBuilder builder(
523 absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
524 NodeDebugInfo(src_node->def()));
525 DataType dtype = src_node->output_type(src_slot);
526 builder.Attr("T", dtype);
527 builder.Attr("index", ret_index);
528 builder.Input(src_image->name(), src_slot, dtype);
529 Status s = builder.Finalize(&ret_def);
530 if (!s.ok()) return s;
531 Node* ret = graph_->AddNode(ret_def, &s);
532 if (!s.ok()) return s;
533
534 graph_->AddEdge(src_image, src_slot, ret, 0);
535 }
536 return Status::OK();
537 }
538
Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
540 Graph* graph_out) {
541 if (sequencer_ == nullptr) {
542 NodeDef seq_def;
543 // TODO(shikharagarwal): What source node should we use for errors?
544 NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
545 builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
546 builder.Device(device_);
547 Status s = builder.Finalize(&seq_def);
548 if (!s.ok()) return s;
549
550 sequencer_ = graph_out->AddNode(seq_def, &s);
551 if (!s.ok()) return s;
552 }
553 return Status::OK();
554 }
555
void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
557 if (sequencer_ != nullptr) {
558 VLOG(2) << "ConnectSequencerToCallNode";
559 graph_out->AddControlEdge(sequencer_, call_node_,
560 /* allow_duplicates= */ true);
561 }
562 }
563
Status Encapsulator::Subgraph::BuildFunctionDef(
565 const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
566 bool reuse_existing_functions, FunctionLibraryDefinition* library) {
  // name_in is copied here because name may be modified below if
  // rewrite_subgraph_fn is set.
569 string name = name_in;
570 call_node_def_.set_op(name);
571 call_node_def_.set_name(name);
572 call_node_def_.set_device(device_);
573
574 if (rewrite_subgraph_fn) {
575 std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
576 for (const auto& arg : args_by_src_) {
577 arg_source_tensors.at(arg.second) = arg.first;
578 }
579 // Initialize the input and output permutations to the identity.
580 std::vector<int> input_permutation(args_by_src_.size());
581 std::iota(input_permutation.begin(), input_permutation.end(), 0);
582 std::vector<int> output_permutation(results_.size());
583 std::iota(output_permutation.begin(), output_permutation.end(), 0);
584
585 TF_RETURN_IF_ERROR(
586 rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
587 &output_permutation, &call_node_def_));
588
589 // Apply the input/output permutations to the 'args_by_...' and 'results_'
590 // mappings, so when we build edges in BuildOutputGraph() we
591 // connect them to the right input/output positions.
592 if (input_permutation.size() != args_by_src_.size()) {
593 return errors::InvalidArgument("Input permutation has incorrect size.");
594 }
595 if (output_permutation.size() != results_.size()) {
596 return errors::InvalidArgument("Output permutation has incorrect size.");
597 }
598 for (auto& arg : args_by_src_) {
599 arg.second = input_permutation[arg.second];
600 }
601 for (auto& arg : args_by_dst_) {
602 arg.second = input_permutation[arg.second];
603 }
604 for (auto& result : results_) {
605 result.second = output_permutation[result.second];
606 }
607
608 name = call_node_def_.op();
609 }
610
611 function_def_name_ = name;
612
613 FunctionDef fdef;
614 auto lookup = [this](const Node* node) -> absl::optional<string> {
615 if (control_output_nodes_.contains(node->name())) {
616 return absl::make_optional(node->name());
617 }
618 return absl::nullopt;
619 };
620 // Verify that the graph has well-formed control flow structure.
621 std::vector<ControlFlowInfo> dummy;
622 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
623 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));
624
625 if (VLOG_IS_ON(1)) {
626 VLOG(2) << "Build function def " << name;
627 DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
628 library);
629 DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
630 }
631
632 const FunctionDef* original_fdef = library->Find(name);
633 if (!reuse_existing_functions || original_fdef == nullptr) {
634 TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
635 } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
636 TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
637 }
638 return Status::OK();
639 }
640
Status Encapsulator::Subgraph::ReplaceFunctionDef(
642 FunctionLibraryDefinition* library) {
643 const string& name = function_def_name_;
644
645 FunctionDef fdef;
646 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
647
648 if (VLOG_IS_ON(1)) {
649 VLOG(2) << "Replace function def " << name;
650 DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
651 *graph_, library);
652 DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
653 fdef);
654 }
655
656 TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
657 return Status::OK();
658 }
659
Status Encapsulator::Subgraph::AddFunctionCallNode(
661 const std::unordered_map<const Node*, Node*>& node_images,
662 Graph* graph_out) {
663 Status s;
664 call_node_ = graph_out->AddNode(call_node_def_, &s);
665 if (!s.ok()) return s;
666
667 // Copy the assigned device and the key_annotation over.
668 call_node_->set_assigned_device_name(device_);
669
670 return Status::OK();
671 }
672
Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
674 AttrSlice attrs = node->attrs();
675 attr->clear();
676 for (const auto& node_attr : attrs) {
677 if (node_attr.first == group_attribute_) {
678 TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
679 *attr = node_attr.second.s();
680 break;
681 }
682 }
683 return Status::OK();
684 }
685
bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }
687
Status Encapsulator::CopySubgraphNodes(
689 std::unordered_map<const Node*, Node*>* node_images) {
690 for (Node* node : graph_in_->op_nodes()) {
691 string func_id;
692 TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
693 if (!IsInSubgraph(func_id)) continue;
694
695 Subgraph& subgraph = subgraphs_[func_id];
696 Node* image = subgraph.MakeNodeImage(graph_in_, node);
697 image->ClearAttr(group_attribute_);
698 (*node_images)[node] = image;
699 }
700 return Status::OK();
701 }
702
Status Encapsulator::CopySubgraphEdges(
704 const std::unordered_map<const Node*, Node*>& node_images,
705 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
706 for (const Edge* edge : graph_in_->edges()) {
707 string src_func_id;
708 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
709 string dst_func_id;
710 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
711 Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
712 Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
713
714 // Copy edges that are local to a subgraph.
715 if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
716 src_func_id == dst_func_id) {
717 Graph* g = subgraphs_[src_func_id].GetGraph();
718 if (edge->IsControlEdge()) {
719 g->AddControlEdge(src_image, dst_image,
720 /* allow_duplicates= */ true);
721 } else {
722 g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
723 }
724 continue;
725 }
726
727 // Record 'src' as an output of its subgraph, if applicable.
728 if (IsInSubgraph(src_func_id)) {
729 if (!edge->IsControlEdge()) {
730 DataType dtype = edge->src()->output_type(edge->src_output());
731 if (IsRefType(dtype)) {
732 return errors::InvalidArgument(
733 "Ref Tensors (e.g., Variables) are not supported as results: "
734 "tensor ",
735 edge->src()->name(), ":", edge->src_output());
736 }
737 }
738
739 Subgraph& src_subgraph = subgraphs_[src_func_id];
740 if (edge->IsControlEdge()) {
741 TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images));
742 } else {
743 TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
744 }
745 }
746
747 // Record 'dst' as an input of its subgraph, if applicable.
748 if (IsInSubgraph(dst_func_id)) {
749 // Look at the type of the destination not the source, since Ref output
750 // Tensors can be automatically cast to non-Ref Tensors at the
751 // destination.
752 if (!edge->IsControlEdge()) {
753 DataType dtype = edge->dst()->input_type(edge->dst_input());
754 if (IsRefType(dtype)) {
755 return errors::InvalidArgument(
756 "Ref Tensors (e.g., Variables) are not supported as args: "
757 "tensor ",
758 edge->src()->name(), ":", edge->src_output());
759 }
760 }
761
762 Subgraph& dst_subgraph = subgraphs_[dst_func_id];
763 // Ignore control edges entering the subgraph. We will lift them onto
764 // the enclosing call operators in BuildOutputGraph().
765 if (!edge->IsControlEdge()) {
766 TF_RETURN_IF_ERROR(
767 dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
768 }
769 }
770 }
771 return Status::OK();
772 }
773
Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
775 Status s;
776
777 // Map from input graph nodes to subgraph nodes.
778 std::unordered_map<const Node*, Node*> node_images;
779
  // Each entry of src_arg_pairs is a pair whose first element is a node in the
  // original graph that has an output edge in the subgraph, and whose second
  // element is the arg node in the subgraph that it sends to. The vector is
  // filled in by CopySubgraphEdges(), via RecordArg().
784 std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
785
786 TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
787 TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
788 MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
789
790 for (auto& entry : subgraphs_) {
791 Subgraph& subgraph = entry.second;
792 FixupSourceAndSinkEdges(subgraph.GetGraph());
793 }
794
795 if (VLOG_IS_ON(1)) {
796 // Dump subgraphs.
797 for (auto& entry : subgraphs_) {
798 DumpGraphToFile(
799 absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
800 *entry.second.GetGraph(), library);
801 }
802 }
803
804 return s;
805 }
806
Status Encapsulator::BuildFunctionDefs(
808 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
809 FunctionLibraryDefinition* library) {
810 for (auto& subgraph_entry : subgraphs_) {
811 string name = subgraph_entry.first;
812 Subgraph& subgraph = subgraph_entry.second;
813 TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
814 name, rewrite_subgraph_fn, reuse_existing_functions, library));
815 }
816 return Status::OK();
817 }
818
Status Encapsulator::CopyNodesToOutputGraph(
820 Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
821 for (Node* node : graph_in_->op_nodes()) {
822 string func_id;
823 TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
824
825 // Don't copy nodes that are going to be encapsulated.
826 if (IsInSubgraph(func_id)) continue;
827
828 Node* image = graph_out->CopyNode(node);
829 (*node_images)[node] = image;
830 }
831 (*node_images)[graph_in_->source_node()] = graph_out->source_node();
832 (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
833 return Status::OK();
834 }
835
Status Encapsulator::AddFunctionCallNodes(
837 const std::unordered_map<const Node*, Node*>& node_images,
838 Graph* graph_out) {
839 for (auto& subgraph_entry : subgraphs_) {
840 TF_RETURN_IF_ERROR(
841 subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
842 }
843 return Status::OK();
844 }
845
Status Encapsulator::FindOutputImageOfEdgeSrc(
847 const string& src_func_id, const string& dst_func_id,
848 const std::unordered_map<const Node*, Node*>& node_images,
849 const Node* original_src_node, Node** src_image) {
850 if (IsInSubgraph(src_func_id)) {
851 // The edge is from a subgraph to a regular node in the output graph so
852 // use the subgraph's call node output.
853 *src_image = subgraphs_.at(src_func_id).GetCallNode();
854 } else {
855 // The source of the edge is in the output graph so use the node image in
856 // the output graph.
857 *src_image = node_images.at(original_src_node);
858 }
859 return Status::OK();
860 }
861
int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
863 const string& dst_func_id,
864 const Edge* edge) {
865 if (IsInSubgraph(src_func_id)) {
866 const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
867 // 'src' is in a subgraph and 'dst' is a regular node in the output
868 // graph. Use the corresponding call output instead.
869 return src_subgraph.GetResultIndexForEdge(edge);
870 } else {
871 // The source of the edge is in the output graph so use the regular edge
872 // slot.
873 return edge->src_output();
874 }
875 }
876
Status Encapsulator::FindOutputImageOfEdgeDst(
878 const string& src_func_id, const string& dst_func_id,
879 const std::unordered_map<const Node*, Node*>& node_images,
880 const Node* original_dst_node, Node** dst_image) {
881 if (IsInSubgraph(dst_func_id)) {
882 // The edge is to a subgraph from a regular node in the output graph so
883 // use the subgraph's call node input.
884 *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
885 } else {
886 // The destination of the edge is in the output graph so use the node image
887 // in the output graph.
888 *dst_image = node_images.at(original_dst_node);
889 }
890 return Status::OK();
891 }
892
int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
894 const string& dst_func_id,
895 const Edge* edge) {
896 if (IsInSubgraph(dst_func_id)) {
897 const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
898 // 'dst' is in a subgraph and 'src' is a regular node in the output
899 // graph. Use the corresponding call input instead.
900 return dst_subgraph.GetArgIndexForEdge(edge);
901 } else {
902 // The destination of the edge is in the output graph so use the regular
903 // edge slot.
904 return edge->dst_input();
905 }
906 }
907
Status Encapsulator::CopyEdgeToOutputGraph(
909 const Edge* edge, const string& src_func_id, const string& dst_func_id,
910 const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
911 std::unordered_set<std::pair<OutputTensor, InputTensor>,
912 OutputInputTensorPairHasher>* edges_added) {
913 Node* src_image;
914 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
915 src_func_id, dst_func_id, node_images, edge->src(), &src_image));
916 Node* dst_image;
917 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
918 src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));
919
920 // If this is a control edge then copy it and return. Lift control edges onto
921 // the enclosing call operator.
922 if (edge->IsControlEdge()) {
923 // Add the control edge, if we have not already added it, using the images
924 // determined above (potentially call operators or RecvAtHost/SendFromHost).
925 if (edges_added
926 ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
927 .second) {
928 graph_out->AddControlEdge(src_image, dst_image,
929 /* allow_duplicates= */ true);
930 }
931
932 return Status::OK();
933 }
934
935 int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);
936
937 int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);
938
939 // Add the edge, if we have not already added it.
940 if (edges_added
941 ->emplace(OutputTensor(src_image, src_output),
942 InputTensor(dst_image, dst_input))
943 .second) {
944 graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
945 }
946 return Status::OK();
947 }
948
Status Encapsulator::AddEdgesToOutputGraph(
950 const std::unordered_map<const Node*, Node*>& node_images,
951 Graph* graph_out) {
952 // Set of edges already added to the output graph, represented as (src, dst)
953 // pairs. We use the set to deduplicate edges; multiple edges in the input
954 // graph may map to one edge in the output graph.
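  // For example, control edges from several nodes inside the same cluster to
  // one destination node all collapse into a single control edge from that
  // cluster's call node.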
955 std::unordered_set<std::pair<OutputTensor, InputTensor>,
956 OutputInputTensorPairHasher>
957 edges_added;
958
959 for (const Edge* edge : graph_in_->edges()) {
960 string src_func_id;
961 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
962 string dst_func_id;
963 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
964
    // Ignore edges that are strictly contained within one subgraph.
967 if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
968 src_func_id == dst_func_id) {
969 continue;
970 }
971
972 // We have an edge that crosses a cluster boundary or is entirely within the
973 // unclustered graph.
974 TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
975 edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added));
976 }
977
978 for (auto& subgraph_entry : subgraphs_) {
979 Subgraph& subgraph = subgraph_entry.second;
980 subgraph.ConnectSequencerToCallNode(graph_out);
981 }
982
983 return Status::OK();
984 }
985
986 namespace {
987
988 // Adds a dummy Const node to graph_out. The "constant" has the type of
989 // data_type and the shape indicated in 'shape'. The dummy node is not a valid
990 // Const node because it does not have any value defined, but this doesn't
991 // matter because it will only be used subsequently for shape inference. (It
992 // would be possible to add a switch statement over data_type to create a value
993 // for the constant, but that would entail maintaining the logic as new types
994 // are added, and is not necessary.) If the node being replaced was within a
995 // control flow frame, adds appropriate Enter nodes so that the use of the Const
996 // is well-formed.
Node* AddDummyShapedNode(const Node* src_node, int src_port,
998 const std::vector<ControlFlowInfo>& control_flow_info,
999 const TensorShapeProto& shape, Graph* graph_out) {
1000 DataType data_type = src_node->output_type(src_port);
1001 TensorProto dummy_proto;
1002 dummy_proto.set_dtype(data_type);
1003 *dummy_proto.mutable_tensor_shape() = shape;
1004 // Don't set any value field in the proto, since it is only going to be used
1005 // for shape inference.
1006
1007 GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
1008 NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
1009 options.op_registry());
1010 node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
1011 Node* node = options.FinalizeBuilder(&node_builder);
1012 // Add any Enter nodes required to bring the constant to the correct control
1013 // flow frame.
1014 while (!control_flow_info[src_node->id()].frame_name.empty()) {
1015 NodeDebugInfo debug_info(*src_node);
1016 NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
1017 options.op_registry(), &debug_info);
1018 enter_builder.Attr("frame_name",
1019 control_flow_info[src_node->id()].frame_name);
1020 enter_builder.Attr("is_constant", true);
1021 enter_builder.Input(node, 0);
1022 Node* enter_node = options.FinalizeBuilder(&enter_builder);
1023 // Adopt the new Enter node as the value in the current frame.
1024 node = enter_node;
1025 // Recurse to the parent frame to see if more Enter nodes need to be added.
1026 src_node = control_flow_info[src_node->id()].parent_frame;
1027 }
1028 return node;
1029 }
1030
1031 } // namespace
1032
Status Encapsulator::MakePrunedGraphCopyAndInline(
1034 const Graph& graph, const std::vector<Node*>& sink_nodes,
1035 std::unique_ptr<Graph>* pruned_graph,
1036 std::unordered_map<const Node*, Node*>* node_images,
1037 FunctionLibraryDefinition* library) {
1038 // First copy all ancestor nodes of sink_nodes into a new graph.
1039 pruned_graph->reset(new Graph(library));
1040 (*pruned_graph)->set_versions(graph.versions());
1041 ReverseDFSFrom(graph, sink_nodes,
1042 /*enter=*/nullptr,
1043 /*leave=*/[&](Node* n) {
1044 if (!n->IsSource()) {
1045 Node* copied = (*pruned_graph)->CopyNode(n);
1046 node_images->emplace(n, copied);
1047 }
1048 });
1049
1050 // Add all the edges between copied nodes.
1051 for (auto entry : *node_images) {
1052 const Node* orig = entry.first;
1053 Node* image = entry.second;
1054 for (const Edge* out_edge : orig->out_edges()) {
1055 auto iter = node_images->find(out_edge->dst());
1056 if (iter != node_images->end()) {
1057 // The source and destination are both in the copied graph.
1058 (*pruned_graph)
1059 ->AddEdge(image, out_edge->src_output(), iter->second,
1060 out_edge->dst_input());
1061 }
1062 }
1063 }
1064
1065 // Find all the function call nodes, and inline them.
1066 std::vector<Node*> function_nodes;
1067 for (auto node : (*pruned_graph)->nodes()) {
1068 const OpRegistrationData* op_reg_data;
1069 TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1070 if (op_reg_data->is_function_op) {
1071 function_nodes.push_back(node);
1072 }
1073 }
1074 for (auto node : function_nodes) {
1075 VLOG(2) << "Inlining function " << node->name();
1076 const FunctionDef* fdef = library->Find(node->type_string());
1077 if (fdef == nullptr) {
1078 return errors::Internal("Failed to find function ", node->type_string(),
1079 " in function library.");
1080 }
1081 std::unique_ptr<FunctionBody> fbody;
1082 TF_RETURN_IF_ERROR(
1083 FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
1084
1085 InlineFunctionBodyOptions inline_opts;
1086 TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
1087 fbody.get(), inline_opts));
1088 }
1089
1090 return Status::OK();
1091 }
1092
Status Encapsulator::BuildOutputGraph(Graph* graph_out,
1094 FunctionLibraryDefinition* library) {
1095 // Map from nodes in the input graph to nodes in the output graph.
1096 std::unordered_map<const Node*, Node*> node_images;
1097
1098 TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
1099 TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
1100 TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));
1101
1102 return Status::OK();
1103 }
1104
1105 } // anonymous namespace
1106
Status EncapsulateSubgraphsInFunctions(
1108 string group_attribute, const Graph& graph_in,
1109 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1110 std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1111 Encapsulator encapsulator(std::move(group_attribute),
1112 &graph_in);
1113 TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
1114
1115 TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1116 rewrite_subgraph_fn, reuse_existing_functions, library));
1117
1118 std::unique_ptr<Graph> out(new Graph(library));
1119 out->set_versions(graph_in.versions());
1120 TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));
1121
1122 *graph_out = std::move(out);
1123 return Status::OK();
1124 }
1125
1126 // Finds the types of the _Arg nodes, indexed by position.
static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
1128 for (Node* n : graph.op_nodes()) {
1129 if (n->type_string() == kArgOp) {
1130 int index;
1131 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1132 const int num_types = types->size();
1133 if (index < 0 || index >= num_types) {
1134 return errors::InvalidArgument("Invalid argument number");
1135 }
1136 (*types)[index] = n->output_type(0);
1137 }
1138 }
1139 return Status::OK();
1140 }
1141
1142 // Renumber the indices of _Arg nodes in a graph, according to
1143 // 'permutation' that maps old indices to new indices.
static Status RenumberArguments(Graph* graph,
1145 const std::vector<int>& permutation) {
1146 for (Node* n : graph->op_nodes()) {
1147 if (n->type_string() == kArgOp) {
1148 int index;
1149 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1150 const int permutation_size = permutation.size();
1151 if (index < 0 || index >= permutation_size) {
1152 return errors::InvalidArgument("Invalid argument number");
1153 }
1154 n->AddAttr("index", permutation[index]);
1155 }
1156 }
1157 return Status::OK();
1158 }
1159
Status EncapsulateSubgraphsPass::Run(
1161 const GraphOptimizationPassOptions& options) {
1162 VLOG(1) << "EncapsulateSubgraphsPass::Run";
1163 if (VLOG_IS_ON(1)) {
1164 DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
1165 options.flib_def);
1166 }
1167
1168 std::unique_ptr<Graph> graph_out;
1169 FunctionLibraryDefinition* const library = options.flib_def;
1170
  // Constant folding below might need to run part of the function to compute
  // constants. Create a FunctionLibraryRuntime with a single CPU device
  // that can run that part of the function.
1174 // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
1175 // `options`.
1176 SessionOptions session_options;
1177 auto* device_count = session_options.config.mutable_device_count();
1178 device_count->insert({"CPU", 1});
1179 std::vector<std::unique_ptr<Device>> devices;
1180
1181 DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
1182 if (!cpu_factory) {
1183 return errors::NotFound(
1184 "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
1185 }
1186 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
1187 session_options, "/job:localhost/replica:0/task:0", &devices));
1188 if (devices.empty()) {
1189 return errors::NotFound(
1190 "Failed to create a CPU device for EncapsulateSubgraphsPass");
1191 }
1192
1193 std::unique_ptr<DeviceMgr> device_mgr =
1194 absl::make_unique<StaticDeviceMgr>(std::move(devices));
1195 const auto* config = &options.session_options->config;
1196 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1197 new ProcessFunctionLibraryRuntime(
1198 device_mgr.get(), options.session_options->env,
1199 /*config=*/config, TF_GRAPH_DEF_VERSION, library,
1200 config->graph_options().optimizer_options()));
1201 FunctionLibraryRuntime* flr =
1202 pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
1203 if (flr == nullptr) {
1204 return errors::Internal(
1205 "Failed to create and retrieve function library runtime to run "
1206 "constant folding");
1207 }
1208
1209 auto rewrite_subgraph =
1210 [flr](const std::vector<OutputTensor>& arg_source_tensors,
1211 std::unique_ptr<Graph>* subgraph,
1212 std::vector<int>* input_permutation,
1213 std::vector<int>* output_permutation, NodeDef* node) {
1214 // Optimize the subgraph.
1215 // Do not constant fold nodes that output DT_VARIANT type tensors.
1216 // XLA does not support Const nodes of Variant type since it needs
1217 // to know the original ops to be able to compile them to the relevant
1218 // XLA form.
1219 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
1220 // the form:
1221 // Const
1222 // |
1223 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
1224 // |
1225 // (Discard popped list)
1226 //
1227 // Would have been reduced to "Const -> Op" without this filter.
1228 // However since we are only allowed to specify the filter at the "Node"
1229 // level there is no good way to allow the above behavior. So we
1230 // disallow any sort of constant folding on Variant nodes for now.
1231 bool disable_constant_folding =
1232 GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
1233 auto cf_consider_fn = [disable_constant_folding](const Node* n) {
1234 if (disable_constant_folding) return false;
1235 for (const auto& output_arg : n->op_def().output_arg()) {
1236 if (output_arg.type() == DT_VARIANT) {
1237 return false;
1238 }
1239 }
1240 return true;
1241 };
1242 GraphOptimizer::Options graph_optimizer_options;
1243 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
1244 OptimizeGraph(flr, subgraph, graph_optimizer_options);
1245
1246 const int num_args = input_permutation->size();
1247 std::vector<bool> const_args(num_args);
1248 TF_RETURN_IF_ERROR(
1249 BackwardsConstAnalysis(**subgraph, &const_args,
1250 /*compile_time_const_nodes=*/nullptr, flr));
1251
1252 DataTypeVector arg_types(num_args);
1253 TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
1254
1255 // Compute a permutation of the arguments such that the constant
1256 // arguments are first.
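        // For example, const_args = {false, true, false} with arg_types =
        // {DT_FLOAT, DT_INT32, DT_RESOURCE} yields the input permutation
        // {1, 0, 2}: the constant argument moves to position 0, the
        // non-constant argument to position 1, and the resource stays last.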
1257 const int num_consts =
1258 std::count(const_args.begin(), const_args.end(), true);
1259
1260 const int num_resources =
1261 std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
1262 const int num_nonconsts = num_args - num_resources - num_consts;
1263 if (num_nonconsts < 0) {
1264 return errors::Internal("num_nonconsts should be >= 0, was ",
1265 num_nonconsts);
1266 }
1267
1268 int const_pos = 0;
1269 int arg_pos = num_consts;
1270 int resource_pos = num_consts + num_nonconsts;
1271 for (int i = 0; i < num_args; ++i) {
1272 if (const_args[i]) {
1273 if (arg_types[i] == DT_RESOURCE) {
1274 return errors::Internal(
1275 "Resource arguments cannot be constant (argument ", i, ")");
1276 }
1277 (*input_permutation)[i] = const_pos;
1278 ++const_pos;
1279 } else if (arg_types[i] == DT_RESOURCE) {
1280 (*input_permutation)[i] = resource_pos;
1281 ++resource_pos;
1282 } else {
1283 (*input_permutation)[i] = arg_pos;
1284 ++arg_pos;
1285 }
1286 }
1287
1288 // Renumber argument nodes in the graph.
1289 TF_RETURN_IF_ERROR(
1290 RenumberArguments(subgraph->get(), *input_permutation));
1291
1292 // TODO(phawkins): add a forward is-constant analysis, similarly split
1293 // outputs into host-memory constants and device-memory non-constants.
1294
1295 AddNodeAttr(kXlaCompiledKernelAttr, true, node);
1296 AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
1297 AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
1298 return Status::OK();
1299 };
1300
  // Don't run EncapsulateSubgraphs if the graph doesn't contain any nodes
  // with kXlaClusterAttr.
1303 bool has_xla_cluster_attribute = false;
1304 for (Node* node : (*options.graph)->nodes()) {
1305 if (HasNodeAttr(node->def(), kXlaClusterAttr)) {
1306 has_xla_cluster_attribute = true;
1307 break;
1308 }
1309 }
1310
1311 if (has_xla_cluster_attribute) {
1312 TF_RETURN_WITH_CONTEXT_IF_ERROR(
1313 EncapsulateSubgraphsInFunctions(
1314 kXlaClusterAttr, **options.graph, rewrite_subgraph,
1315 /*reuse_existing_functions=*/false, &graph_out, library),
1316 "EncapsulateSubgraphsPass failed");
1317 if (VLOG_IS_ON(1)) {
1318 DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
1319 options.flib_def);
1320 }
1321
1322 *options.graph = std::move(graph_out);
1323 }
1324
1325 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
1326 GetNodesRelatedToRefVariables(**options.graph, flr));
1327 for (Node* node : (*options.graph)->nodes()) {
1328 bool has_ref_vars = ref_related_nodes.contains(node);
1329 node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
1330 VLOG(3) << "Has ref vars = " << has_ref_vars
1331 << ", node: " << node->def().SerializeAsString();
1332 }
1333 return Status::OK();
1334 }
1335
bool IsXlaCompiledKernel(const Node& node) {
1337 bool is_compiled = false;
1338 bool has_compilation_attr =
1339 TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
1340 is_compiled;
1341 return has_compilation_attr ? is_compiled : false;
1342 }
1343
1344 } // namespace tensorflow
1345