• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
17 #define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
18 
19 #include <string>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/jit/defs.h"
25 #include "tensorflow/compiler/jit/device_util.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
28 #include "tensorflow/compiler/tf2xla/const_analysis.h"
29 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
30 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
31 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/union_find.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/common_runtime/graph_constructor.h"
37 #include "tensorflow/core/framework/attr_value.pb.h"
38 #include "tensorflow/core/framework/bounds_check.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/memory_types.h"
42 #include "tensorflow/core/framework/node_def.pb.h"
43 #include "tensorflow/core/framework/op_kernel.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/framework/types.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/control_flow.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/lib/gtl/cleanup.h"
50 #include "tensorflow/core/lib/strings/stringprintf.h"
51 #include "tensorflow/core/public/version.h"
52 #include "tensorflow/core/util/dump_graph.h"
53 
54 namespace tensorflow {
55 // Checks whether a TF node can be compiled or not.  "Recursive" as in for call
56 // and functional while nodes it recursively checks whether the callee functions
57 // can be compiled.
58 class RecursiveCompilabilityChecker {
59  public:
60   // Contains node name and function name. If the node is not inside a function
61   // body, function name is an empty string.
62   struct StackFrame {
63     std::string name;
64     std::string function_name;
65     std::shared_ptr<AbstractStackTrace> stack_trace;
66   };
67 
68   // Contains information about uncompilable node inside a function body.
69   struct UncompilableNodeInfo {
70     std::string name;
71     // A list representing a stacktrace from the highest level node in
72     // increasing call depth to immediate node that fails the
73     // compilability checker.
74     std::vector<StackFrame> stack_trace;
75     std::string uncompilable_reason;
76   };
77 
78   // Aggregates information about what kinds of ops are allowed.
79   struct OperationFilter {  // TODO(lzr): Add AllowEverything() helper.
80     // Whether resource variable ops are allowed are allowed in callees.  We do
81     // not allow resource variable ops in called functions (either as direct TF
82     // calls or as higher order control flow ops) because we do not yet model
83     // their memory effects in jit/resource_operation_safety_analysis.
84     bool allow_resource_ops_in_called_functions = false;
85 
86     // Whether Stack operations are allowed.  We avoid auto-clustering Stack
87     // operations in general because we do not support snapshotting them.
88     //
89     // TODO(b/112837194): This restriction can be lifted with some work.
90     bool allow_stack_ops = false;
91 
92     // Whether TensorArray operations are allowed.  We avoid auto-clustering
93     // TensorArray operations in general because we do not support snapshotting
94     // them.
95     //
96     // TODO(b/112837194): This restriction can be lifted with some work.
97     bool allow_tensor_array_ops = false;
98 
99     // Whether stateful RNG ops are allowed.  XLA's RNG does not have the same
100     // seeding behavior as TensorFlow's RNG (b/34749654).  So we avoid
101     // auto-clustering stateful RNG ops.
102     bool allow_stateful_rng_ops = false;
103 
104     // TODO(b/118970344): Whether ControlTrigger ops are allowed.  It is unsound
105     // to cluster ControlTrigger because of how we use deadness analysis.
106     bool allow_control_trigger = false;
107 
108     // Whether it is okay to "cluster" Assert and CheckNumerics by simply
109     // removing them (they're not removed during clustering, but their
110     // XlaOpKernel is a no-op kernel).  We avoid auto-clustering these ops so
111     // that the user is not surprised when XLA is implicitly enabled. If the
112     // user explicitly specifies to use XLA, it is fine to resort to a dummy
113     // implementation. Currently Assert and CheckNumerics ops have dummy XLA
114     // implementations.
115     bool allow_eliding_assert_and_checknumerics_ops = false;
116 
117     // Whether ops that produce or consume DT_VARIANT values are allowed.  We
118     // don't auto-cluster these ops because we don't yet support live-in or
119     // live-out DT_VARIANT values.
120     bool allow_ops_producing_or_consuming_variant = false;
121 
122     // Whether ops known to be slow on XLA-GPU should be considered compilable.
123     bool allow_slow_ops = false;
124 
125     // Whether ops known to have numerical accuracy issues should be considered
126     // compilable..
127     bool allow_inaccurate_ops = false;
128 
129     // Require the function to be always compilable, regardless whether some
130     // control flow branches might be dead for a given input.
131     bool require_always_compilable = false;
132 
133     // Whether string constants are compilable.
134     bool allow_string_consts = true;
135 
136     // Whether to allow the compilation of CollectiveReduceV2Op.
137     bool allow_collective_reduce_v2 = true;
138 
139     // Whether ops that are marked as outside compiled are always considered
140     // compilable.
141     // TODO(b/191502757):  Make this behavior true by default and remove this
142     // option once inference converter supports outside compilation.
143     bool allow_outside_compiled = false;
144   };
145 
RecursiveCompilabilityChecker(OperationFilter op_filter,DeviceType jit_device_type)146   RecursiveCompilabilityChecker(OperationFilter op_filter,
147                                 DeviceType jit_device_type)
148       : op_filter_(std::move(op_filter)),
149         jit_device_type_(std::move(jit_device_type)) {}
150 
151   using UncompilableNodesMap =
152       std::map<std::string,
153                std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;
154 
155   // Returns a map where the key is the function identifier(short debug
156   // string) of the function encapsulating the uncompilable nodes, and the
157   // value is a pair of NameAttrList of the function and a vector of
158   // uncompilable node info. When uncompilable node is not inside any
159   // function call nodes, then key is a ShortDebugString() of an empty
160   // NameAttrList.
161   //
162   // Also, when `node` is inside a function body, users can set
163   // `node_stack_trace` to provide an additional context for `node`'s
164   // placement within the outer most graph.
165   UncompilableNodesMap FindUncompilableNodes(
166       const Node& node, FunctionLibraryRuntime* lib_runtime,
167       const std::vector<StackFrame>* node_stack_trace = nullptr) const;
168 
169   // Returns true if `node` can be compiled by XLA.
IsCompilableNode(const Node & node,FunctionLibraryRuntime * lib_runtime)170   bool IsCompilableNode(const Node& node,
171                         FunctionLibraryRuntime* lib_runtime) const {
172     std::vector<StackFrameView> stack_trace;
173     stack_trace.emplace_back(StackFrameView{node.name(), ""});
174     return IsCompilableNode(node, lib_runtime, &stack_trace);
175   }
176 
177   // Returns true if XLA supports this Op, but we don't want to cluster it (ie:
178   // due to performance or correctness concerns).
179   bool OpIsInaccurate(const Node& node) const;
180   bool OpIsSlow(const Node& node) const;
181 
182  private:
183   struct StackFrameView {
184     absl::string_view name;
185     absl::string_view function_name;
186     std::shared_ptr<AbstractStackTrace> stack_trace;
187   };
188 
189   bool IsCompilableNode(
190       const Node& node, FunctionLibraryRuntime* lib_runtime,
191       std::vector<StackFrameView>* stack_trace,
192       NameAttrList* encapsulating_function = nullptr,
193       UncompilableNodesMap* uncompilable_nodes = nullptr) const;
194   bool IsCompilableCall(
195       const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
196       std::vector<StackFrameView>* stack_trace,
197       NameAttrList* encapsulating_function = nullptr,
198       UncompilableNodesMap* uncompilable_nodes = nullptr) const;
199   bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
200                       std::vector<StackFrameView>* stack_trace,
201                       NameAttrList* encapsulating_function,
202                       UncompilableNodesMap* uncompilable_nodes) const;
203   bool IsCompilableWhile(const Node& while_node,
204                          FunctionLibraryRuntime* lib_runtime,
205                          std::vector<StackFrameView>* stack_trace,
206                          NameAttrList* encapsulating_function,
207                          UncompilableNodesMap* uncompilable_nodes) const;
208 
209   // Tests whether 'case_node' is compilable. Every operator in all branches
210   // must be compilable.
211   bool IsCompilableCase(const Node& case_node,
212                         FunctionLibraryRuntime* lib_runtime,
213                         std::vector<StackFrameView>* stack_trace,
214                         NameAttrList* encapsulating_function,
215                         UncompilableNodesMap* uncompilable_nodes) const;
216 
217   // Returns compilability of node def retrieved from `node`'s attribute with
218   // name `attr_name`.
219   bool ExtractNodeDefAndCheckCompilability(
220       const Node& node, const std::string& attr_name,
221       const std::string& call_name, NameAttrList* encapsulating_function,
222       FunctionLibraryRuntime* lib_runtime,
223       std::vector<StackFrameView>* stack_trace,
224       UncompilableNodesMap* uncompilable_nodes) const;
225 
IsStackOp(const Node & node)226   bool IsStackOp(const Node& node) const {
227     const XlaResourceOpInfo* op_info =
228         GetResourceOpInfoForOp(node.type_string());
229     return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
230   }
231 
IsTensorArrayOp(const Node & node)232   bool IsTensorArrayOp(const Node& node) const {
233     const XlaResourceOpInfo* op_info =
234         GetResourceOpInfoForOp(node.type_string());
235     return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
236   }
237 
IsAssertOrCheckNumerics(absl::string_view op_name)238   bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
239     return op_name == "Assert" || op_name == "CheckNumerics";
240   }
241 
IsStatefulRandomOp(absl::string_view op_name)242   bool IsStatefulRandomOp(absl::string_view op_name) const {
243     return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
244            op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
245            op_name == "TruncatedNormal" || op_name == "Multinomial";
246   }
247 
OpProducesOrConsumesVariant(const Node & node)248   bool OpProducesOrConsumesVariant(const Node& node) const {
249     auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
250     return absl::c_any_of(node.input_types(), is_variant) ||
251            absl::c_any_of(node.output_types(), is_variant);
252   }
253 
254   bool HasXLAKernel(const Node& node,
255                     string* uncompilable_reason = nullptr) const;
256 
257   static void MaybeMarkUncompilableNode(
258       const absl::string_view reason,
259       const std::vector<StackFrameView>& stack_trace,
260       NameAttrList* encapsulating_function,
261       UncompilableNodesMap* uncompilable_nodes_map);
262 
263   // Make sure we don't recurse infinitely on recursive functions.
264   const size_t kMaxRecursionDepth = 50;
265 
266   const OperationFilter op_filter_;
267   const DeviceType jit_device_type_;
268 };
269 
270 RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
271     const XlaOpRegistry::DeviceRegistration& registration);
272 
273 // Given a FunctionLibraryRuntime and a `function`, returns this function's body
274 // in `fbody` as well as the indices of its constant and resource arguments.
275 // `fbody` is owned by `flr`.
276 // `constant_arg_indices` and `resource_arg_indices` should be empty vector.
277 // They are sorted in ascending order on this function's return.
278 Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
279                                        const NameAttrList& function,
280                                        const FunctionBody** fbody,
281                                        std::vector<int>* constant_arg_indices,
282                                        std::vector<int>* resource_arg_indices);
283 
284 // Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
285 // set.
286 bool CanCreateXlaKernel(const NodeDef& node_def);
287 
288 // Returns memory types for the input.
289 // `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
290 // indices corresponding to constant and resource arguments respectively.
291 //
292 // One might wonder, about the case where a compile-time constant argument
293 // (which must be in host memory) is also used as an input into an op,
294 // e.g. `Add`, that expects its inputs in device memory. Here is how it
295 // works now.
296 // First, what do we mean by "op expects an input in XYZ memory"?
297 // There are two types of "ops" here: the tf2xla kernel and the HLO
298 // computation it builds. The tf2xla kernel needs to retrieve the actual
299 // numeric value of the compile-time constant tensors, so it really expects
300 // them to be on in host memory. However, for other inputs, it refers to them
301 // using xla::ComputationDataHandle, which is just a symbolic handle that
302 // xla::ComputationBuilder assigns. How does this handle gets assigned for
303 // constant arguments? Even constant arguments get an _Arg node in the graph
304 // instantiated for Function compilation. The tf2xla kernel for constant _Arg
305 // nodes takes the constant value, converts it to XlaLiteral, and feeds it
306 // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
307 // constant XlaLiteral is included in the HLO graph, and subsequently, in
308 // the actual executable, which is copied to the device before being
309 // executed. Thus, when this executable runs, the constant is available in
310 // device memory.
311 tensorflow::MemoryTypeVector GetInputMemoryTypes(
312     const tensorflow::FunctionBody* fbody,
313     absl::Span<int const> constant_arg_indices,
314     absl::Span<int const> resource_arg_indices);
315 
316 // Returns output memory types.
317 //
318 // XlaLaunch kernel keeps all outputs (including constants, which it copies),
319 // in device memory except for resources.
320 tensorflow::MemoryTypeVector GetOutputMemoryTypes(
321     const tensorflow::FunctionBody* fbody);
322 
323 // Check whether graph can trigger XLA compilation.
324 bool CanTriggerXlaCompilation(const GraphDef& graph);
325 
326 }  // namespace tensorflow
327 
328 #endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
329