/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/compilability_check_util.h"

#include <algorithm>
#include <atomic>
#include <deque>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

bool HasResourceInput(const Node& node) {
  return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
}

void LogNotCompilable(const Node& node, absl::string_view reason = "") {
  VLOG(3) << "Found uncompilable node " << node.name() << " (op "
          << node.type_string() << ")" << (reason.empty() ? "" : ": ")
          << reason;
}

bool IsInOutsideCompilationCluster(const Node& n) {
  return n.attrs().Find(kXlaOutsideCompilationAttr) != nullptr;
}

Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
                                 NodeDef* node_def) {
  const NameAttrList* name_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr));
  node_def->set_op(name_attr->name());
  *(node_def->mutable_attr()) = name_attr->attr();
  return Status::OK();
}

StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
    const Node& node, absl::string_view attr_name,
    absl::string_view call_name) {
  std::vector<NameAttrList> attr_lists;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));

  std::vector<NodeDef> out;
  for (int i = 0; i < attr_lists.size(); i++) {
    out.emplace_back();
    NodeDef& inserted = out.back();
    inserted.set_name(absl::StrCat(call_name, "_", i));
    inserted.set_op(attr_lists[i].name());
    *inserted.mutable_attr() = attr_lists[i].attr();
  }
  return out;
}

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take care to
// order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(absl::Span<int const> values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position in the vector.
  // Returns true iff the vector contains the given value at or after the
  // current position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_.size() &&
           values_[current_index_] <= value) {
      if (values_[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const absl::Span<int const> values_;
};
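
// Example usage (illustrative sketch only; the index lists below are
// hypothetical and not taken from this file). It shows the intended pattern:
// one forward pass over the argument indices classifies every argument while
// each sorted index list is scanned at most once.
//
//   const std::vector<int> constant_idx = {1, 3};  // must be sorted ascending
//   const std::vector<int> resource_idx = {0, 3};  // must be sorted ascending
//   SinglePassSearch constants_search(constant_idx);
//   SinglePassSearch resources_search(resource_idx);
//   for (int i = 0; i < 4; ++i) {
//     // Overall cost is O(num_args + num_indices), not O(num_args * lookups).
//     bool is_resource = resources_search.ScanForValue(i);
//     bool is_constant = constants_search.ScanForValue(i);
//   }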

}  // anonymous namespace

RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  std::vector<StackFrameView> stack_trace;
  // If `node_stack_trace` is provided, that means `node` is inside
  // a function body, and therefore, arg nodes and retval nodes are
  // not considered uncompilable.
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  stack_trace.emplace_back(
      StackFrameView{node.name(), "", node.GetStackTrace()});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  IsCompilableNode(node, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}

bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") {
    *uncompilable_reason =
        "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }

  if (node.type_string() == "Const") {
    const AttrValue* attr = node.attrs().Find("dtype");
    if (!op_filter_.allow_string_consts && attr != nullptr &&
        attr->type() == DT_STRING) {
      *uncompilable_reason =
          "Const op with type DT_STRING is not supported by XLA.";
      return false;
    }
  }

  // XLA does not offer guaranteed aliasing between the input and output of the
  // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
  // such nodes out of XLA clusters.
  if (HasForwardedRefInput(node)) {
    VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
    *uncompilable_reason = "Identity with unsafe cast.";
    return false;
  }

  Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
  if (!s.ok()) {
    *uncompilable_reason = s.error_message();
    return false;
  }
  return true;
}

// Tests whether 'if_node' is compilable. Every operator in the then_branch and
// else_branch functions must be compilable for 'if_node' to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableIf(
    const Node& if_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);
  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

bool RecursiveCompilabilityChecker::IsCompilableCase(
    const Node& case_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  StatusOr<std::vector<NodeDef>> calls =
      MakeCallNodesFromAttribute(case_node, "branches", "branch");
  if (!calls.ok()) {
    VLOG(2) << "Rejecting node " << case_node.name() << ": "
            << "missing attribute 'branches'";
    return false;
  }

  bool is_compilable = true;

  for (const NodeDef& call : *calls) {
    is_compilable &= IsCompilableCall(call, lib_runtime, stack_trace,
                                      encapsulating_function,
                                      uncompilable_nodes);
  }
  return is_compilable;
}

// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableWhile(
    const Node& while_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "body", "while_body", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
    const Node& node, const std::string& attr_name,
    const std::string& call_name, NameAttrList* encapsulating_function,
    FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  NodeDef call;
  call.set_name(call_name);
  if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
    const auto uncompilable_reason = absl::StrCat(
        "missing '", attr_name, "' attribute from node ", node.name());
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
            << ".";
    return false;
  }
  if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    VLOG(2) << "Rejecting node " << node.name()
            << ": can't compile : " << call.op();
    return false;
  }
  return true;
}

// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  if (stack_trace->size() > kMaxRecursionDepth) {
    std::string uncompilable_reason = "function depth limit exceeded";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
            << ".";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status s;
  NameAttrList function;
  s = NameAndAttrsFromFunctionCall(call_def, &function);
  if (s.ok()) {
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }
  if (!s.ok()) {
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
            << uncompilable_reason << " : " << s;
    return false;
  }

  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  bool is_compilable = true;
  for (const Node* node : fbody->graph->op_nodes()) {
    stack_trace->emplace_back(
        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
    is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                      &function, uncompilable_nodes);
    stack_trace->pop_back();
    if (!uncompilable_nodes && !is_compilable) return is_compilable;
  }

  return is_compilable;
}

bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
  // b/127344411: SelfAdjointEigV2 and Svd precision issues.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd";
}

bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
  // b/128001705: SelfAdjointEigV2 and Svd performance issues.
  // b/135640736: MatrixInverse performance issues.
  // b/111271662: MatrixSolve performance issues.
  // https://github.com/tensorflow/tensorflow/pull/31012:
  // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
  // create convolutions too large for CuDNN to handle.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd" || node.type_string() == "Qr" ||
         node.type_string() == "MatrixInverse" ||
         node.type_string() == "MatrixSolve" ||
         node.type_string() == "ResizeNearestNeighbor" ||
         node.type_string() == "ResizeBilinear" ||
         node.type_string() == "ResizeBilinearGrad";
}

bool RecursiveCompilabilityChecker::IsCompilableNode(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  auto stack_depth = stack_trace->size();

  if (op_filter_.allow_outside_compiled && IsInOutsideCompilationCluster(node))
    return true;

  if (node.IsSource() || node.IsSink()) {
    absl::string_view uncompilable_reason = "source or sink node";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // _Arg nodes in a top-level function represent feeds and _Retval nodes in a
  // top-level function represent fetches.
  if (stack_depth == 1 &&
      (node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
    absl::string_view uncompilable_reason = "top level _Arg or _Retval";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.attrs().Find("_scoped_allocator") ||
      node.attrs().Find("_forward_from")) {
    // TODO(b/128858118): XLA does not support _scoped_allocator and
    // _forward_from.
    absl::string_view uncompilable_reason =
        "_scoped_allocator or _forward_from attribute";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  string uncompilable_reason;
  if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
    if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                          encapsulating_function, uncompilable_nodes)) {
      LogNotCompilable(node, "unsupported function");
      return false;
    }
  } else if (!HasXLAKernel(node, &uncompilable_reason)) {
    MaybeMarkUncompilableNode(
        absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
        encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.IsWhileNode() &&
      !IsCompilableWhile(node, lib_runtime, stack_trace,
                         encapsulating_function, uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported while");
    return false;
  }

  if (node.IsIfNode() &&
      !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
                      uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported if");
    return false;
  }

  if (op_filter_.require_always_compilable && node.IsCaseNode() &&
      !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported case");
    return false;
  }

  if (!op_filter_.allow_stateful_rng_ops &&
      IsStatefulRandomOp(node.type_string())) {
    absl::string_view uncompilable_reason = "stateful random op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
    absl::string_view uncompilable_reason = "not allowed control trigger";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
      IsAssertOrCheckNumerics(node.type_string())) {
    absl::string_view uncompilable_reason = "Assert or CheckNumerics";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_collective_reduce_v2 &&
      node.type_string() == "CollectiveReduceV2") {
    absl::string_view uncompilable_reason = "Collective op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_ops_producing_or_consuming_variant &&
      OpProducesOrConsumesVariant(node)) {
    absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
    absl::string_view uncompilable_reason = "Stack op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
    absl::string_view uncompilable_reason = "TensorArray op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_resource_ops_in_called_functions && stack_depth > 1 &&
      HasResourceInput(node)) {
    absl::string_view uncompilable_reason =
        "resource variable op in called function";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
    absl::string_view uncompilable_reason =
        "operation with numerical accuracy issues";
    BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
    absl::string_view uncompilable_reason = "slow operation";
    BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  return true;
}

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration) {
  RecursiveCompilabilityChecker::OperationFilter op_filter;
  op_filter.allow_resource_ops_in_called_functions =
      registration.cluster_resource_variable_ops_unsafely;
  op_filter.allow_stack_ops = registration.cluster_stack_ops;
  op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
  op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
  op_filter.allow_control_trigger = registration.cluster_control_trigger;
  op_filter.allow_eliding_assert_and_checknumerics_ops =
      registration.elide_assert_and_checknumerics;
  op_filter.allow_ops_producing_or_consuming_variant =
      registration.cluster_variant_ops;
  op_filter.allow_slow_ops = registration.cluster_slow_ops;
  op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
  return op_filter;
}

/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
    const absl::string_view reason,
    const std::vector<StackFrameView>& stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
  if (!uncompilable_nodes) return;

  UncompilableNodeInfo node_info;
  node_info.uncompilable_reason = std::string(reason);
  absl::c_transform(stack_trace, std::back_inserter(node_info.stack_trace),
                    [](const StackFrameView& stack_element) {
                      return StackFrame{
                          std::string(stack_element.name),
                          std::string(stack_element.function_name),
                          stack_element.stack_trace};
                    });

  node_info.name = std::string(stack_trace.back().name);
  auto function =
      encapsulating_function ? *encapsulating_function : NameAttrList();
  auto function_identifier = function.ShortDebugString();

  auto it = uncompilable_nodes->find(function_identifier);
  if (it == uncompilable_nodes->end()) {
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
        uncompilable_node_info{std::move(node_info)};
    uncompilable_nodes->emplace(
        std::move(function_identifier),
        std::make_pair(function, std::move(uncompilable_node_info)));
  } else {
    it->second.second.emplace_back(std::move(node_info));
  }
}
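
// Illustrative note on the resulting map layout (a sketch of the code above,
// not a normative description): rejecting node "foo" inside function "f"
// conceptually yields
//
//   uncompilable_nodes[f.ShortDebugString()] ==
//       {NameAttrList for "f", {UncompilableNodeInfo{name: "foo", ...}}};
//
// Nodes in the top-level graph are grouped under the key produced by a
// default-constructed NameAttrList, since `encapsulating_function` is null
// for them.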

// Returns `true` iff node has a given `attr` set to `true`. Returns `false`
// both when the attr is missing and when it is set to `false`.
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
  const auto& it = node.attr().find(attr);
  return it != node.attr().end() && it->second.b();
}

bool CanCreateXlaKernel(const NodeDef& node_def) {
  return HasBoolAttr(node_def, kXlaMustCompileAttr);
}

Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (size_t i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}
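
// Hypothetical caller sketch (an assumption for illustration, not code from
// this file): a kernel-creation path would typically classify the function's
// arguments like this before deciding on memory types.
//
//   const FunctionBody* fbody = nullptr;
//   std::vector<int> constant_arg_indices;
//   std::vector<int> resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
//   // Both index vectors come back sorted in ascending order, which is what
//   // the SinglePassSearch-based GetInputMemoryTypes below relies on.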

tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  // Set input and output memory types.
  tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
                                                  tensorflow::DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us to
  // loop over constant_arg_indices and resource_arg_indices only once while
  // iterating over all the function arguments, checking whether each one is a
  // resource or a constant.
  // The reason we optimized this code is that functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in
  // all 214 variables and a similar number of activations.
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return input_memory_types;
}
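
// Worked example (illustrative only): for a function whose arguments are
// (float tensor, DT_RESOURCE handle, compile-time-constant int32), i.e.
// constant_arg_indices == {2} and resource_arg_indices == {1}, the returned
// vector is {DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY}.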

tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(),
                                                   tensorflow::DEVICE_MEMORY);
  for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) {
      output_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return output_memory_types;
}

static auto const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                         "XlaConv",
                                         "XlaConvV2",
                                         "XlaDequantize",
                                         "XlaDot",
                                         "XlaDotV2",
                                         "XlaDynamicSlice",
                                         "XlaDynamicUpdateSlice",
                                         "XlaEinsum",
                                         "XlaGather",
                                         "XlaIf",
                                         "XlaKeyValueSort",
                                         "XlaPad",
                                         "XlaRecv",
                                         "XlaReduce",
                                         "XlaReduceWindow",
                                         "XlaReplicaId",
                                         "XlaRngBitGenerator",
                                         "XlaScatter",
                                         "XlaSelectAndScatter",
                                         "XlaSelfAdjointEig",
                                         "XlaSend",
                                         "XlaSharding",
                                         "XlaSort",
                                         "XlaSpmdFullToShardShape",
                                         "XlaSpmdShardToFullShape",
                                         "XlaSvd",
                                         "XlaVariadicReduceV2",
                                         "XlaVariadicSort",
                                         "XlaWhile"};

static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
  return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
         HasBoolAttr(node, kXlaMustCompileAttr) ||
         HasBoolAttr(node, kXlaCompileAttr) ||
         HasBoolAttr(node, kXlaScopeAttr) ||
         HasBoolAttr(node, kXlaInternalScopeAttr) ||
         ops_triggering_xla_compilation->count(node.op());
}

bool CanTriggerXlaCompilation(const GraphDef& graph) {
  for (const FunctionDef& function : graph.library().function()) {
    for (const NodeDef& node : function.node_def()) {
      if (NodeCanTriggerXlaCompilation(node)) {
        return true;
      }
    }
  }

  for (const NodeDef& node : graph.node()) {
    if (NodeCanTriggerXlaCompilation(node)) {
      return true;
    }
  }

  return false;
}

}  // namespace tensorflow