/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Rewrites computations generated by the xla.compile() Python code into
// XlaLaunch nodes.
//
// xla.compile() does two main things:
// a) marks operators that make up an XLA computation with the attribute
//    _xla_compile_id=XYZ, where XYZ is a unique key.
// b) adds XlaClusterOutput nodes to represent outputs of the computation.
//    These nodes are not marked with the _xla_compile_id attribute.
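//
// A pass of this kind is wired into graph execution through the optimization
// pass registry; a minimal registration sketch follows. The grouping and
// phase number below are illustrative assumptions, not necessarily the
// registration TensorFlow actually uses:
//
//   REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
//                         EncapsulateXlaComputationsPass);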
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {

// Encapsulates nodes marked with the _xla_compile_id attribute into
// XlaLaunch operators.
class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;

  // The following methods are public only for unit tests.

  // This pass has two stages:
  // a) First, we call EncapsulateSubgraphsPass to encapsulate all nodes
  //    marked with the same _xla_compile_id attribute into functions. These
  //    functions contain the computations to be passed to XlaLaunch. During
  //    encapsulation, we sort the arguments into the order expected by
  //    XlaLaunch.
  static Status Encapsulate(std::unique_ptr<Graph>* graph,
                            FunctionLibraryDefinition* flib_def);

  // b) Then, we rewrite the function calls generated in phase (a) into
  //    XlaLaunch operators, and convert the XlaClusterOutput output nodes of
  //    the function call into the outputs of the XlaLaunch operator.
  static Status BuildXlaLaunchOps(Graph* graph);

  struct XlaFunctionInfo {
    // Index of the first variable (resource) argument in the function's
    // argument list, or -1 if the function takes no variable arguments.
    int variable_start_index = -1;
    // Name of the encapsulated function in the function library.
    std::string function_name;
  };

  // This overload exists to adapt to the output of the GPU inference
  // converter. The single-argument overload above calls this function.
  //
  // When add_edges_to_output_of_downstream_nodes is true, the output edges of
  // the xla_launch_node's immediate downstream nodes are attached to the
  // generated XlaLaunch node. For example, if the original graph is
  //   StatefulPartitionedCall{_xla_compile_id=1} -> XlaClusterOutput -> NodeA
  // then, when add_edges_to_output_of_downstream_nodes is true, the output
  // graph of this function looks like:
  //   XlaLaunch -> NodeA
  static Status BuildXlaLaunchOps(
      Graph* graph,
      const std::function<StatusOr<bool>(const Node&)>& is_xla_launch_node,
      const std::function<StatusOr<XlaFunctionInfo>(const Node&)>&
          get_xla_function_info,
      bool add_edges_to_output_of_downstream_nodes);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_