/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Rewrites computations generated by the xla.compile() Python code into
// XlaLaunch nodes.
//
// xla.compile() does two main things:
// a) marks operators that make up an XLA computation with the attribute
//    _xla_compile_id=XYZ, where XYZ is a unique key.
// b) adds XlaClusterOutput nodes to represent outputs of the computation.
//    These nodes are not marked with the _xla_compile_id attribute.

#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
26 
27 #include <functional>
28 #include <string>
29 
30 #include "tensorflow/core/common_runtime/optimization_registry.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/statusor.h"
34 
namespace tensorflow {

// Encapsulates nodes marked with the _xla_compile_id attribute into
// XlaLaunch operators.
39 class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
40  public:
41   Status Run(const GraphOptimizationPassOptions& options) override;
42 
43   // The following methods are public only for unit tests.
44 
45   // This pass has two stages:
46   // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
47   //    marked with the same _xla_compile_id attribute into functions. These
48   //    functions contain the computations to be passed to XlaLaunch. During
49   //    encapsulation, we sort the arguments into the order expected by
50   //    XlaLaunch.
51   static Status Encapsulate(std::unique_ptr<Graph>* graph,
52                             FunctionLibraryDefinition* flib_def);
53 
54   // b) we rewrite the function calls generated in phase (a) into XlaLaunch
55   //    operators. We also convert the XlaClusterOutput output nodes of the
56   //    function call into the outputs of the XlaLaunch operator.
57   static Status BuildXlaLaunchOps(Graph* graph);
58 
59   struct XlaFunctionInfo {
60     int variable_start_index = -1;
61     std::string function_name;
62   };
63 
64   // We need to introduce this version to adapt to the output of gpu inference
65   // converter. The single argument overload version calls this function.
66   //
67   // When add_edges_to_output_of_downstream_nodes is true, the output edges of
68   // the xla_launch_node's immediate downstream nodes would be attached to the
69   // generated xla node. For example, if the original graph is
70   // StatefulPartitionedCall{_xla_compile_id=1} -> XlaClusterOutput -> NodeA
71   // The output graph of this function would look like the following when
72   // add_edges_to_output_of_downstream_nodes is true:
73   // XlaLaunch -> NodeA
74   static Status BuildXlaLaunchOps(
75       Graph* graph,
76       const std::function<StatusOr<bool>(const Node&)>& is_xla_launch_node,
77       const std::function<StatusOr<XlaFunctionInfo>(const Node&)>&
78           get_xla_function_info,
79       bool add_edges_to_output_of_downstream_nodes);
80 };

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_