/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"

#include <atomic>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
using jit::DeviceId;
using jit::DeviceSet;

// The clusters we create here are eventually lowered into an
// _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
// PartitionedCall op to execute the cluster in the regular graph executor if
// need be. PartitionedCall, however, reruns the entire TF graph optimization
// pipeline over the cluster which includes this mark for compilation pass. To
// avoid endlessly recursing we tag nodes that we've already visited with this
// attribute so that we can bail out if we see them a second time.
//
// TODO(sanjoy): This method is not robust since it is possible that the
// optimizations run by PartitionedCall can mutate the cluster arbitrarily,
// dropping the kXlaAlreadyClustered attributes from all nodes in the process.
// The correct fix is to use the ConfigProto to pass in some sort of flag into
// the PartitionedCall kernel that tells it to not rerun auto-clustering on the
// cluster.
const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";

class MarkForCompilationPassImpl {
 public:
  struct DebugOptions {
    // If true, do not respect the results of deadness analysis.
    bool ignore_deadness_checks;

    // If true, do not do safety checks to preserve TensorFlow's resource
    // variable concurrency semantics.
    bool ignore_resource_variable_checks;

    // If true, do not respect the _XlaCompile=false attribute.
    bool ignore_xla_compile_attr;

    int max_cluster_size;
    int min_cluster_size;

    // Compiler fuel for the auto-clustering algorithm.
    //
    // We decrement this value by one every time we choose a compilation
    // candidate and we stop clustering when it hits zero. This means the
    // initial value for this variable (via --tf_xla_clustering_fuel=N)
    // effectively acts as a "cap" for how much we cluster and we can bisect
    // over this initial value to discover clustering decisions that cause a
    // miscompile or a performance regression.
    std::atomic<int64>* fuel;

    bool dump_graphs;
  };

  MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
                             FunctionLibraryDefinition* flib_def, Env* env,
                             OptimizerOptions::GlobalJitLevel global_jit_level)
      : debug_options_(debug_options),
        graph_(graph),
        flib_def_(flib_def),
        env_(env),
        global_jit_level_(global_jit_level) {}

  Status Run();

 private:
  // Represents a "cluster" or a connected subgraph of a TensorFlow graph.
  class Cluster {
   public:
    // Constructs a trivial cluster representing a single TF node.
    Cluster(int tf_graph_node_id, int effective_cluster_size,
            bool has_functional_control_flow, DeviceSet devices,
            absl::optional<DeviceId> resource_op_device,
            absl::optional<int> resource_var_operation_node_id,
            absl::optional<DeadnessPredicate> deadness_predicate,
            bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
        : cycles_graph_node_id_(tf_graph_node_id),
          effective_cluster_size_(effective_cluster_size),
          has_functional_control_flow_(has_functional_control_flow),
          devices_(std::move(devices)),
          resource_op_device_(resource_op_device),
          deadness_predicate_(deadness_predicate),
          is_xla_compile_attr_true_(is_xla_compile_attr_true),
          xla_scope_(std::move(xla_scope)) {
      if (resource_var_operation_node_id.has_value()) {
        resource_var_operation_node_ids_.push_back(
            *resource_var_operation_node_id);
      }
    }

    // Merges `other` into this cluster, and clears `other`. This method is
    // closely tied with the implementation of `MarkForCompilationPassImpl`.
    void Merge(Cluster* other);

    // If this is a trivial cluster containing only one node then return the ID
    // of that node. May not be called otherwise.
    int GetIdOfOnlyNode() const {
      DCHECK_EQ(cluster_size(), 1);
      return cycles_graph_node_id();
    }

    // The number of TF nodes in this cluster.
    int cluster_size() const { return cluster_size_; }

    // The ID of the cluster as represented in `cycles_graph_`.
    int cycles_graph_node_id() const { return cycles_graph_node_id_; }

    // Sets the ID of the cluster as represented in `cycles_graph_`.
    void set_cycles_graph_node_id(int cycles_graph_node_id) {
      cycles_graph_node_id_ = cycles_graph_node_id;
    }

    // The size of the cluster excluding constant and identity nodes.
    int effective_cluster_size() const { return effective_cluster_size_; }

    // True if the cluster has functional control flow like `If` and `While`.
    bool has_functional_control_flow() const {
      return has_functional_control_flow_;
    }

    // The set of devices nodes in the cluster are placed on.
    const DeviceSet& devices() const { return devices_; }

    // If the cluster has a resource operation then the device the resource
    // operation is placed on. A cluster may have resource ops placed only on a
    // single device.
    const absl::optional<DeviceId>& resource_op_device() const {
      return resource_op_device_;
    }

    // If not nullopt then this is a predicate that is true iff the cluster is
    // alive. Otherwise the user has (unsafely) disabled deadness analysis. If
    // this is unset on a single Cluster instance then it is unset on all
    // Cluster instances.
    const absl::optional<DeadnessPredicate>& deadness_predicate() const {
      return deadness_predicate_;
    }

    // If true then the cluster has an _XlaCompile=true attribute on one of its
    // nodes.
    bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; }

    // If not nullopt then all nodes in the cluster either do not have the
    // XlaScope attribute set or have it set to the value returned.
    const absl::optional<string>& xla_scope() const { return xla_scope_; }

    // Returns the TF graph node IDs for the resource variable operations in
    // this cluster.
    absl::Span<const int> resource_var_operation_node_ids() const {
      return resource_var_operation_node_ids_;
    }

    string DebugString(const Graph& graph) const {
      Node* node = graph.FindNodeId(cycles_graph_node_id());
      if (!node) {
        // This should never happen but we try to be resilient because this is a
        // debugging aid.
        return absl::StrCat("NULL NODE IN #", cycles_graph_node_id());
      }

      if (cluster_size() == 1) {
        return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(),
                            ">");
      }

      return absl::StrCat("<", node->name(), " + ", cluster_size() - 1,
                          " others #", cycles_graph_node_id(), ">");
    }

   private:
    int cluster_size_ = 1;
    int cycles_graph_node_id_;
    int effective_cluster_size_;
    bool has_functional_control_flow_;
    DeviceSet devices_;
    absl::optional<DeviceId> resource_op_device_;
    absl::optional<DeadnessPredicate> deadness_predicate_;
    bool is_xla_compile_attr_true_;
    absl::optional<string> xla_scope_;
    std::vector<int> resource_var_operation_node_ids_;

    TF_DISALLOW_COPY_AND_ASSIGN(Cluster);
  };

  // If `cluster` has only a single node then returns that, otherwise returns
  // nullptr.
  Node* GetOnlyNodeIn(const Cluster& cluster);

  // Returns true if `cluster` is a trivial cluster containing a "sink like"
  // node -- a NoOp node that only the Sink node control depends on.
  bool IsSinkLike(const Cluster& cluster);

  // Returns true if `cluster` looks like an "i++" operation on an integer
  // scalar resource variable.
  bool IsScalarIntegerResourceOperation(const Cluster& cluster);

  // ---------------------------------------------------------------------------
  // The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and
  // `CreateClusters` do most of the heavy lifting.

  // Initializes some internal data structures.
  //
  // If this returns false then Initialize exited early (either because there is
  // nothing to do or we saw a graph that we can't handle) and not all the
  // fields in this MarkForCompilationPassImpl instance are set up.
  StatusOr<bool> Initialize();

  // Runs through the entire cluster graph in post-order and calls `fn(from,
  // to)` on each edge. `fn(from, to)` is expected to return true if it was
  // able to contract `from`->`to`.
  //
  // Returns true if `fn` returned true for any edge.
  template <typename FnTy>
  StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);

  // Contracts as many edges as possible to create XLA clusters. After this
  // finishes the clustering decisions made are implicitly stored in
  // `clusters_`.
  Status RunEdgeContractionLoop();

  // Manifests the clustering decisions into the TF graph by tagging nodes with
  // an `_XlaCluster` attribute. Some basic filtering, like
  // tf_xla_min_cluster_size, is also applied here.
  Status CreateClusters();

  Status DumpDebugInfo();

  bool IsCompilationCandidate(Node* n) const {
    return compilation_candidates_.find(n) != compilation_candidates_.end();
  }

  // Tries to contract the edge from cluster `from` to cluster `to`. Returns
  // true if successful.
  StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);

  // Nodes that XLA can compile are put in `compilation_candidates_`.
  Status FindCompilationCandidates();

  bool CompilationDisallowedByXlaCompileAttr(Node* node);

  // Populates `clusters_`.
  Status BuildInitialClusterSet();

  StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);

  StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);

  StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
      const Cluster& from, const Cluster& to);

  // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
  // and therefore not a hindrance for combining the two clusters into a larger
  // cluster.
  StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a,
                                      const Cluster& cluster_b);

  void DumpPostClusteringGraphs();
  void VLogClusteringSummary();

  Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
                          bool has_functional_control_flow,
                          const DeviceSet& device_set,
                          absl::optional<DeviceId> resource_op_device,
                          absl::optional<int> resource_var_operation_node_id,
                          absl::optional<DeadnessPredicate> deadness_predicate,
                          bool is_xla_compile_attr_true,
                          absl::optional<string> xla_scope) {
    cluster_storage_.push_back(absl::make_unique<Cluster>(
        cycles_graph_node_id, effective_cluster_size,
        has_functional_control_flow, device_set, resource_op_device,
        resource_var_operation_node_id, deadness_predicate,
        is_xla_compile_attr_true, xla_scope));
    return cluster_storage_.back().get();
  }

  absl::optional<string> GetXlaScope(Node* n);

  // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in
  // the same cluster by the clustering algorithm then this function will return
  // the same Cluster instance for N1 and N2.
  //
  // Returns nullptr if `n` is not a compilation candidate.
  Cluster* GetClusterForNode(Node* n) {
    return cluster_for_node_[n->id()].Get();
  }

  // Returns the cluster for a node in `cycles_graph_`. This uses the same
  // underlying map because of how we set things up, but we can do an additional
  // CHECK in this accessor.
  //
  // Returns nullptr if `node_id` is not a compilation candidate.
  Cluster* GetClusterForCyclesGraphNode(int node_id) {
    // We have to check `graph_->FindNodeId(node) == nullptr` because we add all
    // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the
    // TF graph may be missing some node ids.
    if (node_id >= graph_->num_node_ids() ||
        graph_->FindNodeId(node_id) == nullptr) {
      return nullptr;
    }
    Cluster* cluster = cluster_for_node_[node_id].Get();
    if (cluster) {
      DCHECK_EQ(cluster->cycles_graph_node_id(), node_id);
    }
    return cluster;
  }

  bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to,
                                        absl::string_view reason);

  // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct
  // edge from `from` to `to`.
  //
  // Tries to find a path that contains at least one unclusterable node.
  std::vector<int> FindAlternatePathForDebugging(int from, int to);

  // Returns a string representing `cycles_graph_node_id`. If the node is
  // unclusterable (either it is a phantom "frame" node or is not a compilation
  // candidate) then set `*found_unclustered` to true.
  string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered);

  // We could not contract the edge from `from` to `to`. Return a string
  // describing an alternate path from `from` to `to` (besides the direct edge
  // from `from` to `to`) which would have created a cycle had we contracted the
  // edge.
  //
  // Tries (if possible) to find a path that contains at least one unclusterable
  // node as it is surprising to the user if we print "A->B could not be
  // contracted because of the path [P,Q,R]" where P, Q and R are all clusters
  // since in that case a natural question is why we could not form a {A, P, Q,
  // R, B} cluster.
  string DescribePotentialCycle(int from, int to);

  // Merge the clusters `cluster_from` and `cluster_to`. After this step the
  // larger combined cluster is represented by `cluster_from`, but can have
  // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
  // which way will require fewer operations.
  bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
    int from = cluster_from->cycles_graph_node_id();
    int to = cluster_to->cycles_graph_node_id();

    auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
    if (!optional_merged_node.has_value()) {
      VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
              << " -> " << cluster_to->DebugString(*graph_)
              << " because contracting the edge would create a cycle via "
              << DescribePotentialCycle(from, to) << ".";
      return false;
    }

    // Merge the clusters.
    cluster_from->Merge(cluster_to);
    // Update `cycles_graph_`'s ID.
    cluster_from->set_cycles_graph_node_id(optional_merged_node.value());

    // Merge the UnionFind<Cluster*>.
    cluster_for_node_[from].Merge(&cluster_for_node_[to]);

    return true;
  }

  string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
                                   absl::string_view reason) {
    return absl::StrCat("Could not contract ", from->DebugString(*graph_),
                        " -> ", to->DebugString(*graph_), " because ", reason,
                        ".");
  }

  DebugOptions debug_options_;
  Graph* graph_;
  FunctionLibraryDefinition* flib_def_;
  Env* env_;
  OptimizerOptions::GlobalJitLevel global_jit_level_;
  absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_;
  jit::DeviceInfoCache device_info_cache_;

  bool initialized_ = false;
  bool edges_contracted_ = false;
  bool clusters_created_ = false;

  std::vector<std::unique_ptr<Cluster>> cluster_storage_;
  std::vector<UnionFind<Cluster*>> cluster_for_node_;
  GraphCycles cycles_graph_;
  OrderedNodeSet compilation_candidates_;
  std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
  int64 iteration_count_ = 0;
  absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_;
};

std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging(
    int from, int to) {
  std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder();
  absl::c_reverse(rpo);

  // best_pred_for_node[n] is a predecessor of `n` on some path from `from` to
  // `n`, preferring predecessors that are themselves unclusterable nodes.
  // best_pred_for_node[n] is unpopulated for nodes that are not reachable from
  // `from`. We build this table up inductively by traversing the cycles graph
  // in RPO.
  absl::flat_hash_map<int, int> best_pred_for_node;
  best_pred_for_node[from] = -1;

  int rpo_index = 0, current_rpo_node;
  do {
    current_rpo_node = rpo[rpo_index++];
    absl::optional<int> some_pred, preferred_pred;
    for (int pred : cycles_graph_.Predecessors(current_rpo_node)) {
      if (!best_pred_for_node.contains(pred)) {
        continue;
      }

      // Ignore the from->to edge since we're trying to find an alternate path.
      if (current_rpo_node == to && pred == from) {
        continue;
      }

      some_pred = pred;
      if (GetClusterForCyclesGraphNode(pred) == nullptr) {
        preferred_pred = pred;
      }
    }

    if (some_pred || preferred_pred) {
      best_pred_for_node[current_rpo_node] =
          preferred_pred.has_value() ? *preferred_pred : *some_pred;
    }
  } while (current_rpo_node != to);

  auto get_best_pred = [&](int n) {
    auto it = best_pred_for_node.find(n);
    CHECK(it != best_pred_for_node.end());
    return it->second;
  };
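
  // Walk best_pred_for_node backwards from `to` to reconstruct the alternate
  // path; the endpoints `from` and `to` themselves are not included in `path`.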
  std::vector<int> path;
  int current_path_node = get_best_pred(to);
  while (current_path_node != from) {
    path.push_back(current_path_node);
    current_path_node = get_best_pred(current_path_node);
  }

  absl::c_reverse(path);
  return path;
}

string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode(
    int cycles_graph_node_id, bool* found_unclustered) {
  Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id);
  if (cluster) {
    return cluster->DebugString(*graph_);
  }

  *found_unclustered = true;
  if (cycles_graph_node_id >= graph_->num_node_ids()) {
    return absl::StrCat("<oob #", cycles_graph_node_id, ">");
  }

  Node* node = graph_->FindNodeId(cycles_graph_node_id);
  if (!node) {
    return absl::StrCat("<bad #", cycles_graph_node_id, ">");
  }

  return node->name();
}

string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) {
  std::vector<string> path_str;
  bool found_unclustered = false;
  absl::c_transform(FindAlternatePathForDebugging(from, to),
                    std::back_inserter(path_str), [&](int node_id) {
                      return DebugStringForCyclesGraphNode(node_id,
                                                           &found_unclustered);
                    });
  return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[",
                      absl::StrJoin(path_str, ","), "]");
}

void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
  // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does.

  // Clearing out data structures in `other` is just a memory saving
  // optimization and not needed for correctness.

  cluster_size_ += other->cluster_size_;
  effective_cluster_size_ += other->effective_cluster_size_;
  has_functional_control_flow_ |= other->has_functional_control_flow_;

  devices_.UnionWith(other->devices_);

  DCHECK(!(resource_op_device_.has_value() &&
           other->resource_op_device_.has_value()) ||
         *resource_op_device_ == *other->resource_op_device_)
      << "AreDevicesCompatible should have returned false otherwise!";

  if (!resource_op_device_.has_value()) {
    resource_op_device_ = other->resource_op_device_;
  }

  is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;

  if (!xla_scope_.has_value()) {
    xla_scope_ = std::move(other->xla_scope_);
  }

  resource_var_operation_node_ids_.reserve(
      resource_var_operation_node_ids_.size() +
      other->resource_var_operation_node_ids_.size());
  absl::c_copy(other->resource_var_operation_node_ids_,
               std::back_inserter(resource_var_operation_node_ids_));
  other->resource_var_operation_node_ids_.clear();
}

Status IgnoreResourceOpForSafetyAnalysis(
    jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) {
  // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
  // ignore it during resource operation safety analysis. We need this hack
  // for two reasons:
  //
  // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
  // 2. We don't support live-out values of type DT_RESOURCE and live-in values
  //    of type DT_RESOURCE that are not resource variables.
  //
  // Together these imply we cannot let resource variable safety analysis
  // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
  // clusters: both of them will have to be clustered because of (1) and we
  // won't be able to keep the edge between the two as neither the input to the
  // second XLA cluster nor the output from the first XLA cluster are supported
  // because of (2).
  //
  // TODO(b/113100872): This can be fixed if the TensorFlow representation for
  // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
  // (2) would no longer hold.

  if (n.assigned_device_name().empty()) {
    *ignore = false;
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(
      const XlaOpRegistry::DeviceRegistration* registration,
      device_info_cache->GetCompilationDevice(n.assigned_device_name()));

  if (!registration) {
    *ignore = true;
  } else {
    *ignore = registration->cluster_resource_variable_ops_unsafely;
  }
  return Status::OK();
}

StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
  TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
  initialized_ = true;

  TF_RETURN_IF_ERROR(FindCompilationCandidates());

  if (compilation_candidates_.empty()) {
    VLOG(2) << "No compilable candidates";
    return false;
  }

  TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
                      CreateCycleDetectionGraph(graph_, &cycles_graph_));
  if (!cycle_detection_graph_ok) {
    // TODO(sanjoy): This should be logged via the XLA activity listener.
    VLOG(2) << "Could not form cycle detection graph";
    return false;
  }

  if (!debug_options_.ignore_deadness_checks) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
    TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
  }

  // Each compilation candidate belongs to a cluster. The cluster's
  // representative names the node in the 'cycles' graph that represents the
  // cluster.
  TF_RETURN_IF_ERROR(BuildInitialClusterSet());
  return true;
}

template <typename FnTy>
StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
  bool changed = false;
  for (int32_t node : cycles_graph_.AllNodesInPostOrder()) {
    Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
    if (!cluster_from) {
      continue;
    }

    // Make a copy of the set of successors because we may modify the graph in
    // TryToContractEdge.
    std::vector<int32> successors_copy =
        cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());

    for (int to : successors_copy) {
      iteration_count_++;

      Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
      if (!cluster_to) {
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
      changed |= contracted_edge;
    }
  }

  return changed;
}

Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) {
  return cluster.cluster_size() == 1
             ? graph_->FindNodeId(cluster.GetIdOfOnlyNode())
             : nullptr;
}

bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) {
  if (Node* n = GetOnlyNodeIn(cluster)) {
    return n->type_string() == "NoOp" && n->out_edges().size() == 1 &&
           (*n->out_edges().begin())->dst()->IsSink();
  }

  return false;
}

bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
    const Cluster& cluster) {
  Node* n = GetOnlyNodeIn(cluster);
  if (!n) {
    return false;
  }

  if (n->type_string() != "AssignAddVariableOp" &&
      n->type_string() != "AssignSubVariableOp") {
    return false;
  }

  DataType dtype;
  if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
    return false;
  }

  Node* const_input = nullptr;
  for (const Edge* e : n->in_edges()) {
    if (!e->IsControlEdge() && e->src()->IsConstant()) {
      const_input = e->src();
      break;
    }
  }

  if (!const_input) {
    return false;
  }

  const TensorProto* proto = nullptr;
  if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
    return false;
  }

  return TensorShapeUtils::IsScalar(proto->tensor_shape());
}

Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
  TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
  edges_contracted_ = true;

  // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
  // example, from the Grappler fusion pass).

  // In general there are multiple maximal clusterings, but they are not all
  // equally performant. Some clustering decisions are likely to improve
  // performance much more than others, and we cannot order contractions on this
  // cost function, nor can we look at global information while deciding on
  // individual edges to contract. Instead, we first make decisions on these
  // important edges and then on all other edges, which gives the most important
  // edges the highest chance of being contracted.
  //
  // An example of where this might occur is with a digraph:
  // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
  // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
  // should be clustered with A because it will prevent a potentially large
  // tensor from A being computed and copied.
  //
  // To choose better maximal clusterings we make multiple iterations over the
  // graph in post-order, where each such iteration is called a "phase".

  // Phase 0: contract metadata operations with their producer.

  VLOG(4) << "Running phase 0";
  TF_RETURN_IF_ERROR(
      ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
        // Shape consuming operations are desirable to cluster with their
        // operands because they return a small set of scalar values after
        // consuming a large amount of data. For example, given a graph X -> Y
        // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
        // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
        // output of Size will be a small tensor while Y is a potentially large
        // tensor that must be computed and possibly transposed/copied before
        // the second cluster executes.
        Node* n = GetOnlyNodeIn(*to);
        bool is_shape_consumer_op = n && IsShapeConsumerOp(*n);
        if (!is_shape_consumer_op) {
          return false;
        }

        return TryToContractEdge(from, to);
      }).status());

  // Phase 1: apply a heuristic to ensure that we don't mess up clustering due
  // to "group_deps". After this phase most edges should have been contracted.

  VLOG(4) << "Running phase 1";
  TF_RETURN_IF_ERROR(
      ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
        // We split out this phase to get good clustering in the presence of a
        // specific pattern seen in some graphs:
        //
        // digraph {
        //   ApplyWeightUpdates_0 -> "iteration++"
        //   ApplyWeightUpdates_1 -> "iteration++"
        //   ApplyWeightUpdates_2 -> "iteration++"
        //   ApplyWeightUpdates_0 -> Computation_A
        //   ApplyWeightUpdates_1 -> Computation_B
        //   ApplyWeightUpdates_2 -> Computation_C
        //   Computation_A -> NoOp
        //   Computation_B -> NoOp
        //   Computation_C -> NoOp
        //   "iteration++" -> NoOp
        // }
        //
        // In the graph above we can't cluster iteration++ with any of the
        // gradient update operations since that will break the TF resource
        // variable memory model. Given that constraint the ideal clustering
        // would be to put all the gradient updates and all of the Computation_*
        // nodes in one cluster, and leave iteration++ and NoOp unclustered.
        //
        // A naive post-order traversal would not create this good clustering,
        // however. Instead it will first create a cluster that puts
        // Computation_* nodes, the NoOp and iteration++ node in a single
        // cluster, after which it will fail to put any of the
        // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we
        // instead run a pass that avoids contracting edges _into_ NoOps like
        // the above, and avoid clustering edges _from_ "iteration++" like the
        // above. Then we run a second pass that contracts the edges we could
        // not contract the first time around.

        if (IsSinkLike(*to)) {
          return false;
        }

        if (IsScalarIntegerResourceOperation(*from)) {
          return false;
        }

        return TryToContractEdge(from, to);
      }).status());

  // Phase 2: contract any remaining edges. After this phase we should have a
  // maximal clustering:
  //
  // A. We visit a cluster only after maximally clustering all its children.
  // B. By the time we're done with a node all of its children that could have
  //    been absorbed into the node have been absorbed.
  // C. We have an invariant that making a cluster larger does not make edges
  //    leaving it more contractable. That is, if we have
  //    digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
  //    to contract Y->Z if Y->Z was not contractible originally.
  VLOG(4) << "Running phase 2";
  TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
                       return TryToContractEdge(from, to);
                     }).status());

  // Check that the conclusion made above (that iterating over the graph once in
  // post order gives a maximal clustering) holds. Once the linear time
  // post-order scheme has been battle tested we can move this to happen only in
  // debug builds.
  VLOG(2) << "Checking idempotence";
  TF_ASSIGN_OR_RETURN(bool changed,
                      ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
                        return TryToContractEdge(from, to);
                      }));
  TF_RET_CHECK(!changed);

  return Status::OK();
}
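
// Sequence number generator used to pick unique names for newly created
// clusters; shared across invocations of this pass within the process.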
std::atomic<int64> cluster_sequence_num;

int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }

Status MarkForCompilationPassImpl::CreateClusters() {
  TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
  clusters_created_ = true;

  // Names for each cluster.
  std::unordered_map<int, string> cluster_names;

  if (debug_options_.dump_graphs) {
    DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
  }

  // Mark clusters for compilation that:
  //   * are placed on a device that requires compilation (an XlaDevice),
  //   * are explicitly marked for compilation (_XlaCompile=true), or
  //   * have at least debug_options_.min_cluster_size elements (applicable
  //     only if compilation is enabled, otherwise there will be no such
  //     candidates).
  for (Node* n : compilation_candidates_) {
    Cluster* cluster = GetClusterForNode(n);
    TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
                        ShouldCompileCluster(*cluster));
    if (!should_compile_cluster) {
      continue;
    }

    // We assume that functional If and While nodes have at least
    // min_cluster_size non-trivial nodes in them. It would be more principled
    // to (recursively) verify this fact, but that's probably not worth the
    // trouble.

    if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size ||
        cluster->has_functional_control_flow() ||
        cluster->is_xla_compile_attr_true()) {
      string& name = cluster_names[cluster->cycles_graph_node_id()];

      if (name.empty()) {
        name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
      }

      n->AddAttr(kXlaClusterAttr, name);
      n->AddAttr(kXlaAlreadyClustered, true);
      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
    }
  }

  return Status::OK();
}

Status MarkForCompilationPassImpl::DumpDebugInfo() {
  TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);

  if (debug_options_.dump_graphs) {
    DumpPostClusteringGraphs();
  }

  VLogClusteringSummary();

  return Status::OK();
}

StatusOr<bool>
MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
    const Cluster& cluster_from, const Cluster& cluster_to) {
  // If any of the consumer's producers are on a different device, do not
  // cluster these nodes. This prevents other work on this device from being
  // delayed by work on other devices. We consider predecessors of the entire
  // cluster rather than just the inputs to the node to prevent the clusters
  // from still being combined in cases where the 'to' cluster has multiple
  // dependencies on the 'from' cluster and another dependency leads to a
  // merging of the clusters.
  //
  // TODO(b/117085735): We probably want to handle the reciprocal of this case
  // where a cluster is producing data for multiple devices.
  for (const auto& in_id :
       cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) {
    const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id);
    if (cluster_in) {
      TF_ASSIGN_OR_RETURN(bool devices_compatible,
                          AreDevicesCompatible(cluster_to, *cluster_in));
      if (!devices_compatible) {
        return true;
      }
      TF_ASSIGN_OR_RETURN(devices_compatible,
                          AreDevicesCompatible(cluster_from, *cluster_in));
      if (!devices_compatible) {
        return true;
      }
    }
  }

  return false;
}

absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
  // Look for either _XlaScope or _XlaInternalScope on both nodes to guide
  // clustering. If both nodes have a scope and the scopes do not match, do
  // not cluster along this edge. If even one of the nodes lacks a scope
  // attribute, then it is treated as a "bridge" and a cluster may be created
  // along it.
  //
  // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
  // provided by users through jit_scope APIs, while _XlaInternalScope is
  // automatically generated by the ClusterScopingPass when auto_jit is on. As
  // such, we respect _XlaScope only when auto_jit is off, while respecting
  // _XlaInternalScope only when auto_jit is on.
  //
  // We may want to restrict the _XlaScope behavior to require all nodes marked
  // with _XlaCompile=true to also have a _XlaScope property set (and raise an
  // error otherwise); but for now we don't do this.

  if (global_jit_level_ != OptimizerOptions::OFF) {
    // If global_jit_level_ is not OFF, respect only _XlaInternalScope.
    const string& scope =
        GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
    if (!scope.empty()) {
      return scope;
    }
  } else {
    // If global_jit_level_ is OFF, respect only _XlaScope.
    const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
    if (!scope.empty()) {
      return scope;
    }
  }

  return absl::nullopt;
}

// Returns true iff the attribute `attr_name` is attached to either the node or
// to its callee.
static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
                              const char* attr_name) {
  bool out = false;
  bool attr_value;
  if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
    out |= attr_value;
  }

  if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
    out |= attr_value;
  }
  return out;
}
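
// Builds a trivial, single-node Cluster for every compilation candidate and
// records, in `unsafe_resource_deps_`, the pairs of resource operations that
// must not end up in the same cluster.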
Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
  auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
    return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
  };

  std::vector<std::pair<int, int>> unsafe_resource_deps_vect;
  TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
      *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect));
  absl::c_copy(
      unsafe_resource_deps_vect,
      std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin()));

  cluster_for_node_.resize(graph_->num_node_ids());
  for (Node* node : graph_->nodes()) {
    if (!IsCompilationCandidate(node)) {
      cluster_for_node_[node->id()].Get() = nullptr;
      continue;
    }

    // We want clusters to be big enough that the benefit from XLA's
    // optimizations offsets XLA related overhead (for instance we add some
    // Switch/Merge nodes into the graph to implement lazy compilation). To
    // this end, we don't count Identity and Constant nodes because they do not
    // enable interesting optimizations by themselves.
    int effective_cluster_size =
        (node->IsIdentity() || node->IsConstant()) ? 0 : 1;

    bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();

    absl::optional<DeadnessPredicate> deadness_predicate;
    if (deadness_analysis_) {
      TF_ASSIGN_OR_RETURN(
          deadness_predicate,
          deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
    }

    const string& device_name_str = !node->assigned_device_name().empty()
                                        ? node->assigned_device_name()
                                        : node->requested_device();
    TF_ASSIGN_OR_RETURN(DeviceId device,
                        device_info_cache_.GetIdFor(device_name_str));

    bool is_resource_op = HasResourceInputOrOutput(*node);
    absl::optional<DeviceId> resource_op_device;
    if (is_resource_op) {
      resource_op_device = device;
    }

    absl::optional<int> resource_var_operation_node_id;
    if (is_resource_op || MayCallFunction(*node, flib_def_)) {
      resource_var_operation_node_id = node->id();
    }

    bool is_xla_compile_attr_true =
        GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
        (global_jit_level_ != OptimizerOptions::OFF &&
         GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr));

    DeviceSet devices;
    devices.Insert(device);

    Cluster* new_cluster = MakeNewCluster(
        /*cycles_graph_node_id=*/node->id(),
        /*effective_cluster_size=*/effective_cluster_size,
        /*has_functional_control_flow=*/has_functional_control_flow, devices,
        resource_op_device, resource_var_operation_node_id, deadness_predicate,
        /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
        GetXlaScope(node));

    cluster_for_node_[node->id()].Get() = new_cluster;
  }

  return Status::OK();
}
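
// Returns true if `node` is an Identity driven by the true output of a
// loop-condition Switch and driving Const nodes through control edges. See the
// comment at the call site in FindCompilationCandidates for why such
// Identities are not clustered.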
StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
  if (!node->IsIdentity()) {
    return false;
  }

  // Check if the Identity is driven by a Switch on its true path.
  auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
    return e->src()->IsSwitch() && e->src_output() == 1;
  });
  if (it == node->in_edges().end()) {
    return false;
  }
  const Node* switch_node = (*it)->src();

  // Check if the Switch is driven by LoopCond.
  const Node* maybe_loop_cond;
  TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
  if (!maybe_loop_cond->IsLoopCond()) {
    return false;
  }

  // Check if the Identity is driving any const nodes through a control edge.
  bool driving_any_consts =
      absl::c_any_of(node->out_edges(), [](const Edge* e) {
        return e->dst()->IsConstant() && e->IsControlEdge();
      });
  if (!driving_any_consts) {
    return false;
  }

  return true;
}
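
// Builds the set of TF op names eligible for auto-clustering from the
// --tf_xla_ops_to_cluster flag. The special token "FUSIBLE" expands to every
// op in the allowlist table; an empty flag yields an empty set, which callers
// treat as "no restriction".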
absl::flat_hash_set<string> GetOrCreateAllowlist() {
  absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
      tensorflow::GetAllowlistTable();
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  absl::flat_hash_set<string> allowlist;

  for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
    if (s == "FUSIBLE") {
      for (auto pair : *allowlist_table) {
        allowlist.insert(pair.second.begin(), pair.second.end());
      }
    } else if (allowlist_table->contains(s)) {
      auto v = allowlist_table->at(s);
      allowlist.insert(v.begin(), v.end());
    } else if (!s.empty()) {
      // Should be a user provided TF operation.
      allowlist.insert(string(s));
    }
  }

  if (VLOG_IS_ON(2) && !allowlist.empty()) {
    std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
    absl::c_sort(vallowlist);
    VLOG(2) << "XLA clustering will only consider the following TF operations: "
            << absl::StrJoin(vallowlist, " ");
  }
  return allowlist;
}

Status MarkForCompilationPassImpl::FindCompilationCandidates() {
  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
                                        TF_GRAPH_DEF_VERSION, flib_def_, opts));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
  std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph_, /*compile_time_const_arg_indices=*/nullptr,
      &compile_time_const_nodes, lib_runtime));
  // Iterate over nodes in sorted order so that compiler fuel is deterministic.
  // We can't simply pass op_nodes().begin() and op_nodes().end() to the
  // std::vector constructor because they're not proper iterators with
  // iterator_traits defined and so on.
  std::vector<Node*> sorted_nodes;
  for (Node* node : graph_->op_nodes()) {
    sorted_nodes.push_back(node);
  }
  std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());

  if (*debug_options_.fuel >= std::numeric_limits<int64>::max() / 2) {
    // The assumption is that if fuel started out as INT64_MAX, it will forever
    // stay greater than INT64_MAX / 2.
    VLOG(2) << "Starting fuel: infinity";
  } else {
    VLOG(2) << "Starting fuel: " << *debug_options_.fuel;
  }

  VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();

  auto allowlist = GetOrCreateAllowlist();

  std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
  absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
  // Check that the user-provided TF operations really exist.
  for (const auto& s : allowlist) {
    if (!all_ops.contains(string(s))) {
      return errors::InvalidArgument(
          "The operation '", s,
          "' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
    }
  }
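
  // Examine each node in deterministic (node id) order and decide whether it
  // is a candidate for auto-clustering, spending one unit of compiler fuel per
  // accepted candidate.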
  for (Node* node : sorted_nodes) {
    if (*debug_options_.fuel <= 0) {
      VLOG(1)
          << "Hit fuel limit; not marking any remaining ops as clusterable.";
      break;
    }

    TF_ASSIGN_OR_RETURN(
        const DeviceType& device_type,
        device_info_cache_.GetDeviceTypeFor(node->assigned_device_name()));
    VLOG(4) << "Device type for " << node->name() << ": "
            << device_type.type_string();

    if (CompilationDisallowedByXlaCompileAttr(node)) {
      VLOG(2) << "Not clustering " << node->name()
              << ": disallowed by _XlaCompile attribute";
      continue;
    }

    const XlaOpRegistry::DeviceRegistration* registration;
    if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
                                             &registration)) {
      VLOG(2) << "Rejecting " << node->name()
              << ": could not find JIT device for " << device_type.type();
      continue;
    }
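
    // Tighten the filter used for auto-clustering: require that every
    // candidate node be compilable on its own, and do not admit string
    // constants or CollectiveReduceV2 ops.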
    RecursiveCompilabilityChecker::OperationFilter filter =
        CreateOperationFilter(*registration);
    filter.require_always_compilable = true;
    filter.allow_string_consts = false;
    filter.allow_collective_reduce_v2 = false;

    RecursiveCompilabilityChecker checker(
        filter, DeviceType{registration->compilation_device_name});

    if (!checker.IsCompilableNode(*node, lib_runtime)) {
      continue;
    }

    if (node->type_string() == "Const") {
      // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
      // support it.
      const AttrValue* attr = node->attrs().Find("dtype");
      if (attr != nullptr && attr->type() == DT_STRING) {
        continue;
      }
    }

    if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
      VLOG(1) << "Rejecting TF operation " << node->def().op()
              << " as it is not listed in --tf_xla_ops_to_cluster.";
      continue;
    }

    if (compile_time_const_nodes[node->id()]) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(
          graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def));
      if (op_def->is_stateful()) {
        // It is easiest to demonstrate the problem we're trying to solve with
        // an example. Say we have this graph:
        //
        //   shape = RandomUniformInt();
        //   reshape = Reshape(input, shape)
        //
        // Both RandomUniformInt and Reshape are compilable by XLA so, absent
        // any other reason, we will try to put both shape and reshape in the
        // same cluster. However, since XLA only supports statically shaped
        // values, it will expect to be able to constant fold `shape` to get a
        // static shape for `reshape`. This is a problem because side-effecting
        // ops like RandomUniformInt() cannot be constant folded. We fix this
        // by putting `shape` and `reshape` in different clusters, which results
        // in us recompiling `reshape`'s cluster for every new value of `shape`,
        // making `reshape` statically sized within each compilation. We
        // simplify the solution even further by disallowing operations like
        // `shape` from being part of *any* non-trivial cluster. They're either
        // not compiled by XLA altogether or, if assigned to an XLA_* device
        // with "must compile" semantics, compiled into a trivial single-op
        // cluster. This approach leaves some room for improvement, and we can
        // consider implementing a more aggressive data-flow-analysis based
        // solution in the future if needed.
        //
        // One ugly problem we have to contend with: certain sets of ops *have*
        // to be in the same cluster because values flowing between them have
        // types that can't be live-in or live-out of a cluster. These ops are:
        //
        //  - TensorArray ops operating on the same TensorArray instance.
        //  - Stack ops operating on the same Stack instance.
        //
        // To work around this we avoid isolating these specific ops. Because
        // of this concession it is unsound to auto-cluster them because then
        // we'd create clusters we could not compile (because we can't constant
        // fold, say, a TensorArrayRead or a StackPopV2). But we don't
        // auto-cluster these operations today so we're good for now.
        const XlaResourceOpInfo* op_info =
            GetResourceOpInfoForOp(node->type_string());
        bool is_tensor_array_or_stack_op =
            op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
        if (!is_tensor_array_or_stack_op) {
          VLOG(2) << "Isolating " << node->name()
                  << ": must-be-constant stateful op";
          continue;
        }
      }
    }

    // This is a heuristic to avoid creating dependency between while loop
    // condition and body computations. Dependency between them can be created
    // if a special Identity node in the following pattern is clustered in.
    // That is, an Identity node in the loop cond computation is used to drive
    // const nodes consumed by the loop body. If this Identity node goes into
    // the same cluster with nodes from the loop body, extra dependency is
    // created between the loop cond and body computations and it hinders the
    // progression of the loop cond computation at runtime with significant
    // overhead. Specifically, we look for the below pattern and do not cluster
    // in this Identity to avoid the described issue. Since Identity has low
    // execution cost in native TF, the fact that this heuristic gives up these
    // special Identity nodes as candidates should not harm any performance. If
    // other considerations emerge in the future, we can revisit the heuristic
    // and only disallow these Identities to go into the cluster with nodes from
    // the loop body but still consider them candidates.
    //
    // LoopCond ->
    // Merge    -> Switch -> Identity -> i++ -> ... -> NextIteration
    //                               ..> Const -> LoopBody
    //                            (control edge)
    TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
                        IsIdentityDrivingConstsInLoop(node));
    if (is_identity_driving_consts_in_loop) {
      VLOG(2) << "Rejecting " << node->name()
              << ": including it can create dependencies between while loop "
                 "condition and body computations with runtime overhead.";
      continue;
    }

    compilation_candidates_.insert(node);
    --(*debug_options_.fuel);
  }

  VLOG(2) << "compilation_candidates_.size() = "
          << compilation_candidates_.size();
  return Status::OK();
}

bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
    Node* node) {
  if (debug_options_.ignore_xla_compile_attr) {
    return false;
  }

  // If there is a _XlaCompile annotation, use its value.
  bool compile = false;
  Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
  if (status.ok()) {
    if (!compile) {
      VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
              << kXlaCompileAttr << ") is false.";
    }
    return !compile;
  }

  status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile);
  if (status.ok()) {
    if (!compile) {
      VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
              << kXlaCompileAttr << ") on callee is false.";
    }
    return !compile;
  }

  return false;
}

bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
    Cluster* from, Cluster* to, absl::string_view reason) {
  VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
  return false;
}
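
// A contraction is rejected if the two clusters have mismatching deadness
// predicates or XLA scopes, have incompatible devices, would exceed the
// maximum cluster size, would introduce a cross-device dependency, or would
// break resource variable concurrency semantics; otherwise the clusters are
// merged.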
1350
TryToContractEdge(Cluster * from,Cluster * to)1351 StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
1352 Cluster* to) {
1353 DCHECK(from->deadness_predicate().has_value() ==
1354 to->deadness_predicate().has_value());
1355 if (from->deadness_predicate() != to->deadness_predicate()) {
1356 VLOG(3) << EdgeContractionFailureMsg(
1357 from, to,
1358 absl::StrCat(
1359 "the two nodes have mismatching deadness: ",
1360 deadness_analysis_->DebugString(*from->deadness_predicate()),
1361 " and ",
1362 deadness_analysis_->DebugString(*to->deadness_predicate())));
1363 return false;
1364 }
1365
1366 TF_ASSIGN_OR_RETURN(bool devices_compatible,
1367 AreDevicesCompatible(*from, *to));
1368 if (!devices_compatible) {
1369 return LogNotContractableAndReturnFalse(
1370 from, to, "the two nodes have incompatible devices");
1371 }
1372
1373 if (from->xla_scope().has_value() && to->xla_scope().has_value() &&
1374 *from->xla_scope() != *to->xla_scope()) {
1375 return LogNotContractableAndReturnFalse(
1376 from, to, "the two nodes have mismatching XLA scopes");
1377 }
1378
1379 // Don't exceed the maximum cluster size.
1380 if (from->cluster_size() + to->cluster_size() >
1381 debug_options_.max_cluster_size) {
1382 return LogNotContractableAndReturnFalse(
1383 from, to, "the new cluster will be larger than the max cluster size");
1384 }
1385
1386 TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
1387 ClusteringWillIntroduceInterDeviceDependency(*from, *to));
1388
1389 if (will_introduce_cross_device_dependency) {
1390 return LogNotContractableAndReturnFalse(
1391 from, to, "the new cluster will introduce a cross device dependency");
1392 }
1393
1394 // Check if contracting this edge will break the resource variable concurrency
1395 // semantics. In theory this check is quadratic in the number of nodes, but it
1396 // does not seem to be a problem in practice so far.
1397 if (!debug_options_.ignore_resource_variable_checks) {
1398 for (int resource_var_from : from->resource_var_operation_node_ids()) {
1399 for (int resource_var_to : to->resource_var_operation_node_ids()) {
1400 // If unsafe_resource_deps_ contains {A, B} then
1401 //
1402 // a. A and B are resource operations.
1403 // b. A and B cannot be placed in the same cluster.
1404 // c. There is no path from B to A in the cycles graph (but there may
1405 // be a path from A to B).
1406 //
1407 // So check the legality of the edge contraction by checking if any of
1408 // the n^2 pairs of resource variable operations are forbidden.
1409 if (unsafe_resource_deps_.contains(
1410 {resource_var_from, resource_var_to})) {
1411 return LogNotContractableAndReturnFalse(
1412 from, to,
1413 "the new cluster would break resource variable semantics");
1414 }
1415 }
1416 }
1417 }
1418
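// All of the above safety checks passed, so it is safe to merge the two
// clusters; report the outcome of the merge to the caller.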
1419 return MergeClusters(from, to);
1420 }
1421
1422 Status MarkForCompilationPassImpl::Run() {
1423 // Make sure that kernels have been registered on the JIT device.
1424 XlaOpRegistry::RegisterCompilationKernels();
1425
1426 // Start the timer after XlaOpRegistry::RegisterCompilationKernels, which does
1427 // some one-time work.
1428 XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
1429
1430 TF_ASSIGN_OR_RETURN(bool initialized, Initialize());
1431 if (!initialized) {
1432 // Initialization exited early, which means this instance of
1433 // MarkForCompilationPassImpl is not set up to run the subsequent phases.
1434 return Status::OK();
1435 }
1436
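// Run the remaining phases in order: contract edges in the cycle-detection
// graph to form clusters, materialize the clustering decisions onto the graph
// nodes, and dump debug information if it was requested.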
1437 TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
1438 TF_RETURN_IF_ERROR(CreateClusters());
1439 TF_RETURN_IF_ERROR(DumpDebugInfo());
1440
1441 return Status::OK();
1442 }
1443
1444 void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
1445 DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
1446
1447 // We also dump out an annotated version of the TF graph where the node
1448 // names are prefixed with the cluster names. This can help with visualizing
1449 // the clustering decisions on TensorBoard.
1450 Graph new_graph(graph_->op_registry());
1451 CopyGraph(*graph_, &new_graph);
1452
1453 for (Node* n : new_graph.nodes()) {
1454 if (absl::optional<absl::string_view> cluster_name =
1455 GetXlaClusterForNode(*n)) {
1456 n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1457 } else if (n->type_string() == "VarHandleOp") {
1458 n->set_name(absl::StrCat("varhandle/", n->name()));
1459 } else {
1460 // There is room for improvement here. In particular, it may help to
1461 // split these unclustered nodes into classes where every node in a
1462 // specific class has edges to and from the same set of clusters.
1463 n->set_name(absl::StrCat("unclustered/", n->name()));
1464 }
1465 }
1466
1467 DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_);
1468 }
1469
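// Formats a ratio for logging, e.g. RatioToString(3, 12) produces
// "3 / 12 (25.00%)".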
1470 string RatioToString(int numerator, int denominator) {
1471 return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
1472 (100.0 * numerator) / denominator);
1473 }
1474
1475 void MarkForCompilationPassImpl::VLogClusteringSummary() {
1476 if (!VLOG_IS_ON(2)) {
1477 return;
1478 }
1479
1480 XlaAutoClusteringSummary auto_clustering_info =
1481 GetXlaAutoClusteringSummary(*graph_);
1482
1483 VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
1484 VLOG(2) << " Built " << auto_clustering_info.clusters_size()
1485 << " clusters, size "
1486 << RatioToString(auto_clustering_info.clustered_node_count(),
1487 graph_->num_nodes());
1488
1489 for (const XlaAutoClusteringSummary::Cluster& cluster :
1490 auto_clustering_info.clusters()) {
1491 absl::string_view cluster_name = cluster.name();
1492 int size = cluster.size();
1493 VLOG(2) << " " << cluster_name << " "
1494 << RatioToString(size, graph_->num_nodes());
1495 for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1496 cluster.op_histogram()) {
1497 VLOG(3) << " " << op_count.op() << ": " << op_count.count()
1498 << " instances";
1499 }
1500 }
1501
1502 if (!auto_clustering_info.unclustered_op_histogram().empty()) {
1503 VLOG(2) << " Unclustered nodes: "
1504 << RatioToString(auto_clustering_info.unclustered_node_count(),
1505 graph_->num_nodes());
1506 for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1507 auto_clustering_info.unclustered_op_histogram()) {
1508 VLOG(3) << " " << op_count.op() << ": " << op_count.count()
1509 << " instances";
1510 }
1511 }
1512
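// Summarize inter-cluster edges: for every cluster, bucket the edges that
// cross its boundary by (peer node name, peer cluster name) and count how
// many times each pairing occurs.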
1513 struct EdgeInfo {
1514 absl::string_view node_name;
1515 absl::optional<absl::string_view> cluster_name;
1516
1517 absl::string_view GetClusterName() const {
1518 return cluster_name ? *cluster_name : "[none]";
1519 }
1520
1521 std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
1522 const {
1523 return {node_name, cluster_name};
1524 }
1525
1526 bool operator<(const EdgeInfo& other) const {
1527 return AsPair() < other.AsPair();
1528 }
1529 };
1530
1531 using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
1532
1533 EdgeInfoMap incoming_edge_infos;
1534 EdgeInfoMap outgoing_edge_infos;
1535
1536 std::set<absl::string_view> cluster_names_to_print;
1537
1538 for (const Edge* e : graph_->edges()) {
1539 const Node* from = e->src();
1540 absl::optional<absl::string_view> from_cluster_name =
1541 GetXlaClusterForNode(*from);
1542
1543 const Node* to = e->dst();
1544 absl::optional<absl::string_view> to_cluster_name =
1545 GetXlaClusterForNode(*to);
1546
1547 if (to_cluster_name == from_cluster_name) {
1548 continue;
1549 }
1550
1551 if (to_cluster_name) {
1552 incoming_edge_infos[*to_cluster_name]
1553 [EdgeInfo{from->name(), from_cluster_name}]++;
1554 cluster_names_to_print.insert(*to_cluster_name);
1555 }
1556
1557 if (from_cluster_name) {
1558 outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
1559 cluster_names_to_print.insert(*from_cluster_name);
1560 }
1561 }
1562
1563 VLOG(4) << "*** Inter-Cluster edges:";
1564 if (cluster_names_to_print.empty()) {
1565 VLOG(4) << " [none]";
1566 }
1567
1568 auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
1569 const EdgeInfoMap& edge_info_map,
1570 absl::string_view desc) {
1571 auto it = edge_info_map.find(cluster_name);
1572 if (it != edge_info_map.end()) {
1573 VLOG(4) << " " << it->second.size() << " " << desc << " edges";
1574 for (const auto& edge_info_count_pair : it->second) {
1575 VLOG(4) << " " << edge_info_count_pair.first.GetClusterName() << " "
1576 << edge_info_count_pair.first.node_name << " # "
1577 << edge_info_count_pair.second;
1578 }
1579 } else {
1580 VLOG(4) << " No " << desc << " edges.";
1581 }
1582 };
1583
1584 for (absl::string_view cluster_name : cluster_names_to_print) {
1585 VLOG(4) << " ** Cluster " << cluster_name;
1586 print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
1587 "incoming");
1588 print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
1589 "outgoing");
1590 }
1591 }
1592
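// Returns true if the devices in `cluster_a` and `cluster_b` can be combined
// into a single XLA cluster: XLA must be able to pick one device for the
// union of the two device sets, and any resource ops must already be placed
// on that device.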
1593 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
1594 const Cluster& cluster_a, const Cluster& cluster_b) {
1595 DeviceSet devices = cluster_a.devices();
1596 devices.UnionWith(cluster_b.devices());
1597
1598 TF_ASSIGN_OR_RETURN(
1599 absl::optional<jit::DeviceId> maybe_chosen_device,
1600 MaybePickDeviceForXla(device_info_cache_, devices,
1601 /*allow_mixing_unknown_and_cpu=*/false));
1602 if (!maybe_chosen_device.has_value()) {
1603 return false;
1604 }
1605
1606 jit::DeviceId chosen_device = *maybe_chosen_device;
1607
1608 // If we are able to pick a device `chosen_device` for the larger cluster, the
1609 // resource operations in `cluster_a` and `cluster_b` must be placed on the
1610 // same device as `chosen_device`. This is because the _XlaCompile and
1611 // _XlaRun kernels will run on `chosen_device` and therefore access the
1612 // resource variables from there, which is an error if the resource variables
1613 // are placed on some other device.
1614 auto resource_op_device_ok =
1615 [&](absl::optional<DeviceId> resource_op_device) {
1616 return !resource_op_device.has_value() ||
1617 *resource_op_device == chosen_device;
1618 };
1619
1620 return resource_op_device_ok(cluster_a.resource_op_device()) &&
1621 resource_op_device_ok(cluster_b.resource_op_device());
1622 }
1623
1624 // Returns `true` iff we should compile `cluster`.
1625 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
1626 const Cluster& cluster) {
1627 TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
1628 PickDeviceForXla(device_info_cache_, cluster.devices(),
1629 /*allow_mixing_unknown_and_cpu=*/false));
1630
1631 const DeviceType& device_type =
1632 device_info_cache_.GetDeviceTypeFor(chosen_device);
1633 const XlaOpRegistry::DeviceRegistration* registration =
1634 device_info_cache_.GetCompilationDevice(chosen_device);
1635 TF_RET_CHECK(registration)
1636 << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
1637 << "; device type = " << device_type.type() << "; devices: "
1638 << device_info_cache_.DebugString(cluster.devices());
1639
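// Compile if the user explicitly requested compilation via _XlaCompile=true,
// if the device's policy is to always autocluster, or if its policy is to
// autocluster only when the global JIT level is enabled and it is not OFF.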
1640 bool should_compile =
1641 cluster.is_xla_compile_attr_true() ||
1642 registration->autoclustering_policy ==
1643 XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1644 (registration->autoclustering_policy ==
1645 XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1646 global_jit_level_ != OptimizerOptions::OFF);
1647
1648 if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
1649 device_type.type_string() == DEVICE_CPU) {
1650 static absl::once_flag once;
1651 absl::call_once(once, [] {
1652 LOG(WARNING)
1653 << "(One-time warning): Not using XLA:CPU for cluster because envvar "
1654 "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "
1655 "XLA:CPU, either set that envvar, or use experimental_jit_scope "
1656 "to enable XLA:CPU. To confirm that XLA is active, pass "
1657 "--vmodule=xla_compilation_cache=1 (as a proper command-line "
1658 "flag, not via TF_XLA_FLAGS) or set the envvar "
1659 "XLA_FLAGS=--xla_hlo_profile.";
1660 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1661 if (flags->tf_xla_cpu_global_jit) {
1662 LOG(WARNING)
1663 << "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
1664 "perhaps it wasn't enabled at process startup?)";
1665 }
1666 });
1667 }
1668
1669 VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
1670 << " cluster with device "
1671 << device_info_cache_.GetNameFor(chosen_device);
1672
1673 return should_compile;
1674 }
1675
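// Memoizing wrapper around ShouldCompileClusterImpl: each cluster is
// evaluated at most once and the answer is cached.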
1676 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster(
1677 const Cluster& cluster) {
1678 auto it = should_compile_cluster_cache_.find(&cluster);
1679 if (it != should_compile_cluster_cache_.end()) {
1680 return it->second;
1681 }
1682
1683 TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster));
1684 should_compile_cluster_cache_.insert({&cluster, should_compile});
1685 return should_compile;
1686 }
1687
1688 Status MarkForCompilation(
1689 const GraphOptimizationPassOptions& options,
1690 const MarkForCompilationPassImpl::DebugOptions& debug_options) {
1691 Graph* graph = options.graph->get();
1692 FunctionLibraryDefinition* flib_def = options.flib_def;
1693
1694 // Deadness analysis expects a graph with source and sink edges properly
1695 // connected, but sometimes the incoming graph does not follow this invariant.
1696 // So fix up the source and sink edges before calling into deadness analysis.
1697 FixupSourceAndSinkEdges(graph);
1698
1699 for (Node* n : graph->nodes()) {
1700 // See explanation on `kXlaAlreadyClustered`.
1701 if (n->attrs().Find(kXlaAlreadyClustered)) {
1702 return Status::OK();
1703 }
1704 // Skip the pass if we find TPUExecute or TPUExecuteAndUpdateVariables ops
1705 // in the graph, which indicates that the graph was produced by the TPU
1706 // TF-XLA bridge and doesn't require auto clustering.
1707 if (n->type_string() == "TPUExecute" ||
1708 n->type_string() == "TPUExecuteAndUpdateVariables") {
1709 return Status::OK();
1710 }
1711 }
1712
1713 return MarkForCompilationPassImpl{debug_options, graph, flib_def,
1714 options.session_options != nullptr
1715 ? options.session_options->env
1716 : Env::Default(),
1717 GetGlobalJitLevelForGraph(options)}
1718 .Run();
1719 }
1720
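// Returns a pointer to the process-wide clustering "fuel" counter. The
// counter is a function-local static, so it is initialized from
// `initial_value` only on the first call and is then shared by every
// subsequent invocation of the pass; the candidate-selection loop above
// decrements it once per admitted compilation candidate.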
1721 std::atomic<int64>* GetPointerToFuel(int64_t initial_value) {
1722 static std::atomic<int64>* fuel = [&]() {
1723 std::atomic<int64>* fuel = new std::atomic<int64>;
1724 *fuel = initial_value;
1725 return fuel;
1726 }();
1727
1728 return fuel;
1729 }
1730 } // anonymous namespace
1731
1732 Status MarkForCompilationPass::Run(
1733 const GraphOptimizationPassOptions& options) {
1734 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1735
1736 MarkForCompilationPassImpl::DebugOptions debug_options;
1737 debug_options.ignore_deadness_checks =
1738 flags->tf_xla_disable_deadness_safety_checks_for_debugging;
1739 debug_options.ignore_resource_variable_checks =
1740 flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1741 debug_options.ignore_xla_compile_attr = false;
1742 debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1743 debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1744 debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1745 debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1746
1747 return MarkForCompilation(options, debug_options);
1748 }
1749
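// Test-only variant of Run(): lets tests disable the deadness-safety checks
// and always ignores the _XlaCompile attribute.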
1750 Status MarkForCompilationPass::RunForTest(
1751 const GraphOptimizationPassOptions& options,
1752 bool disable_deadness_analysis) {
1753 MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1754
1755 MarkForCompilationPassImpl::DebugOptions debug_options;
1756 debug_options.ignore_deadness_checks = disable_deadness_analysis;
1757 debug_options.ignore_resource_variable_checks =
1758 flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1759 debug_options.ignore_xla_compile_attr = true;
1760 debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1761 debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1762 debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1763 debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1764
1765 return MarkForCompilation(options, debug_options);
1766 }
1767
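// The category names below (e.g. "PW", "RED", "MISC") group ops so that
// callers can restrict auto-clustering to particular categories, presumably
// via the tf_xla_ops_to_cluster flag.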
1768 absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
1769 // Table format: category name: {list of TF operations in that category}
1770 static absl::flat_hash_map<string, std::vector<string>>* result =
1771 new absl::flat_hash_map<string, std::vector<string>>{
1772 // Unary
1773 {"PW",
1774 {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
1775 "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
1776 "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
1777 "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
1778 "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
1779 "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
1780 "Lgamma", "Digamma",
1781 // Binary
1782 "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
1783 "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
1784 "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
1785 "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
1786 "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
1787 "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
1788 "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
1789 "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
1790 // Others
1791 "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
1792 "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
1793 "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
1794 "SelectV2", "Transpose", "ConjugateTranspose",
1795 "_UnaryOpsComposition", "CollectiveReduceV2",
1796 // The following 5 operations are converted to identity
1797 "PlaceholderWithDefault", "PreventGradient", "StopGradient",
1798 "Snapshot", "_EagerConst"}},
1799 // clang-format off
1800 {"RED",
1801 {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
1802 // clang-format on
1803 {"PWRED",
1804 {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1805 "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1806 {"REDUCEWINDOW",
1807 {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1808 "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1809 {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
1810 {"BN",
1811 {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
1812 "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
1813 "FusedBatchNormGradV3"}},
1814 {"SORT", {"TopKV2"}}, // XLA version much faster then TF version.
1815 {"MISC",
1816 // clang-format off
1817 {"BroadcastTo", "ExpandDims", "Fill", "NoOp",
1818 "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
1819 "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
1820 "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
1821 "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
1822 "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
1823 "SplitV", "StridedSlice", "StridedSliceGrad",
1824 "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
1825 "Unpack", "DeviceIndex", "TensorStridedSliceUpdate",
1826 }}};
1827 // clang-format on
1828 return result;
1829 }
1830
1831 namespace testing {
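// Test-only helper: resets the global counter used to generate sequential
// cluster names (cluster_0, cluster_1, ...) so that tests see deterministic
// names.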
1832 void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
1833
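// Returns the set of ops known to the XLA auto-clustering allowlist; exposed
// for tests, e.g. so a test can fail when new ops are added without updating
// this list.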
1834 absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
1835 absl::flat_hash_set<string> result{"AdjustContrastv2",
1836 "AdjustHue",
1837 "AdjustSaturation",
1838 "Asinh",
1839 "Assert",
1840 "AssignAddVariableOp",
1841 "AssignSubVariableOp",
1842 "AssignVariableOp",
1843 "AvgPool",
1844 "AvgPool3D",
1845 "AvgPool3DGrad",
1846 "AvgPoolGrad",
1847 "BatchMatMul",
1848 "BatchMatMulV2",
1849 "BatchMatMulV3",
1850 "BatchToSpace",
1851 "BatchToSpaceND",
1852 "BesselI0e",
1853 "BesselI1e",
1854 "Betainc",
1855 "BiasAddV1",
1856 "Bucketize",
1857 "Case",
1858 "CheckNumerics",
1859 "Cholesky",
1860 "ControlTrigger",
1861 "Conv2D",
1862 "Conv2DBackpropFilter",
1863 "Conv2DBackpropInput",
1864 "Conv3D",
1865 "Conv3DBackpropFilterV2",
1866 "Conv3DBackpropInputV2",
1867 "Cross",
1868 "Cumprod",
1869 "Cumsum",
1870 "DataFormatDimMap",
1871 "DataFormatVecPermute",
1872 "DepthToSpace",
1873 "DepthwiseConv2dNative",
1874 "DepthwiseConv2dNativeBackpropFilter",
1875 "DepthwiseConv2dNativeBackpropInput",
1876 "Dequantize",
1877 "Diag",
1878 "DynamicStitch",
1879 "DynamicPartition",
1880 "Einsum",
1881 "EmptyTensorList",
1882 "EnsureShape",
1883 "ExtractImagePatches",
1884 "Igamma",
1885 "IgammaGradA",
1886 "RandomGammaGrad",
1887 "Igammac",
1888 "FFT",
1889 "FFT2D",
1890 "FFT3D",
1891 "FakeParam",
1892 "FakeQuantWithMinMaxArgs",
1893 "FakeQuantWithMinMaxArgsGradient",
1894 "FakeQuantWithMinMaxVars",
1895 "FakeQuantWithMinMaxVarsGradient",
1896 "Gather",
1897 "GatherNd",
1898 "GatherV2",
1899 "HSVToRGB",
1900 "IFFT",
1901 "IFFT2D",
1902 "IFFT3D",
1903 "IRFFT",
1904 "IRFFT2D",
1905 "IRFFT3D",
1906 "If",
1907 "InTopKV2",
1908 "L2Loss",
1909 "LeakyRelu",
1910 "LinSpace",
1911 "ListDiff",
1912 "LogMatrixDeterminant",
1913 "LowerBound",
1914 "MatMul",
1915 "MatrixBandPart",
1916 "MatrixDiag",
1917 "MatrixDiagPart",
1918 "MatrixDiagPartV2",
1919 "MatrixDiagPartV3",
1920 "MatrixDiagV2",
1921 "MatrixDiagV3",
1922 "MatrixInverse",
1923 "MatrixSetDiag",
1924 "MatrixSetDiagV2",
1925 "MatrixSetDiagV3",
1926 "MatrixSolve",
1927 "MatrixTriangularSolve",
1928 "MaxPool",
1929 "MaxPool3D",
1930 "MaxPool3DGrad",
1931 "MaxPool3DGradGrad",
1932 "MaxPoolGrad",
1933 "MaxPoolGradGrad",
1934 "MaxPoolGradGradV2",
1935 "MaxPoolGradV2",
1936 "MaxPoolV2",
1937 "Multinomial",
1938 "NextAfter",
1939 "NonMaxSuppressionV4",
1940 "ParallelDynamicStitch",
1941 "ParameterizedTruncatedNormal",
1942 "PartitionedCall",
1943 "Polygamma",
1944 "PopulationCount",
1945 "Qr",
1946 "QuantizeAndDequantizeV2",
1947 "QuantizeAndDequantizeV3",
1948 "QuantizeAndDequantizeV4",
1949 "RFFT",
1950 "RFFT2D",
1951 "RFFT3D",
1952 "RGBToHSV",
1953 "RandomShuffle",
1954 "RandomStandardNormal",
1955 "RandomUniform",
1956 "RandomUniformInt",
1957 "ReadVariableOp",
1958 "ResizeBilinear",
1959 "ResizeBilinearGrad",
1960 "ResizeNearestNeighbor",
1961 "ResourceApplyAdaMax",
1962 "ResourceApplyAdadelta",
1963 "ResourceApplyAdagrad",
1964 "ResourceApplyAdagradDA",
1965 "ResourceApplyAdagradV2",
1966 "ResourceApplyAdam",
1967 "ResourceApplyAddSign",
1968 "ResourceApplyCenteredRMSProp",
1969 "ResourceApplyFtrl",
1970 "ResourceApplyFtrlV2",
1971 "ResourceApplyGradientDescent",
1972 "ResourceApplyKerasMomentum",
1973 "ResourceApplyMomentum",
1974 "ResourceApplyPowerSign",
1975 "ResourceApplyProximalAdagrad",
1976 "ResourceApplyProximalGradientDescent",
1977 "ResourceApplyRMSProp",
1978 "ResourceGather",
1979 "ResourceScatterAdd",
1980 "ResourceScatterDiv",
1981 "ResourceScatterMax",
1982 "ResourceScatterMin",
1983 "ResourceScatterMul",
1984 "ResourceScatterNdAdd",
1985 "ResourceScatterNdSub",
1986 "ResourceScatterNdUpdate",
1987 "ResourceScatterSub",
1988 "ResourceScatterUpdate",
1989 "RngReadAndSkip",
1990 "RngSkip",
1991 "Roll",
1992 "ScatterNd",
1993 "SelfAdjointEigV2",
1994 "SoftmaxCrossEntropyWithLogits",
1995 "SpaceToBatch",
1996 "SpaceToBatchND",
1997 "SpaceToDepth",
1998 "SparseMatMul",
1999 "SparseToDense",
2000 "StackCloseV2",
2001 "StackPopV2",
2002 "StackPushV2",
2003 "StackV2",
2004 "StatefulPartitionedCall",
2005 "StatefulStandardNormalV2",
2006 "StatefulTruncatedNormal",
2007 "StatefulUniform",
2008 "StatefulUniformFullInt",
2009 "StatefulUniformInt",
2010 "StatelessCase",
2011 "StatelessIf",
2012 "StatelessMultinomial",
2013 "StatelessRandomGetAlg",
2014 "StatelessRandomGetKeyCounter",
2015 "StatelessRandomGetKeyCounterAlg",
2016 "StatelessRandomNormal",
2017 "StatelessRandomNormalV2",
2018 "StatelessRandomUniform",
2019 "StatelessRandomUniformV2",
2020 "StatelessRandomUniformInt",
2021 "StatelessRandomUniformIntV2",
2022 "StatelessRandomUniformFullInt",
2023 "StatelessRandomUniformFullIntV2",
2024 "StatelessTruncatedNormal",
2025 "StatelessTruncatedNormalV2",
2026 "StatelessWhile",
2027 "Svd",
2028 "SymbolicGradient",
2029 "TensorArrayCloseV3",
2030 "TensorArrayConcatV3",
2031 "TensorArrayGatherV3",
2032 "TensorArrayGradV3",
2033 "TensorArrayReadV3",
2034 "TensorArrayScatterV3",
2035 "TensorArraySizeV3",
2036 "TensorArraySplitV3",
2037 "TensorArrayV3",
2038 "TensorArrayWriteV3",
2039 "TensorListConcatV2",
2040 "TensorListElementShape",
2041 "TensorListFromTensor",
2042 "TensorListGather",
2043 "TensorListGetItem",
2044 "TensorListLength",
2045 "TensorListPopBack",
2046 "TensorListPushBack",
2047 "TensorListReserve",
2048 "TensorListSetItem",
2049 "TensorListSplit",
2050 "TensorListStack",
2051 "TensorScatterAdd",
2052 "TensorScatterMax",
2053 "TensorScatterMin",
2054 "TensorScatterSub",
2055 "TensorScatterUpdate",
2056 "ToBool",
2057 "TridiagonalSolve",
2058 "TruncatedNormal",
2059 "Unique",
2060 "UpperBound",
2061 "UnsortedSegmentMax",
2062 "UnsortedSegmentMin",
2063 "UnsortedSegmentProd",
2064 "UnsortedSegmentSum",
2065 "VarIsInitializedOp",
2066 "VariableShape",
2067 "Where",
2068 "While",
2069 "XlaBroadcastHelper",
2070 "XlaConv",
2071 "XlaConvV2",
2072 "XlaDequantize",
2073 "XlaDot",
2074 "XlaDotV2",
2075 "XlaDynamicSlice",
2076 "XlaDynamicUpdateSlice",
2077 "XlaEinsum",
2078 "XlaGather",
2079 "XlaIf",
2080 "XlaKeyValueSort",
2081 "XlaPad",
2082 "XlaRecv",
2083 "XlaReduce",
2084 "XlaReduceWindow",
2085 "XlaRemoveDynamicDimensionSize",
2086 "XlaReplicaId",
2087 "XlaRngBitGenerator",
2088 "XlaScatter",
2089 "XlaSelectAndScatter",
2090 "XlaSelfAdjointEig",
2091 "XlaSend",
2092 "XlaSetBound",
2093 "XlaSetDynamicDimensionSize",
2094 "XlaSharding",
2095 "XlaSort",
2096 "XlaSpmdFullToShardShape",
2097 "XlaSpmdShardToFullShape",
2098 "XlaSvd",
2099 "XlaVariadicReduce",
2100 "XlaVariadicReduceV2",
2101 "XlaVariadicSort",
2102 "XlaWhile",
2103 "Zeta",
2104 "_Arg",
2105 "_ArrayToList",
2106 "_ListToArray",
2107 "_Retval"};
2108 return result;
2109 }
2110
2111 } // namespace testing
2112 } // namespace tensorflow
2113