/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/compilability_check_util.h"

#include <algorithm>
#include <atomic>
#include <deque>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

bool HasResourceInput(const Node& node) {
  return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
}

void LogNotCompilable(const Node& node, absl::string_view reason = "") {
  VLOG(3) << "Found uncompilable node " << node.name() << " (op "
          << node.type_string() << ")" << (reason.empty() ? "" : ": ")
          << reason;
}

bool IsInOutsideCompilationCluster(const Node& n) {
  return n.attrs().Find(kXlaOutsideCompilationAttr) != nullptr;
}

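// Builds a NodeDef that calls the function named by the `attr_name` attribute
// of `node`, copying that function's attributes onto the call node.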
Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
                                 NodeDef* node_def) {
  const NameAttrList* name_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr));
  node_def->set_op(name_attr->name());
  *(node_def->mutable_attr()) = name_attr->attr();
  return Status::OK();
}

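// Builds one call NodeDef per function listed in the `attr_name` attribute of
// `node` (e.g. the `branches` attribute of a Case node), naming the calls
// `call_name`_0, `call_name`_1, and so on.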
StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
    const Node& node, absl::string_view attr_name,
    absl::string_view call_name) {
  std::vector<NameAttrList> attr_lists;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));

  std::vector<NodeDef> out;
  for (int i = 0; i < attr_lists.size(); i++) {
    out.emplace_back();
    NodeDef& inserted = out.back();
    inserted.set_name(absl::StrCat(call_name, "_", i));
    inserted.set_op(attr_lists[i].name());
    *inserted.mutable_attr() = attr_lists[i].attr();
  }
  return out;
}

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take care to
// order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
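//
// Example usage (illustrative): values must be queried in non-decreasing
// order, because skipped values are never revisited.
//
//   std::vector<int> sorted = {1, 3, 5, 7};
//   SinglePassSearch search(sorted);
//   search.ScanForValue(3);  // true; the scan advances past 3.
//   search.ScanForValue(4);  // false; the scan stops at 5.
//   search.ScanForValue(5);  // true.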
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(absl::Span<int const> values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position into the vector.
  // Returns true iff the vector contains the given value at or after the
  // current position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_.size() &&
           values_[current_index_] <= value) {
      if (values_[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const absl::Span<int const> values_;
};

}  // anonymous namespace

RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  std::vector<StackFrameView> stack_trace;
  // If `node_stack_trace` is provided, `node` is inside a function body, and
  // therefore arg nodes and retval nodes are not considered uncompilable.
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  stack_trace.emplace_back(
      StackFrameView{node.name(), "", node.GetStackTrace()});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  IsCompilableNode(node, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}

bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") {
    *uncompilable_reason =
        "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }

  if (node.type_string() == "Const") {
    const AttrValue* attr = node.attrs().Find("dtype");
    if (!op_filter_.allow_string_consts && attr != nullptr &&
        attr->type() == DT_STRING) {
      *uncompilable_reason =
          "Const op with type DT_STRING is not supported by XLA.";
      return false;
    }
  }

  // XLA does not offer guaranteed aliasing between the input and output of the
  // XLA cluster so it can't implement the forward-tensor-ref semantic.  Leave
  // such nodes out of XLA clusters.
  if (HasForwardedRefInput(node)) {
    VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
    *uncompilable_reason = "Identity with unsafe cast.";
    return false;
  }

  Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
  if (!s.ok()) {
    *uncompilable_reason = s.error_message();
    return false;
  }
  return true;
}

// Tests whether 'if_node' is compilable. Every operator in the then_branch and
// else_branch functions must be compilable for 'if_node' to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableIf(
    const Node& if_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);
  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

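// Tests whether 'case_node' is compilable. Every operator in every branch
// function must be compilable for 'case_node' to be compilable.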
bool RecursiveCompilabilityChecker::IsCompilableCase(
    const Node& case_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  StatusOr<std::vector<NodeDef>> calls =
      MakeCallNodesFromAttribute(case_node, "branches", "branch");
  if (!calls.ok()) {
    VLOG(2) << "Rejecting node " << case_node.name() << ": "
            << "missing attribute 'branches'";
    return false;
  }

  bool is_compilable = true;

  for (const NodeDef& call : *calls) {
    is_compilable &=
        IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                         uncompilable_nodes);
  }
  return is_compilable;
}

// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableWhile(
    const Node& while_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "body", "while_body", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
    const Node& node, const std::string& attr_name,
    const std::string& call_name, NameAttrList* encapsulating_function,
    FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  NodeDef call;
  call.set_name(call_name);
  if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
    const auto uncompilable_reason = absl::StrCat(
        "missing '", attr_name, "' attribute from node ", node.name());
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
            << ".";
    return false;
  }
  if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    VLOG(2) << "Rejecting node " << node.name()
            << ": can't compile: " << call.op();
    return false;
  }
  return true;
}

// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  if (stack_trace->size() > kMaxRecursionDepth) {
    std::string uncompilable_reason = "function depth limit exceeded";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
            << ".";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status s;
  NameAttrList function;
  s = NameAndAttrsFromFunctionCall(call_def, &function);
  if (s.ok()) {
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }
  if (!s.ok()) {
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
            << uncompilable_reason << " : " << s;
    return false;
  }

  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  bool is_compilable = true;
  for (const Node* node : fbody->graph->op_nodes()) {
    stack_trace->emplace_back(
        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
    is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                      &function, uncompilable_nodes);
    stack_trace->pop_back();
    if (!uncompilable_nodes && !is_compilable) return is_compilable;
  }

  return is_compilable;
}

bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
  // b/127344411: SelfAdjointEigV2 and Svd precision issues.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd";
}

bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
  // b/128001705: SelfAdjointEigV2 and Svd performance issues.
  // b/135640736: MatrixInverse performance issues.
  // b/111271662: MatrixSolve performance issues.
  // https://github.com/tensorflow/tensorflow/pull/31012:
  //    ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
  //    create convolutions too large for CuDNN to handle.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd" || node.type_string() == "Qr" ||
         node.type_string() == "MatrixInverse" ||
         node.type_string() == "MatrixSolve" ||
         node.type_string() == "ResizeNearestNeighbor" ||
         node.type_string() == "ResizeBilinear" ||
         node.type_string() == "ResizeBilinearGrad";
}

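// Checks a single node against the operation filter and the structural
// constraints below, recursing into function calls and into the bodies of
// While/If/Case nodes. Returns false at the first failed check and, when
// `uncompilable_nodes` is non-null, records the reason for the failure.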
bool RecursiveCompilabilityChecker::IsCompilableNode(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  auto stack_depth = stack_trace->size();

  if (op_filter_.allow_outside_compiled && IsInOutsideCompilationCluster(node))
    return true;

  if (node.IsSource() || node.IsSink()) {
    absl::string_view uncompilable_reason = "source or sink node";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // _Arg nodes in a top-level function represent feeds and _Retval nodes in a
  // top-level function represent fetches.
  if (stack_depth == 1 &&
      (node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
    absl::string_view uncompilable_reason = "top level _Arg or _Retval";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.attrs().Find("_scoped_allocator") ||
      node.attrs().Find("_forward_from")) {
    // TODO(b/128858118): XLA does not support _scoped_allocator and
    // _forward_from.
    absl::string_view uncompilable_reason =
        "_scoped_allocator or _forward_from attribute";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  string uncompilable_reason;
  if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
    if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                          encapsulating_function, uncompilable_nodes)) {
      LogNotCompilable(node, "unsupported function");
      return false;
    }
  } else if (!HasXLAKernel(node, &uncompilable_reason)) {
    MaybeMarkUncompilableNode(
        absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
        encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.IsWhileNode() &&
      !IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
                         uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported while");
    return false;
  }

  if (node.IsIfNode() &&
      !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
                      uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported if");
    return false;
  }

  if (op_filter_.require_always_compilable && node.IsCaseNode() &&
      !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported case");
    return false;
  }

  if (!op_filter_.allow_stateful_rng_ops &&
      IsStatefulRandomOp(node.type_string())) {
    absl::string_view uncompilable_reason = "stateful random op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
    absl::string_view uncompilable_reason = "not allowed control trigger";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
      IsAssertOrCheckNumerics(node.type_string())) {
    absl::string_view uncompilable_reason = "Assert or CheckNumerics";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_collective_reduce_v2 &&
      node.type_string() == "CollectiveReduceV2") {
    absl::string_view uncompilable_reason = "Collective op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_ops_producing_or_consuming_variant &&
      OpProducesOrConsumesVariant(node)) {
    absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
    absl::string_view uncompilable_reason = "Stack op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
    absl::string_view uncompilable_reason = "TensorArray op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_resource_ops_in_called_functions && stack_depth > 1 &&
      HasResourceInput(node)) {
    absl::string_view uncompilable_reason =
        "resource variable op in called function";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
    absl::string_view uncompilable_reason =
        "operation with numerical accuracy issues";
    BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
    absl::string_view uncompilable_reason = "slow operation";
    BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  return true;
}

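// Builds an operation filter from a device registration: each cluster_*
// capability on the registration enables the corresponding allow_* knob on
// the filter.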
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration) {
  RecursiveCompilabilityChecker::OperationFilter op_filter;
  op_filter.allow_resource_ops_in_called_functions =
      registration.cluster_resource_variable_ops_unsafely;
  op_filter.allow_stack_ops = registration.cluster_stack_ops;
  op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
  op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
  op_filter.allow_control_trigger = registration.cluster_control_trigger;
  op_filter.allow_eliding_assert_and_checknumerics_ops =
      registration.elide_assert_and_checknumerics;
  op_filter.allow_ops_producing_or_consuming_variant =
      registration.cluster_variant_ops;
  op_filter.allow_slow_ops = registration.cluster_slow_ops;
  op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
  return op_filter;
}

/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
    const absl::string_view reason,
    const std::vector<StackFrameView>& stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
  if (!uncompilable_nodes) return;

  UncompilableNodeInfo node_info;
  node_info.uncompilable_reason = std::string(reason);
  absl::c_transform(stack_trace, std::back_inserter(node_info.stack_trace),
                    [](const StackFrameView& stack_element) {
                      return StackFrame{
                          std::string(stack_element.name),
                          std::string(stack_element.function_name),
                          stack_element.stack_trace};
                    });

  node_info.name = std::string(stack_trace.back().name);
  auto function =
      encapsulating_function ? *encapsulating_function : NameAttrList();
  auto function_identifier = function.ShortDebugString();

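  // The map is keyed by the serialized encapsulating function (an empty
  // NameAttrList for top-level nodes); each entry accumulates the
  // uncompilable nodes discovered within that function.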
  auto it = uncompilable_nodes->find(function_identifier);
  if (it == uncompilable_nodes->end()) {
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
        uncompilable_node_info{std::move(node_info)};
    uncompilable_nodes->emplace(
        std::move(function_identifier),
        std::make_pair(function, std::move(uncompilable_node_info)));
  } else {
    it->second.second.emplace_back(std::move(node_info));
  }
}

// Returns `true` iff node has a given `attr` set to `true`. Returns `false`
// both when the attr is missing and when it is set to `false`.
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
  const auto& it = node.attr().find(attr);
  return it != node.attr().end() && it->second.b();
}

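// Returns true iff `node_def` was explicitly marked for XLA compilation via
// the kXlaMustCompileAttr attribute.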
bool CanCreateXlaKernel(const NodeDef& node_def) {
  return HasBoolAttr(node_def, kXlaMustCompileAttr);
}

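// Instantiates `function` via `flr` and returns its body along with the
// indices of its compile-time-constant and resource arguments. A sketch of
// the expected call pattern (illustrative only):
//
//   const FunctionBody* fbody;
//   std::vector<int> constant_arg_indices, resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));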
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (size_t i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}

tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  // Set input and output memory types.
  tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
                                                  tensorflow::DEVICE_MEMORY);
  // These single-pass searches are purely an optimization: they let us walk
  // constant_arg_indices and resource_arg_indices only once while iterating
  // over all the function arguments, checking whether each is a resource or a
  // constant. This matters because functions can have a lot of captured
  // arguments; for example, the backward pass of ResNet50 takes in all 214
  // variables and a similar number of activations.
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return input_memory_types;
}

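// Returns the memory types of the function's results: DT_RESOURCE results
// live in host memory, everything else in device memory.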
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(),
                                                   tensorflow::DEVICE_MEMORY);
  for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) {
      output_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return output_memory_types;
}

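// Ops whose presence in a graph (or in its function library) indicates that
// the graph can trigger XLA compilation.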
static auto const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                         "XlaConv",
                                         "XlaConvV2",
                                         "XlaDequantize",
                                         "XlaDot",
                                         "XlaDotV2",
                                         "XlaDynamicSlice",
                                         "XlaDynamicUpdateSlice",
                                         "XlaEinsum",
                                         "XlaGather",
                                         "XlaIf",
                                         "XlaKeyValueSort",
                                         "XlaPad",
                                         "XlaRecv",
                                         "XlaReduce",
                                         "XlaReduceWindow",
                                         "XlaReplicaId",
                                         "XlaRngBitGenerator",
                                         "XlaScatter",
                                         "XlaSelectAndScatter",
                                         "XlaSelfAdjointEig",
                                         "XlaSend",
                                         "XlaSharding",
                                         "XlaSort",
                                         "XlaSpmdFullToShardShape",
                                         "XlaSpmdShardToFullShape",
                                         "XlaSvd",
                                         "XlaVariadicReduceV2",
                                         "XlaVariadicSort",
                                         "XlaWhile"};

static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
  return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
         HasBoolAttr(node, kXlaMustCompileAttr) ||
         HasBoolAttr(node, kXlaCompileAttr) ||
         HasBoolAttr(node, kXlaScopeAttr) ||
         HasBoolAttr(node, kXlaInternalScopeAttr) ||
         ops_triggering_xla_compilation->count(node.op());
}

bool CanTriggerXlaCompilation(const GraphDef& graph) {
  for (const FunctionDef& function : graph.library().function()) {
    for (const NodeDef& node : function.node_def()) {
      if (NodeCanTriggerXlaCompilation(node)) {
        return true;
      }
    }
  }

  for (const NodeDef& node : graph.node()) {
    if (NodeCanTriggerXlaCompilation(node)) {
      return true;
    }
  }

  return false;
}

}  // namespace tensorflow