/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_cluster_util.h"

#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
const char* const kXlaCompileTimeConstantInputsAttr =
    "_XlaCompileTimeConstantInputs";

namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
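// For illustration only (hypothetical node names): for an edge C -> A that
// would close a cycle A -> B -> C, the returned description looks roughly
// like:
//
//   Edge from C to A would create a cycle.
//   +-> A
//   |   B
//   +-- C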
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
                     int dst) {
  int32 max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                  node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32 node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

}  // namespace

Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
  DeviceNameUtils::ParsedName parsed;
  if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
    return errors::Internal("Malformed assigned device '", device, "'");
  }
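  // For example, a fully specified name like
  // "/job:worker/replica:0/task:0/device:GPU:0" parses to device type "GPU".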
  *device_type = DeviceType(parsed.type);
  return Status::OK();
}

bool HasForwardedRefInput(const Node& node) {
  if (AlwaysForwardsRefInput(node)) {
    for (const Edge* incoming_edge : node.in_edges()) {
      if (incoming_edge->IsControlEdge()) {
        continue;
      }

      Node* incoming_node = incoming_edge->src();
      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
        VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
                << incoming_node->name() << " " << incoming_node->type_string();
        return true;
      }
    }
  }
  return false;
}

xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
                                              GraphCycles* cycles) {
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles->NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing
  //   from "NextIteration" nodes).

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles->NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
      const char* src_type = "pre-enter";
      const char* dst_type = "post-exit";
      int src = edge->src()->id();
      int dst = edge->dst()->id();

      if (edge->dst()->IsEnter()) {
        // Lift edges to an "Enter" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->dst()->id()].frame_name;
        dst = GetOrAddFrameNodeId(frame_name);
        dst_type = "frame";
      }

      if (edge->src()->IsExit()) {
        // Lift edges from an "Exit" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->src()->id()].frame_name;
        src = GetOrAddFrameNodeId(frame_name);
        src_type = "frame";
      }

      if (!cycles->InsertEdge(src, dst)) {
        // TODO(b/127521408): We can probably handle this situation with a more
        // sophisticated SCC based algorithm, but for now we bail out.
        VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
                << " edge: " << DescribeCycle(cycles, *graph, src, dst);
        return false;
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  return true;
}

absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
  const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
  if (attr_value == nullptr) {
    return absl::nullopt;
  }
  Status s = AttrValueHasType(*attr_value, "string");
  if (!s.ok()) {
    return absl::nullopt;
  }
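  // Note: the returned view aliases string storage owned by the node's
  // attributes, so its lifetime is tied to `node`.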
  return attr_value->s();
}

bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

void RemoveFromXlaCluster(NodeDef* node_def) {
  node_def->mutable_attr()->erase(kXlaClusterAttr);
}

void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }

Status AdjustCycleDetectionGraphForResourceOps(
    const Graph* graph, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    GraphCycles* cycles) {
  std::vector<std::pair<int, int>> unsafe_deps;
  TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
      *graph, flib_def, resource_ops_to_ignore, &unsafe_deps));

  // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
  // operations that interact with resource variables, must not be put in the
  // same cluster. We enforce this constraint by creating a phantom node, X,
  // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
  // and Q together since that would create a cycle with X.
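  //
  // Sketch: with P -> X -> Q in the cycle detection graph, merging P and Q
  // into a single cluster C would yield both C -> X and X -> C, i.e. a cycle
  // through X.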

  for (std::pair<int, int> unsafe_dep : unsafe_deps) {
    int phantom_node_id = cycles->NewNode();
    CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
    CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
  }
  return Status::OK();
}

Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
                            bool allow_mixing_unknown_and_cpu,
                            bool* out_can_pick_device,
                            string* out_device_picked) {
  if (out_can_pick_device) {
    *out_can_pick_device = true;
  }

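// If we cannot pick a device: when out_can_pick_device is provided, report
// false and succeed; otherwise fail with `failing_status`.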
#define FAILED_TO_PICK_DEVICE(failing_status) \
  do {                                        \
    if (out_can_pick_device) {                \
      *out_can_pick_device = false;           \
      return Status::OK();                    \
    } else {                                  \
      return failing_status;                  \
    }                                         \
  } while (false)

  TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
  DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);

  absl::flat_hash_set<absl::string_view> device_names_set;
  for (absl::string_view device_name : device_names) {
    if (!device_name.empty()) {
      device_names_set.insert(device_name);
    }
  }

  absl::optional<absl::string_view> maybe_gpu_device;
  absl::optional<absl::string_view> maybe_cpu_device;
  absl::optional<absl::string_view> maybe_unknown_device;

  for (absl::string_view device_name : device_names_set) {
    DeviceNameUtils::ParsedName parsed_name;
    TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
        << device_name;
    if (parsed_name.type == "GPU") {
      if (maybe_gpu_device) {
        FAILED_TO_PICK_DEVICE(errors::Internal(
            "Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
      }
      maybe_gpu_device = device_name;
    } else if (parsed_name.type == "CPU") {
      if (maybe_cpu_device) {
        FAILED_TO_PICK_DEVICE(errors::Internal(
            "Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
      }
      maybe_cpu_device = device_name;
    } else {
      if (maybe_unknown_device) {
        FAILED_TO_PICK_DEVICE(errors::Internal(
            "Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
      }
      maybe_unknown_device = device_name;
    }
  }

  if (maybe_unknown_device && maybe_gpu_device) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
        *maybe_gpu_device));
  }

  if (!allow_mixing_unknown_and_cpu) {
    if (maybe_unknown_device && maybe_cpu_device) {
      FAILED_TO_PICK_DEVICE(errors::Internal(
          "Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
          *maybe_cpu_device));
    }
  }

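  // Preference order when a device is requested: GPU first, then an unknown
  // (non-CPU, non-GPU) device type, then CPU.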
  if (out_device_picked) {
    if (maybe_gpu_device) {
      *out_device_picked = string(*maybe_gpu_device);
    } else if (maybe_unknown_device) {
      *out_device_picked = string(*maybe_unknown_device);
    } else {
      *out_device_picked = string(*maybe_cpu_device);
    }
  }

  return Status::OK();

#undef FAILED_TO_PICK_DEVICE
}

Status PickDeviceForXla(absl::Span<const string> device_names,
                        bool allow_mixing_unknown_and_cpu,
                        string* out_device_picked) {
  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
                              /*out_can_pick_device=*/nullptr,
                              out_device_picked);
}

Status CanPickDeviceForXla(absl::Span<const string> device_names,
                           bool allow_mixing_unknown_and_cpu,
                           bool* out_can_pick_device) {
  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
                              out_can_pick_device,
                              /*out_device_picked=*/nullptr);
}

OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
    const GraphOptimizationPassOptions& options) {
  OptimizerOptions::GlobalJitLevel global_jit_level =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  if (global_jit_level == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    global_jit_level = OptimizerOptions::OFF;
  }
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) {
    // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
    // the setting in ConfigProto.
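    // (In practice the flag is typically supplied via the TF_XLA_FLAGS
    // environment variable, e.g. TF_XLA_FLAGS=--tf_xla_auto_jit=2.)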
    global_jit_level =
        static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
  }
  return global_jit_level;
}

}  // namespace tensorflow