1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17
18 #include <atomic>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/jit/deadness_analysis.h"
28 #include "tensorflow/compiler/jit/defs.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
31 #include "tensorflow/compiler/jit/union_find.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
35 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/framework/bounds_check.h"
39 #include "tensorflow/core/framework/graph_def_util.h"
40 #include "tensorflow/core/framework/memory_types.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/control_flow.h"
46 #include "tensorflow/core/graph/graph_constructor.h"
47 #include "tensorflow/core/lib/gtl/cleanup.h"
48 #include "tensorflow/core/lib/strings/stringprintf.h"
49 #include "tensorflow/core/public/version.h"
50 #include "tensorflow/core/util/dump_graph.h"
51
52 namespace tensorflow {
53
54 namespace {
55 // The clusters we create here are eventually lowered into an
56 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
57 // PartitionedCall op to execute the cluster in the regular graph executor if
58 // need be. PartitionedCall, however, reruns the entire TF graph optimization
59 // pipeline over the cluster which includes this mark for compilation pass. To
60 // avoid endlessly recursing we tag nodes that we've already visited with this
61 // attribute so that we can bail out if we see them a second time.
62 //
63 // TODO(sanjoy): This method is not robust since it is possible that the
64 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
65 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
66 // The correct fix is to use the ConfigProto to pass in some sort of flag into
67 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
68 // cluster.
69 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
70
71 // Aggregates information about what kinds of ops are allowed.
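// In this file the filter is configured in two ways: FindCompilationCandidates
// derives it from the device registration (resource ops follow
// compile_all_resource_ops; the other categories are allowed only under the
// kAlways autoclustering policy), while the exported IsCompilable() helper
// below allows every category because it asks whether a call *can* be
// compiled at all rather than whether it should be auto-clustered.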
72 struct OperationFilter {
73 // Whether resource variable ops are allowed. We do not allow resource
74 // variable ops in called functions (either as direct TF calls or as higher
75 // order control flow ops) because we do not yet model their memory effects in
76 // jit/resource_variable_safety_analysis.
77 bool allow_resource_ops;
78
79 // Whether stateful RNG ops are allowed. XLA's RNG does not have the same
80 // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
81 // auto-clustering stateful RNG ops.
82 bool allow_stateful_rng_ops;
83
84 // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
85 // to cluster ControlTrigger because of how we use deadness analysis.
86 bool allow_control_trigger;
87
88 // Whether ops with dummy implementations are allowed. We avoid
89 // auto-clustering these ops so that the user is not surprised when XLA is
90 // implicitly enabled. If the user explicitly specifies to use XLA, it is fine
91 // to resort to a dummy implementation. Currently Assert and CheckNumerics ops
92 // have dummy XLA implementations.
93 bool allow_dummy_ops;
94
95 // Whether ops that produce or consume DT_VARIANT values are allowed. We
96 // don't auto-cluster these ops because we don't yet support live-in or
97 // live-out DT_VARIANT values.
98 bool allow_ops_producing_or_consuming_variant;
99 };
100
101 bool IsDummyImplOp(absl::string_view op_name) {
102 return op_name == "Assert" || op_name == "CheckNumerics";
103 }
104
105 bool IsStatefulRandomOp(absl::string_view op_name) {
106 return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
107 op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
108 op_name == "TruncatedNormal" || op_name == "Multinomial";
109 }
110
111 bool OpProducesOrConsumesVariant(const Node& node) {
112 auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
113 return absl::c_any_of(node.input_types(), is_variant) ||
114 absl::c_any_of(node.output_types(), is_variant);
115 }
116
117 bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
118 // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
119 // is really a kind of function call and will be handled by
120 // IsCompilableCall().
121 if (node.type_string() == "SymbolicGradient") return false;
122 if (node.type_string() == "Const") {
123     // Skip Const ops with type DT_STRING: XLA doesn't support them, even
124     // though the registered Const KernelDef claims it does (in order to
125     // support no-op Assert for tfcompile).
126 const AttrValue* attr = node.attrs().Find("dtype");
127 if (attr != nullptr && attr->type() == DT_STRING) {
128 return false;
129 }
130 }
131
132 // XLA does not offer guaranteed aliasing between the input and output of the
133 // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
134 // such nodes out of XLA clusters.
135 if (HasForwardedRefInput(node)) {
136 VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
137 return false;
138 }
139
140 return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
141 }
142
143 bool HasResourceOutput(const Node& node) {
144 return std::find(node.output_types().begin(), node.output_types().end(),
145 DT_RESOURCE) != node.output_types().end();
146 }
147
148 bool HasResourceInput(const Node& node) {
149 return std::find(node.input_types().begin(), node.input_types().end(),
150 DT_RESOURCE) != node.input_types().end();
151 }
152
153 // Returns true if `node` is a resource operation recognized by tf2xla that
154 // operates on something other than resource variables.
155 bool IsNonResourceVarResourceOp(const Node& node) {
156 // TODO(b/112837194): We can't cluster these because we only support
157 // snapshotting resource variables (and we can't e.g. snapshot stacks). This
158 // limitation may be fixable with some work.
159 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string());
160 return op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
161 }
162
163 // Make sure we don't recurse infinitely on recursive functions.
164 const int kMaxRecursionDepth = 10;
165
166 bool IsCompilableCall(const NodeDef& call_def,
167 const DeviceType& jit_device_type,
168 const OperationFilter& op_filter, int depth,
169 FunctionLibraryRuntime* lib_runtime);
170
171 // Tests whether 'while_node' is a completely compilable loop.
172 // Every operator in the condition and body functions must be compilable for a
173 // while loop to be compilable.
174 bool IsCompilableWhile(const Node& while_node,
175 const DeviceType& jit_device_type,
176 const OperationFilter& op_filter, int depth,
177 FunctionLibraryRuntime* lib_runtime) {
178 const NameAttrList* name_attr;
179 NodeDef call;
180 Status status;
181 status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
182 if (!status.ok()) {
183 VLOG(2) << "Rejecting While " << while_node.name()
184 << ": missing 'cond' attribute on While node.";
185 return false;
186 }
187 const string cond_func = name_attr->name();
188 call.set_name("while_cond");
189 call.set_op(cond_func);
190 *call.mutable_attr() = name_attr->attr();
191 if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
192 lib_runtime)) {
193 VLOG(2) << "Rejecting While " << while_node.name()
194 << ": can't compile loop condition: " << cond_func;
195 return false;
196 }
197 status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
198 if (!status.ok()) {
199 VLOG(2) << "Rejecting While " << while_node.name()
200 << ": missing 'body' attribute on While node.";
201 return false;
202 }
203 const string body_func = name_attr->name();
204 call.set_name("while_body");
205 call.set_op(body_func);
206 *call.mutable_attr() = name_attr->attr();
207 if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1,
208 lib_runtime)) {
209 VLOG(2) << "Rejecting While " << while_node.name()
210 << ": can't compile loop body: " << body_func;
211 return false;
212 }
213 return true;
214 }
215
216 // Tests whether 'call_def' is a call to a completely compilable function.
217 // Every operator in the function must be compilable for a function to be
218 // compilable.
219 bool IsCompilableCall(const NodeDef& call_def,
220 const DeviceType& jit_device_type,
221 const OperationFilter& op_filter, int depth,
222 FunctionLibraryRuntime* lib_runtime) {
223 if (depth > kMaxRecursionDepth) {
224 VLOG(2) << "Rejecting " << call_def.op()
225 << ": function depth limit exceeded.";
226 return false;
227 }
228
229 FunctionLibraryRuntime::Handle handle;
230 Status status = InstantiateFunctionCall(call_def, *lib_runtime, &handle);
231 if (!status.ok()) {
232 VLOG(2) << "Rejecting " << call_def.DebugString()
233 << ": could not instantiate: " << status;
234 return false;
235 }
236
237 auto release_handle_on_return = gtl::MakeCleanup(
238 [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
239
240 const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
241 CHECK(fbody);
242 const FunctionDef& fdef = fbody->fdef;
243 bool noinline = false;
244 if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
245 noinline) {
246 // The underlying mechanism that calls non-inlined functions uses
247 // LocalExecutor, which interacts poorly with the LocalExecutor used by
248 // tf2xla to translate the TF graph into XLA. So we avoid this for now.
249 //
250 // TODO(b/36139787): Create a mechanism to set inlining hints.
251 VLOG(2) << "Rejecting " << call_def.op()
252 << ": can't compile noinline function.";
253 return false;
254 }
255
256 for (Node* node : fbody->graph->op_nodes()) {
257 if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
258 continue;
259 if (node->type_string() == "While") {
260 // Handle functional While loop.
261 return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1,
262 lib_runtime);
263 }
264 if (!op_filter.allow_resource_ops &&
265 (HasResourceInput(*node) || HasResourceOutput(*node))) {
266 return false;
267 }
268 if (!op_filter.allow_stateful_rng_ops &&
269 IsStatefulRandomOp(node->type_string())) {
270 return false;
271 }
272 if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
273 return false;
274 }
275 if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
276 return false;
277 }
278 if (!op_filter.allow_ops_producing_or_consuming_variant &&
279 OpProducesOrConsumesVariant(*node)) {
280 return false;
281 }
282 if (!HasXLAKernel(*node, jit_device_type) &&
283 !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
284 lib_runtime)) {
285 VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
286 << node->name() << ": " << node->def().ShortDebugString();
287 return false;
288 }
289 }
290 return true;
291 }
292
293 // Returns true if the op can be decomposed into XLA ops for which
294 // there are fusable elemental implementations.
295 //
296 // TODO(hpucha): Remove this code since this functionality is subsumed by
297 // Grappler XlaFusionOptimizer.
298 bool IsXlaFusable(const NodeDef& node) {
299 static const std::unordered_set<std::string>* elementwise_ops =
300 new std::unordered_set<std::string>(
301 {// tf2xla/kernels/aggregate_ops.cc
302 "AddN",
303 // tf2xla/kernels/batchtospace_op.cc
304 "BatchToSpace", "BatchToSpaceND",
305 // tf2xla/kernels/bcast_ops.cc
306 "BroadcastArgs", "BroadcastGradientArgs",
307 // tf2xla/kernels/bias_ops.cc
308 "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
309 // tf2xla/kernels/binary_ops.cc
310 "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
311 "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
312 "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
313 "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
314 "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
315 "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
316 "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
317 // tf2xla/kernels/cast_op.cc
318 "Cast",
319 // tf2xla/kernels/categorical_op.cc
320 "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/,
321 // tf2xla/kernels/concat_op.cc
322 "Concat", "ConcatV2", "ConcatOffset",
323 // tf2xla/kernels/const_op.cc
324 "Const",
325 // tf2xla/kernels/cross_op.cc
326 "Cross",
327 // tf2xla/kernels/depthtospace_op.cc
328 "DepthToSpace",
329 // tf2xla/kernels/diag_op.cc
330 "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart",
331 // tf2xla/kernels/dynamic_stitch_op.cc
332 "DynamicStitch", "ParallelDynamicStitch",
333 // tf2xla/kernels/elu_op.cc
334 "Elu", "EluGrad", "Selu", "SeluGrad",
335 // tf2xla/kernels/fake_quantize_ops.cc
336 "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient",
337 "FakeQuantWithMinMaxVars",
338 "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/,
339 // tf2xla/kernels/fill_op.cc
340 "Fill",
341 // tf2xla/kernels/gather_op.cc
342 "Gather", "GatherV2", "GatherNd",
343 // tf2xla/kernels/identity_op.cc
344 "Identity", "IdentityN", "PreventGradient", "StopGradient",
345 "Snapshot",
346 // tf2xla/kernels/image_ops.cc
347 "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/,
348 "AdjustSaturation", "AdjustHue",
349 // tf2xla/kernels/index_ops.cc
350 "ArgMax", "ArgMin",
351 // tf2xla/kernels/l2loss_op.cc
352 "L2Loss" /*(Reduce)*/,
353 // tf2xla/kernels/lrn_ops.cc (ReduceWindow)
354 "LRN", "LRNGrad",
355 // tf2xla/kernels/matrix_band_part_op.cc
356 "MatrixBandPart",
357 // tf2xla/kernels/matrix_set_diag_op.cc
358 "MatrixSetDiag",
359 // tf2xla/kernels/mirror_pad_op.cc
360 "MirrorPad",
361 // tf2xla/kernels/no_op.cc
362 "NoOp", "ControlTrigger",
363 // tf2xla/kernels/one_hot_op.cc
364 "OneHot",
365 // tf2xla/kernels/pack_op.cc
366 "Pack",
367 // tf2xla/kernels/pad_op.cc
368 "Pad", "PadV2",
369 // tf2xla/kernels/pooling_ops.cc
370 "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool",
371 "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/
372 "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad",
373 "AvgPool3DGrad",
374 // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce)
375 "QuantizeAndDequantizeV2",
376 // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend
377 // currently)
378 "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
379 "TruncatedNormal",
380 // tf2xla/kernels/reduction_ops.cc (Reduce)
381 "Sum", "Prod", "Min", "Max", "Mean", "All", "Any",
382 // tf2xla/kernels/relu_op.cc
383 "Relu", "Relu6", "ReluGrad", "Relu6Grad",
384 // tf2xla/kernels/reshape_op.cc
385 "Reshape",
386 // tf2xla/kernels/reverse_op.cc
387 "Reverse", "ReverseV2",
388 // tf2xla/kernels/reverse_sequence_op.cc
389 "ReverseSequence",
390 // tf2xla/kernels/scan_ops.cc (ReduceWindow)
391 "Cumsum", "Cumprod",
392 // tf2xla/kernels/scatter_nd_op.cc (Reduce)
393 "ScatterNd",
394 // tf2xla/kernels/segment_reduction_ops.cc (Reduce)
395 "UnsortedSegmentSum",
396 // tf2xla/kernels/select_op.cc
397 "Select",
398 // tf2xla/kernels/sequence_ops.cc
399 "Range", "LinSpace",
400 // tf2xla/kernels/shape_op.cc
401 "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
402 "ZerosLike", "OnesLike",
403 // tf2xla/kernels/slice_op.cc
404 "Slice",
405 // tf2xla/kernels/softmax_op.cc (Reduce)
406 "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits",
407 "SparseSoftmaxCrossEntropyWithLogits",
408 // tf2xla/kernels/spacetobatch_op.cc
409 "SpaceToBatchND", "SpaceToBatch",
410 // tf2xla/kernels/spacetodepth_op.cc
411 "SpaceToDepth",
412 // tf2xla/kernels/split_op.cc
413 "Split", "SplitV",
414 // tf2xla/kernels/stack_ops.cc
415 "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2",
416 // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on
417 // GPU
418 // backend currently)
419 "StatelessRandomUniform",
420 "StatelessRandomNormal"
421 // tf2xla/kernels/strided_slice_op.cc
422 "StridedSlice",
423 "StridedSliceGrad", "ResourceStridedSliceAssign",
424 // tf2xla/kernels/tile_ops.cc
425 "Tile",
426 // tf2xla/kernels/training_ops.cc
427 "ResourceApplyGradientDescent", "ResourceApplyMomentum",
428 "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp",
429 "ResourceApplyFtrl", "ResourceApplyFtrlV2",
430 // tf2xla/kernels/transpose_op.cc
431 "Transpose", "InvertPermutation",
432 // tf2xla/kernels/unary_ops.cc
433 "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
434 "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
435 "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
436 "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
437 "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
438 "Square", "Tan", "Tanh", "Real", "Imag",
439 // tf2xla/kernels/unpack_op.cc
440 "Unpack"});
441
442 return elementwise_ops->count(node.op()) > 0;
443 }
444
445 // Nodes that XLA can compile are put in `candidates`. Nodes put in
446 // `isolated_nodes` must either be unclustered or be put in trivial single-node
447 // clusters.
448 Status FindCompilationCandidates(
449 const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
450 const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
451 OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
452 OptimizerOptions opts;
453 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
454 new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
455 flib_def, opts));
456 FunctionLibraryRuntime* lib_runtime =
457 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
458 std::vector<bool> compile_time_const_nodes(graph.num_node_ids(), false);
459 TF_RETURN_IF_ERROR(
460 BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
461 &compile_time_const_nodes, lib_runtime));
462
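  // Clustering "fuel": taken from MarkForCompilationPassFlags::
  // tf_xla_clustering_fuel, it caps how many nodes this pass may add to
  // `candidates`. One unit is consumed per candidate (see `--fuel` at the
  // bottom of the loop); once it reaches zero the remaining nodes are left
  // unclustered. Note that `fuel` is a reference, so the budget is shared
  // across invocations of this pass within the process.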
463 int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
464
465 // Iterate over nodes in sorted order so that compiler fuel is deterministic.
466   // We can't simply pass op_nodes().begin() and op_nodes().end() to the
467   // std::vector constructor because they're not proper iterators (no
468   // iterator_traits and so on).
469 std::vector<Node*> sorted_nodes;
470 for (Node* node : graph.op_nodes()) {
471 sorted_nodes.push_back(node);
472 }
473 std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
474
475 if (fuel >= std::numeric_limits<int64>::max() / 2) {
476 // The assumption is that if fuel started out as INT64_MAX, it will forever
477 // stay greater than INT64_MAX / 2.
478 VLOG(2) << "Starting fuel: infinity";
479 } else {
480 VLOG(2) << "Starting fuel: " << fuel;
481 }
482
483 for (Node* node : sorted_nodes) {
484 if (fuel <= 0) {
485 VLOG(1)
486 << "Hit fuel limit; not marking any remaining ops as clusterable.";
487 break;
488 }
489
490 DeviceType device_type("");
491 TF_RETURN_IF_ERROR(
492 DeviceToDeviceType(node->assigned_device_name(), &device_type));
493 VLOG(4) << "Device type for " << node->name() << ": "
494 << device_type.type_string();
495
496 if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
497 // is_compilable_fn has already logged the reason if it returned false.
498 continue;
499 }
500
501 const XlaOpRegistry::DeviceRegistration* registration;
502 CHECK(
503         XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
504 DeviceType jit_device_type(registration->compilation_device_name);
505
506 bool always_auto_cluster = registration->autoclustering_policy ==
507 XlaOpRegistry::AutoclusteringPolicy::kAlways;
508
509 OperationFilter op_filter;
510 op_filter.allow_resource_ops = registration->compile_all_resource_ops;
511 op_filter.allow_stateful_rng_ops = always_auto_cluster;
512 op_filter.allow_control_trigger = always_auto_cluster;
513 op_filter.allow_dummy_ops = always_auto_cluster;
514 op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster;
515
516 if (!HasXLAKernel(*node, jit_device_type) &&
517 !IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
518 lib_runtime)) {
519 VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
520 << node->type_string();
521 continue;
522 }
523
524 if (!op_filter.allow_stateful_rng_ops &&
525 IsStatefulRandomOp(node->type_string())) {
526 VLOG(2) << "Rejecting " << node->name() << ": stateful random operation";
527 continue;
528 }
529 if (!op_filter.allow_control_trigger && node->IsControlTrigger()) {
530 VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op";
531 continue;
532 }
533 if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
534 VLOG(2) << "Rejecting " << node->name() << ": dummy op ("
535 << node->type_string() << ")";
536 continue;
537 }
538 if (!op_filter.allow_ops_producing_or_consuming_variant &&
539 OpProducesOrConsumesVariant(*node)) {
540 VLOG(2) << "Rejecting " << node->name()
541 << ": produces or consumes DT_VARIANT";
542 continue;
543 }
544
545 if (!registration->compile_all_resource_ops &&
546 (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
547 // We don't have a way of returning values of type DT_RESOURCE from XLA
548 // computations so we avoid auto-clustering nodes producing DT_RESOURCE.
549 // XlaLaunchOp also cannot snapshot resources that are not resource
550 // variables so we avoid clustering resource operations that operate on
551 // non-resource variables.
552 VLOG(2) << "Rejecting: " << node->name() << ": resource output "
553 << node->type_string();
554 continue;
555 }
556
557 if (compile_time_const_nodes[node->id()]) {
558 const OpDef* op_def;
559 TF_RETURN_IF_ERROR(
560 graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
561 if (op_def->is_stateful()) {
562 // It is easiest to demonstrate the problem we're trying to solve with
563 // an example. Say we have this graph:
564 //
565 // shape = RandomUniformInt();
566 // reshape = Reshape(input, shape)
567 //
568 // Both RandomUniformInt and Reshape are compilable by XLA so, absent
569 // any other reason, we will try to put both shape and reshape in the
570 // same cluster. However, since XLA only supports statically shaped
571 // values, it will expect to be able to constant fold `shape` to get a
572 // static shape for `reshape`. This is a problem because side-effecting
573 // ops like RandomUniformInt() cannot be constant folded. We fix this
574 // by putting `shape` and `reshape` in different clusters, which results
575 // in us recompiling `reshape`'s cluster for every new value of `shape`,
576 // making `reshape` statically sized within each compilation. We
577 // simplify the solution even further by disallowing operations like
578 // `shape` from being part of *any* non-trivial cluster. They're either
579 // not compiled by XLA altogether or, if assigned to an XLA_* device
580 // with "must compile" semantics, compiled into a trivial single-op
581 // cluster. This approach leaves some room for improvement, and we can
582 // consider implementing a more aggressive data-flow-analysis based
583 // solution in the future if needed.
584 //
585 // One ugly problem we have to contend with: certain sets of ops *have*
586 // to be in the same cluster because values flowing between them have
587 // types that can't be live-in or live-out of a cluster. These ops are:
588 //
589 // - TensorArray ops operating on the same TensorArray instance.
590 // - Stack ops operating on the same Stack instance.
591 //
592       // To work around this we avoid isolating these specific ops. Because
593       // of this concession it is unsound to auto-cluster them, since we
594       // would then create clusters we could not compile (we can't constant
595       // fold, say, a TensorArrayRead or a StackPopV2). But we don't
596       // auto-cluster these operations today, so we're good for now.
597 const XlaResourceOpInfo* op_info =
598 GetResourceOpInfoForOp(node->type_string());
599 bool is_tensor_array_or_stack_op =
600 op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
601 if (!is_tensor_array_or_stack_op) {
602 VLOG(2) << "Isolating " << node->name()
603 << ": must-be-constant stateful op";
604 isolated_nodes->insert(node);
605 // Keep going and execute all the other checks.
606 }
607 }
608 }
609 // We don't auto-cluster functional control flow nodes containing resource
610 // operations because safety checks are trickier in this case.
611 // registration->compile_all_resource_ops is true for XLA_CPU/XLA_GPU but
612 // not for CPU/GPU.
613 if (node->type_string() == "While" &&
614 !IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) {
615 continue;
616 }
617 // _Arg nodes in a top-level function represent feeds.
618 // Do not compile them.
619 if (node->type_string() == "_Arg") {
620 continue;
621 }
622 // _Retval nodes in a top-level function represent fetches.
623 // Do not compile them.
624 if (node->type_string() == "_Retval") {
625 continue;
626 }
627 candidates->insert(node);
628 --fuel;
629 }
630 VLOG(2) << "candidates->size() = " << candidates->size();
631 return Status::OK();
632 }
633
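// A cluster being grown by the union-find edge contraction in RunImpl below.
// When two clusters are merged, `devices` is unioned, `resource_op_device` is
// inherited from whichever side has one, and `has_xla_compile_attr` is OR'ed.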
634 struct Cluster {
635 // Identifies the node that represents this cluster in the cycle detection
636 // graph.
637 int representative = -1;
638
639 // The set of devices the nodes in this cluster are placed on.
640 absl::flat_hash_set<string> devices;
641
642   // If there are resource operations in the cluster then this is the device that
643 // resource operations are placed on. All resource operations in a cluster
644 // must be placed on the same device.
645 string resource_op_device;
646
647 // True if any node in the cluster has an _XlaCompile attribute set to true.
648 bool has_xla_compile_attr;
649 };
650
651 } // anonymous namespace
652
653 bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
654 Device* device = flr->device();
655 const XlaOpRegistry::DeviceRegistration* registration;
656 CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
657                                             &registration));
658 DeviceType jit_device_type(registration->compilation_device_name);
659
660 // We can always *compile* resource operations, stateful RNGs and dummy ops,
661 // even if we are sometimes unable to auto-cluster them.
662 OperationFilter op_filter;
663 op_filter.allow_resource_ops = true;
664 op_filter.allow_stateful_rng_ops = true;
665 op_filter.allow_control_trigger = true;
666 op_filter.allow_dummy_ops = true;
667 op_filter.allow_ops_producing_or_consuming_variant = true;
668
669 return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
670 }
671
672 Status MarkForCompilationPass::Run(
673 const GraphOptimizationPassOptions& options) {
674 // TODO(phawkins): precompute the "GetCompilationDevice" properties of each
675 // device ahead of time.
676 OptimizerOptions::GlobalJitLevel global_jit_level =
677 GetGlobalJitLevel(options);
678 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
679 bool fusion_only = flags->tf_xla_fusion_only;
680
681 VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
682 VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
683 const FunctionLibraryDefinition* fld = options.flib_def;
684
685 // Deadness analysis expects a graph with source and sink edges properly
686 // connected but sometimes the incoming graph does not follow this invariant.
687 // So fix up the source and sink edges before calling into deadness analysis.
688 FixupSourceAndSinkEdges(options.graph->get());
689
690 // See explanation on `kXlaAlreadyClustered`.
691 for (Node* n : options.graph->get()->nodes()) {
692 if (n->attrs().Find(kXlaAlreadyClustered)) {
693 return Status::OK();
694 }
695 }
696
697 std::unique_ptr<DeadnessAnalysis> deadness;
698 {
699 XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
700 TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
701 }
702
703 bool deadness_analysis_disabled =
704 GetMarkForCompilationPassFlags()
705 ->tf_xla_disable_deadness_safety_checks_for_debugging;
706
707 if (deadness_analysis_disabled) {
708 LOG(WARNING) << "Deadness analysis was manually disabled via "
709 "--tf_xla_disable_deadness_safety_checks_for_debugging; "
710 "auto-clustering "
711 "is unsound!";
712 }
713
714 auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
715 const XlaOpRegistry::DeviceRegistration* registration;
716 if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
717                                              &registration)) {
718 VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device.";
719 return false;
720 }
721
722 // If there is a _XlaCompile annotation, use its value.
723 bool compile = false;
724 Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
725 if (status.ok()) {
726 if (!compile) {
727 VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
728 << kXlaCompileAttr << ") is false.";
729 }
730 return compile;
731 }
732
733 status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
734 if (status.ok()) {
735 if (!compile) {
736 VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
737 << kXlaCompileAttr << ") on callee is false.";
738 }
739 return compile;
740 }
741
742 // If inputs to `node` can have conflicting deadness (i.e. some are alive
743 // and some are dead) then don't compile it. XLA cannot represent the
744 // deadness semantics of these nodes correctly and auto-clustering these
745 // nodes can cause deadness to propagate to nodes that should be live.
746 if (!deadness_analysis_disabled) {
747 if (node->IsMerge() ||
748 deadness->HasInputsWithMismatchingDeadness(*node)) {
749 VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness.";
750 return false;
751 }
752 }
753
754 // Check for fusable ops only if requested.
755 if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
756 VLOG(2) << "Rejecting " << node->name()
757 << ": not fusable op but fusion_only enabled.";
758 return false;
759 }
760
761 return true;
762 };
763
764 return RunImpl(options, is_compilable);
765 }
766
767 static string RatioToString(int numerator, int denominator) {
768 return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
769 (100.0 * numerator) / denominator);
770 }
771
772 static void VLogClusteringSummary(const Graph& g) {
773 if (!VLOG_IS_ON(2)) {
774 return;
775 }
776
777 std::map<absl::string_view, int> cluster_name_to_size;
778 std::map<absl::string_view, std::map<absl::string_view, int>>
779 cluster_name_to_op_histogram;
780 std::map<absl::string_view, int> unclustered_op_histogram;
781 int clustered_node_count = 0;
782
783 for (Node* n : g.nodes()) {
784 absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
785 if (cluster_name) {
786 clustered_node_count++;
787 cluster_name_to_size[*cluster_name]++;
788 cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
789 } else {
790 unclustered_op_histogram[n->type_string()]++;
791 }
792 }
793
794 int unclustered_node_count = g.num_nodes() - clustered_node_count;
795
796 VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes();
797 VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
798 << RatioToString(clustered_node_count, g.num_nodes());
799
800 for (const auto& cluster_name_size_pair : cluster_name_to_size) {
801 absl::string_view cluster_name = cluster_name_size_pair.first;
802 int size = cluster_name_size_pair.second;
803 VLOG(2) << " " << cluster_name << " "
804 << RatioToString(size, g.num_nodes());
805 for (const auto& op_count_pair :
806 cluster_name_to_op_histogram[cluster_name]) {
807 VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
808 << " instances";
809 }
810 }
811
812 if (!unclustered_op_histogram.empty()) {
813 VLOG(2) << " Unclustered nodes: "
814 << RatioToString(unclustered_node_count, g.num_nodes());
815 for (const auto& pair : unclustered_op_histogram) {
816 VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
817 }
818 }
819
820 struct EdgeInfo {
821 absl::string_view node_name;
822 absl::optional<absl::string_view> cluster_name;
823
824 absl::string_view GetClusterName() const {
825 return cluster_name ? *cluster_name : "[none]";
826 }
827
828 std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
829 const {
830 return {node_name, cluster_name};
831 }
832
833 bool operator<(const EdgeInfo& other) const {
834 return AsPair() < other.AsPair();
835 }
836 };
837
838 using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
839
840 EdgeInfoMap incoming_edge_infos;
841 EdgeInfoMap outgoing_edge_infos;
842
843 std::set<absl::string_view> cluster_names_to_print;
844
845 for (const Edge* e : g.edges()) {
846 const Node* from = e->src();
847 absl::optional<absl::string_view> from_cluster_name =
848 GetXlaClusterForNode(*from);
849
850 const Node* to = e->dst();
851 absl::optional<absl::string_view> to_cluster_name =
852 GetXlaClusterForNode(*to);
853
854 if (to_cluster_name == from_cluster_name) {
855 continue;
856 }
857
858 if (to_cluster_name) {
859 incoming_edge_infos[*to_cluster_name]
860 [EdgeInfo{from->name(), from_cluster_name}]++;
861 cluster_names_to_print.insert(*to_cluster_name);
862 }
863
864 if (from_cluster_name) {
865 outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
866 cluster_names_to_print.insert(*from_cluster_name);
867 }
868 }
869
870 VLOG(2) << "*** Inter-Cluster edges:";
871 if (cluster_names_to_print.empty()) {
872 VLOG(2) << " [none]";
873 }
874
875 auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
876 const EdgeInfoMap& edge_info_map,
877 absl::string_view desc) {
878 auto it = edge_info_map.find(cluster_name);
879 if (it != edge_info_map.end()) {
880 VLOG(2) << " " << it->second.size() << " " << desc << " edges";
881 for (const auto& edge_info_count_pair : it->second) {
882 VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " "
883 << edge_info_count_pair.first.node_name << " # "
884 << edge_info_count_pair.second;
885 }
886 } else {
887 VLOG(2) << " No " << desc << " edges.";
888 }
889 };
890
891 for (absl::string_view cluster_name : cluster_names_to_print) {
892 VLOG(2) << " ** Cluster " << cluster_name;
893 print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
894 "incoming");
895 print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
896 "outgoing");
897 }
898 }
899
900 // Is 'node' an operator that consumes only the shape of its input, not the
901 // data itself?
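// Used in RunImpl below: a single-node cluster whose node is a shape consumer
// is never used as the starting point ("root") of a larger cluster.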
902 static bool IsShapeConsumerOp(const Node& node) {
903 return node.type_string() == "Shape" || node.type_string() == "Rank" ||
904 node.type_string() == "Size";
905 }
906
907 static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) {
908 // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
909 // ignore it during resource operation safety analysis. We need this hack
910 // because of two reasons:
911 //
912 // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
913 // 2. We don't support live-out values of type DT_RESOURCE and live-in values
914 // of type DT_RESOURCE that are not resource variables.
915 //
916 // Together these imply we cannot let resource variable safety analysis
917 // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
918 // clusters: both of them will have to be clustered because of (1) and we
919 // won't be able to keep the edge between the two as neither the input to the
920 // second XLA cluster nor the output from the first XLA cluster are supported
921 // because of (2).
922 //
923 // TODO(b/113100872): This can be fixed if the TensorFlow representation for
924 // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
925 // (2) would no longer hold.
926
927 if (n.assigned_device_name().empty()) {
928 *ignore = false;
929 return Status::OK();
930 }
931 DeviceType device_type("");
932 TF_RETURN_IF_ERROR(
933 DeviceToDeviceType(n.assigned_device_name(), &device_type));
934
935 const XlaOpRegistry::DeviceRegistration* registration;
936   if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
937 *ignore = true;
938 } else {
939 *ignore = registration->compile_all_resource_ops;
940 }
941 return Status::OK();
942 }
943
944 // Sequence number generator to ensure clusters have unique names.
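// The counter is a process-wide atomic so that concurrently running instances
// of this pass still hand out distinct cluster names.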
945 static std::atomic<int64> cluster_sequence_num;
946
947 // Returns true if the devices in `cluster_a` and `cluster_b` are compatible and
948 // therefore not a hindrance for combining the two clusters into a larger
949 // cluster.
950 static Status AreDevicesCompatible(
951 const Cluster& cluster_a, const Cluster& cluster_b,
952 OptimizerOptions::GlobalJitLevel global_jit_level, bool* result) {
953 std::vector<string> devices;
954 absl::c_remove_copy(cluster_a.devices, std::back_inserter(devices), "");
955 absl::c_remove_copy(cluster_b.devices, std::back_inserter(devices), "");
956 absl::c_sort(devices);
957
958 if (devices.empty()) {
959 *result = false;
960 return Status::OK();
961 }
962
963 // First check if we will even be able to pick a device for the larger
964 // combined cluster.
965 bool can_pick_device;
966 TF_RETURN_IF_ERROR(CanPickDeviceForXla(
967 devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
968 if (!can_pick_device) {
969 *result = false;
970 return Status::OK();
971 }
972
973 string chosen_device;
974 TF_RETURN_IF_ERROR(PickDeviceForXla(
975 devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
976
977 // If we are able to pick a device `chosen_device` for the larger cluster, the
978 // resource operations in `cluster_a` and `cluster_b` must be placed on the
979 // same device as `chosen_device`. This is because the _XlaCompile and
980   // _XlaRun kernels are going to run on `chosen_device` and will therefore try
981   // to access the resource variables from there, which will be an error if the
982   // resource variables are placed on some other device.
983 auto resource_op_device_ok = [&](const string& resource_op_device) {
984 return resource_op_device.empty() || resource_op_device == chosen_device;
985 };
986
987 *result = resource_op_device_ok(cluster_a.resource_op_device) &&
988 resource_op_device_ok(cluster_b.resource_op_device);
989 if (!*result) {
990 return Status::OK();
991 }
992
993 // We will check this again later, but here we prune out clusters that would
994 // never have been sent to XLA to save compile time. Without this change we
995 // will e.g. create a CPU cluster only to later notice that the user did not
996 // enable the CPU JIT via --tf_xla_cpu_global_jit. With this change we avoid
997 // creating the cluster to begin with.
998 //
999 // TODO(b/126629785): It is possible that this is just papering over O(n^2)
1000 // behavior in our clustering algorithm.
1001 const XlaOpRegistry::DeviceRegistration* registration;
1002 DeviceType device_type("");
1003 TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type));
1004 TF_RET_CHECK(
1005       XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration))
1006 << "chosen device = " << chosen_device
1007 << "; device type = " << device_type.type() << "; devices ("
1008 << devices.size() << ") = " << absl::StrJoin(devices, ", ");
1009
1010 *result = cluster_a.has_xla_compile_attr || cluster_b.has_xla_compile_attr ||
1011 registration->autoclustering_policy ==
1012 XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1013 (registration->autoclustering_policy ==
1014 XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1015 global_jit_level != OptimizerOptions::OFF);
1016
1017 return Status::OK();
1018 }
1019
1020 // Returns `true` iff we should compile `cluster`.
1021 static Status ShouldCompileClusterImpl(
1022 const Cluster& cluster, OptimizerOptions::GlobalJitLevel global_jit_level,
1023 bool* should_compile, string* device) {
1024 std::vector<string> devices;
1025 absl::c_remove_copy(cluster.devices, std::back_inserter(devices), "");
1026 absl::c_sort(devices);
1027
1028 string chosen_device;
1029 TF_RETURN_IF_ERROR(PickDeviceForXla(
1030 devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
1031
1032 const XlaOpRegistry::DeviceRegistration* registration;
1033 DeviceType device_type("");
1034 TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type));
1035 TF_RET_CHECK(
1036       XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration))
1037 << "chosen device = " << chosen_device
1038 << "; device type = " << device_type.type() << "; devices ("
1039 << devices.size() << ") = " << absl::StrJoin(devices, ", ");
1040
1041 *should_compile =
1042 cluster.has_xla_compile_attr ||
1043 registration->autoclustering_policy ==
1044 XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1045 (registration->autoclustering_policy ==
1046 XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1047 global_jit_level != OptimizerOptions::OFF);
1048
1049 VLOG(3) << (*should_compile ? "Compiling" : "Not compiling")
1050 << " cluster with device " << chosen_device;
1051
1052 *device = std::move(chosen_device);
1053 return Status::OK();
1054 }
1055
1056 static Status ShouldCompileCluster(
1057 absl::flat_hash_map<int, std::pair<bool, string>>* cache,
1058 OptimizerOptions::GlobalJitLevel global_jit_level, const Cluster& cluster,
1059 bool* should_compile, string* device) {
1060 auto it = cache->find(cluster.representative);
1061 if (it != cache->end()) {
1062 *should_compile = it->second.first;
1063 *device = it->second.second;
1064 return Status::OK();
1065 }
1066
1067 string device_s;
1068 TF_RETURN_IF_ERROR(ShouldCompileClusterImpl(cluster, global_jit_level,
1069 should_compile, &device_s));
1070 cache->insert({cluster.representative, {*should_compile, device_s}});
1071 *device = std::move(device_s);
1072 return Status::OK();
1073 }
1074
1075 Status MarkForCompilationPass::RunImpl(
1076 const GraphOptimizationPassOptions& options,
1077 const std::function<bool(const Node*, const DeviceType&)>&
1078 is_compilable_fn) {
1079 VLOG(1) << "MarkForCompilationPass::Run";
1080
1081 // Make sure that kernels have been registered on the JIT device.
1082 XlaOpRegistry::RegisterCompilationKernels();
1083
1084 Graph* graph = options.graph->get();
1085
1086 OrderedNodeSet compilation_candidates;
1087 absl::flat_hash_set<Node*> isolated_nodes;
1088 TF_RETURN_IF_ERROR(FindCompilationCandidates(
1089 *graph, options.flib_def,
1090 (options.session_options != nullptr) ? options.session_options->env
1091 : Env::Default(),
1092 is_compilable_fn, &compilation_candidates, &isolated_nodes));
1093
1094 if (compilation_candidates.empty()) {
1095 VLOG(2) << "No compilable candidates";
1096 return Status::OK();
1097 }
1098
1099 GraphCycles cycles;
1100 TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
1101 CreateCycleDetectionGraph(graph, &cycles));
1102 if (!cycle_detection_graph_ok) {
1103 return Status::OK();
1104 }
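  // Roughly speaking, resource-operation safety is enforced through the cycle
  // detection graph: the call below adds ordering edges between resource
  // operations (as computed by the resource operation safety analysis, subject
  // to the IgnoreResourceOpForSafetyAnalysis exemption above) so that a
  // cluster merge that would reorder unsafe resource accesses is rejected as
  // a cycle.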
1105 TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
1106 graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));
1107
1108   // Each compilation candidate belongs to a cluster. The cluster's
1109   // representative names the node in the 'cycles' graph that represents the
1110   // cluster.
1111 std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
1112 std::deque<UnionFind<Cluster>*> worklist;
1113 for (Node* node : compilation_candidates) {
1114 Cluster& cluster = clusters[node->id()].Get();
1115 cluster.representative = node->id();
1116 const string& device = !node->assigned_device_name().empty()
1117 ? node->assigned_device_name()
1118 : node->requested_device();
1119 if (HasResourceInput(*node) || HasResourceOutput(*node)) {
1120 cluster.resource_op_device = device;
1121 }
1122 cluster.has_xla_compile_attr = false;
1123 bool xla_compile_attr;
1124 if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) {
1125 cluster.has_xla_compile_attr |= xla_compile_attr;
1126 }
1127 if (options.flib_def->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr)
1128 .ok()) {
1129 cluster.has_xla_compile_attr |= xla_compile_attr;
1130 }
1131
1132 cluster.devices.insert(device);
1133 worklist.push_back(&clusters[node->id()]);
1134 }
1135
1136 OptimizerOptions::GlobalJitLevel global_jit_level =
1137 GetGlobalJitLevel(options);
1138 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1139
1140 // Repeatedly contract edges between clusters that are on the same device,
1141 // provided the contraction would not create a cycle.
1142 //
1143 // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
1144 // example, from the Grappler fusion pass).
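  // For example, given candidates a -> b -> c on one device, the loop first
  // contracts the edge (a, b), re-enqueues the merged cluster, and on a later
  // iteration contracts (ab, c), leaving a single cluster {a, b, c}.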
1145 while (!worklist.empty()) {
1146 Cluster* cluster_from = &worklist.front()->Get();
1147 int from = cluster_from->representative;
1148 worklist.pop_front();
1149
1150 Node* node_from = graph->FindNodeId(from);
1151 if (node_from->IsControlFlow()) {
1152 // Control flow nodes aren't compilation candidates and should never
1153 // appear.
1154 return errors::Internal(
1155 "Found control flow node in clustering worklist: ",
1156 node_from->type_string());
1157 }
1158
1159 if (isolated_nodes.count(node_from)) {
1160 continue;
1161 }
1162
1163 string from_scope;
1164 string to_scope;
1165 for (int to : cycles.Successors(from)) {
1166 if (to >= graph->num_node_ids()) {
1167 // Node is a fictitious node that is present only in the cycle detection
1168 // graph. No clustering is possible.
1169 continue;
1170 }
1171
1172 const Cluster& cluster_to = clusters[to].Get();
1173 Node* node_to = graph->FindNodeId(to);
1174 if (compilation_candidates.find(node_to) ==
1175 compilation_candidates.cend()) {
1176 continue;
1177 }
1178 bool devices_compatible;
1179 TF_RETURN_IF_ERROR(AreDevicesCompatible(
1180 *cluster_from, cluster_to, global_jit_level, &devices_compatible));
1181 if (!devices_compatible) {
1182 continue;
1183 }
1184 if (isolated_nodes.count(node_to)) {
1185 continue;
1186 }
1187 // Look for an _XlaScope on both nodes. If both nodes have a
1188 // scope and the scopes do not match, do not cluster along this
1189 // edge. This restriction is overridden if the global_jit_level is ON. If
1190 // even one of the nodes lacks an _XlaScope attribute,
1191 // then it is treated as a "bridge" and a cluster may be created
1192 // along it. We may want to restrict this behavior to require
1193 // all nodes marked with _XlaCompile=true to also have a
1194 // _XlaScope property set (and raise an error otherwise); but
1195 // for now we don't do this.
1196 if (global_jit_level == OptimizerOptions::OFF &&
1197 GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
1198 GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
1199 from_scope != to_scope) {
1200 continue;
1201 }
1202
1203 // Ops that consume shapes cannot be the root of a cluster. This is an
1204 // optimization.
1205 if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
1206 continue;
1207 }
1208
1209 // Don't exceed the maximum cluster size.
1210 if (clusters[from].Size() + clusters[to].Size() >
1211 flags->tf_xla_max_cluster_size) {
1212 continue;
1213 }
1214
1215 // If any of the consumer's producers are on a different device, do not
1216 // cluster these nodes. This prevents other work on this device from being
1217 // delayed by work on other devices. We consider predecessors of the
1218 // entire cluster rather than just the inputs to the node to prevent the
1219 // cluster still being combined in cases where the 'to' cluster has
1220 // multiple dependencies on the 'from' cluster and another dependency
1221 // leads to a merging of the clusters.
1222 //
1223 // TODO(b/117085735): We probably want to handle the reciprocal of this
1224 // case where a cluster is producing data for multiple devices.
1225 bool found_split = false;
1226 for (const auto& in_id : cycles.Predecessors(to)) {
1227 if (in_id >= graph->num_node_ids()) continue;
1228
1229 Node* in = graph->FindNodeId(in_id);
1230 const Cluster& cluster_in = clusters[in_id].Get();
1231 if (compilation_candidates.find(in) != compilation_candidates.cend()) {
1232 bool devices_compatible;
1233 TF_RETURN_IF_ERROR(AreDevicesCompatible(
1234 cluster_to, cluster_in, global_jit_level, &devices_compatible));
1235 if (!devices_compatible) {
1236 found_split = true;
1237 }
1238 }
1239 }
1240 if (found_split) continue;
1241
1242 // If contracting the edge would create a cycle, bail out.
1243 // However, just because we can't merge the clusters now does not mean
1244 // we won't be able to merge them in the future.
1245 // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
1246 // 1->3. But if we first contract 1->2 then we can later contract 1->3.
1247 if (!cycles.ContractEdge(from, to)) continue;
1248
1249 // Merge the clusters. ContractEdge uses 'from' as the number of the
1250 // merged node, so make sure 'from' is the chosen representative.
1251 cluster_from->devices.insert(cluster_to.devices.begin(),
1252 cluster_to.devices.end());
1253 if (!cluster_to.resource_op_device.empty()) {
1254 cluster_from->resource_op_device = cluster_to.resource_op_device;
1255 }
1256 cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr;
1257 clusters[from].Merge(&clusters[to]);
1258
1259 worklist.push_back(&clusters[from]);
1260 break;
1261 }
1262 }
1263
1264 // Count the number of non-trivial elements in each cluster.
1265 std::vector<int> effective_cluster_sizes(graph->num_node_ids());
1266
1267 // has_functional_control_flow remembers if a cluster contains a functional
1268 // control flow node.
1269 std::vector<bool> has_functional_control_flow(graph->num_node_ids());
1270
1271 for (const Node* n : compilation_candidates) {
1272 int cluster = clusters[n->id()].Get().representative;
1273 // We want clusters to be big enough that the benefit from XLA's
1274 // optimizations offsets XLA related overhead (for instance we add some
1275 // Switch/Merge nodes into the graph to implement lazy compilation). To
1276 // this end, we don't count Identity and Constant nodes because they do not
1277 // enable interesting optimizations by themselves.
1278 if (!n->IsIdentity() && !n->IsConstant()) {
1279 effective_cluster_sizes[cluster]++;
1280 }
1281 if (n->type_string() == "While" || n->type_string() == "If") {
1282 has_functional_control_flow[cluster] = true;
1283 }
1284 }
1285
1286 // Names for each cluster.
1287 std::unordered_map<int, string> cluster_names;
1288
1289 if (flags->tf_xla_clustering_debug) {
1290 DumpGraphToFile("before_mark_for_compilation", **options.graph,
1291 options.flib_def);
1292 }
1293
1294 absl::flat_hash_map<int, std::pair<bool, string>>
1295 should_compile_cluster_cache;
1296
1297 // Mark clusters for compilation that:
1298 // * are placed on a device that requires compilation (an XlaDevice),
1299 // * are explicitly marked for compilation (_XlaCompile=true), or
1300 // * have more than flags->tf_xla_min_cluster_size elements (applicable only
1301 // if compilation is enabled, otherwise there will be no such candidates).
1302 const int min_cluster_size = flags->tf_xla_min_cluster_size;
1303 for (Node* n : compilation_candidates) {
1304 const Cluster& cluster = clusters[n->id()].Get();
1305 bool should_compile;
1306 string device;
1307 TF_RETURN_IF_ERROR(ShouldCompileCluster(&should_compile_cluster_cache,
1308 global_jit_level, cluster,
1309 &should_compile, &device));
1310 if (!should_compile) {
1311 continue;
1312 }
1313
1314 int cluster_repr = cluster.representative;
1315
1316 // Compile if the user marked this node _XlaCompile=true
1317 bool compile_attr = false;
1318 bool marked_for_compilation = false;
1319 if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
1320 marked_for_compilation = compile_attr;
1321 } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
1322 .ok()) {
1323 marked_for_compilation = compile_attr;
1324 }
1325
1326 // We assume that functional If and While nodes have at least
1327 // min_cluster_size non-trivial nodes in them. It would be more principled
1328 // to (recursively) verify this fact, but that's probably not worth the
1329 // trouble.
1330
1331 if (effective_cluster_sizes[cluster_repr] >= min_cluster_size ||
1332 has_functional_control_flow[cluster_repr] || marked_for_compilation) {
1333 string& name = cluster_names[cluster_repr];
1334
1335 if (name.empty()) {
1336 name = absl::StrCat("cluster_", cluster_sequence_num++);
1337 }
1338 n->AddAttr(kXlaClusterAttr, name);
1339 n->AddAttr(kXlaAlreadyClustered, true);
1340 VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
1341 }
1342 }
1343
1344 if (flags->tf_xla_clustering_debug) {
1345 DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def);
1346
1347     // We also dump out an annotated version of the TF graph where the node
1348     // names are prefixed with the cluster names. This can help when visualizing
1349     // the clustering decisions on TensorBoard.
1350 Graph new_graph((*options.graph)->op_registry());
1351 CopyGraph(**options.graph, &new_graph);
1352
1353 for (Node* n : new_graph.nodes()) {
1354 if (absl::optional<absl::string_view> cluster_name =
1355 GetXlaClusterForNode(*n)) {
1356 n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1357 } else if (n->type_string() == "VarHandleOp") {
1358 n->set_name(absl::StrCat("varhandle/", n->name()));
1359 } else {
1360 // There is room for improvement here. In particular, it may help to
1361 // split these unclustered nodes into classes where every node in a
1362 // specific class has edges to and from the same set of clusters.
1363 n->set_name(absl::StrCat("unclustered/", n->name()));
1364 }
1365 }
1366
1367 DumpGraphToFile("mark_for_compilation_annotated", new_graph,
1368 options.flib_def);
1369 }
1370
1371 VLogClusteringSummary(*graph);
1372
1373 return Status::OK();
1374 }
1375
1376 } // namespace tensorflow
1377