1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
17 
18 #include <atomic>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 
24 #include "absl/base/call_once.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/jit/compilability_check_util.h"
29 #include "tensorflow/compiler/jit/deadness_analysis.h"
30 #include "tensorflow/compiler/jit/defs.h"
31 #include "tensorflow/compiler/jit/device_util.h"
32 #include "tensorflow/compiler/jit/flags.h"
33 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
34 #include "tensorflow/compiler/jit/xla_cluster_util.h"
35 #include "tensorflow/compiler/tf2xla/const_analysis.h"
36 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
37 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
38 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/union_find.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/common_runtime/function.h"
43 #include "tensorflow/core/common_runtime/graph_constructor.h"
44 #include "tensorflow/core/framework/bounds_check.h"
45 #include "tensorflow/core/framework/graph_def_util.h"
46 #include "tensorflow/core/framework/memory_types.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/op_kernel.h"
49 #include "tensorflow/core/framework/tensor.pb.h"
50 #include "tensorflow/core/framework/types.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/control_flow.h"
53 #include "tensorflow/core/lib/gtl/cleanup.h"
54 #include "tensorflow/core/lib/strings/stringprintf.h"
55 #include "tensorflow/core/public/version.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 
58 namespace tensorflow {
59 
60 namespace {
61 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
62 using jit::DeviceId;
63 using jit::DeviceSet;
64 
65 // The clusters we create here are eventually lowered into an
66 // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the
67 // PartitionedCall op to execute the cluster in the regular graph executor if
68 // need be.  PartitionedCall, however, reruns the entire TF graph optimization
69 // pipeline over the cluster which includes this mark for compilation pass.  To
70 // avoid endlessly recursing we tag nodes that we've already visited with this
71 // attribute so that we can bail out if we see them a second time.
72 //
73 // TODO(sanjoy): This method is not robust since it is possible that the
74 // optimizations run by PartitionedCall can mutate the cluster arbitrarily,
75 // dropping the kXlaAlreadyClustered attributes from all nodes in the process.
76 // The correct fix is to use the ConfigProto to pass in some sort of flag into
77 // the PartitionedCall kernel that tells it to not rerun auto-clustering on the
78 // cluster.
79 const char* kXlaAlreadyClustered = "_XlaAlreadyClustered";
80 
81 class MarkForCompilationPassImpl {
82  public:
83   struct DebugOptions {
84     // If true, do not respect the results of deadness analysis.
85     bool ignore_deadness_checks;
86 
87     // If true, do not do safety checks to preserve TensorFlow's resource
88     // variable concurrency semantics.
89     bool ignore_resource_variable_checks;
90 
91     // If true, do not respect the _XlaCompile=false attribute.
92     bool ignore_xla_compile_attr;
93 
94     int max_cluster_size;
95     int min_cluster_size;
96 
97     // Compiler fuel for the auto-clustering algorithm.
98     //
99     // We decrement this value by one every time we choose a compilation
100     // candidate and we stop clustering when it hits zero.  This means the
101     // initial value for this variable (via --tf_xla_clustering_fuel=N)
102     // effectively acts as a "cap" for how much we cluster and we can bisect
103     // over this initial value to discover clustering decisions that cause a
104     // miscompile or a performance regression.
105     std::atomic<int64>* fuel;
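    //
    // (Illustrative note, not used by the pass itself: the initial fuel value
    // is typically supplied through the TF_XLA_FLAGS environment variable,
    // for example
    //
    //   TF_XLA_FLAGS="--tf_xla_clustering_fuel=512" python train.py
    //
    // and then halved or doubled (256, 128, ...) until the first clustering
    // decision that triggers the suspected miscompile or regression is
    // isolated.  The exact command line above is an example, not something
    // this file prescribes.)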
106 
107     bool dump_graphs;
108   };
109 
110   MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph,
111                              FunctionLibraryDefinition* flib_def, Env* env,
112                              OptimizerOptions::GlobalJitLevel global_jit_level)
113       : debug_options_(debug_options),
114         graph_(graph),
115         flib_def_(flib_def),
116         env_(env),
117         global_jit_level_(global_jit_level) {}
118 
119   Status Run();
120 
121  private:
122   // Represents a "cluster" or a connected subgraph of a TensorFlow graph.
123   class Cluster {
124    public:
125     // Constructs a trivial cluster representing a single TF node.
126     Cluster(int tf_graph_node_id, int effective_cluster_size,
127             bool has_functional_control_flow, DeviceSet devices,
128             absl::optional<DeviceId> resource_op_device,
129             absl::optional<int> resource_var_operation_node_id,
130             absl::optional<DeadnessPredicate> deadness_predicate,
131             bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
132         : cycles_graph_node_id_(tf_graph_node_id),
133           effective_cluster_size_(effective_cluster_size),
134           has_functional_control_flow_(has_functional_control_flow),
135           devices_(std::move(devices)),
136           resource_op_device_(resource_op_device),
137           deadness_predicate_(deadness_predicate),
138           is_xla_compile_attr_true_(is_xla_compile_attr_true),
139           xla_scope_(std::move(xla_scope)) {
140       if (resource_var_operation_node_id.has_value()) {
141         resource_var_operation_node_ids_.push_back(
142             *resource_var_operation_node_id);
143       }
144     }
145 
146     // Merges `other` into this cluster, and clears `other`.  This method is
147     // closely tied with the implementation of `MarkForCompilationPassImpl`.
148     void Merge(Cluster* other);
149 
150     // If this is a trivial cluster containing only one node then return the ID
151     // of that node.  May not be called otherwise.
152     int GetIdOfOnlyNode() const {
153       DCHECK_EQ(cluster_size(), 1);
154       return cycles_graph_node_id();
155     }
156 
157     // The number of TF nodes in this cluster.
158     int cluster_size() const { return cluster_size_; }
159 
160     // The ID of the cluster as represented in `cycles_graph_`.
161     int cycles_graph_node_id() const { return cycles_graph_node_id_; }
162 
163     // Sets the ID of the cluster as represented in `cycles_graph_`.
164     void set_cycles_graph_node_id(int cycles_graph_node_id) {
165       cycles_graph_node_id_ = cycles_graph_node_id;
166     }
167 
168     // The size of the cluster excluding constant and identity nodes.
169     int effective_cluster_size() const { return effective_cluster_size_; }
170 
171     // True if the cluster has functional control flow like `If` and `While`.
172     bool has_functional_control_flow() const {
173       return has_functional_control_flow_;
174     }
175 
176     // The set of devices nodes in the cluster are placed on.
177     const DeviceSet& devices() const { return devices_; }
178 
179     // If the cluster has a resource operation then the device the resource
180     // operation is placed on.  A cluster may have resource ops placed only on a
181     // single device.
182     const absl::optional<DeviceId>& resource_op_device() const {
183       return resource_op_device_;
184     }
185 
186     // If not nullopt, a predicate that is true iff the cluster is alive.
187     // Otherwise the user has (unsafely) disabled deadness analysis.  If this is
188     // unset on a single Cluster instance then it is unset on all Cluster
189     // instances.
190     const absl::optional<DeadnessPredicate>& deadness_predicate() const {
191       return deadness_predicate_;
192     }
193 
194     // If true then the cluster has a XlaCompile=true attribute on one of its
195     // nodes.
196     bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; }
197 
198     // If not nullopt then all nodes in the cluster either do not have the
199     // XlaScope attribute set or have it set to the value returned.
200     const absl::optional<string>& xla_scope() const { return xla_scope_; }
201 
202     // Returns the TF graph node IDs for the resource variable operations in
203     // this cluster.
204     absl::Span<const int> resource_var_operation_node_ids() const {
205       return resource_var_operation_node_ids_;
206     }
207 
208     string DebugString(const Graph& graph) const {
209       Node* node = graph.FindNodeId(cycles_graph_node_id());
210       if (!node) {
211         // This should never happen but we try to be resilient because this is a
212         // debugging aid.
213         return absl::StrCat("NULL NODE IN #", cycles_graph_node_id());
214       }
215 
216       if (cluster_size() == 1) {
217         return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(),
218                             ">");
219       }
220 
221       return absl::StrCat("<", node->name(), " + ", cluster_size() - 1,
222                           " others #", cycles_graph_node_id(), ">");
223     }
224 
225    private:
226     int cluster_size_ = 1;
227     int cycles_graph_node_id_;
228     int effective_cluster_size_;
229     bool has_functional_control_flow_;
230     DeviceSet devices_;
231     absl::optional<DeviceId> resource_op_device_;
232     absl::optional<DeadnessPredicate> deadness_predicate_;
233     bool is_xla_compile_attr_true_;
234     absl::optional<string> xla_scope_;
235     std::vector<int> resource_var_operation_node_ids_;
236 
237     TF_DISALLOW_COPY_AND_ASSIGN(Cluster);
238   };
239 
240   // If `cluster` has only a single node then returns that, otherwise returns
241   // nullptr.
242   Node* GetOnlyNodeIn(const Cluster& cluster);
243 
244   // Returns true if `cluster` is a trivial cluster containing a "sink like"
245   // node -- a NoOp node that only the Sink node control depends on.
246   bool IsSinkLike(const Cluster& cluster);
247 
248   // Returns true if `cluster` looks like an "i++" operation on an integer
249   // scalar resource variable.
250   bool IsScalarIntegerResourceOperation(const Cluster& cluster);
251 
252   // ---------------------------------------------------------------------------
253   // The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and
254   // `CreateClusters` do most of the heavy lifting.
255 
256   // Initializes some internal data structures.
257   //
258   // If this returns false then Initialize exited early (either because there is
259   // nothing to do or we saw a graph that we can't handle) and not all the
260   // fields in this MarkForCompilationPassImpl instance are set up.
261   StatusOr<bool> Initialize();
262 
263   // Runs through the entire cluster graph in post-order and calls `fn(from,
264   // to)` on each edge.  `fn(from, to)` is expected to return true if it was
265   // able to contract `from`->`to`.
266   //
267   // Returns true if `fn` returned true for any edge.
268   template <typename FnTy>
269   StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn);
270 
271   // Contracts as many edges as possible to create XLA clusters.  After this
272   // finishes the clustering decisions made are implicitly stored in
273   // `clusters_`.
274   Status RunEdgeContractionLoop();
275 
276   // Manifests the clustering decisions into the TF graph by tagging nodes with
277   // an `_XlaCluster` attribute.  Some basic filtering logic, such as
278   // tf_xla_min_cluster_size, is also applied here.
279   Status CreateClusters();
280 
281   Status DumpDebugInfo();
282 
283   bool IsCompilationCandidate(Node* n) const {
284     return compilation_candidates_.find(n) != compilation_candidates_.end();
285   }
286 
287   // Tries to contract the edge from cluster `from` to cluster `to`.  Returns
288   // true if successful.
289   StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to);
290 
291   // Nodes that XLA can compile are put in `compilation_candidates_`.
292   Status FindCompilationCandidates();
293 
294   bool CompilationDisallowedByXlaCompileAttr(Node* node);
295 
296   // Populates `clusters_`.
297   Status BuildInitialClusterSet();
298 
299   StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster);
300 
301   StatusOr<bool> ShouldCompileCluster(const Cluster& cluster);
302 
303   StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency(
304       const Cluster& from, const Cluster& to);
305 
306   // Returns true if the devices in `cluster_a` and `cluster_b` are compatible
307   // and therefore not a hindrance for combining the two clusters into a larger
308   // cluster.
309   StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a,
310                                       const Cluster& cluster_b);
311 
312   void DumpPostClusteringGraphs();
313   void VLogClusteringSummary();
314 
315   Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
316                           bool has_functional_control_flow,
317                           const DeviceSet& device_set,
318                           absl::optional<DeviceId> resource_op_device,
319                           absl::optional<int> resource_var_operation_node_id,
320                           absl::optional<DeadnessPredicate> deadness_predicate,
321                           bool is_xla_compile_attr_true,
322                           absl::optional<string> xla_scope) {
323     cluster_storage_.push_back(absl::make_unique<Cluster>(
324         cycles_graph_node_id, effective_cluster_size,
325         has_functional_control_flow, device_set, resource_op_device,
326         resource_var_operation_node_id, deadness_predicate,
327         is_xla_compile_attr_true, xla_scope));
328     return cluster_storage_.back().get();
329   }
330 
331   absl::optional<string> GetXlaScope(Node* n);
332 
333   // Returns the cluster for node `n`.  If two nodes, N1 and N2, are placed in
334   // the same cluster by the clustering algorithm then this function will return
335   // the same Cluster instance for N1 and N2.
336   //
337   // Returns nullptr if `n` is not a compilation candidate.
338   Cluster* GetClusterForNode(Node* n) {
339     return cluster_for_node_[n->id()].Get();
340   }
341 
342   // Returns the cluster for a node in `cycles_graph_`.  This uses the same
343   // underlying map because of how we set things up, but we can do an additional
344   // CHECK in this accessor.
345   //
346   // Returns nullptr if `node_id` is not a compilation candidate.
347   Cluster* GetClusterForCyclesGraphNode(int node_id) {
348     // We have to check `graph_->FindNodeId(node) == nullptr` because we add all
349     // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the
350     // TF graph may be missing some node ids.
351     if (node_id >= graph_->num_node_ids() ||
352         graph_->FindNodeId(node_id) == nullptr) {
353       return nullptr;
354     }
355     Cluster* cluster = cluster_for_node_[node_id].Get();
356     if (cluster) {
357       DCHECK_EQ(cluster->cycles_graph_node_id(), node_id);
358     }
359     return cluster;
360   }
361 
362   bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to,
363                                         absl::string_view reason);
364 
365   // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct
366   // edge from `from` to `to`.
367   //
368   // Tries to find a path that contains at least one unclusterable node.
369   std::vector<int> FindAlternatePathForDebugging(int from, int to);
370 
371   // Returns a string representing `cycles_graph_node_id`.  If the node is
372   // unclusterable (either it is a phantom "frame" node or is not a compilation
373   // candidate) then set `*found_unclustered` to true.
374   string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered);
375 
376   // We could not contract the edge from `from` to `to`.  Return a string
377   // describing an alternate path from `from` to `to` (besides the direct edge
378   // from `from` to `to`) which would have created a cycle had we contracted the
379   // edge.
380   //
381   // Tries (if possible) to find a path that contains at least one unclusterable
382   // node as it is surprising to the user if we print "A->B could not be
383   // contracted because of the path [P,Q,R]" where P, Q and R are all clusters
384   // since in that case a natural question is why we could not form a {A, P, Q,
385   // R, B} cluster.
386   string DescribePotentialCycle(int from, int to);
387 
388   // Merge the clusters `cluster_from` and `cluster_to`. After this step the
389   // larger combined cluster is represented by `cluster_from`, but can have
390   // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
391   // which way requires fewer operations.
392   bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
393     int from = cluster_from->cycles_graph_node_id();
394     int to = cluster_to->cycles_graph_node_id();
395 
396     auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
397     if (!optional_merged_node.has_value()) {
398       VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
399               << " -> " << cluster_to->DebugString(*graph_)
400               << " because contracting the edge would create a cycle via "
401               << DescribePotentialCycle(from, to) << ".";
402       return false;
403     }
404 
405     // Merge the clusters.
406     cluster_from->Merge(cluster_to);
407     // Update `cycle_graph_`'s ID.
408     cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
409 
410     // Merge the UnionFind<Cluster*>.
411     cluster_for_node_[from].Merge(&cluster_for_node_[to]);
412 
413     return true;
414   }
415 
416   string EdgeContractionFailureMsg(Cluster* from, Cluster* to,
417                                    absl::string_view reason) {
418     return absl::StrCat("Could not contract ", from->DebugString(*graph_),
419                         " -> ", to->DebugString(*graph_), " because ", reason,
420                         ".");
421   }
422 
423   DebugOptions debug_options_;
424   Graph* graph_;
425   FunctionLibraryDefinition* flib_def_;
426   Env* env_;
427   OptimizerOptions::GlobalJitLevel global_jit_level_;
428   absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_;
429   jit::DeviceInfoCache device_info_cache_;
430 
431   bool initialized_ = false;
432   bool edges_contracted_ = false;
433   bool clusters_created_ = false;
434 
435   std::vector<std::unique_ptr<Cluster>> cluster_storage_;
436   std::vector<UnionFind<Cluster*>> cluster_for_node_;
437   GraphCycles cycles_graph_;
438   OrderedNodeSet compilation_candidates_;
439   std::unique_ptr<DeadnessAnalysis> deadness_analysis_;
440   int64 iteration_count_ = 0;
441   absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_;
442 };
443 
444 std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging(
445     int from, int to) {
446   std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder();
447   absl::c_reverse(rpo);
448 
449   // best_pred_for_node[n] contains a predecessor of `n` that has an
450   // unclusterable node in some path from `from` to itself.
451   // best_pred_for_node[n] is unpopulated for nodes that are not reachable from
452   // `from`.  We build this table up inductively by traversing the cycles graph
453   // in RPO.
454   absl::flat_hash_map<int, int> best_pred_for_node;
455   best_pred_for_node[from] = -1;
456 
457   int rpo_index = 0, current_rpo_node;
458   do {
459     current_rpo_node = rpo[rpo_index++];
460     absl::optional<int> some_pred, preferred_pred;
461     for (int pred : cycles_graph_.Predecessors(current_rpo_node)) {
462       if (!best_pred_for_node.contains(pred)) {
463         continue;
464       }
465 
466       // Ignore the from->to edge since we're trying to find an alternate path.
467       if (current_rpo_node == to && pred == from) {
468         continue;
469       }
470 
471       some_pred = pred;
472       if (GetClusterForCyclesGraphNode(pred) == nullptr) {
473         preferred_pred = pred;
474       }
475     }
476 
477     if (some_pred || preferred_pred) {
478       best_pred_for_node[current_rpo_node] =
479           preferred_pred.has_value() ? *preferred_pred : *some_pred;
480     }
481   } while (current_rpo_node != to);
482 
483   auto get_best_pred = [&](int n) {
484     auto it = best_pred_for_node.find(n);
485     CHECK(it != best_pred_for_node.end());
486     return it->second;
487   };
488 
489   std::vector<int> path;
490   int current_path_node = get_best_pred(to);
491   while (current_path_node != from) {
492     path.push_back(current_path_node);
493     current_path_node = get_best_pred(current_path_node);
494   }
495 
496   absl::c_reverse(path);
497   return path;
498 }
499 
500 string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode(
501     int cycles_graph_node_id, bool* found_unclustered) {
502   Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id);
503   if (cluster) {
504     return cluster->DebugString(*graph_);
505   }
506 
507   *found_unclustered = true;
508   if (cycles_graph_node_id >= graph_->num_node_ids()) {
509     return absl::StrCat("<oob #", cycles_graph_node_id, ">");
510   }
511 
512   Node* node = graph_->FindNodeId(cycles_graph_node_id);
513   if (!node) {
514     return absl::StrCat("<bad #", cycles_graph_node_id, ">");
515   }
516 
517   return node->name();
518 }
519 
520 string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) {
521   std::vector<string> path_str;
522   bool found_unclustered = false;
523   absl::c_transform(FindAlternatePathForDebugging(from, to),
524                     std::back_inserter(path_str), [&](int node_id) {
525                       return DebugStringForCyclesGraphNode(node_id,
526                                                            &found_unclustered);
527                     });
528   return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[",
529                       absl::StrJoin(path_str, ","), "]");
530 }
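// (Illustrative example of the string produced above, using the Cluster
// DebugString convention "<name #id>" / "<name + N others #id>"; the node
// names are hypothetical:
//   "(all clusters) [<mul_0 #12>,<add_1 + 2 others #15>]"
// The "(all clusters) " prefix is dropped when the path contains an
// unclusterable node.)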
531 
532 void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
533   // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does.
534 
535   // Clearing out data structures in `other` is just a memory saving
536   // optimization and not needed for correctness.
537 
538   cluster_size_ += other->cluster_size_;
539   effective_cluster_size_ += other->effective_cluster_size_;
540   has_functional_control_flow_ |= other->has_functional_control_flow_;
541 
542   devices_.UnionWith(other->devices_);
543 
544   DCHECK(!(resource_op_device_.has_value() &&
545            other->resource_op_device_.has_value()) ||
546          *resource_op_device_ == *other->resource_op_device_)
547       << "AreDevicesCompatible should have returned false otherwise!";
548 
549   if (!resource_op_device_.has_value()) {
550     resource_op_device_ = other->resource_op_device_;
551   }
552 
553   is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
554 
555   if (!xla_scope_.has_value()) {
556     xla_scope_ = std::move(other->xla_scope_);
557   }
558 
559   resource_var_operation_node_ids_.reserve(
560       resource_var_operation_node_ids_.size() +
561       other->resource_var_operation_node_ids_.size());
562   absl::c_copy(other->resource_var_operation_node_ids_,
563                std::back_inserter(resource_var_operation_node_ids_));
564   other->resource_var_operation_node_ids_.clear();
565 }
566 
567 Status IgnoreResourceOpForSafetyAnalysis(
568     jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) {
569   // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
570   // ignore it during resource operation safety analysis.  We need this hack
571   // because of two reasons:
572   //
573   //  1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled.
574   //  2. We don't support live-out values of type DT_RESOURCE and live-in values
575   //     of type DT_RESOURCE that are not resource variables.
576   //
577   // Together these imply we cannot let resource variable safety analysis
578   // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
579   // clusters: both of them will have to be clustered because of (1) and we
580   // won't be able to keep the edge between the two as neither the input to the
581   // second XLA cluster nor the output from the first XLA cluster are supported
582   // because of (2).
583   //
584   // TODO(b/113100872): This can be fixed if the TensorFlow representation for
585   // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
586   // (2) would no longer hold.
587 
588   if (n.assigned_device_name().empty()) {
589     *ignore = false;
590     return Status::OK();
591   }
592 
593   TF_ASSIGN_OR_RETURN(
594       const XlaOpRegistry::DeviceRegistration* registration,
595       device_info_cache->GetCompilationDevice(n.assigned_device_name()));
596 
597   if (!registration) {
598     *ignore = true;
599   } else {
600     *ignore = registration->cluster_resource_variable_ops_unsafely;
601   }
602   return Status::OK();
603 }
604 
605 StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
606   TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_);
607   initialized_ = true;
608 
609   TF_RETURN_IF_ERROR(FindCompilationCandidates());
610 
611   if (compilation_candidates_.empty()) {
612     VLOG(2) << "No compilable candidates";
613     return false;
614   }
615 
616   TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
617                       CreateCycleDetectionGraph(graph_, &cycles_graph_));
618   if (!cycle_detection_graph_ok) {
619     // TODO(sanjoy): This should be logged via the XLA activity listener.
620     VLOG(2) << "Could not form cycle detection graph";
621     return false;
622   }
623 
624   if (!debug_options_.ignore_deadness_checks) {
625     XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
626     TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_));
627   }
628 
629   // Each compilation candidate belongs to a cluster. The cluster's
630   // representative names the node in the 'cycles' graph that represents the
631   // cluster.
632   TF_RETURN_IF_ERROR(BuildInitialClusterSet());
633   return true;
634 }
635 
636 template <typename FnTy>
637 StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) {
638   bool changed = false;
639   for (int32_t node : cycles_graph_.AllNodesInPostOrder()) {
640     Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
641     if (!cluster_from) {
642       continue;
643     }
644 
645     // Make a copy of the set of successors because we may modify the graph in
646     // TryToContractEdge.
647     std::vector<int32> successors_copy =
648         cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
649 
650     for (int to : successors_copy) {
651       iteration_count_++;
652 
653       Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
654       if (!cluster_to) {
655         continue;
656       }
657 
658       TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to));
659       changed |= contracted_edge;
660     }
661   }
662 
663   return changed;
664 }
665 
666 Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) {
667   return cluster.cluster_size() == 1
668              ? graph_->FindNodeId(cluster.GetIdOfOnlyNode())
669              : nullptr;
670 }
671 
672 bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) {
673   if (Node* n = GetOnlyNodeIn(cluster)) {
674     return n->type_string() == "NoOp" && n->out_edges().size() == 1 &&
675            (*n->out_edges().begin())->dst()->IsSink();
676   }
677 
678   return false;
679 }
680 
681 bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
682     const Cluster& cluster) {
683   Node* n = GetOnlyNodeIn(cluster);
684   if (!n) {
685     return false;
686   }
687 
688   if (n->type_string() != "AssignAddVariableOp" &&
689       n->type_string() != "AssignSubVariableOp") {
690     return false;
691   }
692 
693   DataType dtype;
694   if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
695     return false;
696   }
697 
698   Node* const_input = nullptr;
699   for (const Edge* e : n->in_edges()) {
700     if (!e->IsControlEdge() && e->src()->IsConstant()) {
701       const_input = e->src();
702       break;
703     }
704   }
705 
706   if (!const_input) {
707     return false;
708   }
709 
710   const TensorProto* proto = nullptr;
711   if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
712     return false;
713   }
714 
715   return TensorShapeUtils::IsScalar(proto->tensor_shape());
716 }
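// (Illustrative example of the pattern matched above, with hypothetical node
// names: an "i += 1" style fragment, i.e. a scalar int32 Const feeding an
// AssignAddVariableOp on a loop counter variable.  Phase 1 of
// RunEdgeContractionLoop declines to contract edges out of such clusters; see
// the rationale in that function.)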
717 
718 Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
719   TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_);
720   edges_contracted_ = true;
721 
722   // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
723   // example, from the Grappler fusion pass).
724 
725   // In general there are multiple maximal clusterings, but they are not all
726   // equally performant.  Some clustering decision are likely to improve
727   // performance much more than others, and we cannot order contractions on this
728   // cost function, nor can we look at global information while deciding on
729   // individual edges to contract.  Instead, we will make decisions on these
730   // important edges first and then on all other edges, giving the most
731   // important edges the best chance of being contracted.
732   //
733   // An example of where this might occur is with a digraph:
734   // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is
735   // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B
736   // should be clustered with A because it will prevent a potentially large
737   // tensor from A being computed and copied.
738   //
739   // To choose better maximal clusterings we make multiple iterations over the
740   // graph in post-order, where each such iteration is called a "phase".
741 
742   // Phase 0: contract metadata operations with their producer.
743 
744   VLOG(4) << "Running phase 0";
745   TF_RETURN_IF_ERROR(
746       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
747         // Shape consuming operations are desirable to cluster with their
748         // operands because they return a small set of scalar values after
749         // consuming a large amount of data.  For example, given a graph X -> Y
750         // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or
751         // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the
752         // output of Size will be a small tensor while Y is a potentially large
753         // tensor that must be computed and possibly transposed/copied before
754         // the second cluster executes.
755         Node* n = GetOnlyNodeIn(*to);
756         bool is_shape_consumer_op = n && IsShapeConsumerOp(*n);
757         if (!is_shape_consumer_op) {
758           return false;
759         }
760 
761         return TryToContractEdge(from, to);
762       }).status());
763 
764   // Phase 1: apply a heuristic to ensure that we don't mess up clustering due
765   // to "group_deps".  After this phase most edges should have been contracted.
766 
767   VLOG(4) << "Running phase 1";
768   TF_RETURN_IF_ERROR(
769       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> {
770         // We split out this phase to get good clustering in the presence of a
771         // specific pattern seen in some graphs:
772         //
773         // digraph {
774         //   ApplyWeightUpdates_0 -> "iteration++"
775         //   ApplyWeightUpdates_1 -> "iteration++"
776         //   ApplyWeightUpdates_2 -> "iteration++"
777         //   ApplyWeightUpdates_0 -> Computation_A
778         //   ApplyWeightUpdates_1 -> Computation_B
779         //   ApplyWeightUpdates_2 -> Computation_C
780         //   Computation_A -> NoOp
781         //   Computation_B -> NoOp
782         //   Computation_C -> NoOp
783         //   "iteration++" -> NoOp
784         // }
785         //
786         // In the graph above we can't cluster iteration++ with any of the
787         // gradient update operations since that will break the TF resource
788         // variable memory model.  Given that constraint the ideal clustering
789         // would be to put all the gradient updates and all of the Computation_*
790         // nodes in one cluster, and leave iteration++ and NoOp unclustered.
791         //
792         // A naive post-order traversal would not create this good clustering,
793         // however.  Instead it will first create a cluster that puts
794         // Computation_* nodes, the NoOp and iteration++ node in a single
795         // cluster, after which it will fail to put any of the
796         // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we
797         // instead run a pass that avoids contracting edges _into_ NoOps like
798         // the above, and avoid clustering edges _from_ "iteration++" like the
799         // above.  Then we run a second pass that contracts the edges we could
800         // not contract the first time around.
801 
802         if (IsSinkLike(*to)) {
803           return false;
804         }
805 
806         if (IsScalarIntegerResourceOperation(*from)) {
807           return false;
808         }
809 
810         return TryToContractEdge(from, to);
811       }).status());
812 
813   // Phase 2: contract any remaining edges.  After this phase we should have a
814   // maximal clustering:
815   //
816   // A. We visit a cluster only after maximally clustering all its children.
817   // B. By the time we're done with a node all of its children that could have
818   //    been absorbed into the node have been absorbed.
819   // C. We have an invariant that making a cluster larger does not make edges
820   //    leaving it more contractable. That is, if we have
821   //    digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible
822   //    to contract Y->Z if Y->Z was not contractible originally.
823   VLOG(4) << "Running phase 2";
824   TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
825                        return TryToContractEdge(from, to);
826                      }).status());
827 
828   // Check that the conclusion made above (that iterating over the graph once in
829   // post order gives a maximal clustering) holds.  Once the linear time
830   // post-order scheme has been battle tested we can move this to happen only in
831   // debug builds.
832   VLOG(2) << "Checking idempotence";
833   TF_ASSIGN_OR_RETURN(bool changed,
834                       ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) {
835                         return TryToContractEdge(from, to);
836                       }));
837   TF_RET_CHECK(!changed);
838 
839   return Status::OK();
840 }
841 
842 std::atomic<int64> cluster_sequence_num;
843 
844 int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; }
845 
846 Status MarkForCompilationPassImpl::CreateClusters() {
847   TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_);
848   clusters_created_ = true;
849 
850   // Names for each cluster.
851   std::unordered_map<int, string> cluster_names;
852 
853   if (debug_options_.dump_graphs) {
854     DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_);
855   }
856 
857   // Mark clusters for compilation that:
858   // * are placed on a device that requires compilation (an XlaDevice),
859   // * are explicitly marked for compilation (_XlaCompile=true), or
860   // * have at least debug_options_.min_cluster_size elements (applicable
861   //   only if compilation is enabled, otherwise there will be no such
862   //   candidates).
863   for (Node* n : compilation_candidates_) {
864     Cluster* cluster = GetClusterForNode(n);
865     TF_ASSIGN_OR_RETURN(bool should_compile_cluster,
866                         ShouldCompileCluster(*cluster));
867     if (!should_compile_cluster) {
868       continue;
869     }
870 
871     // We assume that functional If and While nodes have at least
872     // min_cluster_size non-trivial nodes in them.  It would be more principled
873     // to (recursively) verify this fact, but that's probably not worth the
874     // trouble.
875 
876     if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size ||
877         cluster->has_functional_control_flow() ||
878         cluster->is_xla_compile_attr_true()) {
879       string& name = cluster_names[cluster->cycles_graph_node_id()];
880 
881       if (name.empty()) {
882         name = absl::StrCat("cluster_", GetNextClusterSequenceNumber());
883       }
884 
885       n->AddAttr(kXlaClusterAttr, name);
886       n->AddAttr(kXlaAlreadyClustered, true);
887       VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
888     }
889   }
890 
891   return Status::OK();
892 }
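// (Illustrative sketch of what CreateClusters produces, assuming
// kXlaClusterAttr resolves to "_XlaCluster": a clustered NodeDef ends up with
//
//   attr { key: "_XlaCluster"          value { s: "cluster_0" } }
//   attr { key: "_XlaAlreadyClustered" value { b: true } }
//
// which later passes use to build the _XlaCompile/_XlaRun pair described at
// the top of this file.)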
893 
894 Status MarkForCompilationPassImpl::DumpDebugInfo() {
895   TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_);
896 
897   if (debug_options_.dump_graphs) {
898     DumpPostClusteringGraphs();
899   }
900 
901   VLogClusteringSummary();
902 
903   return Status::OK();
904 }
905 
906 StatusOr<bool>
907 MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
908     const Cluster& cluster_from, const Cluster& cluster_to) {
909   // If any of the consumer's producers are on a different device, do not
910   // cluster these nodes. This prevents other work on this device from being
911   // delayed by work on other devices. We consider predecessors of the entire
912   // cluster rather than just the inputs to the node to prevent the cluster
913   // still being combined in cases where the 'to' cluster has multiple
914   // dependencies on the 'from' cluster and another dependency leads to a
915   // merging of the clusters.
916   //
917   // TODO(b/117085735): We probably want to handle the reciprocal of this case
918   // where a cluster is producing data for multiple devices.
919   for (const auto& in_id :
920        cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) {
921     const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id);
922     if (cluster_in) {
923       TF_ASSIGN_OR_RETURN(bool devices_compatible,
924                           AreDevicesCompatible(cluster_to, *cluster_in));
925       if (!devices_compatible) {
926         return true;
927       }
928       TF_ASSIGN_OR_RETURN(devices_compatible,
929                           AreDevicesCompatible(cluster_from, *cluster_in));
930       if (!devices_compatible) {
931         return true;
932       }
933     }
934   }
935 
936   return false;
937 }
938 
939 absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
940   // Look for either _XlaScope or _XlaInternalScope on both nodes to guide
941   // clustering.  If both nodes have a scope and the scopes do not match, do
942   // not cluster along this edge.  If even one of the nodes lacks a scope
943   // attribute, then it is treated as a "bridge" and a cluster may be created
944   // along it.
945   //
946   // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
947   // provided by users through jit_scope APIs, while _XlaInternalScope is
948   // automatically generated by the ClusterScopingPass when auto_jit is on.  As
949   // such, we respect _XlaScope only when auto_jit is off, while respecting
950   // _XlaInternalScope only when auto_jit is on.
951   //
952   // We may want to restrict the _XlaScope behavior to require all nodes marked
953   // with _XlaCompile=true to also have a _XlaScope property set (and raise an
954   // error otherwise); but for now we don't do this.
955 
956   if (global_jit_level_ != OptimizerOptions::OFF) {
957     // If global_jit_level_ is ON, respect only _XlaInternalScope.
958     const string& scope =
959         GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
960     if (!scope.empty()) {
961       return scope;
962     }
963   } else {
964     // If global_jit_level_ is OFF, respect only _XlaScope.
965     const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
966     if (!scope.empty()) {
967       return scope;
968     }
969   }
970 
971   return absl::nullopt;
972 }
973 
974 // Returns true iff the attribute `attr_name` is attached to either the node or
975 // to its callee.
976 static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
977                               const char* attr_name) {
978   bool out = false;
979   bool attr_value;
980   if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
981     out |= attr_value;
982   }
983 
984   if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
985     out |= attr_value;
986   }
987   return out;
988 }
989 
990 Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
991   auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
992     return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
993   };
994 
995   std::vector<std::pair<int, int>> unsafe_resource_deps_vect;
996   TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
997       *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect));
998   absl::c_copy(
999       unsafe_resource_deps_vect,
1000       std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin()));
1001 
1002   cluster_for_node_.resize(graph_->num_node_ids());
1003   for (Node* node : graph_->nodes()) {
1004     if (!IsCompilationCandidate(node)) {
1005       cluster_for_node_[node->id()].Get() = nullptr;
1006       continue;
1007     }
1008 
1009     // We want clusters to be big enough that the benefit from XLA's
1010     // optimizations offsets XLA related overhead (for instance we add some
1011     // Switch/Merge nodes into the graph to implement lazy compilation).  To
1012     // this end, we don't count Identity and Constant nodes because they do not
1013     // enable interesting optimizations by themselves.
1014     int effective_cluster_size =
1015         (node->IsIdentity() || node->IsConstant()) ? 0 : 1;
1016 
1017     bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
1018 
1019     absl::optional<DeadnessPredicate> deadness_predicate;
1020     if (deadness_analysis_) {
1021       TF_ASSIGN_OR_RETURN(
1022           deadness_predicate,
1023           deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
1024     }
1025 
1026     const string& device_name_str = !node->assigned_device_name().empty()
1027                                         ? node->assigned_device_name()
1028                                         : node->requested_device();
1029     TF_ASSIGN_OR_RETURN(DeviceId device,
1030                         device_info_cache_.GetIdFor(device_name_str));
1031 
1032     bool is_resource_op = HasResourceInputOrOutput(*node);
1033     absl::optional<DeviceId> resource_op_device;
1034     if (is_resource_op) {
1035       resource_op_device = device;
1036     }
1037 
1038     absl::optional<int> resource_var_operation_node_id;
1039     if (is_resource_op || MayCallFunction(*node, flib_def_)) {
1040       resource_var_operation_node_id = node->id();
1041     }
1042 
1043     bool is_xla_compile_attr_true =
1044         GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
1045         (global_jit_level_ != OptimizerOptions::OFF &&
1046          GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr));
1047 
1048     DeviceSet devices;
1049     devices.Insert(device);
1050 
1051     Cluster* new_cluster = MakeNewCluster(
1052         /*cycles_graph_node_id=*/node->id(),
1053         /*effective_cluster_size=*/effective_cluster_size,
1054         /*has_functional_control_flow=*/has_functional_control_flow, devices,
1055         resource_op_device, resource_var_operation_node_id, deadness_predicate,
1056         /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
1057         GetXlaScope(node));
1058 
1059     cluster_for_node_[node->id()].Get() = new_cluster;
1060   }
1061 
1062   return Status::OK();
1063 }
1064 
1065 StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
1066   if (!node->IsIdentity()) {
1067     return false;
1068   }
1069 
1070   // Check if the Identity is driven by a Switch on its true path.
1071   auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
1072     return e->src()->IsSwitch() && e->src_output() == 1;
1073   });
1074   if (it == node->in_edges().end()) {
1075     return false;
1076   }
1077   const Node* switch_node = (*it)->src();
1078 
1079   // Check if the Switch is driven by LoopCond.
1080   const Node* maybe_loop_cond;
1081   TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
1082   if (!maybe_loop_cond->IsLoopCond()) {
1083     return false;
1084   }
1085 
1086   // Check if the Identity is driving any const nodes through a control edge.
1087   bool driving_any_consts =
1088       absl::c_any_of(node->out_edges(), [](const Edge* e) {
1089         return e->dst()->IsConstant() && e->IsControlEdge();
1090       });
1091   if (!driving_any_consts) {
1092     return false;
1093   }
1094 
1095   return true;
1096 }
1097 
1098 absl::flat_hash_set<string> GetOrCreateAllowlist() {
1099   absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
1100       tensorflow::GetAllowlistTable();
1101   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1102   absl::flat_hash_set<string> allowlist;
1103 
1104   for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
1105     if (s == "FUSIBLE") {
1106       for (auto pair : *allowlist_table) {
1107         allowlist.insert(pair.second.begin(), pair.second.end());
1108       }
1109     } else if (allowlist_table->contains(s)) {
1110       auto v = allowlist_table->at(s);
1111       allowlist.insert(v.begin(), v.end());
1112     } else if (!s.empty()) {
1113       // Should be a user provided TF operation.
1114       allowlist.insert(string(s));
1115     }
1116   }
1117 
1118   if (VLOG_IS_ON(2) && !allowlist.empty()) {
1119     std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
1120     absl::c_sort(vallowlist);
1121     VLOG(2) << "XLA clustering will only consider the following TF operations: "
1122             << absl::StrJoin(vallowlist, " ");
1123   }
1124   return allowlist;
1125 }
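// (Illustrative usage, not part of the pass: the allowlist above is driven by
// the --tf_xla_ops_to_cluster flag, typically set via TF_XLA_FLAGS, e.g.
//
//   TF_XLA_FLAGS="--tf_xla_ops_to_cluster=FUSIBLE"
//
// or an explicit comma-separated list such as
//
//   TF_XLA_FLAGS="--tf_xla_ops_to_cluster=MatMul,Relu,BiasAdd"
//
// Op names not registered with XLA are rejected with an InvalidArgument error
// in FindCompilationCandidates below; the specific op names here are examples
// only.)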
1126 
1127 Status MarkForCompilationPassImpl::FindCompilationCandidates() {
1128   OptimizerOptions opts;
1129   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1130       new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
1131                                         TF_GRAPH_DEF_VERSION, flib_def_, opts));
1132   FunctionLibraryRuntime* lib_runtime =
1133       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
1134   std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
1135   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
1136       *graph_, /*compile_time_const_arg_indices=*/nullptr,
1137       &compile_time_const_nodes, lib_runtime));
1138   // Iterate over nodes in sorted order so that compiler fuel is deterministic.
1139   // We can't simply pass op_nodes().begin() and op_nodes().end() to the
1140   // std::vector constructor because they're not proper iterators, with
1141   // iterator_traits defined and so on.
1142   std::vector<Node*> sorted_nodes;
1143   for (Node* node : graph_->op_nodes()) {
1144     sorted_nodes.push_back(node);
1145   }
1146   std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
1147 
1148   if (*debug_options_.fuel >= std::numeric_limits<int64>::max() / 2) {
1149     // The assumption is that if fuel started out as INT64_MAX, it will forever
1150     // stay greater than INT64_MAX / 2.
1151     VLOG(2) << "Starting fuel: infinity";
1152   } else {
1153     VLOG(2) << "Starting fuel: " << *debug_options_.fuel;
1154   }
1155 
1156   VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
1157 
1158   auto allowlist = GetOrCreateAllowlist();
1159 
1160   std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
1161   absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
1162   // Check that the user-provided TF operations really exist.
1163   for (const auto& s : allowlist) {
1164     if (!all_ops.contains(string(s))) {
1165       return errors::InvalidArgument(
1166           "The operation '", s,
1167           "' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
1168     }
1169   }
1170 
1171   for (Node* node : sorted_nodes) {
1172     if (*debug_options_.fuel <= 0) {
1173       VLOG(1)
1174           << "Hit fuel limit; not marking any remaining ops as clusterable.";
1175       break;
1176     }
1177 
1178     TF_ASSIGN_OR_RETURN(
1179         const DeviceType& device_type,
1180         device_info_cache_.GetDeviceTypeFor(node->assigned_device_name()));
1181     VLOG(4) << "Device type for " << node->name() << ": "
1182             << device_type.type_string();
1183 
1184     if (CompilationDisallowedByXlaCompileAttr(node)) {
1185       VLOG(2) << "Not clustering " << node->name()
1186               << ": disallowed by _XlaCompile attribute";
1187       continue;
1188     }
1189 
1190     const XlaOpRegistry::DeviceRegistration* registration;
1191     if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
1192                                              &registration)) {
1193       VLOG(2) << "Rejecting " << node->name()
1194               << ": could not find JIT device for " << device_type.type();
1195       continue;
1196     }
1197 
1198     RecursiveCompilabilityChecker::OperationFilter filter =
1199         CreateOperationFilter(*registration);
1200     filter.require_always_compilable = true;
1201     filter.allow_string_consts = false;
1202     filter.allow_collective_reduce_v2 = false;
1203 
1204     RecursiveCompilabilityChecker checker(
1205         filter, DeviceType{registration->compilation_device_name});
1206 
1207     if (!checker.IsCompilableNode(*node, lib_runtime)) {
1208       continue;
1209     }
1210 
1211     if (node->type_string() == "Const") {
1212       // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
1213       // support it.
1214       const AttrValue* attr = node->attrs().Find("dtype");
1215       if (attr != nullptr && attr->type() == DT_STRING) {
1216         continue;
1217       }
1218     }
1219 
1220     if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
1221       VLOG(1) << "Rejecting TF operation " << node->def().op()
1222               << " as it is not listed in --tf_xla_ops_to_cluster.";
1223       continue;
1224     }
1225 
1226     if (compile_time_const_nodes[node->id()]) {
1227       const OpDef* op_def;
1228       TF_RETURN_IF_ERROR(
1229           graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def));
1230       if (op_def->is_stateful()) {
1231         // It is easiest to demonstrate the problem we're trying to solve with
1232         // an example.  Say we have this graph:
1233         //
1234         //   shape = RandomUniformInt();
1235         //   reshape = Reshape(input, shape)
1236         //
1237         // Both RandomUniformInt and Reshape are compilable by XLA so, absent
1238         // any other reason, we will try to put both shape and reshape in the
1239         // same cluster.  However, since XLA only supports statically shaped
1240         // values, it will expect to be able to constant fold `shape` to get a
1241         // static shape for `reshape`.  This is a problem because side-effecting
1242         // ops like RandomUniformInt() cannot be constant folded.  We fix this
1243         // by putting `shape` and `reshape` in different clusters, which results
1244         // in us recompiling `reshape`'s cluster for every new value of `shape`,
1245         // making `reshape` statically sized within each compilation.  We
1246         // simplify the solution even further by disallowing operations like
1247         // `shape` from being part of *any* non-trivial cluster.  They're either
1248         // not compiled by XLA altogether or, if assigned to an XLA_* device
1249         // with "must compile" semantics, compiled into a trivial single-op
1250         // cluster.  This approach leaves some room for improvement, and we can
1251         // consider implementing a more aggressive data-flow-analysis based
1252         // solution in the future if needed.
1253         //
1254         // One ugly problem we have to contend with: certain sets of ops *have*
1255         // to be in the same cluster because values flowing between them have
1256         // types that can't be live-in or live-out of a cluster.  These ops are:
1257         //
1258         //  - TensorArray ops operating on the same TensorArray instance.
1259         //  - Stack ops operating on the same Stack instance.
1260         //
1261         // To work around this we avoid isolating these specific ops.  Because
1262         // of this concession it would be unsound to auto-cluster them: we'd
1263         // create clusters we could not compile (since we can't constant fold,
1264         // say, a TensorArrayRead or a StackPopV2).  But we don't auto-cluster
1265         // these operations today, so we're fine for now.
1266         const XlaResourceOpInfo* op_info =
1267             GetResourceOpInfoForOp(node->type_string());
1268         bool is_tensor_array_or_stack_op =
1269             op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
1270         if (!is_tensor_array_or_stack_op) {
1271           VLOG(2) << "Isolating " << node->name()
1272                   << ": must-be-constant stateful op";
1273           continue;
1274         }
1275       }
1276     }
1277 
1278     // This is a heuristic to avoid creating a dependency between the while
1279     // loop condition and body computations.  Such a dependency can be created
1280     // if a special Identity node in the following pattern is clustered in.
1281     // That is, an Identity node in the loop cond computation is used to drive
1282     // const nodes consumed by the loop body.  If this Identity node goes into
1283     // the same cluster as nodes from the loop body, an extra dependency is
1284     // created between the loop cond and body computations, which hinders the
1285     // progress of the loop cond computation at runtime and adds significant
1286     // overhead.  Specifically, we look for the pattern below and do not pull
1287     // this Identity into a cluster, avoiding the issue.  Since Identity has low
1288     // execution cost in native TF, giving up these special Identity nodes as
1289     // candidates should not hurt performance.  If other considerations emerge
1290     // in the future, we can revisit the heuristic and only disallow clustering
1291     // these Identities with nodes from the loop body while still considering
1292     // them candidates.
1293     //
1294     // LoopCond ->
1295     // Merge    -> Switch -> Identity -> i++ -> ... -> NextIteration
1296     //                               ..> Const -> LoopBody
1297     //                            (control edge)
1298     TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
1299                         IsIdentityDrivingConstsInLoop(node));
1300     if (is_identity_driving_consts_in_loop) {
1301       VLOG(2) << "Rejecting " << node->name()
1302               << ": including it can create dependencies between while loop "
1303                  "condition and body computations with runtime overhead.";
1304       continue;
1305     }
1306 
1307     compilation_candidates_.insert(node);
1308     --(*debug_options_.fuel);
1309   }
1310 
1311   VLOG(2) << "compilation_candidates_.size() = "
1312           << compilation_candidates_.size();
1313   return Status::OK();
1314 }
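
// The Const-with-DT_STRING rejection in the candidate-selection loop above can
// be read as a small predicate on a node.  The sketch below is illustrative
// only: it is not used by this pass, and the helper name ExampleIsStringConst
// is hypothetical.
//
//   static bool ExampleIsStringConst(const Node& node) {
//     if (node.type_string() != "Const") return false;
//     // Const nodes carry their element type in the "dtype" attribute.
//     const AttrValue* attr = node.attrs().Find("dtype");
//     return attr != nullptr && attr->type() == DT_STRING;
//   }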
1315 
1316 bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr(
1317     Node* node) {
1318   if (debug_options_.ignore_xla_compile_attr) {
1319     return false;
1320   }
1321 
1322   // If there is a _XlaCompile annotation, use its value.
1323   bool compile = false;
1324   Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
1325   if (status.ok()) {
1326     if (!compile) {
1327       VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1328               << kXlaCompileAttr << ") is false.";
1329     }
1330     return !compile;
1331   }
1332 
1333   status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile);
1334   if (status.ok()) {
1335     if (!compile) {
1336       VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
1337               << kXlaCompileAttr << ") on callee is false.";
1338     }
1339     return !compile;
1340   }
1341 
1342   return false;
1343 }
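
// For context, a minimal sketch of how the attribute consulted above can be
// attached when constructing a graph by hand.  Illustrative only: `node_def`
// is hypothetical and nothing below is used by this pass.  Setting
// kXlaCompileAttr ("_XlaCompile") to false on a node, or on the function it
// calls, vetoes clustering for that node.
//
//   NodeDef node_def;
//   node_def.set_op("MatMul");
//   AddNodeAttr(kXlaCompileAttr, false, &node_def);  // opt this node out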
1344 
1345 bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse(
1346     Cluster* from, Cluster* to, absl::string_view reason) {
1347   VLOG(3) << EdgeContractionFailureMsg(from, to, reason);
1348   return false;
1349 }
1350 
1351 StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from,
1352                                                              Cluster* to) {
1353   DCHECK(from->deadness_predicate().has_value() ==
1354          to->deadness_predicate().has_value());
1355   if (from->deadness_predicate() != to->deadness_predicate()) {
1356     VLOG(3) << EdgeContractionFailureMsg(
1357         from, to,
1358         absl::StrCat(
1359             "the two nodes have mismatching deadness: ",
1360             deadness_analysis_->DebugString(*from->deadness_predicate()),
1361             " and ",
1362             deadness_analysis_->DebugString(*to->deadness_predicate())));
1363     return false;
1364   }
1365 
1366   TF_ASSIGN_OR_RETURN(bool devices_compatible,
1367                       AreDevicesCompatible(*from, *to));
1368   if (!devices_compatible) {
1369     return LogNotContractableAndReturnFalse(
1370         from, to, "the two nodes have incompatible devices");
1371   }
1372 
1373   if (from->xla_scope().has_value() && to->xla_scope().has_value() &&
1374       *from->xla_scope() != *to->xla_scope()) {
1375     return LogNotContractableAndReturnFalse(
1376         from, to, "the two nodes have mismatching XLA scopes");
1377   }
1378 
1379   // Don't exceed the maximum cluster size.
1380   if (from->cluster_size() + to->cluster_size() >
1381       debug_options_.max_cluster_size) {
1382     return LogNotContractableAndReturnFalse(
1383         from, to, "the new cluster will be larger than the max cluster size");
1384   }
1385 
1386   TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency,
1387                       ClusteringWillIntroduceInterDeviceDependency(*from, *to));
1388 
1389   if (will_introduce_cross_device_dependency) {
1390     return LogNotContractableAndReturnFalse(
1391         from, to, "the new cluster will introduce a cross device dependency");
1392   }
1393 
1394   // Check if contracting this edge will break the resource variable concurrency
1395   // semantics.  In theory this is quadratic in the number of nodes, but seems
1396   // to not be a problem in practice so far.
1397   if (!debug_options_.ignore_resource_variable_checks) {
1398     for (int resource_var_from : from->resource_var_operation_node_ids()) {
1399       for (int resource_var_to : to->resource_var_operation_node_ids()) {
1400         // If unsafe_resource_deps_ contains {A, B} then
1401         //
1402         //  a. A and B are resource operations.
1403         //  b. A and B cannot be placed in the same cluster.
1404         //  c. There is no path from B to A in the cycles graph (but there may
1405         //     be a path from A to B).
1406         //
1407         // So check the legality of the edge contraction by checking if any of
1408         // the n^2 pairs of resource variable operations are forbidden.
1409         if (unsafe_resource_deps_.contains(
1410                 {resource_var_from, resource_var_to})) {
1411           return LogNotContractableAndReturnFalse(
1412               from, to,
1413               "the new cluster would break resource variable semantics");
1414         }
1415       }
1416     }
1417   }
1418 
1419   return MergeClusters(from, to);
1420 }
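
// The resource-variable legality test above amounts to: reject the contraction
// if any (from, to) pair of resource-op node ids appears in the precomputed
// set of unsafe dependencies.  A standalone sketch, illustrative only (the
// helper name and parameter types are hypothetical):
//
//   static bool ExampleHasUnsafeResourcePair(
//       const absl::flat_hash_set<std::pair<int, int>>& unsafe,
//       const std::vector<int>& from_ids, const std::vector<int>& to_ids) {
//     for (int a : from_ids) {
//       for (int b : to_ids) {
//         if (unsafe.contains({a, b})) return true;  // forbidden pair
//       }
//     }
//     return false;
//   }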
1421 
1422 Status MarkForCompilationPassImpl::Run() {
1423   // Make sure that kernels have been registered on the JIT device.
1424   XlaOpRegistry::RegisterCompilationKernels();
1425 
1426   // Start the timer after XlaOpRegistry::RegisterCompilationKernels which does
1427   // some one-time work.
1428   XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1);
1429 
1430   TF_ASSIGN_OR_RETURN(bool initialized, Initialize());
1431   if (!initialized) {
1432     // Initialization exited early which means this instance of
1433     // MarkForCompilationPassImpl is not set up to run the subsequent phases.
1434     return Status::OK();
1435   }
1436 
1437   TF_RETURN_IF_ERROR(RunEdgeContractionLoop());
1438   TF_RETURN_IF_ERROR(CreateClusters());
1439   TF_RETURN_IF_ERROR(DumpDebugInfo());
1440 
1441   return Status::OK();
1442 }
1443 
1444 void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
1445   DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
1446 
1447   // We also dump out an annotated version of the TF graph where the node
1448   // names are prefixed with the cluster names.  This can help visualize the
1449   // clustering decisions on TensorBoard.
1450   Graph new_graph(graph_->op_registry());
1451   CopyGraph(*graph_, &new_graph);
1452 
1453   for (Node* n : new_graph.nodes()) {
1454     if (absl::optional<absl::string_view> cluster_name =
1455             GetXlaClusterForNode(*n)) {
1456       n->set_name(absl::StrCat(*cluster_name, "/", n->name()));
1457     } else if (n->type_string() == "VarHandleOp") {
1458       n->set_name(absl::StrCat("varhandle/", n->name()));
1459     } else {
1460       // There is room for improvement here.  In particular, it may help to
1461       // split these unclustered nodes into classes where every node in a
1462       // specific class has edges to and from the same set of clusters.
1463       n->set_name(absl::StrCat("unclustered/", n->name()));
1464     }
1465   }
1466 
1467   DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_);
1468 }
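
// As a concrete illustration of the renaming above (the names are
// hypothetical): a node "while/MatMul" assigned to a cluster named "cluster_0"
// is dumped as "cluster_0/while/MatMul", a VarHandleOp "v" as "varhandle/v",
// and any other unclustered node "n" as "unclustered/n", so TensorBoard groups
// nodes by clustering decision.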
1469 
1470 string RatioToString(int numerator, int denominator) {
1471   return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
1472                          (100.0 * numerator) / denominator);
1473 }
1474 
1475 void MarkForCompilationPassImpl::VLogClusteringSummary() {
1476   if (!VLOG_IS_ON(2)) {
1477     return;
1478   }
1479 
1480   XlaAutoClusteringSummary auto_clustering_info =
1481       GetXlaAutoClusteringSummary(*graph_);
1482 
1483   VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes();
1484   VLOG(2) << " Built " << auto_clustering_info.clusters_size()
1485           << " clusters, size "
1486           << RatioToString(auto_clustering_info.clustered_node_count(),
1487                            graph_->num_nodes());
1488 
1489   for (const XlaAutoClusteringSummary::Cluster& cluster :
1490        auto_clustering_info.clusters()) {
1491     absl::string_view cluster_name = cluster.name();
1492     int size = cluster.size();
1493     VLOG(2) << "  " << cluster_name << " "
1494             << RatioToString(size, graph_->num_nodes());
1495     for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1496          cluster.op_histogram()) {
1497       VLOG(3) << "   " << op_count.op() << ": " << op_count.count()
1498               << " instances";
1499     }
1500   }
1501 
1502   if (!auto_clustering_info.unclustered_op_histogram().empty()) {
1503     VLOG(2) << " Unclustered nodes: "
1504             << RatioToString(auto_clustering_info.unclustered_node_count(),
1505                              graph_->num_nodes());
1506     for (const XlaAutoClusteringSummary::OpAndCount& op_count :
1507          auto_clustering_info.unclustered_op_histogram()) {
1508       VLOG(3) << "  " << op_count.op() << ": " << op_count.count()
1509               << " instances";
1510     }
1511   }
1512 
1513   struct EdgeInfo {
1514     absl::string_view node_name;
1515     absl::optional<absl::string_view> cluster_name;
1516 
1517     absl::string_view GetClusterName() const {
1518       return cluster_name ? *cluster_name : "[none]";
1519     }
1520 
1521     std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
1522         const {
1523       return {node_name, cluster_name};
1524     }
1525 
1526     bool operator<(const EdgeInfo& other) const {
1527       return AsPair() < other.AsPair();
1528     }
1529   };
1530 
1531   using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
1532 
1533   EdgeInfoMap incoming_edge_infos;
1534   EdgeInfoMap outgoing_edge_infos;
1535 
1536   std::set<absl::string_view> cluster_names_to_print;
1537 
1538   for (const Edge* e : graph_->edges()) {
1539     const Node* from = e->src();
1540     absl::optional<absl::string_view> from_cluster_name =
1541         GetXlaClusterForNode(*from);
1542 
1543     const Node* to = e->dst();
1544     absl::optional<absl::string_view> to_cluster_name =
1545         GetXlaClusterForNode(*to);
1546 
1547     if (to_cluster_name == from_cluster_name) {
1548       continue;
1549     }
1550 
1551     if (to_cluster_name) {
1552       incoming_edge_infos[*to_cluster_name]
1553                          [EdgeInfo{from->name(), from_cluster_name}]++;
1554       cluster_names_to_print.insert(*to_cluster_name);
1555     }
1556 
1557     if (from_cluster_name) {
1558       outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
1559       cluster_names_to_print.insert(*from_cluster_name);
1560     }
1561   }
1562 
1563   VLOG(4) << "*** Inter-Cluster edges:";
1564   if (cluster_names_to_print.empty()) {
1565     VLOG(4) << "   [none]";
1566   }
1567 
1568   auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
1569                                              const EdgeInfoMap& edge_info_map,
1570                                              absl::string_view desc) {
1571     auto it = edge_info_map.find(cluster_name);
1572     if (it != edge_info_map.end()) {
1573       VLOG(4) << "  " << it->second.size() << " " << desc << " edges";
1574       for (const auto& edge_info_count_pair : it->second) {
1575         VLOG(4) << "   " << edge_info_count_pair.first.GetClusterName() << " "
1576                 << edge_info_count_pair.first.node_name << " # "
1577                 << edge_info_count_pair.second;
1578       }
1579     } else {
1580       VLOG(4) << "  No " << desc << " edges.";
1581     }
1582   };
1583 
1584   for (absl::string_view cluster_name : cluster_names_to_print) {
1585     VLOG(4) << " ** Cluster " << cluster_name;
1586     print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
1587                                     "incoming");
1588     print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
1589                                     "outgoing");
1590   }
1591 }
1592 
1593 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
1594     const Cluster& cluster_a, const Cluster& cluster_b) {
1595   DeviceSet devices = cluster_a.devices();
1596   devices.UnionWith(cluster_b.devices());
1597 
1598   TF_ASSIGN_OR_RETURN(
1599       absl::optional<jit::DeviceId> maybe_chosen_device,
1600       MaybePickDeviceForXla(device_info_cache_, devices,
1601                             /*allow_mixing_unknown_and_cpu=*/false));
1602   if (!maybe_chosen_device.has_value()) {
1603     return false;
1604   }
1605 
1606   jit::DeviceId chosen_device = *maybe_chosen_device;
1607 
1608   // If we are able to pick a device `chosen_device` for the larger cluster, the
1609   // resource operations in `cluster_a` and `cluster_b` must be placed on the
1610   // same device as `chosen_device`.  This is because the _XlaCompile and
1611   // _XlaRun kernels will run on `chosen_device` and will therefore try to
1612   // access the resource variables from there, which is an error if the
1613   // resource variables are placed on some other device.
1614   auto resource_op_device_ok =
1615       [&](absl::optional<DeviceId> resource_op_device) {
1616         return !resource_op_device.has_value() ||
1617                *resource_op_device == chosen_device;
1618       };
1619 
1620   return resource_op_device_ok(cluster_a.resource_op_device()) &&
1621          resource_op_device_ok(cluster_b.resource_op_device());
1622 }
1623 
1624 // Returns `true` iff we should compile `cluster`.
1625 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
1626     const Cluster& cluster) {
1627   TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
1628                       PickDeviceForXla(device_info_cache_, cluster.devices(),
1629                                        /*allow_mixing_unknown_and_cpu=*/false));
1630 
1631   const DeviceType& device_type =
1632       device_info_cache_.GetDeviceTypeFor(chosen_device);
1633   const XlaOpRegistry::DeviceRegistration* registration =
1634       device_info_cache_.GetCompilationDevice(chosen_device);
1635   TF_RET_CHECK(registration)
1636       << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
1637       << "; device type = " << device_type.type() << "; devices ("
1638       << device_info_cache_.DebugString(cluster.devices()) << ")";
1639 
1640   bool should_compile =
1641       cluster.is_xla_compile_attr_true() ||
1642       registration->autoclustering_policy ==
1643           XlaOpRegistry::AutoclusteringPolicy::kAlways ||
1644       (registration->autoclustering_policy ==
1645            XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
1646        global_jit_level_ != OptimizerOptions::OFF);
1647 
1648   if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
1649       device_type.type_string() == DEVICE_CPU) {
1650     static absl::once_flag once;
1651     absl::call_once(once, [] {
1652       LOG(WARNING)
1653           << "(One-time warning): Not using XLA:CPU for cluster because envvar "
1654              "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want "
1655              "XLA:CPU, either set that envvar, or use experimental_jit_scope "
1656              "to enable XLA:CPU.  To confirm that XLA is active, pass "
1657              "--vmodule=xla_compilation_cache=1 (as a proper command-line "
1658              "flag, not via TF_XLA_FLAGS) or set the envvar "
1659              "XLA_FLAGS=--xla_hlo_profile.";
1660       MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1661       if (flags->tf_xla_cpu_global_jit) {
1662         LOG(WARNING)
1663             << "(Although the tf_xla_cpu_global_jit flag is currently enabled, "
1664                "perhaps it wasn't enabled at process startup?)";
1665       }
1666     });
1667   }
1668 
1669   VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
1670           << " cluster with device "
1671           << device_info_cache_.GetNameFor(chosen_device);
1672 
1673   return should_compile;
1674 }
1675 
1676 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster(
1677     const Cluster& cluster) {
1678   auto it = should_compile_cluster_cache_.find(&cluster);
1679   if (it != should_compile_cluster_cache_.end()) {
1680     return it->second;
1681   }
1682 
1683   TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster));
1684   should_compile_cluster_cache_.insert({&cluster, should_compile});
1685   return should_compile;
1686 }
1687 
1688 Status MarkForCompilation(
1689     const GraphOptimizationPassOptions& options,
1690     const MarkForCompilationPassImpl::DebugOptions& debug_options) {
1691   Graph* graph = options.graph->get();
1692   FunctionLibraryDefinition* flib_def = options.flib_def;
1693 
1694   // Deadness analysis expects a graph with source and sink edges properly
1695   // connected but sometimes the incoming graph does not follow this invariant.
1696   // So fix up the source and sink edges before calling into deadness analysis.
1697   FixupSourceAndSinkEdges(graph);
1698 
1699   for (Node* n : graph->nodes()) {
1700     // See explanation on `kXlaAlreadyClustered`.
1701     if (n->attrs().Find(kXlaAlreadyClustered)) {
1702       return Status::OK();
1703     }
1704     // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops
1705     // in the graph, which indicates the graph is produced by TPU TF-XLA bridge
1706     // and doesn't require auto clustering.
1707     if (n->type_string() == "TPUExecute" ||
1708         n->type_string() == "TPUExecuteAndUpdateVariables") {
1709       return Status::OK();
1710     }
1711   }
1712 
1713   return MarkForCompilationPassImpl{debug_options, graph, flib_def,
1714                                     options.session_options != nullptr
1715                                         ? options.session_options->env
1716                                         : Env::Default(),
1717                                     GetGlobalJitLevelForGraph(options)}
1718       .Run();
1719 }
1720 
1721 std::atomic<int64>* GetPointerToFuel(int64_t initial_value) {
1722   static std::atomic<int64>* fuel = [&]() {
1723     std::atomic<int64>* fuel = new std::atomic<int64>;
1724     *fuel = initial_value;
1725     return fuel;
1726   }();
1727 
1728   return fuel;
1729 }
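
// A sketch of how the fuel counter above is consumed elsewhere in this pass:
// one unit per admitted compilation candidate, so clustering decisions can be
// bisected by lowering --tf_xla_clustering_fuel.  Illustrative only; `fuel`
// and `node_is_candidate` below are hypothetical locals.
//
//   std::atomic<int64>* fuel = GetPointerToFuel(/*initial_value=*/100);
//   if (*fuel > 0 && node_is_candidate) {
//     --(*fuel);  // consume one unit per admitted candidate
//     // ... record the node as a compilation candidate ...
//   }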
1730 }  // anonymous namespace
1731 
1732 Status MarkForCompilationPass::Run(
1733     const GraphOptimizationPassOptions& options) {
1734   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1735 
1736   MarkForCompilationPassImpl::DebugOptions debug_options;
1737   debug_options.ignore_deadness_checks =
1738       flags->tf_xla_disable_deadness_safety_checks_for_debugging;
1739   debug_options.ignore_resource_variable_checks =
1740       flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1741   debug_options.ignore_xla_compile_attr = false;
1742   debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1743   debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1744   debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1745   debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1746 
1747   return MarkForCompilation(options, debug_options);
1748 }
1749 
1750 Status MarkForCompilationPass::RunForTest(
1751     const GraphOptimizationPassOptions& options,
1752     bool disable_deadness_analysis) {
1753   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
1754 
1755   MarkForCompilationPassImpl::DebugOptions debug_options;
1756   debug_options.ignore_deadness_checks = disable_deadness_analysis;
1757   debug_options.ignore_resource_variable_checks =
1758       flags->tf_xla_disable_resource_variable_safety_checks_for_debugging;
1759   debug_options.ignore_xla_compile_attr = true;
1760   debug_options.max_cluster_size = flags->tf_xla_max_cluster_size;
1761   debug_options.min_cluster_size = flags->tf_xla_min_cluster_size;
1762   debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel);
1763   debug_options.dump_graphs = flags->tf_xla_clustering_debug;
1764 
1765   return MarkForCompilation(options, debug_options);
1766 }
1767 
1768 absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
1769   // Table format: category name: {list of TF operations in that category}
1770   static absl::flat_hash_map<string, std::vector<string>>* result =
1771       new absl::flat_hash_map<string, std::vector<string>>{
1772           // Unary
1773           {"PW",
1774            {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
1775             "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
1776             "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
1777             "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
1778             "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
1779             "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
1780             "Lgamma", "Digamma",
1781             // Binary
1782             "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
1783             "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod",
1784             "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift",
1785             "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
1786             "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
1787             "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
1788             "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
1789             "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
1790             // Others
1791             "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
1792             "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
1793             "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
1794             "SelectV2", "Transpose", "ConjugateTranspose",
1795             "_UnaryOpsComposition", "CollectiveReduceV2",
1796             // The following 5 operations are converted to identity
1797             "PlaceholderWithDefault", "PreventGradient", "StopGradient",
1798             "Snapshot", "_EagerConst"}},
1799           // clang-format off
1800     {"RED",
1801      {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
1802           // clang-format on
1803           {"PWRED",
1804            {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1805             "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1806           {"REDUCEWINDOW",
1807            {"ArgMax", "ArgMin", "DiagPart", "Softmax",
1808             "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
1809           {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
1810           {"BN",
1811            {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
1812             "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
1813             "FusedBatchNormGradV3"}},
1814           {"SORT", {"TopKV2"}},  // XLA version much faster than the TF version.
1815           {"MISC",
1816            // clang-format off
1817      {"BroadcastTo", "ExpandDims", "Fill", "NoOp",
1818       "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
1819       "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
1820       "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
1821       "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad",
1822       "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split",
1823       "SplitV", "StridedSlice", "StridedSliceGrad",
1824       "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation",
1825       "Unpack", "DeviceIndex", "TensorStridedSliceUpdate",
1826      }}};
1827   // clang-format on
1828   return result;
1829 }
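
// The table above maps a short category name to the TF ops in that category
// and backs the --tf_xla_ops_to_cluster restriction checked earlier in this
// file.  A hedged usage sketch (the exact flag grammar may differ; treat this
// as illustrative):
//
//   TF_XLA_FLAGS="--tf_xla_ops_to_cluster=PW,RED,MatMul"
//
// would limit autoclustering candidates to the pointwise ops, the reduction
// ops, and MatMul.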
1830 
1831 namespace testing {
1832 void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
1833 
1834 absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
1835   absl::flat_hash_set<string> result{"AdjustContrastv2",
1836                                      "AdjustHue",
1837                                      "AdjustSaturation",
1838                                      "Asinh",
1839                                      "Assert",
1840                                      "AssignAddVariableOp",
1841                                      "AssignSubVariableOp",
1842                                      "AssignVariableOp",
1843                                      "AvgPool",
1844                                      "AvgPool3D",
1845                                      "AvgPool3DGrad",
1846                                      "AvgPoolGrad",
1847                                      "BatchMatMul",
1848                                      "BatchMatMulV2",
1849                                      "BatchMatMulV3",
1850                                      "BatchToSpace",
1851                                      "BatchToSpaceND",
1852                                      "BesselI0e",
1853                                      "BesselI1e",
1854                                      "Betainc",
1855                                      "BiasAddV1",
1856                                      "Bucketize",
1857                                      "Case",
1858                                      "CheckNumerics",
1859                                      "Cholesky",
1860                                      "ControlTrigger",
1861                                      "Conv2D",
1862                                      "Conv2DBackpropFilter",
1863                                      "Conv2DBackpropInput",
1864                                      "Conv3D",
1865                                      "Conv3DBackpropFilterV2",
1866                                      "Conv3DBackpropInputV2",
1867                                      "Cross",
1868                                      "Cumprod",
1869                                      "Cumsum",
1870                                      "DataFormatDimMap",
1871                                      "DataFormatVecPermute",
1872                                      "DepthToSpace",
1873                                      "DepthwiseConv2dNative",
1874                                      "DepthwiseConv2dNativeBackpropFilter",
1875                                      "DepthwiseConv2dNativeBackpropInput",
1876                                      "Dequantize",
1877                                      "Diag",
1878                                      "DynamicStitch",
1879                                      "DynamicPartition",
1880                                      "Einsum",
1881                                      "EmptyTensorList",
1882                                      "EnsureShape",
1883                                      "ExtractImagePatches",
1884                                      "Igamma",
1885                                      "IgammaGradA",
1886                                      "RandomGammaGrad",
1887                                      "Igammac",
1888                                      "FFT",
1889                                      "FFT2D",
1890                                      "FFT3D",
1891                                      "FakeParam",
1892                                      "FakeQuantWithMinMaxArgs",
1893                                      "FakeQuantWithMinMaxArgsGradient",
1894                                      "FakeQuantWithMinMaxVars",
1895                                      "FakeQuantWithMinMaxVarsGradient",
1896                                      "Gather",
1897                                      "GatherNd",
1898                                      "GatherV2",
1899                                      "HSVToRGB",
1900                                      "IFFT",
1901                                      "IFFT2D",
1902                                      "IFFT3D",
1903                                      "IRFFT",
1904                                      "IRFFT2D",
1905                                      "IRFFT3D",
1906                                      "If",
1907                                      "InTopKV2",
1908                                      "L2Loss",
1909                                      "LeakyRelu",
1910                                      "LinSpace",
1911                                      "ListDiff",
1912                                      "LogMatrixDeterminant",
1913                                      "LowerBound",
1914                                      "MatMul",
1915                                      "MatrixBandPart",
1916                                      "MatrixDiag",
1917                                      "MatrixDiagPart",
1918                                      "MatrixDiagPartV2",
1919                                      "MatrixDiagPartV3",
1920                                      "MatrixDiagV2",
1921                                      "MatrixDiagV3",
1922                                      "MatrixInverse",
1923                                      "MatrixSetDiag",
1924                                      "MatrixSetDiagV2",
1925                                      "MatrixSetDiagV3",
1926                                      "MatrixSolve",
1927                                      "MatrixTriangularSolve",
1928                                      "MaxPool",
1929                                      "MaxPool3D",
1930                                      "MaxPool3DGrad",
1931                                      "MaxPool3DGradGrad",
1932                                      "MaxPoolGrad",
1933                                      "MaxPoolGradGrad",
1934                                      "MaxPoolGradGradV2",
1935                                      "MaxPoolGradV2",
1936                                      "MaxPoolV2",
1937                                      "Multinomial",
1938                                      "NextAfter",
1939                                      "NonMaxSuppressionV4",
1940                                      "ParallelDynamicStitch",
1941                                      "ParameterizedTruncatedNormal",
1942                                      "PartitionedCall",
1943                                      "Polygamma",
1944                                      "PopulationCount",
1945                                      "Qr",
1946                                      "QuantizeAndDequantizeV2",
1947                                      "QuantizeAndDequantizeV3",
1948                                      "QuantizeAndDequantizeV4",
1949                                      "RFFT",
1950                                      "RFFT2D",
1951                                      "RFFT3D",
1952                                      "RGBToHSV",
1953                                      "RandomShuffle",
1954                                      "RandomStandardNormal",
1955                                      "RandomUniform",
1956                                      "RandomUniformInt",
1957                                      "ReadVariableOp",
1958                                      "ResizeBilinear",
1959                                      "ResizeBilinearGrad",
1960                                      "ResizeNearestNeighbor",
1961                                      "ResourceApplyAdaMax",
1962                                      "ResourceApplyAdadelta",
1963                                      "ResourceApplyAdagrad",
1964                                      "ResourceApplyAdagradDA",
1965                                      "ResourceApplyAdagradV2",
1966                                      "ResourceApplyAdam",
1967                                      "ResourceApplyAddSign",
1968                                      "ResourceApplyCenteredRMSProp",
1969                                      "ResourceApplyFtrl",
1970                                      "ResourceApplyFtrlV2",
1971                                      "ResourceApplyGradientDescent",
1972                                      "ResourceApplyKerasMomentum",
1973                                      "ResourceApplyMomentum",
1974                                      "ResourceApplyPowerSign",
1975                                      "ResourceApplyProximalAdagrad",
1976                                      "ResourceApplyProximalGradientDescent",
1977                                      "ResourceApplyRMSProp",
1978                                      "ResourceGather",
1979                                      "ResourceScatterAdd",
1980                                      "ResourceScatterDiv",
1981                                      "ResourceScatterMax",
1982                                      "ResourceScatterMin",
1983                                      "ResourceScatterMul",
1984                                      "ResourceScatterNdAdd",
1985                                      "ResourceScatterNdSub",
1986                                      "ResourceScatterNdUpdate",
1987                                      "ResourceScatterSub",
1988                                      "ResourceScatterUpdate",
1989                                      "RngReadAndSkip",
1990                                      "RngSkip",
1991                                      "Roll",
1992                                      "ScatterNd",
1993                                      "SelfAdjointEigV2",
1994                                      "SoftmaxCrossEntropyWithLogits",
1995                                      "SpaceToBatch",
1996                                      "SpaceToBatchND",
1997                                      "SpaceToDepth",
1998                                      "SparseMatMul",
1999                                      "SparseToDense",
2000                                      "StackCloseV2",
2001                                      "StackPopV2",
2002                                      "StackPushV2",
2003                                      "StackV2",
2004                                      "StatefulPartitionedCall",
2005                                      "StatefulStandardNormalV2",
2006                                      "StatefulTruncatedNormal",
2007                                      "StatefulUniform",
2008                                      "StatefulUniformFullInt",
2009                                      "StatefulUniformInt",
2010                                      "StatelessCase",
2011                                      "StatelessIf",
2012                                      "StatelessMultinomial",
2013                                      "StatelessRandomGetAlg",
2014                                      "StatelessRandomGetKeyCounter",
2015                                      "StatelessRandomGetKeyCounterAlg",
2016                                      "StatelessRandomNormal",
2017                                      "StatelessRandomNormalV2",
2018                                      "StatelessRandomUniform",
2019                                      "StatelessRandomUniformV2",
2020                                      "StatelessRandomUniformInt",
2021                                      "StatelessRandomUniformIntV2",
2022                                      "StatelessRandomUniformFullInt",
2023                                      "StatelessRandomUniformFullIntV2",
2024                                      "StatelessTruncatedNormal",
2025                                      "StatelessTruncatedNormalV2",
2026                                      "StatelessWhile",
2027                                      "Svd",
2028                                      "SymbolicGradient",
2029                                      "TensorArrayCloseV3",
2030                                      "TensorArrayConcatV3",
2031                                      "TensorArrayGatherV3",
2032                                      "TensorArrayGradV3",
2033                                      "TensorArrayReadV3",
2034                                      "TensorArrayScatterV3",
2035                                      "TensorArraySizeV3",
2036                                      "TensorArraySplitV3",
2037                                      "TensorArrayV3",
2038                                      "TensorArrayWriteV3",
2039                                      "TensorListConcatV2",
2040                                      "TensorListElementShape",
2041                                      "TensorListFromTensor",
2042                                      "TensorListGather",
2043                                      "TensorListGetItem",
2044                                      "TensorListLength",
2045                                      "TensorListPopBack",
2046                                      "TensorListPushBack",
2047                                      "TensorListReserve",
2048                                      "TensorListSetItem",
2049                                      "TensorListSplit",
2050                                      "TensorListStack",
2051                                      "TensorScatterAdd",
2052                                      "TensorScatterMax",
2053                                      "TensorScatterMin",
2054                                      "TensorScatterSub",
2055                                      "TensorScatterUpdate",
2056                                      "ToBool",
2057                                      "TridiagonalSolve",
2058                                      "TruncatedNormal",
2059                                      "Unique",
2060                                      "UpperBound",
2061                                      "UnsortedSegmentMax",
2062                                      "UnsortedSegmentMin",
2063                                      "UnsortedSegmentProd",
2064                                      "UnsortedSegmentSum",
2065                                      "VarIsInitializedOp",
2066                                      "VariableShape",
2067                                      "Where",
2068                                      "While",
2069                                      "XlaBroadcastHelper",
2070                                      "XlaConv",
2071                                      "XlaConvV2",
2072                                      "XlaDequantize",
2073                                      "XlaDot",
2074                                      "XlaDotV2",
2075                                      "XlaDynamicSlice",
2076                                      "XlaDynamicUpdateSlice",
2077                                      "XlaEinsum",
2078                                      "XlaGather",
2079                                      "XlaIf",
2080                                      "XlaKeyValueSort",
2081                                      "XlaPad",
2082                                      "XlaRecv",
2083                                      "XlaReduce",
2084                                      "XlaReduceWindow",
2085                                      "XlaRemoveDynamicDimensionSize",
2086                                      "XlaReplicaId",
2087                                      "XlaRngBitGenerator",
2088                                      "XlaScatter",
2089                                      "XlaSelectAndScatter",
2090                                      "XlaSelfAdjointEig",
2091                                      "XlaSend",
2092                                      "XlaSetBound",
2093                                      "XlaSetDynamicDimensionSize",
2094                                      "XlaSharding",
2095                                      "XlaSort",
2096                                      "XlaSpmdFullToShardShape",
2097                                      "XlaSpmdShardToFullShape",
2098                                      "XlaSvd",
2099                                      "XlaVariadicReduce",
2100                                      "XlaVariadicReduceV2",
2101                                      "XlaVariadicSort",
2102                                      "XlaWhile",
2103                                      "Zeta",
2104                                      "_Arg",
2105                                      "_ArrayToList",
2106                                      "_ListToArray",
2107                                      "_Retval"};
2108   return result;
2109 }
2110 
2111 }  // namespace testing
2112 }  // namespace tensorflow
2113