1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // An optimization pass that groups nodes marked with a common 17 // kXlaClusterAttr into functions, and replaces the original nodes by 18 // calls. The calls are annotated with kXlaCompiledKernelAttr. 19 20 #ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ 21 #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ 22 23 #include "tensorflow/core/common_runtime/optimization_registry.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/graph/graph.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 30 // EncapsulateSubgraphs pass takes all the nodes with the same cluster ID 31 // (derived from kXlaClusterAttr=ID (kXlaClusterAttr) attribute), puts them into 32 // a TF function, and replaces the subgraph in the main graph with a call to 33 // that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel). 34 class EncapsulateSubgraphsPass : public GraphOptimizationPass { 35 public: 36 Status Run(const GraphOptimizationPassOptions& options) override; 37 }; 38 39 // A rewriting function to apply to each subgraph during encapsulation. 40 // 'arg_source_tensors' are the tensors corresponding to the arguments in the 41 // original source graph (*not* 'graph'). 42 // 43 // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; 44 // 'input_permutation' is a mapping from old argument numbers to new argument 45 // numbers, whereas 'output_permutation' is the same for outputs. Both 46 // 'input_permutation' and 'output_permutation' are initialized to the identity 47 // permutation. 'nodedef' is the NodeDef for the call to the function under 48 // construction, provided to allow additional attributes to be set. 49 // The rewrite may also change the NodeDef's operator name, and that 50 // name will be used as the name of the generated function. 51 typedef std::function<Status( 52 const std::vector<OutputTensor>& arg_source_tensors, 53 std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation, 54 std::vector<int>* output_permutation, NodeDef* node_def)> 55 RewriteSubgraphFn; 56 57 // Transformation that finds subgraphs whose nodes are marked with 58 // 'group_attribute', splits those subgraphs into functions, and replaces 59 // the originals with function calls. 60 // 61 // 'group_attribute' must be a string valued-attribute that names the new 62 // functions to introduce. 63 // 64 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before 65 // function conversion. 66 // 67 // If 'reuse_existing_functions' is set, use an existing function with the 68 // same name, if any. 69 // 70 // TODO(phawkins): currently, some information in control edges 71 // is not preserved. Suppose you have A and B in the main 72 // graph, C and D in a subgraph. B and C have control deps from A, D has control 73 // dep from B. Originally D must run after C, post-transformation this 74 // dependency is lost. 75 Status EncapsulateSubgraphsInFunctions( 76 string group_attribute, const Graph& graph_in, 77 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, 78 std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library); 79 80 // The attribute that marks function calls produced by the encapsulate 81 // subgraphs pass and that should in turn be compiled via XlaLaunch operators. 82 extern const char* const kXlaCompiledKernelAttr; 83 84 // Does `node` have the kXlaCompiledKernelAttr attribute? 85 bool IsXlaCompiledKernel(const Node& node); 86 87 // Functions produced by the EncapsulateSubgraphs pass have their arguments in 88 // the order: 89 // 1) compile-time constant arguments, in host memory, 90 // 2) other arguments, in device memory. 91 // 3) resource variable arguments, in host memory. Note that only the resource 92 // Tensor itself is in host memory; the underlying value may be in device 93 // memory. 94 // The functions are annotated with the following attributes that describe how 95 // many constant and resource arguments there are: 96 97 // Name of the attribute containing the number of constant arguments. 98 extern const char* const kXlaNumConstantArgsAttr; 99 100 // Name of the attribute containing the number of resource variable arguments. 101 extern const char* const kXlaNumResourceArgsAttr; 102 103 // Name of the attribute defining whether the cluster has reference variables. 104 extern const char* const kXlaHasReferenceVarsAttr; 105 106 // Sorts each node's control inputs by their names. This guarantees that for two 107 // structurally equivalent GraphDefs, we get the same traversal ordering on 108 // node's control input fields. 109 // TODO(hpucha): Move the utilities to a more appropriate place. 110 void SortControlInputs(GraphDef* gdef); 111 112 } // namespace tensorflow 113 114 #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ 115