/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

#include <string>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
// Checks whether a TF node can be compiled or not. "Recursive" because for
// call and functional While nodes it recursively checks whether the callee
// functions can be compiled.
class RecursiveCompilabilityChecker {
 public:
  // Contains the node name and function name. If the node is not inside a
  // function body, the function name is an empty string.
  struct StackFrame {
    std::string name;
    std::string function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  // Contains information about an uncompilable node inside a function body.
  struct UncompilableNodeInfo {
    std::string name;
    // A list representing a stack trace from the highest-level node, in
    // increasing call depth, down to the immediate node that fails the
    // compilability check.
    std::vector<StackFrame> stack_trace;
    std::string uncompilable_reason;
  };
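
  // For illustration only (exact frame contents depend on the graph being
  // checked): if the outer graph calls function `f`, and `f` contains an
  // uncompilable node `n`, the reported stack trace conceptually looks like:
  //
  //   {name: "f_call_node", function_name: ""}   // call site, outermost graph
  //   {name: "n",           function_name: "f"}  // failing node inside `f`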

  // Aggregates information about what kinds of ops are allowed.
  struct OperationFilter {  // TODO(lzr): Add AllowEverything() helper.
    // Whether resource variable ops are allowed in callees. We do not allow
    // resource variable ops in called functions (either as direct TF calls or
    // as higher-order control flow ops) because we do not yet model their
    // memory effects in jit/resource_operation_safety_analysis.
    bool allow_resource_ops_in_called_functions = false;

    // Whether Stack operations are allowed. We avoid auto-clustering Stack
    // operations in general because we do not support snapshotting them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_stack_ops = false;

    // Whether TensorArray operations are allowed. We avoid auto-clustering
    // TensorArray operations in general because we do not support snapshotting
    // them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_tensor_array_ops = false;

    // Whether stateful RNG ops are allowed. XLA's RNG does not have the same
    // seeding behavior as TensorFlow's RNG (b/34749654), so we avoid
    // auto-clustering stateful RNG ops.
    bool allow_stateful_rng_ops = false;

    // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
    // to cluster ControlTrigger because of how we use deadness analysis.
    bool allow_control_trigger = false;

    // Whether it is okay to "cluster" Assert and CheckNumerics by simply
    // removing them (they're not removed during clustering, but their
    // XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so
    // that the user is not surprised when XLA is implicitly enabled. If the
    // user explicitly specifies to use XLA, it is fine to resort to a dummy
    // implementation. Currently Assert and CheckNumerics ops have dummy XLA
    // implementations.
    bool allow_eliding_assert_and_checknumerics_ops = false;

    // Whether ops that produce or consume DT_VARIANT values are allowed. We
    // don't auto-cluster these ops because we don't yet support live-in or
    // live-out DT_VARIANT values.
    bool allow_ops_producing_or_consuming_variant = false;

    // Whether ops known to be slow on XLA-GPU should be considered compilable.
    bool allow_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be considered
    // compilable.
    bool allow_inaccurate_ops = false;

    // Require the function to be always compilable, regardless of whether some
    // control flow branches might be dead for a given input.
    bool require_always_compilable = false;

    // Whether string constants are compilable.
    bool allow_string_consts = true;

    // Whether to allow the compilation of CollectiveReduceV2Op.
    bool allow_collective_reduce_v2 = true;

    // Whether ops that are marked as outside compiled are always considered
    // compilable.
    // TODO(b/191502757): Make this behavior true by default and remove this
    // option once the inference converter supports outside compilation.
    bool allow_outside_compiled = false;
  };
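
  // Example usage (an illustrative sketch, not part of this header's API
  // surface): configure a filter and check a single node. DEVICE_GPU_XLA_JIT
  // is declared in tensorflow/compiler/tf2xla/xla_op_registry.h; `node` and
  // `lib_runtime` are assumed to be supplied by the caller.
  //
  //   RecursiveCompilabilityChecker::OperationFilter filter;
  //   filter.allow_stack_ops = true;  // Relax a default for explicit XLA use.
  //   filter.allow_eliding_assert_and_checknumerics_ops = true;
  //   RecursiveCompilabilityChecker checker(filter,
  //                                         DeviceType(DEVICE_GPU_XLA_JIT));
  //   bool compilable = checker.IsCompilableNode(node, lib_runtime);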

  RecursiveCompilabilityChecker(OperationFilter op_filter,
                                DeviceType jit_device_type)
      : op_filter_(std::move(op_filter)),
        jit_device_type_(std::move(jit_device_type)) {}

  using UncompilableNodesMap =
      std::map<std::string,
               std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;

  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the NameAttrList of the function and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `node` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `node`'s placement
  // within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;
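
  // Example usage (an illustrative sketch; `checker`, `node`, and
  // `lib_runtime` are assumed to exist as in the example above): walk the
  // returned map and log every uncompilable node with its reason.
  //
  //   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable =
  //       checker.FindUncompilableNodes(node, lib_runtime);
  //   for (const auto& entry : uncompilable) {
  //     for (const auto& info : entry.second.second) {
  //       LOG(INFO) << info.name << ": " << info.uncompilable_reason;
  //     }
  //   }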

  // Returns true if `node` can be compiled by XLA.
  bool IsCompilableNode(const Node& node,
                        FunctionLibraryRuntime* lib_runtime) const {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{node.name(), ""});
    return IsCompilableNode(node, lib_runtime, &stack_trace);
  }

  // Returns true if XLA supports this Op, but we don't want to cluster it
  // (e.g., due to performance or correctness concerns).
  bool OpIsInaccurate(const Node& node) const;
  bool OpIsSlow(const Node& node) const;

 private:
  struct StackFrameView {
    absl::string_view name;
    absl::string_view function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  bool IsCompilableNode(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableCall(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
                      std::vector<StackFrameView>* stack_trace,
                      NameAttrList* encapsulating_function,
                      UncompilableNodesMap* uncompilable_nodes) const;
  bool IsCompilableWhile(const Node& while_node,
                         FunctionLibraryRuntime* lib_runtime,
                         std::vector<StackFrameView>* stack_trace,
                         NameAttrList* encapsulating_function,
                         UncompilableNodesMap* uncompilable_nodes) const;

  // Tests whether 'case_node' is compilable. Every operator in all branches
  // must be compilable.
  bool IsCompilableCase(const Node& case_node,
                        FunctionLibraryRuntime* lib_runtime,
                        std::vector<StackFrameView>* stack_trace,
                        NameAttrList* encapsulating_function,
                        UncompilableNodesMap* uncompilable_nodes) const;

  // Returns the compilability of the node def retrieved from `node`'s
  // attribute with name `attr_name`.
  bool ExtractNodeDefAndCheckCompilability(
      const Node& node, const std::string& attr_name,
      const std::string& call_name, NameAttrList* encapsulating_function,
      FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      UncompilableNodesMap* uncompilable_nodes) const;

  bool IsStackOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
  }

  bool IsTensorArrayOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
  }

  bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
    return op_name == "Assert" || op_name == "CheckNumerics";
  }

  bool IsStatefulRandomOp(absl::string_view op_name) const {
    return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
           op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
           op_name == "TruncatedNormal" || op_name == "Multinomial";
  }

  bool OpProducesOrConsumesVariant(const Node& node) const {
    auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
    return absl::c_any_of(node.input_types(), is_variant) ||
           absl::c_any_of(node.output_types(), is_variant);
  }

  bool HasXLAKernel(const Node& node,
                    string* uncompilable_reason = nullptr) const;

  static void MaybeMarkUncompilableNode(
      const absl::string_view reason,
      const std::vector<StackFrameView>& stack_trace,
      NameAttrList* encapsulating_function,
      UncompilableNodesMap* uncompilable_nodes_map);

  // Make sure we don't recurse infinitely on recursive functions.
  const size_t kMaxRecursionDepth = 50;

  const OperationFilter op_filter_;
  const DeviceType jit_device_type_;
};

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration);

// Given a FunctionLibraryRuntime and a `function`, returns this function's body
// in `fbody` as well as the indices of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors;
// they are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices);
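
// Example usage (an illustrative sketch; `flr` is assumed to be a valid
// FunctionLibraryRuntime and `function` a NameAttrList naming a function in
// its library):
//
//   const FunctionBody* fbody;
//   std::vector<int> constant_arg_indices;
//   std::vector<int> resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, function, &fbody, &constant_arg_indices,
//       &resource_arg_indices));
//   // On success, both index vectors are sorted in ascending order.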

// Given a NodeDef `node_def`, returns true iff `node_def` has kXlaCompileAttr
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);

// Returns the memory types of the inputs.
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
// indices corresponding to constant and resource arguments respectively.
//
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. `Add`, that expects its inputs in device memory. Here is how it
// works now.
//
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to an XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently in the
// actual executable, which is copied to the device before being executed.
// Thus, when this executable runs, the constant is available in device
// memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices);

// Returns the output memory types.
//
// The XlaLaunch kernel keeps all outputs (including constants, which it
// copies) in device memory, except for resources.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody);

// Checks whether the graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_