/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_cluster_util.h"

#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
const char* const kXlaCompileTimeConstantInputsAttr =
    "_XlaCompileTimeConstantInputs";

namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
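// For example, given a hypothetical path dst -> n1 -> src returned by
// FindPath below, the description reads:
//
//   Edge from src to dst would create a cycle.
//   +-> dst
//   |   n1
//   +-- src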
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
                     int dst) {
  int32 max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&graph](int node_id) {
    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
      return string("(null)");
    }
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
                  node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32 node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }

}  // namespace

Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
  DeviceNameUtils::ParsedName parsed;
  if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
    return errors::Internal("Malformed assigned device '", device, "'");
  }
  *device_type = DeviceType(parsed.type);
  return Status::OK();
}

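// Returns true if `node` always forwards one of its inputs to its output (per
// AlwaysForwardsRefInput above, currently only Identity nodes) and at least
// one of its non-control inputs is ref-typed, e.g. an Identity node reading
// directly from a ref Variable.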
bool HasForwardedRefInput(const Node& node) {
  if (AlwaysForwardsRefInput(node)) {
    for (const Edge* incoming_edge : node.in_edges()) {
      if (incoming_edge->IsControlEdge()) {
        continue;
      }

      Node* incoming_node = incoming_edge->src();
      if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
        VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
                << incoming_node->name() << " " << incoming_node->type_string();
        return true;
      }
    }
  }
  return false;
}

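// Builds, in `cycles`, a cycle-detection graph mirroring `graph`, with each
// control-flow loop collapsed into a single "frame" node (see the comments
// below).  Returns false if lifting a loop edge onto its frame node would
// itself introduce a cycle.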
xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
                                              GraphCycles* cycles) {
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles->NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing from
  //   "NextIteration" nodes).

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles->NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
      const char* src_type = "pre-enter";
      const char* dst_type = "post-exit";
      int src = edge->src()->id();
      int dst = edge->dst()->id();

      if (edge->dst()->IsEnter()) {
        // Lift edges to an "Enter" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->dst()->id()].frame_name;
        dst = GetOrAddFrameNodeId(frame_name);
        dst_type = "frame";
      }

      if (edge->src()->IsExit()) {
        // Lift edges from an "Exit" node to the corresponding frame node.
        const string& frame_name =
            control_flow_info[edge->src()->id()].frame_name;
        src = GetOrAddFrameNodeId(frame_name);
        src_type = "frame";
      }

      if (!cycles->InsertEdge(src, dst)) {
        // TODO(b/127521408): We can probably handle this situation with a more
        // sophisticated SCC based algorithm, but for now we bail out.
        VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
                << " edge: " << DescribeCycle(cycles, *graph, src, dst);
        return false;
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  return true;
}

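// Returns the name of the XLA cluster `node` was assigned to via its
// kXlaClusterAttr attribute, or absl::nullopt if the attribute is absent or
// not a string.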
absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
  const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
  if (attr_value == nullptr) {
    return absl::nullopt;
  }
  Status s = AttrValueHasType(*attr_value, "string");
  if (!s.ok()) {
    return absl::nullopt;
  }
  return attr_value->s();
}

bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

void RemoveFromXlaCluster(NodeDef* node_def) {
  node_def->mutable_attr()->erase(kXlaClusterAttr);
}

void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }

Status AdjustCycleDetectionGraphForResourceOps(
    const Graph* graph, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    GraphCycles* cycles) {
  std::vector<std::pair<int, int>> unsafe_deps;
  TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
      *graph, flib_def, resource_ops_to_ignore, &unsafe_deps));

  // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
  // operations that interact with resource variables, must not be put in the
  // same cluster.  We enforce this constraint by creating a phantom node, X,
  // and adding edges P->X and X->Q.  MarkForCompilation then cannot cluster P
  // and Q together since that would create a cycle with X.

  for (std::pair<int, int> unsafe_dep : unsafe_deps) {
    int phantom_node_id = cycles->NewNode();
    CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
    CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
  }
  return Status::OK();
}

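// Shared implementation of PickDeviceForXla and CanPickDeviceForXla below.
// Devices are picked with the precedence GPU > unknown > CPU.  If
// `out_can_pick_device` is non-null, a conflicting device set reports failure
// through it instead of returning an error status.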
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
                            bool allow_mixing_unknown_and_cpu,
                            bool* out_can_pick_device,
                            string* out_device_picked) {
  if (out_can_pick_device) {
    *out_can_pick_device = true;
  }

#define FAILED_TO_PICK_DEVICE(failing_status) \
  do {                                        \
    if (out_can_pick_device) {                \
      *out_can_pick_device = false;           \
      return Status::OK();                    \
    } else {                                  \
      return failing_status;                  \
    }                                         \
  } while (false)

  TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
  DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);

  absl::flat_hash_set<absl::string_view> device_names_set;
  for (absl::string_view device_name : device_names) {
    if (!device_name.empty()) {
      device_names_set.insert(device_name);
    }
  }

  absl::optional<absl::string_view> maybe_gpu_device;
  absl::optional<absl::string_view> maybe_cpu_device;
  absl::optional<absl::string_view> maybe_unknown_device;

  for (absl::string_view device_name : device_names_set) {
    DeviceNameUtils::ParsedName parsed_name;
    TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
        << device_name;
    if (parsed_name.type == "GPU") {
      if (maybe_gpu_device) {
        FAILED_TO_PICK_DEVICE(errors::Internal(
            "Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
      }
      maybe_gpu_device = device_name;
    } else if (parsed_name.type == "CPU") {
      if (maybe_cpu_device) {
        FAILED_TO_PICK_DEVICE(errors::Internal(
            "Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
      }
      maybe_cpu_device = device_name;
    } else {
      if (maybe_unknown_device) {
        FAILED_TO_PICK_DEVICE(errors::Internal(
            "Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
      }
      maybe_unknown_device = device_name;
    }
  }

  if (maybe_unknown_device && maybe_gpu_device) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
        *maybe_gpu_device));
  }

  if (!allow_mixing_unknown_and_cpu) {
    if (maybe_unknown_device && maybe_cpu_device) {
      FAILED_TO_PICK_DEVICE(errors::Internal(
          "Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
          *maybe_cpu_device));
    }
  }

  if (out_device_picked) {
    if (maybe_gpu_device) {
      *out_device_picked = string(*maybe_gpu_device);
    } else if (maybe_unknown_device) {
      *out_device_picked = string(*maybe_unknown_device);
    } else {
      *out_device_picked = string(*maybe_cpu_device);
    }
  }

  return Status::OK();

#undef FAILED_TO_PICK_DEVICE
}

Status PickDeviceForXla(absl::Span<const string> device_names,
                        bool allow_mixing_unknown_and_cpu,
                        string* out_device_picked) {
  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
                              /*out_can_pick_device=*/nullptr,
                              out_device_picked);
}

Status CanPickDeviceForXla(absl::Span<const string> device_names,
                           bool allow_mixing_unknown_and_cpu,
                           bool* out_can_pick_device) {
  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
                              out_can_pick_device,
                              /*out_device_picked=*/nullptr);
}

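// Returns the effective global jit level: the level from the session's
// OptimizerOptions (with DEFAULT resolved to OFF), unless overridden by a
// valid, non-DEFAULT tf_xla_auto_jit flag.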
OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
    const GraphOptimizationPassOptions& options) {
  OptimizerOptions::GlobalJitLevel global_jit_level =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  if (global_jit_level == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    global_jit_level = OptimizerOptions::OFF;
  }
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
  if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) {
    // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides
    // the setting in ConfigProto.
    global_jit_level =
        static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
  }
  return global_jit_level;
}

}  // namespace tensorflow