1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17 
18 #include <atomic>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/jit/deadness_analysis.h"
28 #include "tensorflow/compiler/jit/defs.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
31 #include "tensorflow/compiler/jit/union_find.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
35 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/framework/bounds_check.h"
39 #include "tensorflow/core/framework/graph_def_util.h"
40 #include "tensorflow/core/framework/memory_types.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/control_flow.h"
46 #include "tensorflow/core/graph/graph_constructor.h"
47 #include "tensorflow/core/lib/gtl/cleanup.h"
48 #include "tensorflow/core/lib/strings/stringprintf.h"
49 #include "tensorflow/core/public/version.h"
50 #include "tensorflow/core/util/dump_graph.h"
51 
52 namespace tensorflow {
53 
54 namespace {
55 // The clusters we create here are eventually lowered into an
56 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
57 // PartitionedCall op to execute the cluster in the regular graph executor if
58 // need be.  PartitionedCall, however, reruns the entire TF graph optimization
59 // pipeline over the cluster which includes this mark for compilation pass.  To
60 // avoid endlessly recursing we tag nodes that we've already visited with this
61 // attribute so that we can bail out if we see them a second time.
62 //
63 // TODO(sanjoy): This method is not robust since it is possible that the
64 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
65 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
66 // The correct fix is to use the ConfigProto to pass in some sort of flag into
67 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
68 // cluster.
69 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
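// Editorial sketch (not part of this excerpt): the tagging itself happens later
// in the pass, roughly along the lines of
//
//   n->AddAttr(kXlaAlreadyClustered, true);
//
// so the early-out in MarkForCompilationPass::Run() below only needs
// n->attrs().Find(kXlaAlreadyClustered) to detect a repeat visit.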
70 
71 // Aggregates information about what kinds of ops are allowed.
72 struct OperationFilter {
73   // Whether resource variable ops are allowed.  We do not allow resource
74   // variable ops in called functions (either as direct TF calls or as higher
75   // order control flow ops) because we do not yet model their memory effects in
76   // jit/resource_variable_safety_analysis.
77   bool allow_resource_ops;
78 
79   // Whether stateful RNG ops are allowed.  XLA's RNG does not have the same
80   // seeding behavior as TensorFlow's RNG (b/34749654).  So we avoid
81   // auto-clustering stateful RNG ops.
82   bool allow_stateful_rng_ops;
83 
84   // TODO(b/118970344): Whether ControlTrigger ops are allowed.  It is unsound
85   // to cluster ControlTrigger because of how we use deadness analysis.
86   bool allow_control_trigger;
87 
88   // Whether ops with dummy implementations are allowed. We avoid
89   // auto-clustering these ops so that the user is not surprised when XLA is
90   // implicitly enabled. If the user explicitly specifies to use XLA, it is fine
91   // to resort to a dummy implementation. Currently Assert and CheckNumerics ops
92   // have dummy XLA implementations.
93   bool allow_dummy_ops;
94 
95   // Whether ops that produce or consume DT_VARIANT values are allowed.  We
96   // don't auto-cluster these ops because we don't yet support live-in or
97   // live-out DT_VARIANT values.
98   bool allow_ops_producing_or_consuming_variant;
99 };
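// Illustrative sketch (mirrors the call sites below, not additional logic): the
// two users of this struct configure it differently.  When compilation is
// explicitly requested, IsCompilable() permits everything:
//
//   OperationFilter op_filter;
//   op_filter.allow_resource_ops = true;
//   op_filter.allow_stateful_rng_ops = true;
//   op_filter.allow_control_trigger = true;
//   op_filter.allow_dummy_ops = true;
//   op_filter.allow_ops_producing_or_consuming_variant = true;
//
// whereas FindCompilationCandidates() keys most of these flags off of whether
// the target device auto-clusters unconditionally (kAlways).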
100 
101 bool IsDummyImplOp(absl::string_view op_name) {
102   return op_name == "Assert" || op_name == "CheckNumerics";
103 }
104 
105 bool IsStatefulRandomOp(absl::string_view op_name) {
106   return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
107          op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
108          op_name == "TruncatedNormal" || op_name == "Multinomial";
109 }
110 
111 bool OpProducesOrConsumesVariant(const Node& node) {
112   auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
113   return absl::c_any_of(node.input_types(), is_variant) ||
114          absl::c_any_of(node.output_types(), is_variant);
115 }
116 
117 bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
118   // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
119   // is really a kind of function call and will be handled by
120   // IsCompilableCall().
121   if (node.type_string() == "SymbolicGradient") return false;
122   if (node.type_string() == "Const") {
123     // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
124     // registered Const KernelDef says that it does, to support no-op Assert for
125     // tfcompile.
126     const AttrValue* attr = node.attrs().Find("dtype");
127     if (attr != nullptr && attr->type() == DT_STRING) {
128       return false;
129     }
130   }
131 
132   // XLA does not offer guaranteed aliasing between the input and output of the
133   // XLA cluster so it can't implement the forward-tensor-ref semantic.  Leave
134   // such nodes out of XLA clusters.
135   if (HasForwardedRefInput(node)) {
136     VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
137     return false;
138   }
139 
140   return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
141 }
142 
143 bool HasResourceOutput(const Node& node) {
144   return std::find(node.output_types().begin(), node.output_types().end(),
145                    DT_RESOURCE) != node.output_types().end();
146 }
147 
148 bool HasResourceInput(const Node& node) {
149   return std::find(node.input_types().begin(), node.input_types().end(),
150                    DT_RESOURCE) != node.input_types().end();
151 }
152 
153 // Returns true if `node` is a resource operation recognized by tf2xla that
154 // operates on something other than resource variables.
155 bool IsNonResourceVarResourceOp(const Node& node) {
156   // TODO(b/112837194): We can't cluster these because we only support
157   // snapshotting resource variables (and we can't e.g. snapshot stacks).  This
158   // limitation may be fixable with some work.
159   const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string());
160   return op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
161 }
162 
163 // Make sure we don't recurse infinitely on recursive functions.
164 const int kMaxRecursionDepth = 10;
165 
166 bool IsCompilableCall(const NodeDef& call_def,
167                       const DeviceType& jit_device_type,
168                       const OperationFilter& op_filter, int depth,
169                       FunctionLibraryRuntime* lib_runtime);
170 
171 // Tests whether 'while_node' is a completely compilable loop.
172 // Every operator in the condition and body functions must be compilable for a
173 // while loop to be compilable.
174 bool IsCompilableWhile(const Node& while_node,
175                        const DeviceType& jit_device_type,
176                        const OperationFilter& op_filter, int depth,
177                        FunctionLibraryRuntime* lib_runtime) {
178   const NameAttrList* name_attr;
179   NodeDef call;
180   Status status;
181   status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
182   if (!status.ok()) {
183     VLOG(2) << "Rejecting While " << while_node.name()
184             << ": missing 'cond' attribute on While node.";
185     return false;
186   }
187   const string cond_func = name_attr->name();
188   call.set_name("while_cond");
189   call.set_op(cond_func);
190   *call.mutable_attr() = name_attr->attr();
191   if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
192                         lib_runtime)) {
193     VLOG(2) << "Rejecting While " << while_node.name()
194             << ": can't compile loop condition: " << cond_func;
195     return false;
196   }
197   status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
198   if (!status.ok()) {
199     VLOG(2) << "Rejecting While " << while_node.name()
200             << ": missing 'body' attribute on While node.";
201     return false;
202   }
203   const string body_func = name_attr->name();
204   call.set_name("while_body");
205   call.set_op(body_func);
206   *call.mutable_attr() = name_attr->attr();
207   if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
208                         lib_runtime)) {
209     VLOG(2) << "Rejecting While " << while_node.name()
210             << ": can't compile loop body: " << body_func;
211     return false;
212   }
213   return true;
214 }
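// Illustrative sketch (names are hypothetical): a functional While node that
// this check inspects carries its condition and body as function attrs,
// roughly:
//
//   node {
//     name: "my_loop"
//     op: "While"
//     attr { key: "cond" value { func { name: "loop_cond_fn" } } }
//     attr { key: "body" value { func { name: "loop_body_fn" } } }
//   }
//
// IsCompilableWhile() builds a synthetic call NodeDef for each of the two
// functions and defers to IsCompilableCall().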
215 
216 // Tests whether 'call_def' is a call to a completely compilable function.
217 // Every operator in the function must be compilable for a function to be
218 // compilable.
219 bool IsCompilableCall(const NodeDef& call_def,
220                       const DeviceType& jit_device_type,
221                       const OperationFilter& op_filter, int depth,
222                       FunctionLibraryRuntime* lib_runtime) {
223   if (depth > kMaxRecursionDepth) {
224     VLOG(2) << "Rejecting " << call_def.op()
225             << ": function depth limit exceeded.";
226     return false;
227   }
228 
229   FunctionLibraryRuntime::Handle handle;
230   Status status = InstantiateFunctionCall(call_def, *lib_runtime, &handle);
231   if (!status.ok()) {
232     VLOG(2) << "Rejecting " << call_def.DebugString()
233             << ": could not instantiate: " << status;
234     return false;
235   }
236 
237   auto release_handle_on_return = gtl::MakeCleanup(
238       [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
239 
240   const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
241   CHECK(fbody);
242   const FunctionDef& fdef = fbody->fdef;
243   bool noinline = false;
244   if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
245       noinline) {
246     // The underlying mechanism that calls non-inlined functions uses
247     // LocalExecutor, which interacts poorly with the LocalExecutor used by
248     // tf2xla to translate the TF graph into XLA.  So we avoid this for now.
249     //
250     // TODO(b/36139787): Create a mechanism to set inlining hints.
251     VLOG(2) << "Rejecting " << call_def.op()
252             << ": can't compile noinline function.";
253     return false;
254   }
255 
256   for (Node* node : fbody->graph->op_nodes()) {
257     if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
258       continue;
259     if (node->type_string() == "While") {
260       // Handle functional While loop.
261       return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1,
262                                lib_runtime);
263     }
264     if (!op_filter.allow_resource_ops &&
265         (HasResourceInput(*node) || HasResourceOutput(*node))) {
266       return false;
267     }
268     if (!op_filter.allow_stateful_rng_ops &&
269         IsStatefulRandomOp(node->type_string())) {
270       return false;
271     }
272     if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
273       return false;
274     }
275     if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
276       return false;
277     }
278     if (!op_filter.allow_ops_producing_or_consuming_variant &&
279         OpProducesOrConsumesVariant(*node)) {
280       return false;
281     }
282     if (!HasXLAKernel(*node, jit_device_type) &&
283         !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
284                           lib_runtime)) {
285       VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
286               << node->name() << ": " << node->def().ShortDebugString();
287       return false;
288     }
289   }
290   return true;
291 }
292 
293 // Returns true if the op can be decomposed into XLA ops for which
294 // there are fusable elemental implementations.
295 //
296 // TODO(hpucha): Remove this code since this functionality is subsumed by
297 // Grappler XlaFusionOptimizer.
298 bool IsXlaFusable(const NodeDef& node) {
299   static const std::unordered_set<std::string>* elementwise_ops =
300       new std::unordered_set<std::string>(
301           {// tf2xla/kernels/aggregate_ops.cc
302            "AddN",
303            // tf2xla/kernels/batchtospace_op.cc
304            "BatchToSpace", "BatchToSpaceND",
305            // tf2xla/kernels/bcast_ops.cc
306            "BroadcastArgs", "BroadcastGradientArgs",
307            // tf2xla/kernels/bias_ops.cc
308            "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
309            // tf2xla/kernels/binary_ops.cc
310            "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
311            "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
312            "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
313            "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
314            "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
315            "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
316            "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
317            // tf2xla/kernels/cast_op.cc
318            "Cast",
319            // tf2xla/kernels/categorical_op.cc
320            "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/,
321            // tf2xla/kernels/concat_op.cc
322            "Concat", "ConcatV2", "ConcatOffset",
323            // tf2xla/kernels/const_op.cc
324            "Const",
325            // tf2xla/kernels/cross_op.cc
326            "Cross",
327            // tf2xla/kernels/depthtospace_op.cc
328            "DepthToSpace",
329            // tf2xla/kernels/diag_op.cc
330            "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart",
331            // tf2xla/kernels/dynamic_stitch_op.cc
332            "DynamicStitch", "ParallelDynamicStitch",
333            // tf2xla/kernels/elu_op.cc
334            "Elu", "EluGrad", "Selu", "SeluGrad",
335            // tf2xla/kernels/fake_quantize_ops.cc
336            "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient",
337            "FakeQuantWithMinMaxVars",
338            "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/,
339            // tf2xla/kernels/fill_op.cc
340            "Fill",
341            // tf2xla/kernels/gather_op.cc
342            "Gather", "GatherV2", "GatherNd",
343            // tf2xla/kernels/identity_op.cc
344            "Identity", "IdentityN", "PreventGradient", "StopGradient",
345            "Snapshot",
346            // tf2xla/kernels/image_ops.cc
347            "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/,
348            "AdjustSaturation", "AdjustHue",
349            // tf2xla/kernels/index_ops.cc
350            "ArgMax", "ArgMin",
351            // tf2xla/kernels/l2loss_op.cc
352            "L2Loss" /*(Reduce)*/,
353            // tf2xla/kernels/lrn_ops.cc (ReduceWindow)
354            "LRN", "LRNGrad",
355            // tf2xla/kernels/matrix_band_part_op.cc
356            "MatrixBandPart",
357            // tf2xla/kernels/matrix_set_diag_op.cc
358            "MatrixSetDiag",
359            // tf2xla/kernels/mirror_pad_op.cc
360            "MirrorPad",
361            // tf2xla/kernels/no_op.cc
362            "NoOp", "ControlTrigger",
363            // tf2xla/kernels/one_hot_op.cc
364            "OneHot",
365            // tf2xla/kernels/pack_op.cc
366            "Pack",
367            // tf2xla/kernels/pad_op.cc
368            "Pad", "PadV2",
369            // tf2xla/kernels/pooling_ops.cc
370            "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool",
371            "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/
372            "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad",
373            "AvgPool3DGrad",
374            // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce)
375            "QuantizeAndDequantizeV2",
376            // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend
377            // currently)
378            "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
379            "TruncatedNormal",
380            // tf2xla/kernels/reduction_ops.cc (Reduce)
381            "Sum", "Prod", "Min", "Max", "Mean", "All", "Any",
382            // tf2xla/kernels/relu_op.cc
383            "Relu", "Relu6", "ReluGrad", "Relu6Grad",
384            // tf2xla/kernels/reshape_op.cc
385            "Reshape",
386            // tf2xla/kernels/reverse_op.cc
387            "Reverse", "ReverseV2",
388            // tf2xla/kernels/reverse_sequence_op.cc
389            "ReverseSequence",
390            // tf2xla/kernels/scan_ops.cc (ReduceWindow)
391            "Cumsum", "Cumprod",
392            // tf2xla/kernels/scatter_nd_op.cc (Reduce)
393            "ScatterNd",
394            // tf2xla/kernels/segment_reduction_ops.cc (Reduce)
395            "UnsortedSegmentSum",
396            // tf2xla/kernels/select_op.cc
397            "Select",
398            // tf2xla/kernels/sequence_ops.cc
399            "Range", "LinSpace",
400            // tf2xla/kernels/shape_op.cc
401            "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
402            "ZerosLike", "OnesLike",
403            // tf2xla/kernels/slice_op.cc
404            "Slice",
405            // tf2xla/kernels/softmax_op.cc (Reduce)
406            "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits",
407            "SparseSoftmaxCrossEntropyWithLogits",
408            // tf2xla/kernels/spacetobatch_op.cc
409            "SpaceToBatchND", "SpaceToBatch",
410            // tf2xla/kernels/spacetodepth_op.cc
411            "SpaceToDepth",
412            // tf2xla/kernels/split_op.cc
413            "Split", "SplitV",
414            // tf2xla/kernels/stack_ops.cc
415            "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2",
416            // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on
417            // GPU
418            // backend currently)
419            "StatelessRandomUniform",
420            "StatelessRandomNormal",
421            // tf2xla/kernels/strided_slice_op.cc
422            "StridedSlice",
423            "StridedSliceGrad", "ResourceStridedSliceAssign",
424            // tf2xla/kernels/tile_ops.cc
425            "Tile",
426            // tf2xla/kernels/training_ops.cc
427            "ResourceApplyGradientDescent", "ResourceApplyMomentum",
428            "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp",
429            "ResourceApplyFtrl", "ResourceApplyFtrlV2",
430            // tf2xla/kernels/transpose_op.cc
431            "Transpose", "InvertPermutation",
432            // tf2xla/kernels/unary_ops.cc
433            "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
434            "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
435            "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
436            "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
437            "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
438            "Square", "Tan", "Tanh", "Real", "Imag",
439            // tf2xla/kernels/unpack_op.cc
440            "Unpack"});
441 
442   return elementwise_ops->count(node.op()) > 0;
443 }
444 
445 // Nodes that XLA can compile are put in `candidates`.  Nodes put in
446 // `isolated_nodes` must either be unclustered or be put in trivial single-node
447 // clusters.
448 Status FindCompilationCandidates(
449     const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
450     const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
451     OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
452   OptimizerOptions opts;
453   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
454       new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
455                                         flib_def, opts));
456   FunctionLibraryRuntime* lib_runtime =
457       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
458   std::vector<bool> compile_time_const_nodes(graph.num_node_ids(), false);
459   TF_RETURN_IF_ERROR(
460       BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
461                              &compile_time_const_nodes, lib_runtime));
462 
463   int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
464 
465   // Iterate over nodes in sorted order so that compiler fuel is deterministic.
466   // We can't simply pass op_nodes().begin() and op_nodes().end to the
467   // std::vector constructor because they're not proper iterators, with
468   // iterator_traits defined and so on.
469   std::vector<Node*> sorted_nodes;
470   for (Node* node : graph.op_nodes()) {
471     sorted_nodes.push_back(node);
472   }
473   std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
474 
475   if (fuel >= std::numeric_limits<int64>::max() / 2) {
476     // The assumption is that if fuel started out as INT64_MAX, it will forever
477     // stay greater than INT64_MAX / 2.
478     VLOG(2) << "Starting fuel: infinity";
479   } else {
480     VLOG(2) << "Starting fuel: " << fuel;
481   }
482 
483   for (Node* node : sorted_nodes) {
484     if (fuel <= 0) {
485       VLOG(1)
486           << "Hit fuel limit; not marking any remaining ops as clusterable.";
487       break;
488     }
489 
490     DeviceType device_type("");
491     TF_RETURN_IF_ERROR(
492         DeviceToDeviceType(node->assigned_device_name(), &device_type));
493     VLOG(4) << "Device type for " << node->name() << ": "
494             << device_type.type_string();
495 
496     if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
497       // is_compilable_fn has already logged the reason if it returned false.
498       continue;
499     }
500 
501     const XlaOpRegistry::DeviceRegistration* registration;
502     CHECK(
503         XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
504     DeviceType jit_device_type(registration->compilation_device_name);
505 
506     bool always_auto_cluster = registration->autoclustering_policy ==
507                                XlaOpRegistry::AutoclusteringPolicy::kAlways;
508 
509     OperationFilter op_filter;
510     op_filter.allow_resource_ops = registration->compile_all_resource_ops;
511     op_filter.allow_stateful_rng_ops = always_auto_cluster;
512     op_filter.allow_control_trigger = always_auto_cluster;
513     op_filter.allow_dummy_ops = always_auto_cluster;
514     op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster;
515 
516     if (!HasXLAKernel(*node, jit_device_type) &&
517         !IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
518                           lib_runtime)) {
519       VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
520               << node->type_string();
521       continue;
522     }
523 
524     if (!op_filter.allow_stateful_rng_ops &&
525         IsStatefulRandomOp(node->type_string())) {
526       VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
527       continue;
528     }
529     if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
530       VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op";
531       continue;
532     }
533     if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
534       VLOG(2) << "Rejecting " << node->name() << ": dummy op ("
535               << node->type_string() << ")";
536       continue;
537     }
538     if (!op_filter.allow_ops_producing_or_consuming_variant &&
539         OpProducesOrConsumesVariant(*node)) {
540       VLOG(2) << "Rejecting " << node->name()
541               << ": produces or consumes DT_VARIANT";
542       continue;
543     }
544 
545     if (!registration->compile_all_resource_ops &&
546         (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
547       // We don't have a way of returning values of type DT_RESOURCE from XLA
548       // computations so we avoid auto-clustering nodes producing DT_RESOURCE.
549       // XlaLaunchOp also cannot snapshot resources that are not resource
550       // variables so we avoid clustering resource operations that operate on
551       // non-resource variables.
552       VLOG(2) << "Rejecting: " << node->name() << ": resource output "
553               << node->type_string();
554       continue;
555     }
556 
557     if (compile_time_const_nodes[node->id()]) {
558       const OpDef* op_def;
559       TF_RETURN_IF_ERROR(
560           graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
561       if (op_def->is_stateful()) {
562         // It is easiest to demonstrate the problem we're trying to solve with
563         // an example.  Say we have this graph:
564         //
565         //   shape = RandomUniformInt();
566         //   reshape = Reshape(input, shape)
567         //
568         // Both RandomUniformInt and Reshape are compilable by XLA so, absent
569         // any other reason, we will try to put both shape and reshape in the
570         // same cluster.  However, since XLA only supports statically shaped
571         // values, it will expect to be able to constant fold `shape` to get a
572         // static shape for `reshape`.  This is a problem because side-effecting
573         // ops like RandomUniformInt() cannot be constant folded.  We fix this
574         // by putting `shape` and `reshape` in different clusters, which results
575         // in us recompiling `reshape`'s cluster for every new value of `shape`,
576         // making `reshape` statically sized within each compilation.  We
577         // simplify the solution even further by disallowing operations like
578         // `shape` from being part of *any* non-trivial cluster.  They're either
579         // not compiled by XLA altogether or, if assigned to an XLA_* device
580         // with "must compile" semantics, compiled into a trivial single-op
581         // cluster.  This approach leaves some room for improvement, and we can
582         // consider implementing a more aggressive data-flow-analysis based
583         // solution in the future if needed.
584         //
585         // One ugly problem we have to contend with: certain sets of ops *have*
586         // to be in the same cluster because values flowing between them have
587         // types that can't be live-in or live-out of a cluster.  These ops are:
588         //
589         //  - TensorArray ops operating on the same TensorArray instance.
590         //  - Stack ops operating on the same Stack instance.
591         //
592         // To work around this we avoid isolating these specific ops.  Because
593         // of this concession it is unsound to auto-cluster them because then
594         // we'd create clusters we could not compile (because we can't constant
595         // fold, say, a TensorArrayRead or a StackPopV2).  But we don't
596         // auto-cluster these operations today so we're good for now.
597         const XlaResourceOpInfo* op_info =
598             GetResourceOpInfoForOp(node->type_string());
599         bool is_tensor_array_or_stack_op =
600             op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
601         if (!is_tensor_array_or_stack_op) {
602           VLOG(2) << "Isolating " << node->name()
603                   << ": must-be-constant stateful op";
604           isolated_nodes->insert(node);
605           // Keep going and execute all the other checks.
606         }
607       }
608     }
609     // We don't auto-cluster functional control flow nodes containing resource
610     // operations because safety checks are trickier in this case.
611     // registration->compile_all_resource_ops is true for XLA_CPU/XLA_GPU but
612     // not for CPU/GPU.
613     if (node->type_string() == "While" &&
614         !IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) {
615       continue;
616     }
617     // _Arg nodes in a top-level function represent feeds.
618     // Do not compile them.
619     if (node->type_string() == "_Arg") {
620       continue;
621     }
622     // _Retval nodes in a top-level function represent fetches.
623     // Do not compile them.
624     if (node->type_string() == "_Retval") {
625       continue;
626     }
627     candidates->insert(node);
628     --fuel;
629   }
630   VLOG(2) << "candidates->size() = " << candidates->size();
631   return Status::OK();
632 }
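// Illustrative usage (my_model.py is a placeholder for whatever program builds
// the graph): the fuel consumed above comes from the tf_xla_clustering_fuel
// flag read through GetMarkForCompilationPassFlags(), so clustering can be
// bisected externally by capping it, e.g.
//
//   TF_XLA_FLAGS="--tf_xla_clustering_fuel=100" python my_model.py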
633 
634 struct Cluster {
635   // Identifies the node that represents this cluster in the cycle detection
636   // graph.
637   int representative = -1;
638 
639   // The set of devices the nodes in this cluster are placed on.
640   absl::flat_hash_set<string> devices;
641 
642   // If there are resource operation in the cluster then this is the device that
643   // resource operations are placed on.  All resource operations in a cluster
644   // must be placed on the same device.
645   string resource_op_device;
646 
647   // True if any node in the cluster has an _XlaCompile attribute set to true.
648   bool has_xla_compile_attr;
649 };
650 
651 }  // anonymous namespace
652 
653 bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
654   Device* device = flr->device();
655   const XlaOpRegistry::DeviceRegistration* registration;
656   CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
657                                             &registration));
658   DeviceType jit_device_type(registration->compilation_device_name);
659 
660   // We can always *compile* resource operations, stateful RNGs and dummy ops,
661   // even if we are sometimes unable to auto-cluster them.
662   OperationFilter op_filter;
663   op_filter.allow_resource_ops = true;
664   op_filter.allow_stateful_rng_ops = true;
665   op_filter.allow_control_trigger = true;
666   op_filter.allow_dummy_ops = true;
667   op_filter.allow_ops_producing_or_consuming_variant = true;
668 
669   return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
670 }
671 
672 Status MarkForCompilationPass::Run(
673     const GraphOptimizationPassOptions& options) {
674   // TODO(phawkins): precompute the "GetCompilationDevice" properties of each
675   // device ahead of time.
676   OptimizerOptions::GlobalJitLevel global_jit_level =
677       GetGlobalJitLevel(options);
678   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
679   bool fusion_only = flags->tf_xla_fusion_only;
680 
681   VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
682   VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
683   const FunctionLibraryDefinition* fld = options.flib_def;
684 
685   // Deadness analysis expects a graph with source and sink edges properly
686   // connected but sometimes the incoming graph does not follow this invariant.
687   // So fix up the source and sink edges before calling into deadness analysis.
688   FixupSourceAndSinkEdges(options.graph->get());
689 
690   // See explanation on `kXlaAlreadyClustered`.
691   for (Node* n : options.graph->get()->nodes()) {
692     if (n->attrs().Find(kXlaAlreadyClustered)) {
693       return Status::OK();
694     }
695   }
696 
697   std::unique_ptr<DeadnessAnalysis> deadness;
698   {
699     XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
700     TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
701   }
702 
703   bool deadness_analysis_disabled =
704       GetMarkForCompilationPassFlags()
705           ->tf_xla_disable_deadness_safety_checks_for_debugging;
706 
707   if (deadness_analysis_disabled) {
708     LOG(WARNING) << "Deadness analysis was manually disabled via "
709                     "--tf_xla_disable_deadness_safety_checks_for_debugging; "
710                     "auto-clustering "
711                     "is unsound!";
712   }
713 
714   auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
715     const XlaOpRegistry::DeviceRegistration* registration;
716     if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
717                                              &registration)) {
718       VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device.";
719       return false;
720     }
721 
722     // If there is a _XlaCompile annotation, use its value.
723     bool compile = false;
724     Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
725     if (status.ok()) {
726       if (!compile) {
727         VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
728                 << kXlaCompileAttr << ") is false.";
729       }
730       return compile;
731     }
732 
733     status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
734     if (status.ok()) {
735       if (!compile) {
736         VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
737                 << kXlaCompileAttr << ") on callee is false.";
738       }
739       return compile;
740     }
741 
742     // If inputs to `node` can have conflicting deadness (i.e. some are alive
743     // and some are dead) then don't compile it.  XLA cannot represent the
744     // deadness semantics of these nodes correctly and auto-clustering these
745     // nodes can cause deadness to propagate to nodes that should be live.
746     if (!deadness_analysis_disabled) {
747       if (node->IsMerge() ||
748           deadness->HasInputsWithMismatchingDeadness(*node)) {
749         VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness.";
750         return false;
751       }
752     }
753 
754     // Check for fusable ops only if requested.
755     if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
756       VLOG(2) << "Rejecting " << node->name()
757               << ": not fusable op but fusion_only enabled.";
758       return false;
759     }
760 
761     return true;
762   };
763 
764   return RunImpl(options, is_compilable);
765 }
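// Illustrative sketch (client-side code, not part of this pass): the
// global_jit_level consulted above ultimately comes from the session's
// OptimizerOptions, which a client can set along these lines:
//
//   ConfigProto config;
//   config.mutable_graph_options()
//       ->mutable_optimizer_options()
//       ->set_global_jit_level(OptimizerOptions::ON_1);
//
// or process-wide via TF_XLA_FLAGS="--tf_xla_auto_jit=1", which is surfaced
// here through GetMarkForCompilationPassFlags().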
766 
767 static string RatioToString(int numerator, int denominator) {
768   return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
769                          (100.0 * numerator) / denominator);
770 }
771 
772 static void VLogClusteringSummary(const Graph& g) {
773   if (!VLOG_IS_ON(2)) {
774     return;
775   }
776 
777   std::map<absl::string_view, int> cluster_name_to_size;
778   std::map<absl::string_view, std::map<absl::string_view, int>>
779       cluster_name_to_op_histogram;
780   std::map<absl::string_view, int> unclustered_op_histogram;
781   int clustered_node_count = 0;
782 
783   for (Node* n : g.nodes()) {
784     absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
785     if (cluster_name) {
786       clustered_node_count++;
787       cluster_name_to_size[*cluster_name]++;
788       cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
789     } else {
790       unclustered_op_histogram[n->type_string()]++;
791     }
792   }
793 
794   int unclustered_node_count = g.num_nodes() - clustered_node_count;
795 
796   VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes();
797   VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
798           << RatioToString(clustered_node_count, g.num_nodes());
799 
800   for (const auto& cluster_name_size_pair : cluster_name_to_size) {
801     absl::string_view cluster_name = cluster_name_size_pair.first;
802     int size = cluster_name_size_pair.second;
803     VLOG(2) << "  " << cluster_name << " "
804             << RatioToString(size, g.num_nodes());
805     for (const auto& op_count_pair :
806          cluster_name_to_op_histogram[cluster_name]) {
807       VLOG(3) << "   " << op_count_pair.first << ": " << op_count_pair.second
808               << " instances";
809     }
810   }
811 
812   if (!unclustered_op_histogram.empty()) {
813     VLOG(2) << " Unclustered nodes: "
814             << RatioToString(unclustered_node_count, g.num_nodes());
815     for (const auto& pair : unclustered_op_histogram) {
816       VLOG(3) << "  " << pair.first << ": " << pair.second << " instances";
817     }
818   }
819 
820   struct EdgeInfo {
821     absl::string_view node_name;
822     absl::optional<absl::string_view> cluster_name;
823 
824     absl::string_view GetClusterName() const {
825       return cluster_name ? *cluster_name : "[none]";
826     }
827 
828     std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
829         const {
830       return {node_name, cluster_name};
831     }
832 
833     bool operator<(const EdgeInfo& other) const {
834       return AsPair() < other.AsPair();
835     }
836   };
837 
838   using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
839 
840   EdgeInfoMap incoming_edge_infos;
841   EdgeInfoMap outgoing_edge_infos;
842 
843   std::set<absl::string_view> cluster_names_to_print;
844 
845   for (const Edge* e : g.edges()) {
846     const Node* from = e->src();
847     absl::optional<absl::string_view> from_cluster_name =
848         GetXlaClusterForNode(*from);
849 
850     const Node* to = e->dst();
851     absl::optional<absl::string_view> to_cluster_name =
852         GetXlaClusterForNode(*to);
853 
854     if (to_cluster_name == from_cluster_name) {
855       continue;
856     }
857 
858     if (to_cluster_name) {
859       incoming_edge_infos[*to_cluster_name]
860                          [EdgeInfo{from->name(), from_cluster_name}]++;
861       cluster_names_to_print.insert(*to_cluster_name);
862     }
863 
864     if (from_cluster_name) {
865       outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
866       cluster_names_to_print.insert(*from_cluster_name);
867     }
868   }
869 
870   VLOG(2) << "*** Inter-Cluster edges:";
871   if (cluster_names_to_print.empty()) {
872     VLOG(2) << "   [none]";
873   }
874 
875   auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
876                                              const EdgeInfoMap& edge_info_map,
877                                              absl::string_view desc) {
878     auto it = edge_info_map.find(cluster_name);
879     if (it != edge_info_map.end()) {
880       VLOG(2) << "  " << it->second.size() << " " << desc << " edges";
881       for (const auto& edge_info_count_pair : it->second) {
882         VLOG(2) << "   " << edge_info_count_pair.first.GetClusterName() << " "
883                 << edge_info_count_pair.first.node_name << " # "
884                 << edge_info_count_pair.second;
885       }
886     } else {
887       VLOG(2) << "  No " << desc << " edges.";
888     }
889   };
890 
891   for (absl::string_view cluster_name : cluster_names_to_print) {
892     VLOG(2) << " ** Cluster " << cluster_name;
893     print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
894                                     "incoming");
895     print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
896                                     "outgoing");
897   }
898 }
899 
900 // Is 'node' an operator that consumes only the shape of its input, not the
901 // data itself?
902 static bool IsShapeConsumerOp(const Node& node) {
903   return node.type_string() == "Shape" || node.type_string() == "Rank" ||
904          node.type_string() == "Size";
905 }
906 
907 static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) {
908   // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
909   // ignore it during resource operation safety analysis.  We need this hack
910   // because of two reasons:
911   //
912   //  1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
913   //  2. We don't support live-out values of type DT_RESOURCE and live-in values
914   //     of type DT_RESOURCE that are not resource variables.
915   //
916   // Together these imply we cannot let resource variable safety analysis
917   // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
918   // clusters: both of them will have to be clustered because of (1) and we
919   // won't be able to keep the edge between the two as neither the input to the
920   // second XLA cluster nor the output from the first XLA cluster are supported
921   // because of (2).
922   //
923   // TODO(b/113100872): This can be fixed if the TensorFlow representation for
924   // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
925   // (2) would no longer hold.
926 
927   if (n.assigned_device_name().empty()) {
928     *ignore = false;
929     return Status::OK();
930   }
931   DeviceType device_type("");
932   TF_RETURN_IF_ERROR(
933       DeviceToDeviceType(n.assigned_device_name(), &device_type));
934 
935   const XlaOpRegistry::DeviceRegistration* registration;
936   if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
937     *ignore = true;
938   } else {
939     *ignore = registration->compile_all_resource_ops;
940   }
941   return Status::OK();
942 }
943 
944 // Sequence number generator to ensure clusters have unique names.
945 static std::atomic<int64> cluster_sequence_num;
946 
947 // Returns true if the devices in `cluster_a` and `cluster_b` are compatible and
948 // therefore not a hindrance for combining the two clusters into a larger
949 // cluster.
950 static Status AreDevicesCompatible(
951     const Cluster& cluster_a, const Cluster& cluster_b,
952     OptimizerOptions::GlobalJitLevel global_jit_level, bool* result) {
953   std::vector<string> devices;
954   absl::c_remove_copy(cluster_a.devices, std::back_inserter(devices), "");
955   absl::c_remove_copy(cluster_b.devices, std::back_inserter(devices), "");
956   absl::c_sort(devices);
957 
958   if (devices.empty()) {
959     *result = false;
960     return Status::OK();
961   }
962 
963   // First check if we will even be able to pick a device for the larger
964   // combined cluster.
965   bool can_pick_device;
966   TF_RETURN_IF_ERROR(CanPickDeviceForXla(
967       devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
968   if (!can_pick_device) {
969     *result = false;
970     return Status::OK();
971   }
972 
973   string chosen_device;
974   TF_RETURN_IF_ERROR(PickDeviceForXla(
975       devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
976 
977   // If we are able to pick a device `chosen_device` for the larger cluster, the
978   // resource operations in `cluster_a` and `cluster_b` must be placed on the
979   // same device as `chosen_device`.  This is because the _XlaCompile and
980   // _XlaRun kernels are going to run on and therefore try to access the
981   // resource variables from `chosen_device`, which will be an error if the
982   // resource variables are placed on some other device.
983   auto resource_op_device_ok = [&](const string& resource_op_device) {
984     return resource_op_device.empty() || resource_op_device == chosen_device;
985   };
986 
987   *result = resource_op_device_ok(cluster_a.resource_op_device) &&
988             resource_op_device_ok(cluster_b.resource_op_device);
989   if (!*result) {
990     return Status::OK();
991   }
992 
993   // We will check this again later, but here we prune out clusters that would
994   // never have been sent to XLA to save compile time.  Without this change we
995   // will e.g. create a CPU cluster only to later notice that the user did not
996   // enable the CPU JIT via --tf_xla_cpu_global_jit.  With this change we avoid
997   // creating the cluster to begin with.
998   //
999   // TODO(b/126629785): It is possible that this is just papering over O(n^2)
1000   // behavior in our clustering algorithm.
1001   const XlaOpRegistry::DeviceRegistration* registration;
1002   DeviceType device_type("");
1003   TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type));
1004   TF_RET_CHECK(
1005       XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration))
1006       << "chosen device = " << chosen_device
1007       << "; device type = " << device_type.type() << "; devices ("
1008       << devices.size() << ") = " << absl::StrJoin(devices, ", ");
1009 
1010   *result = cluster_a.has_xla_compile_attr || cluster_b.has_xla_compile_attr ||
1011             registration->autoclustering_policy ==
1012                 XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1013             (registration->autoclustering_policy ==
1014                  XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1015              global_jit_level != OptimizerOptions::OFF);
1016 
1017   return Status::OK();
1018 }
1019 
1020 // Returns `true` iff we should compile `cluster`.
1021 static Status ShouldCompileClusterImpl(
1022     const Cluster& cluster, OptimizerOptions::GlobalJitLevel global_jit_level,
1023     bool* should_compile, string* device) {
1024   std::vector<string> devices;
1025   absl::c_remove_copy(cluster.devices, std::back_inserter(devices), "");
1026   absl::c_sort(devices);
1027 
1028   string chosen_device;
1029   TF_RETURN_IF_ERROR(PickDeviceForXla(
1030       devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
1031 
1032   const XlaOpRegistry::DeviceRegistration* registration;
1033   DeviceType device_type("");
1034   TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type));
1035   TF_RET_CHECK(
1036       XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration))
1037       << "chosen device = " << chosen_device
1038       << "; device type = " << device_type.type() << "; devices ("
1039       << devices.size() << ") = " << absl::StrJoin(devices, ", ");
1040 
1041   *should_compile =
1042       cluster.has_xla_compile_attr ||
1043       registration->autoclustering_policy ==
1044           XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1045       (registration->autoclustering_policy ==
1046            XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1047        global_jit_level != OptimizerOptions::OFF);
1048 
1049   VLOG(3) << (*should_compile ? "Compiling" : "Not compiling")
1050           << " cluster with device " << chosen_device;
1051 
1052   *device = std::move(chosen_device);
1053   return Status::OK();
1054 }
1055 
1056 static Status ShouldCompileCluster(
1057     absl::flat_hash_map<int, std::pair<bool, string>>* cache,
1058     OptimizerOptions::GlobalJitLevel global_jit_level, const Cluster& cluster,
1059     bool* should_compile, string* device) {
1060   auto it = cache->find(cluster.representative);
1061   if (it != cache->end()) {
1062     *should_compile = it->second.first;
1063     *device = it->second.second;
1064     return Status::OK();
1065   }
1066 
1067   string device_s;
1068   TF_RETURN_IF_ERROR(ShouldCompileClusterImpl(cluster, global_jit_level,
1069                                               should_compile, &device_s));
1070   cache->insert({cluster.representative, {*should_compile, device_s}});
1071   *device = std::move(device_s);
1072   return Status::OK();
1073 }
1074 
1075 Status MarkForCompilationPass::RunImpl(
1076     const GraphOptimizationPassOptions& options,
1077     const std::function<bool(const Node*, const DeviceType&)>&
1078         is_compilable_fn) {
1079   VLOG(1) << "MarkForCompilationPass::Run";
1080 
1081   // Make sure that kernels have been registered on the JIT device.
1082   XlaOpRegistry::RegisterCompilationKernels();
1083 
1084   Graph* graph = options.graph->get();
1085 
1086   OrderedNodeSet compilation_candidates;
1087   absl::flat_hash_set<Node*> isolated_nodes;
1088   TF_RETURN_IF_ERROR(FindCompilationCandidates(
1089       *graph, options.flib_def,
1090       (options.session_options != nullptr) ? options.session_options->env
1091                                            : Env::Default(),
1092       is_compilable_fn, &compilation_candidates, &isolated_nodes));
1093 
1094   if (compilation_candidates.empty()) {
1095     VLOG(2) << "No compilable candidates";
1096     return Status::OK();
1097   }
1098 
1099   GraphCycles cycles;
1100   TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
1101                       CreateCycleDetectionGraph(graph, &cycles));
1102   if (!cycle_detection_graph_ok) {
1103     return Status::OK();
1104   }
1105   TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
1106       graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));
1107 
1108   // Each compilation candidate belongs to a cluster. The cluster's
1109   // representative
1110   // names the node in the 'cycles' graph that represents the cluster.
1111   std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
1112   std::deque<UnionFind<Cluster>*> worklist;
1113   for (Node* node : compilation_candidates) {
1114     Cluster& cluster = clusters[node->id()].Get();
1115     cluster.representative = node->id();
1116     const string& device = !node->assigned_device_name().empty()
1117                                ? node->assigned_device_name()
1118                                : node->requested_device();
1119     if (HasResourceInput(*node) || HasResourceOutput(*node)) {
1120       cluster.resource_op_device = device;
1121     }
1122     cluster.has_xla_compile_attr = false;
1123     bool xla_compile_attr;
1124     if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) {
1125       cluster.has_xla_compile_attr |= xla_compile_attr;
1126     }
1127     if (options.flib_def->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr)
1128             .ok()) {
1129       cluster.has_xla_compile_attr |= xla_compile_attr;
1130     }
1131 
1132     cluster.devices.insert(device);
1133     worklist.push_back(&clusters[node->id()]);
1134   }
1135 
1136   OptimizerOptions::GlobalJitLevel global_jit_level =
1137       GetGlobalJitLevel(options);
1138   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1139 
1140   // Repeatedly contract edges between clusters that are on the same device,
1141   // provided the contraction would not create a cycle.
1142   //
1143   // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
1144   // example, from the Grappler fusion pass).
1145   while (!worklist.empty()) {
1146     Cluster* cluster_from = &worklist.front()->Get();
1147     int from = cluster_from->representative;
1148     worklist.pop_front();
1149 
1150     Node* node_from = graph->FindNodeId(from);
1151     if (node_from->IsControlFlow()) {
1152       // Control flow nodes aren't compilation candidates and should never
1153       // appear.
1154       return errors::Internal(
1155           "Found control flow node in clustering worklist: ",
1156           node_from->type_string());
1157     }
1158 
1159     if (isolated_nodes.count(node_from)) {
1160       continue;
1161     }
1162 
1163     string from_scope;
1164     string to_scope;
1165     for (int to : cycles.Successors(from)) {
1166       if (to >= graph->num_node_ids()) {
1167         // Node is a fictitious node that is present only in the cycle detection
1168         // graph. No clustering is possible.
1169         continue;
1170       }
1171 
1172       const Cluster& cluster_to = clusters[to].Get();
1173       Node* node_to = graph->FindNodeId(to);
1174       if (compilation_candidates.find(node_to) ==
1175           compilation_candidates.cend()) {
1176         continue;
1177       }
1178       bool devices_compatible;
1179       TF_RETURN_IF_ERROR(AreDevicesCompatible(
1180           *cluster_from, cluster_to, global_jit_level, &devices_compatible));
1181       if (!devices_compatible) {
1182         continue;
1183       }
1184       if (isolated_nodes.count(node_to)) {
1185         continue;
1186       }
1187       // Look for an _XlaScope on both nodes.  If both nodes have a
1188       // scope and the scopes do not match, do not cluster along this
1189       // edge. This restriction is overridden if the global_jit_level is ON. If
1190       // even one of the nodes lacks an _XlaScope attribute,
1191       // then it is treated as a "bridge" and a cluster may be created
1192       // along it.  We may want to restrict this behavior to require
1193       // all nodes marked with _XlaCompile=true to also have a
1194       // _XlaScope property set (and raise an error otherwise); but
1195       // for now we don't do this.
1196       if (global_jit_level == OptimizerOptions::OFF &&
1197           GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
1198           GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
1199           from_scope != to_scope) {
1200         continue;
1201       }
1202 
1203       // Ops that consume shapes cannot be the root of a cluster. This is an
1204       // optimization.
1205       if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
1206         continue;
1207       }
1208 
1209       // Don't exceed the maximum cluster size.
1210       if (clusters[from].Size() + clusters[to].Size() >
1211           flags->tf_xla_max_cluster_size) {
1212         continue;
1213       }
1214 
      // If any of the consumer's producers are on a different device, do not
      // cluster these nodes. This prevents other work on this device from
      // being delayed by work on other devices. We consider predecessors of
      // the entire cluster rather than just the inputs to the node, so that
      // the clusters are not still combined in cases where the 'to' cluster
      // has multiple dependencies on the 'from' cluster and another of those
      // dependencies leads to a merge.
      //
      // TODO(b/117085735): We probably want to handle the reciprocal of this
      // case where a cluster is producing data for multiple devices.
      bool found_split = false;
      for (const auto& in_id : cycles.Predecessors(to)) {
        if (in_id >= graph->num_node_ids()) continue;

        Node* in = graph->FindNodeId(in_id);
        const Cluster& cluster_in = clusters[in_id].Get();
        if (compilation_candidates.find(in) != compilation_candidates.cend()) {
          bool devices_compatible;
          TF_RETURN_IF_ERROR(AreDevicesCompatible(
              cluster_to, cluster_in, global_jit_level, &devices_compatible));
          if (!devices_compatible) {
            found_split = true;
          }
        }
      }
      if (found_split) continue;

      // If contracting the edge would create a cycle, bail out.
      // However, just because we can't merge the clusters now does not mean
      // we won't be able to merge them in the future.
      // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
      // 1->3. But if we first contract 1->2 then we can later contract 1->3.
      if (!cycles.ContractEdge(from, to)) continue;

      // Merge the clusters. ContractEdge uses 'from' as the id of the merged
      // node, so make sure 'from' is the chosen representative.
      cluster_from->devices.insert(cluster_to.devices.begin(),
                                   cluster_to.devices.end());
      if (!cluster_to.resource_op_device.empty()) {
        cluster_from->resource_op_device = cluster_to.resource_op_device;
      }
      cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr;
      clusters[from].Merge(&clusters[to]);

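      // Contracting the edge changes the successor set of 'from', so stop
      // iterating over the (now stale) successor list; the merged cluster is
      // re-enqueued below and its remaining edges are examined on a later
      // worklist iteration.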
      worklist.push_back(&clusters[from]);
      break;
    }
  }

  // Count the number of non-trivial elements in each cluster.
  std::vector<int> effective_cluster_sizes(graph->num_node_ids());

  // has_functional_control_flow remembers if a cluster contains a functional
  // control flow node.
  std::vector<bool> has_functional_control_flow(graph->num_node_ids());
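  // Note: both vectors are indexed by a cluster's representative node id,
  // which is why they are sized to num_node_ids(); entries for ids that are
  // not representatives simply keep their default values.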

  for (const Node* n : compilation_candidates) {
    int cluster = clusters[n->id()].Get().representative;
    // We want clusters to be big enough that the benefit from XLA's
    // optimizations offsets XLA related overhead (for instance we add some
    // Switch/Merge nodes into the graph to implement lazy compilation).  To
    // this end, we don't count Identity and Constant nodes because they do not
    // enable interesting optimizations by themselves.
    if (!n->IsIdentity() && !n->IsConstant()) {
      effective_cluster_sizes[cluster]++;
    }
    if (n->type_string() == "While" || n->type_string() == "If") {
      has_functional_control_flow[cluster] = true;
    }
  }

  // Names for each cluster.
  std::unordered_map<int, string> cluster_names;

  if (flags->tf_xla_clustering_debug) {
    DumpGraphToFile("before_mark_for_compilation", **options.graph,
                    options.flib_def);
  }

  absl::flat_hash_map<int, std::pair<bool, string>>
      should_compile_cluster_cache;
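  // Memoizes the outcome of ShouldCompileCluster so the decision is computed
  // once per cluster and reused for every node that belongs to it.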

  // Mark clusters for compilation that:
  // * are placed on a device that requires compilation (an XlaDevice),
  // * are explicitly marked for compilation (_XlaCompile=true), or
  // * have more than flags->tf_xla_min_cluster_size elements (applicable only
  //   if compilation is enabled, otherwise there will be no such candidates).
  const int min_cluster_size = flags->tf_xla_min_cluster_size;
  for (Node* n : compilation_candidates) {
    const Cluster& cluster = clusters[n->id()].Get();
    bool should_compile;
    string device;
    TF_RETURN_IF_ERROR(ShouldCompileCluster(&should_compile_cluster_cache,
                                            global_jit_level, cluster,
                                            &should_compile, &device));
    if (!should_compile) {
      continue;
    }

    int cluster_repr = cluster.representative;

    // Compile if the user marked this node _XlaCompile=true, either directly
    // on the node or on the function it calls.
    bool compile_attr = false;
    bool marked_for_compilation = false;
    if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
                   .ok()) {
      marked_for_compilation = compile_attr;
    }

    // We assume that functional If and While nodes have at least
    // min_cluster_size non-trivial nodes in them.  It would be more principled
    // to (recursively) verify this fact, but that's probably not worth the
    // trouble.

    if (effective_cluster_sizes[cluster_repr] >= min_cluster_size ||
        has_functional_control_flow[cluster_repr] || marked_for_compilation) {
      string& name = cluster_names[cluster_repr];

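      // cluster_sequence_num is a counter defined outside this loop;
      // post-incrementing it gives each newly named cluster a distinct name.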
      if (name.empty()) {
        name = absl::StrCat("cluster_", cluster_sequence_num++);
      }
      n->AddAttr(kXlaClusterAttr, name);
      n->AddAttr(kXlaAlreadyClustered, true);
      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
    }
  }

  if (flags->tf_xla_clustering_debug) {
    DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def);

    // We also dump out an annotated version of the TF graph where the node
    // names are prefixed with the cluster names.  This can help visualize the
    // clustering decisions on TensorBoard.
    Graph new_graph((*options.graph)->op_registry());
    CopyGraph(**options.graph, &new_graph);
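    // The renaming below is done on a copy so that the original graph is left
    // untouched; the copy exists only for this debug dump.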

    for (Node* n : new_graph.nodes()) {
      if (absl::optional<absl::string_view> cluster_name =
              GetXlaClusterForNode(*n)) {
        n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
      } else if (n->type_string() == "VarHandleOp") {
        n->set_name(absl::StrCat("varhandle/", n->name()));
      } else {
        // There is room for improvement here.  In particular, it may help to
        // split these unclustered nodes into classes where every node in a
        // specific class has edges to and from the same set of clusters.
        n->set_name(absl::StrCat("unclustered/", n->name()));
      }
    }

    DumpGraphToFile("mark_for_compilation_annotated", new_graph,
                    options.flib_def);
  }

  VLogClusteringSummary(*graph);

  return Status::OK();
}

}  // namespace tensorflow