/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_cluster_util.h"

#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
const char* const kXlaCompileTimeConstantInputsAttr =
    "_XlaCompileTimeConstantInputs";

namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
                     int dst) {
  int32 max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                  node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32 node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

}  // namespace

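// Returns true if `node` is an op that forwards one of its inputs to its
// output (currently only Identity) and at least one of its data inputs is a
// ref-typed tensor.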
bool HasForwardedRefInput(const Node& node) {
  if (AlwaysForwardsRefInput(node)) {
    for (const Edge* incoming_edge : node.in_edges()) {
      if (incoming_edge->IsControlEdge()) {
        continue;
      }

      Node* incoming_node = incoming_edge->src();
      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
        VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
                << incoming_node->name() << " " << incoming_node->type_string();
        return true;
      }
    }
  }
  return false;
}

xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
                                              GraphCycles* cycles) {
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles->NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing
  //   from "NextIteration" nodes).

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles->NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
      const char* src_type = "pre-enter";
      const char* dst_type = "post-exit";
      int src = edge->src()->id();
      int dst = edge->dst()->id();

      if (edge->dst()->IsEnter()) {
        // Lift edges to an "Enter" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->dst()->id()].frame_name;
        dst = GetOrAddFrameNodeId(frame_name);
        dst_type = "frame";
      }

      if (edge->src()->IsExit()) {
        // Lift edges from an "Exit" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->src()->id()].frame_name;
        src = GetOrAddFrameNodeId(frame_name);
        src_type = "frame";
      }

      if (!cycles->InsertEdge(src, dst)) {
        // TODO(b/127521408): We can probably handle this situation with a more
        // sophisticated SCC based algorithm, but for now we bail out.
        VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
                << " edge: " << DescribeCycle(cycles, *graph, src, dst);
        return false;
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  return true;
}

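// Returns the name of the XLA cluster `node` has been assigned to, if any, by
// reading its kXlaClusterAttr ("_XlaCluster") attribute.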
absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
  const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
  if (attr_value == nullptr) {
    return absl::nullopt;
  }
  Status s = AttrValueHasType(*attr_value, "string");
  if (!s.ok()) {
    return absl::nullopt;
  }
  return attr_value->s();
}

bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

void RemoveFromXlaCluster(NodeDef* node_def) {
  node_def->mutable_attr()->erase(kXlaClusterAttr);
}

void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }

namespace {
typedef xla_config_registry::XlaGlobalJitLevel XlaGlobalJitLevel;

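// Returns the combined single-GPU and general global jit levels: the level
// from the session options, overridden by any non-DEFAULT value set through
// the --tf_xla_auto_jit flag.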
XlaGlobalJitLevel GetXlaGlobalJitLevel(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  XlaGlobalJitLevel result;

  if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    result.single_gpu = result.general = OptimizerOptions::OFF;
  } else {
    result.single_gpu = result.general = jit_level_in_session_opts;
  }

  // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
  // the setting in ConfigProto.
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->xla_auto_jit_flag.optimization_level_single_gpu !=
      OptimizerOptions::DEFAULT) {
    result.single_gpu = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_single_gpu);
  }
  if (flags->xla_auto_jit_flag.optimization_level_general !=
      OptimizerOptions::DEFAULT) {
    result.general = static_cast<OptimizerOptions::GlobalJitLevel>(
        flags->xla_auto_jit_flag.optimization_level_general);
  }

  return result;
}

int GetGpuNumber(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return -1;
  }

  return parsed_name.type == DEVICE_GPU ? parsed_name.id : -1;
}
}  // namespace

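// Returns true if the nodes of `g` are assigned to exactly one GPU device;
// nodes placed on non-GPU devices do not affect the result.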
bool IsSingleGpuGraph(const Graph& g) {
  int gpus_seen = 0;
  absl::flat_hash_set<string> devices_seen;

  for (Node* n : g.op_nodes()) {
    if (devices_seen.contains(n->assigned_device_name())) {
      continue;
    }

    int gpu_number = GetGpuNumber(n->assigned_device_name());
    if (gpu_number != -1) {
      if (++gpus_seen > 1) {
        return false;
      }
    }

    devices_seen.insert(n->assigned_device_name());
  }

  return gpus_seen == 1;
}

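// Returns the effective global jit level for `options.graph`: the single-GPU
// setting if the graph is a single-GPU graph, otherwise the general setting.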
OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
    const GraphOptimizationPassOptions& options) {
  OptimizerOptions::GlobalJitLevel jit_level_in_session_opts =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  XlaGlobalJitLevel xla_global_jit_level =
      GetXlaGlobalJitLevel(jit_level_in_session_opts);
  if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) {
    VLOG(4) << "GetGlobalJitLevelForGraph returning "
            << xla_global_jit_level.single_gpu;
    return xla_global_jit_level.single_gpu;
  }
  OptimizerOptions::GlobalJitLevel result =
      IsSingleGpuGraph(**options.graph) ? xla_global_jit_level.single_gpu
                                        : xla_global_jit_level.general;
  VLOG(4) << "GetGlobalJitLevelForGraph returning " << result;
  return result;
}

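// Conservatively returns true if `n` may call a function: either its op type
// names a function in `flib_def`, or one of its attributes holds a function.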
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
  if (flib_def->Contains(n.type_string())) {
    return true;
  }

  // This is a conservative check: there may be nodes with a `func`
  // attribute that do not make function calls.
  return absl::c_any_of(n.def().attr(),
                        [](const std::pair<string, AttrValue>& name_attr_pair) {
                          return name_attr_pair.second.has_func();
                        });
}

bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

namespace {
struct ClusterInfo {
  int size;

  // Maps op names to the number of times they appear in the cluster.
  absl::flat_hash_map<absl::string_view, int> op_histogram;
};

void HistogramMapToRepeatedOpAndCount(
    protobuf::RepeatedPtrField<XlaAutoClusteringSummary::OpAndCount>* result,
    const absl::flat_hash_map<absl::string_view, int>& histogram) {
  for (const auto& pair : histogram) {
    XlaAutoClusteringSummary::OpAndCount* new_entry = result->Add();
    new_entry->set_op(std::string(pair.first));
    new_entry->set_count(pair.second);
  }

  absl::c_sort(*result, [](const XlaAutoClusteringSummary::OpAndCount& a,
                           const XlaAutoClusteringSummary::OpAndCount& b) {
    return a.op() < b.op();
  });
}

void ClusterInfoToProtobuf(XlaAutoClusteringSummary::Cluster* result,
                           absl::string_view name, const ClusterInfo& info) {
  result->set_name(std::string(name));
  result->set_size(info.size);
  HistogramMapToRepeatedOpAndCount(result->mutable_op_histogram(),
                                   info.op_histogram);
}
}  // namespace

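// Builds a summary of the auto-clustering decisions recorded on `graph`:
// per-cluster sizes and op histograms, plus a histogram of unclustered ops.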
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
  absl::flat_hash_map<absl::string_view, ClusterInfo> cluster_name_to_info;
  XlaAutoClusteringSummary result;

  absl::flat_hash_map<absl::string_view, int> unclustered_op_histogram;

  for (Node* n : graph.nodes()) {
    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
    if (cluster_name) {
      result.set_clustered_node_count(result.clustered_node_count() + 1);
      ClusterInfo* info = &cluster_name_to_info[*cluster_name];
      info->size++;
      info->op_histogram[n->type_string()]++;
    } else {
      result.set_unclustered_node_count(result.unclustered_node_count() + 1);
      unclustered_op_histogram[n->type_string()]++;
    }
  }

  for (const auto& pair : cluster_name_to_info) {
    XlaAutoClusteringSummary::Cluster* new_cluster = result.add_clusters();
    ClusterInfoToProtobuf(new_cluster, pair.first, pair.second);
  }

  absl::c_sort(*result.mutable_clusters(),
               [&](const XlaAutoClusteringSummary::Cluster& a,
                   const XlaAutoClusteringSummary::Cluster& b) {
                 return a.name() < b.name();
               });

  HistogramMapToRepeatedOpAndCount(result.mutable_unclustered_op_histogram(),
                                   unclustered_op_histogram);

  return result;
}

namespace {
using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;

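// Returns the list of functions `n` may call: the node's own op if it is
// itself a function call, otherwise any functions referenced from its
// func-valued or list-of-func-valued attributes.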
CallTargetListTy GetCallTargetListFromNode(
    const Node& n, FunctionLibraryRuntime* lib_runtime) {
  const FunctionLibraryDefinition& flib_def =
      *lib_runtime->GetFunctionLibraryDefinition();
  if (flib_def.Find(n.type_string())) {
    NameAttrList callee;
    callee.set_name(n.type_string());
    *callee.mutable_attr() = n.def().attr();
    return {callee};
  }

  CallTargetListTy result;
  for (const auto& name_attr_pair : n.attrs()) {
    const AttrValue& attr_value = name_attr_pair.second;
    if (attr_value.value_case() == AttrValue::kFunc) {
      result.push_back(attr_value.func());
    } else if (attr_value.value_case() == AttrValue::kList) {
      result.insert(result.end(), attr_value.list().func().begin(),
                    attr_value.list().func().end());
    }
  }

  return result;
}

enum class Direction { kForward, kBackward };

Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result);

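// Conservatively returns true if any callee in `call_target_list` has ref
// outputs or contains nodes related to ref variables. Errs on the side of
// returning true when a callee cannot be instantiated or the recursion depth
// limit is reached.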
xla::StatusOr<bool> DoesAnyCalleeHaveRefNodes(
    const CallTargetListTy& call_target_list,
    FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
  const int kMaxDepth = 10;

  if (depth == kMaxDepth && !call_target_list.empty()) {
    // Conservative answer to avoid recursing too much.
    return true;
  }

  absl::flat_hash_set<Node*> callee_ref_nodes;
  for (const NameAttrList& call_target : call_target_list) {
    const OpRegistrationData* op_reg;
    if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
      const OpDef& op = op_reg->op_def;
      if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef& arg) {
            return arg.is_ref();
          })) {
        return true;
      }
      continue;
    }

    callee_ref_nodes.clear();
    FunctionLibraryRuntime::Handle handle;
    if (!lib_runtime
             ->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
                           &handle)
             .ok()) {
      VLOG(2) << "Could not find " << call_target.name()
              << " in the function library.";
      // Since we don't know the semantics of `n` we don't know if this is an
      // error. We return true to signal a conservative answer.
      return true;
    }

    auto release_handle_on_return = gtl::MakeCleanup(
        [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });

    const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
    TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
        *fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));

    // We could possibly use something cheaper than
    // GetNodesRelatedToRefVariablesInDirection since we only care about the
    // size of `callee_ref_nodes`, but for now we don't care.
    if (!callee_ref_nodes.empty()) {
      return true;
    }
  }

  return false;
}

// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
// direction.
Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
  std::vector<Node*> nodes_in_order;
  if (direction == Direction::kForward) {
    GetReversePostOrder(graph, &nodes_in_order,
                        /*stable_comparator=*/NodeComparatorName());
  } else {
    GetPostOrder(graph, &nodes_in_order,
                 /*stable_comparator=*/NodeComparatorName());
  }

  size_t old_result_size;
  int iterations = 0;

  const int kMaxIterations = 10 * 1000;

  std::vector<bool> callee_has_ref_nodes_cache;
  callee_has_ref_nodes_cache.resize(graph.num_node_ids());

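  // Whether the callees of a node contain ref nodes does not change across
  // iterations of the fixed-point loop below, so we compute the answer on the
  // first iteration and reuse the cached value afterwards.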
  auto does_callee_have_ref_nodes = [&](Node* n) -> xla::StatusOr<bool> {
    if (iterations == 1) {
      TF_ASSIGN_OR_RETURN(
          bool callee_has_ref_nodes,
          DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
                                    lib_runtime, direction, depth));
      callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
      return callee_has_ref_nodes;
    } else {
      return {callee_has_ref_nodes_cache[n->id()]};
    }
  };

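  // Iterate to a fixed point: keep adding nodes that are adjacent, in the
  // chosen direction, to nodes already in `result` until a full pass adds
  // nothing new.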
  do {
    TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";

    old_result_size = result->size();
    for (Node* n : nodes_in_order) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }

      bool inserted_n = false;
      const EdgeSet& edges =
          direction == Direction::kForward ? n->in_edges() : n->out_edges();
      for (const Edge* e : edges) {
        if (result->contains(direction == Direction::kForward ? e->src()
                                                               : e->dst())) {
          result->insert(n);
          inserted_n = true;
          break;
        }
      }

      if (inserted_n) {
        continue;
      }

      if (direction == Direction::kForward &&
          absl::c_any_of(n->output_types(), IsRefType)) {
        result->insert(n);
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
                          does_callee_have_ref_nodes(n));
      if (callee_has_ref_nodes) {
        result->insert(n);
        continue;
      }
    }

    // Loop until convergence.
  } while (result->size() != old_result_size);

  VLOG(2) << "# iterations = " << iterations;

  return Status::OK();
}
}  // namespace

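// Computes the set of nodes related to ref-typed tensors by running the
// directional traversal above over `graph` both forwards and backwards.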
xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
  absl::flat_hash_set<Node*> result;
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kForward, 0, &result));
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kBackward, 0, &result));

  VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
          << " nodes";
  return result;
}

// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);

}  // namespace tensorflow