1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file contains some utility functions for encapsulating XLA computation 17 // in host graph and encapsulating outside compilation in XLA computation. 18 19 #ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ 20 #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ 21 22 #include "absl/container/flat_hash_map.h" 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/stream_executor/lib/statusor.h" 25 26 namespace tensorflow { 27 28 // Attribute marking output tensor shapes inferred by XLA. Attribute value is 29 // a list of PartialTensorShape objects. 30 extern const char kXlaInferredShapesAttrName[]; 31 32 // Infers output shapes for all nodes in graph `g`. The output shapes will be 33 // stored in node attribute `kXlaInferredShapesAttrName`. 34 // 35 // We have to perform shape inference before encapsulation because after 36 // encapsulation, some nodes will be encapsulated into function call, and shape 37 // inference does not handle function call at the moment. 38 Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); 39 40 // Attribute indicating that some ops in this node's XLA computation has control 41 // dependency on this node. Attribute value will always be "true". 42 extern const char kXlaConnectedToXlaComputationAttrName[]; 43 44 // Attribute indicating that this node has control dependency on some ops in 45 // this node's XLA computation. Attribute value will always be "true". 46 extern const char kXlaConnectedFromXlaComputationAttrName[]; 47 48 // Attribute indicating that this is an Placeholder node added to act as a 49 // temporary input node for an outside compilation node. Attribute value will be 50 // string (original input node name). 51 extern const char kOutsideCompilationOriginalNodeAttrName[]; 52 53 // Attribute indicating that this is an Placeholder node added to act as a 54 // temporary input node for an outside compilation node. Attribute value will be 55 // int (src_output for original edge). 56 extern const char kOutsideCompilationSrcOutputAttrName[]; 57 58 // Attribute indicating that this node has control dependencies on some other 59 // nodes within the same XLA cluster. Attribute value will be a list of string 60 // (node names). 61 extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; 62 63 // Attribute indicating that this node is an outside compilation node which is 64 // lifted out of If/While/function node. Attribute value will always be boolean 65 // value "true". 66 extern const char kXlaIsLiftedArgAttrName[]; 67 68 // Attribute indicating that this node is a Placeholder node for an _Arg node 69 // lifted out of If/While/function node. Attribute value will be a string, which 70 // is the outside compilation cluster name sending the lifted arg node to host. 71 extern const char kXlaLiftedArgOutsideCompilationAttrName[]; 72 73 // Attribute indicating that this is an IdentityN node receiving inputs for a 74 // outside compilation Placeholder node (the original outside compilation node 75 // is moved out of TPU computation, and we left a Placeholder node there). 76 // Attribute value will be a string, which is the outside compilation cluster 77 // name for the outside compilation Placeholder node. 78 extern const char kXlaOutsideCompilationInputsAttrName[]; 79 80 // Attribute indicating that this is a Placeholder node for an _Arg node used in 81 // outside compilation. We should not move this node out of XLA computation. 82 // Attribute value will always be boolean value "true". 83 extern const char kXlaIsPlaceholderForArg[]; 84 85 // Information for XLA computation. 86 struct XlaClusterInfo { 87 // Add an explicitly-defined default constructor for this class. 88 // 89 // The compiler may delete the default constructor here because 90 // host_compute_core is a const member whose type (std::map) doesn't 91 // necessarily have a user provided constructor -- while libc++ and 92 // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at 93 // least >= 7.3 does not. See also c++11 [class.ctor] p5. 94 // 95 // TODO(klimek): In c++17 we'll be able to initialize host_compute_core 96 // without losing aggregate initialization, which allows us to get rid of 97 // the constructor definitions again. XlaClusterInfoXlaClusterInfo98 XlaClusterInfo() {} XlaClusterInfoXlaClusterInfo99 XlaClusterInfo(const string& cluster_name, 100 const NameAttrList& func_name_attrs, Node* node, 101 const std::map<string, int>& host_compute_core) 102 : cluster_name(cluster_name), 103 func_name_attrs(func_name_attrs), 104 node(node), 105 host_compute_core(host_compute_core) {} 106 // XLA cluster name. It might be different from `func_name`. 107 const string cluster_name; 108 // Name and attributes of XLA computation function. 109 const NameAttrList func_name_attrs; 110 // The XLA computation node in the graph. 111 Node* node; 112 // A mapping from outside compilation cluster name to its device assignment. 113 const std::map<string, int> host_compute_core; 114 }; 115 116 // Finds dependencies between outside compilation clusters, including both data 117 // dependencies and control dependencies. cluster_deps maps the name name of an 118 // outside compilation cluster to a set of names of outside compilation clusters 119 // that it depends on. 120 stream_executor::port::StatusOr< 121 std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>> 122 OutsideCompilationClusterDependencies( 123 const Graph* g, const string& outside_compilation_attr_name); 124 125 // Preprocesses edges within the same XLA cluster. It will perform the following 126 // operations in order: 127 // 128 // 0. Remove edges from source node to outside compilation nodes, and edges 129 // from outside compilation nodes to sink node. 130 // 1a. For edges between different outside compilation clusters, remove the edge 131 // and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node 132 // name" to dst node. 133 // 1b. For control edges between outside compilation and its XLA computation, 134 // add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the 135 // outside compilation node. 136 // 2. For data edges between different outside compilations, remove the edge 137 // and create a Placeholder node as dst node's input. 138 Status PreprocessEdgesBetweenOutsideCompilations( 139 Graph* g, const string& outside_compilation_attr_name); 140 141 // Postprocesses edges within the same XLA cluster. This function reverts what 142 // `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the 143 // following operations in order: 144 // 145 // 1. Remove Placeholder nodes between different outside compilations (created 146 // in `PreprocessEdgesBetweenOutsideCompilations` step 2). 147 // 2a. Reconnect control edges between different outside compilations (marked by 148 // `PreprocessEdgesBetweenOutsideCompilations` step 1a). 149 // Notice that control edges marked by 150 // `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. 151 // They are handled in `RewriteOutsideCompilationSubgraphFn`. 152 Status PostprocessEdgesBetweenOutsideCompilations( 153 Graph* g, const string& outside_compilation_attr_name); 154 } // namespace tensorflow 155 156 #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ 157