• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/colocation_graph.h"
17 
18 #include <memory>
19 #include <set>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/device_set.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/device_attributes.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/framework/types.pb.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37 #include "tensorflow/core/util/dump_graph.h"
38 #include "tensorflow/core/util/port.h"
39 
40 namespace tensorflow {
41 
42 namespace {
43 
44 // We hoist the conversion from C-style string literal to StringPiece here,
45 // so that we can avoid the many repeated calls to strlen().
46 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
47 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
48 
49 // Returns a list of devices having type in supported_device_types.  The
50 // returned list is sorted by preferred type (higher numeric type is preferred).
FilterSupportedDevices(const std::vector<Device * > & devices,const PrioritizedDeviceTypeVector & supported_device_types,const Device * default_device)51 std::vector<Device*> FilterSupportedDevices(
52     const std::vector<Device*>& devices,
53     const PrioritizedDeviceTypeVector& supported_device_types,
54     const Device* default_device) {
55   Device* filtered_default_device = nullptr;
56   std::vector<std::pair<Device*, int32>> prioritized_filtered_devices;
57   for (const auto& supported_device_type : supported_device_types) {
58     for (Device* device : devices) {
59       if (DeviceType(device->attributes().device_type()) ==
60           supported_device_type.first) {
61         if (device == default_device) {
62           filtered_default_device = device;
63         } else {
64           prioritized_filtered_devices.emplace_back(
65               device, supported_device_type.second);
66         }
67       }
68     }
69   }
70 
71   auto device_sort = [](const std::pair<Device*, int32>& a,
72                         const std::pair<Device*, int32>& b) {
73     if (a.second != b.second) {
74       return a.second > b.second;
75     }
76 
77     auto a_priority =
78         DeviceSet::DeviceTypeOrder(DeviceType(a.first->device_type()));
79     auto b_priority =
80         DeviceSet::DeviceTypeOrder(DeviceType(b.first->device_type()));
81     // First sort by prioritized device type (higher is preferred) and
82     // then by device name (lexicographically).
83     if (a_priority != b_priority) {
84       return a_priority > b_priority;
85     }
86     return StringPiece(a.first->name()) < StringPiece(b.first->name());
87   };
88   std::sort(prioritized_filtered_devices.begin(),
89             prioritized_filtered_devices.end(), device_sort);
90 
91   std::vector<Device*> filtered_devices;
92   if (filtered_default_device != nullptr) {
93     filtered_devices.emplace_back(filtered_default_device);
94   }
95   for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
96     filtered_devices.push_back(prioritized_filtered_device.first);
97   }
98   return filtered_devices;
99 }
100 
101 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)102 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
103   std::vector<string> v;
104   v.reserve(devices.size());
105   for (Device* d : devices) {
106     v.push_back(d->name());
107   }
108   return v;
109 }
110 
111 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)112 std::vector<string> DeviceTypeAndPriorityToString(
113     const PrioritizedDeviceTypeVector& devices) {
114   std::vector<string> v;
115   v.reserve(devices.size());
116   for (const std::pair<DeviceType, int32>& device_and_type : devices) {
117     v.push_back(DeviceTypeString(device_and_type.first));
118   }
119   return v;
120 }
121 
122 // While Placer can override requested device on ops processing
123 // resources, i.e. node that take (and potentially return) a resource,
124 // it must not override requested device on ops generating a resource,
125 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
126 // output nodes.
IsResourceGeneratorNode(const Node & node)127 bool IsResourceGeneratorNode(const Node& node) {
128   return node.num_inputs() == 0 && node.num_outputs() == 1 &&
129          (IsRefType(node.output_type(0)) || node.output_type(0) == DT_RESOURCE);
130 }
131 
IsExemptFromResourceInputColocation(const Node * node)132 bool IsExemptFromResourceInputColocation(const Node* node) {
133   // Note: Partitioned function calls, which place and partition their
134   // function bodies, are exempt from this check: they forward resource and
135   // ref inputs to operations that are appropriately placed, instead of
136   // dereferencing them.
137   const string& op_type = node->op_def().name();
138   return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
139 }
140 
HasPriorities(const PrioritizedDeviceTypeVector & device_types)141 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
142   for (const auto& prioritized_device_type : device_types) {
143     if (prioritized_device_type.second != 0) return true;
144   }
145   return false;
146 }
147 
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)148 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
149                        const PrioritizedDeviceTypeVector& b_types) {
150   if (a_types.size() != b_types.size()) {
151     return false;
152   }
153   for (int i = 0; i < a_types.size(); ++i) {
154     if (a_types[i].first != b_types[i].first) {
155       return false;
156     }
157   }
158   return true;
159 }
160 
161 }  // namespace
162 
SetParentAndSupportedDevices(const Node & node,const std::vector<DeviceType> & types)163 Status Member::SetParentAndSupportedDevices(
164     const Node& node, const std::vector<DeviceType>& types) {
165   int id = node.id();
166   if (id < 0) {
167     return errors::Internal("Placer should not be creating a Member for node: ",
168                             node.DebugString());
169   }
170   parent_ = id;
171   return SupportedDeviceTypesForNode(types, node.def(),
172                                      &supported_device_types_);
173 }
174 
SetAssignedDeviceName(const string & device_name)175 Status Member::SetAssignedDeviceName(const string& device_name) {
176   if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
177     return errors::Internal(
178         "Setting assigned device name when there is a requested device set "
179         "is unsupported");
180   }
181   if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
182     return errors::Internal("Malformed assigned device '", device_name, "'");
183   }
184   // Set requested device to assigned_device to maintain the invariant that
185   // requested is a specialization of assigned.
186   requested_device_name_ = assigned_device_name_;
187   return Status::OK();
188 }
189 
SetRequestedDeviceName(const Node & node)190 Status Member::SetRequestedDeviceName(const Node& node) {
191   if (!DeviceNameUtils::ParseFullName(node.requested_device(),
192                                       &requested_device_name_)) {
193     return errors::InvalidArgument("Malformed device specification '",
194                                    node.requested_device(),
195                                    "' in node: ", node.DebugString());
196   }
197   if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
198     return errors::Internal(
199         "Setting requested device name when there is an assigned device set "
200         "is unsupported");
201   }
202   return Status::OK();
203 }
204 
EnsureCompatibilityAcrossResourceEdge(const Node & src,const Member & src_root,const Node & dst,bool log_device_placement)205 Status Member::EnsureCompatibilityAcrossResourceEdge(
206     const Node& src, const Member& src_root,
207     const Node& dst, /*dst_root is this*/
208     bool log_device_placement) {
209   if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
210                                               assigned_device_name_)) {
211     return errors::InvalidArgument(
212         "Cannot place the graph because a reference or resource edge "
213         "connects colocation groups with incompatible assigned devices: ",
214         DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
215         " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
216         ". The edge src node is ", src.name(), " , and the dst node is ",
217         dst.name());
218   }
219 
220   if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
221                                              requested_device_name_)) {
222     return Status::OK();
223   }
224 
225   // If we are here, assigned devices are compatible but requested ones are
226   // not. We will be overriding the requested device for destination node, but
227   // need to preserve the invariant that it will be a specialization of
228   // the assigned device.
229   if (log_device_placement) {
230     LOG(INFO) << "Ignoring device specification "
231               << DeviceNameUtils::ParsedNameToString(requested_device_name_)
232               << " for node '" << dst.name()
233               << "' because the input edge from '" << src.name()
234               << "' is a reference connection and already has a device "
235                  "field set to "
236               << DeviceNameUtils::ParsedNameToString(
237                      src_root.requested_device_name_);
238   }
239   requested_device_name_ = src_root.requested_device_name_;
240   DeviceNameUtils::EnsureSpecification(&requested_device_name_,
241                                        assigned_device_name_);
242   return Status::OK();
243 }
244 
Merge(std::vector<Member> * tree,int x_root,int y_root,Member ** new_root,Member ** old_root,bool dry_run)245 void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
246                    Member** new_root, Member** old_root, bool dry_run) {
247   Member& x_root_member = (*tree)[x_root];
248   Member& y_root_member = (*tree)[y_root];
249 
250   // Merge the sets by setting the parent pointer of the smaller tree's root
251   // node to point to the root of the larger tree. Together with path
252   // compression in ColocationGraph::FindRoot, this ensures that we do not
253   // experience pathological performance on graphs such as chains.
254   int new_root_id, old_root_id;
255   if (x_root_member.rank_ < y_root_member.rank_) {
256     // The tree rooted at x_root is shallower, so connect it to
257     // y_root. The rank of y_root is unchanged because its new
258     // child has strictly less rank.
259     if (!dry_run) {
260       x_root_member.parent_ = y_root;
261     }
262     new_root_id = y_root;
263     old_root_id = x_root;
264   } else if (x_root_member.rank_ > y_root_member.rank_) {
265     // The tree rooted at y_root is shallower, so connect it to
266     // x_root. The rank of x_root is unchanged because its new
267     // child has strictly less rank.
268     if (!dry_run) {
269       y_root_member.parent_ = x_root;
270     }
271     new_root_id = x_root;
272     old_root_id = y_root;
273   } else {
274     if (!dry_run) {
275       // Both trees have the same rank, so break the tie by choosing
276       // x_root as the new root.
277       y_root_member.parent_ = x_root;
278       // Increment the rank of the tree rooted at x_root, because it
279       // is now strictly deeper than before.
280       ++x_root_member.rank_;
281     }
282     new_root_id = x_root;
283     old_root_id = y_root;
284   }
285 
286   *new_root = &(*tree)[new_root_id];
287   *old_root = &(*tree)[old_root_id];
288 }
289 
290 // tree is non-const because we can change some `parent` pointers in some
291 // members for more efficient future lookups. The vector itself is not
292 // changed.
FindRoot(std::vector<Member> * tree,int node_id)293 int Member::FindRoot(std::vector<Member>* tree, int node_id) {
294   Member& member = (*tree)[node_id];
295   if (member.parent_ == node_id) {
296     // member.parent is the root of this disjoint tree.  Do nothing.
297   } else {
298     member.parent_ = FindRoot(tree, member.parent_);
299   }
300   // Now it is guaranteed that member.parent is the root of this disjoint
301   // tree.
302   return member.parent_;
303 }
304 
MergeDeviceNames(const Member & other,bool allow_soft_placement)305 Status Member::MergeDeviceNames(const Member& other,
306                                 bool allow_soft_placement) {
307   // Assuming the "requested is a specialization of assigned" invariant holds
308   // for this and `other`, it will hold after the two merges below.
309   DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
310   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
311       &assigned_device_name_copy, other.assigned_device_name_));
312 
313   DeviceNameUtils::ParsedName requested_device_name_copy =
314       requested_device_name_;
315   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
316       &requested_device_name_copy, other.requested_device_name_,
317       allow_soft_placement));
318 
319   // We checked for all errors, now change the devices.
320   assigned_device_name_ = assigned_device_name_copy;
321   requested_device_name_ = requested_device_name_copy;
322   return Status::OK();
323 }
324 
325 // Updates this to contain the intersection of the device types in
326 // this and "other".
MergeSupportedDevices(const Member & other)327 bool Member::MergeSupportedDevices(const Member& other) {
328   // Generate intersection with priorities.
329   // Each vector contains the same device types but with different priorities.
330   // The priorities are taken from the corresponding source vector.
331   PrioritizedDeviceTypeVector target_intersection;
332   PrioritizedDeviceTypeVector other_intersection;
333   for (const auto& prioritized_device_type : supported_device_types_) {
334     bool found = false;
335     for (const auto& other_prioritized_device_type :
336          other.supported_device_types_) {
337       if (prioritized_device_type.first ==
338           other_prioritized_device_type.first) {
339         found = true;
340         other_intersection.push_back(other_prioritized_device_type);
341         break;
342       }
343     }
344     if (found) {
345       target_intersection.push_back(prioritized_device_type);
346     }
347   }
348 
349   // Sort the devices by priority order.
350   auto device_sort = [](const std::pair<DeviceType, int32>& a,
351                         const std::pair<DeviceType, int32>& b) {
352     // First look at set priorities.
353     if (a.second != b.second) {
354       return a.second > b.second;
355     }
356     // Then fallback to default priorities.
357     auto a_priority = DeviceSet::DeviceTypeOrder(a.first);
358     auto b_priority = DeviceSet::DeviceTypeOrder(b.first);
359     if (a_priority != b_priority) {
360       return a_priority > b_priority;
361     }
362     // Finally just look at the Device type strings.
363     return a.first.type_string() < b.first.type_string();
364   };
365 
366   std::sort(target_intersection.begin(), target_intersection.end(),
367             device_sort);
368   std::sort(other_intersection.begin(), other_intersection.end(), device_sort);
369 
370   PrioritizedDeviceTypeVector result;
371 
372   bool is_target_prioritized = HasPriorities(target_intersection);
373   bool is_other_prioritized = HasPriorities(other_intersection);
374   if (!is_target_prioritized && !is_other_prioritized) {
375     // If neither are prioritized then we just return the original i.e. target
376     // prioritization.
377     result = target_intersection;
378   } else if (is_target_prioritized && !is_other_prioritized) {
379     // If only one is prioritized, then we respect priorities of that in the
380     // intersection.
381     result = target_intersection;
382   } else if (!is_target_prioritized && is_other_prioritized) {
383     result = other_intersection;
384   } else {
385     // If both have priorities and agree then we go with that. If the
386     // prioritization order is different, then we just fallback to the default
387     // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
388     // merged priorities to 0, so that downstream merges work correctly as well.
389     if (ArePrioritiesSame(target_intersection, other_intersection)) {
390       result = target_intersection;
391     } else {
392       for (const auto& prioritized_device : target_intersection) {
393         result.push_back(std::make_pair(prioritized_device.first, 0));
394       }
395       std::sort(result.begin(), result.end(), device_sort);
396     }
397   }
398 
399   if (result.empty()) {
400     return false;
401   }
402   supported_device_types_ = result;
403   return true;
404 }
405 
AssignDevice(const Node & node,bool allow_soft_placement)406 Status Member::AssignDevice(const Node& node, bool allow_soft_placement) {
407   if (node.assigned_device_name_index() == assigned_device_name_index_) {
408     return Status::OK();
409   }
410 
411   DeviceNameUtils::ParsedName parsed;
412   DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
413   Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed,
414                                             allow_soft_placement);
415   if (!s.ok()) {
416     return errors::Internal(
417         "Constraining by assigned device should not cause an error. Original "
418         "root's assigned device name: ",
419         DeviceNameUtils::ParsedNameToString(assigned_device_name_),
420         " node's assigned device name \"", node.assigned_device_name(),
421         ". Error: ", s.error_message());
422   }
423   s = DeviceNameUtils::MergeDevNames(&requested_device_name_, parsed,
424                                      allow_soft_placement);
425   if (!s.ok()) {
426     return errors::Internal(
427         "Constraining by assigned device should not cause an error. Original "
428         "root's requested device name: \"",
429         DeviceNameUtils::ParsedNameToString(requested_device_name_),
430         "\", node's assigned device name \"", node.assigned_device_name(),
431         "\". Error: ", s.error_message());
432   }
433 
434   assigned_device_name_index_ = node.assigned_device_name_index();
435   // Clear cached possible_devices, if any.
436   possible_devices_.clear();
437   return Status::OK();
438 }
DebugString()439 string Member::DebugString() {
440   return absl::StrCat(
441       "Member(assigned_device_name_index_=", assigned_device_name_index_,
442       " requested_device_name_=",
443       DeviceNameUtils::ParsedNameToString(requested_device_name_),
444       " assigned_device_name_=",
445       DeviceNameUtils::ParsedNameToString(assigned_device_name_),
446       " supported_device_types_=[",
447       absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
448                     ", "),
449       "] possible_devices_=[",
450       absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
451 }
ColocationGraph(const Graph * graph,const DeviceSet * device_set,const Device * default_device,bool allow_soft_placement,bool log_device_placement)452 ColocationGraph::ColocationGraph(const Graph* graph,
453                                  const DeviceSet* device_set,
454                                  const Device* default_device,
455                                  bool allow_soft_placement,
456                                  bool log_device_placement)
457     : graph_(graph),
458       device_set_(device_set),
459       device_types_(device_set->PrioritizedDeviceTypeList()),
460       default_device_(default_device),
461       allow_soft_placement_(allow_soft_placement),
462       log_device_placement_(log_device_placement) {
463   members_.resize(graph->num_node_ids());
464 }
465 
466 // Adds each node of the Graph to this ColocationGraph as a singleton.
467 //
468 // NOTE: The implementation assumes that the ids of nodes passed to
469 // this method are dense and zero-based; the memory used will be linear in
470 // the largest node ID.
471 // NOTE: If this method returns an error, *this is left in an undefined
472 // state.
ColocateAllNodes()473 Status ColocationGraph::ColocateAllNodes() {
474   // This maps from a colocation group identifier to the 'root' of that
475   // colocation group.  Note that the keys in this map are StringPiece; the
476   // actual strings are stored under the NodeDef.  The lifetime of this map
477   // is limited to this ColocateAllNodes() method, and no part of the
478   // NodeDef trees are changed during the lifetime of this method, so using
479   // StringPiece as a key is safe.
480   //
481   // Also, as a further optimization, we remove the "loc:@" prefix from
482   // "class" attribute values, when they are used as keys in this table.
483   // This allows us to use StringPiece values that refer to substrings of
484   // 'string' values stored in NodeDef attribute lists, as well as StringPiece
485   // values that refer to 'string' values from NodeDef::name(), without
486   // performing any string allocations.
487   std::unordered_map<StringPiece, const Node*, StringPieceHasher>
488       colocation_group_root;
489 
490   for (const Node* node : graph_->op_nodes()) {
491     // When adding the node, identify whether it is part of a colocation
492     // group.
493 
494     // This code is effectively the equivalent of GetNodeAttr() for a string
495     // array, but it avoids all internal allocations (the allocation of the
496     // backing store of the std::vector<string> as well as the copies of the
497     // strings within it).  Instead, we combine the query of the colocation
498     // attribute with the calls to ColocateNodeToGroup.
499     bool found_spec = false;
500     const AttrValue* attr_value =
501         node->attrs().Find(kColocationAttrNameStringPiece);
502     if (attr_value != nullptr && attr_value->has_list()) {
503       for (const string& class_spec : attr_value->list().s()) {
504         StringPiece spec(class_spec);
505         if (str_util::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
506           found_spec = true;
507           TF_RETURN_IF_ERROR(
508               ColocateNodeToGroup(&colocation_group_root, node, spec));
509         }
510       }
511     }
512 
513     // TODO(iga): Even when the node has a spec, we need to colocate the
514     // node to its "name group" because other nodes can still use
515     // "loc:@<this_node_name>" in their colocation specs.
516     if (!found_spec) {
517       // If the node does not specify a colocation group, then use the
518       // name of this node as the colocation group.
519       TF_RETURN_IF_ERROR(
520           ColocateNodeToGroup(&colocation_group_root, node, node->name()));
521     }
522   }
523 
524   return Status::OK();
525 }
526 
ColocateResourceOrRefEdge(Node * src,Node * dst)527 Status ColocationGraph::ColocateResourceOrRefEdge(Node* src, Node* dst) {
528   // Colocate `src` and `dst` to maintain the invariant that nodes
529   // connected by reference edges are colocated.
530   int src_root_id = FindRoot(src->id());
531   int dst_root_id = FindRoot(dst->id());
532   auto& src_root = members_[src_root_id];
533   auto& dst_root = members_[dst_root_id];
534 
535   TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
536       *src, src_root, *dst, log_device_placement_));
537   Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
538   if (!status.ok()) {
539     return AttachDef(
540         errors::InvalidArgument("Nodes were connected by a "
541                                 "reference connection (requiring them to "
542                                 "be on the same device), but the two nodes "
543                                 "were assigned two different devices: ",
544                                 status.error_message()),
545         *dst);
546   }
547   return Status::OK();
548 }
549 
ColocateResourceAndRefEdges()550 Status ColocationGraph::ColocateResourceAndRefEdges() {
551   // Enumerate the constraint edges, and use them to update the disjoint
552   // node set.
553   // If `node` has an input edge with reference type, add an edge from the
554   // source of that edge to `node`.
555   for (const Edge* edge : graph_->edges()) {
556     if (edge->IsControlEdge()) {
557       continue;
558     }
559     Node* src = edge->src();
560     Node* dst = edge->dst();
561     DataType input_type = dst->input_type(edge->dst_input());
562     if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
563         !IsExemptFromResourceInputColocation(dst)) {
564       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
565     }
566   }
567   return Status::OK();
568 }
569 
Initialize()570 Status ColocationGraph::Initialize() {
571   TF_RETURN_IF_ERROR(InitializeMembers());
572   TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges());
573   TF_RETURN_IF_ERROR(ColocateAllNodes());
574   return Status::OK();
575 }
576 
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)577 Status ColocationGraph::ColocateNodeToGroup(
578     std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
579         colocation_group_root,
580     const Node* node, StringPiece colocation_group) {
581   const Node*& root_node = (*colocation_group_root)[colocation_group];
582   if (root_node == nullptr) {
583     // This is the first node of the colocation group, so
584     // designate this node as the 'root' of that colocation group.
585     root_node = node;
586   } else {
587     // Try to colocate the node with the root.  If there is an
588     // error, return it.
589     Status s = ColocateNodes(*node, *root_node);
590     if (!s.ok()) {
591       if (!allow_soft_placement_) {
592         return AttachDef(s, *node);
593       }
594       if (log_device_placement_) {
595         LOG(INFO) << "Ignoring request to colocate node '" << node->name()
596                   << "' with nodes in colocation group '" << colocation_group
597                   << "' because soft placement is on and an attempt at doing "
598                      "so resulted in the following error: "
599                   << AttachDef(s, *node).ToString();
600       }
601     }
602   }
603   return Status::OK();
604 }
605 
606 // Merge the (possibly disjoint) sets containing nodes "x" and
607 // "y". Returns OK if the all nodes in the union of these sets can
608 // be placed on the same device type.
609 //
610 // NOTE: If this method returns an error, *this is left in an undefined
611 // state.
ColocateNodes(const Node & x,const Node & y)612 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
613   int x_root = FindRoot(x.id());
614   int y_root = FindRoot(y.id());
615   return ColocateNodes(x, x_root, y, y_root);
616 }
617 
618 // This overload of ColocateNodes() allows a caller to provide the root node
619 // ids for the two nodes. For large graphs, this noticeably reduces the
620 // graph load time.
ColocateNodes(const Node & x,int x_root,const Node & y,int y_root)621 Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
622                                       int y_root) {
623   if (x_root == y_root) {
624     return Status::OK();
625   }
626 
627   Member* new_root_member;
628   Member* old_root_member;
629   Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
630                 /*dry_run=*/true);
631 
632   // Merge the partial device specifications, and ensure that they are
633   // compatible. NULL options_ is treated as allowing soft placement.
634   // If there is an error, nothing is modified.
635   // TODO(mrry): Consider enriching the error message by pointing
636   // out which nodes have the explicit partial device
637   // specifications that caused this conflict.
638   Status s = new_root_member->MergeDeviceNames(*old_root_member,
639                                                allow_soft_placement_);
640   if (!s.ok()) {
641     return errors::InvalidArgument(
642         "Cannot colocate nodes ",
643         errors::FormatColocationNodeForError(x.name()), " and ",
644         errors::FormatColocationNodeForError(y.name()), ": ",
645         s.error_message());
646   }
647 
648   // Ensure that the common root has at least one supported device
649   // type, by computing the intersection of
650   // new_root_member.supported_device_types and
651   // old_root_member.supported_device_types.
652   if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
653     return errors::InvalidArgument(
654         "Cannot colocate nodes ",
655         errors::FormatColocationNodeForError(x.name()), " and ",
656         errors::FormatColocationNodeForError(y.name()),
657         " because no device type supports both of those nodes and the "
658         "other nodes colocated with them.",
659         DebugInfo(x_root), DebugInfo(y_root));
660   }
661 
662   // All error checks are done, merge the colocation graphs.
663   Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
664                 /*dry_run=*/false);
665   return Status::OK();
666 }
667 
668 // Limits the possible devices of `node`'s colocation group to the device
669 // to which `node` is assigned. This makes sure that all nodes in this
670 // colocation group will be assigned to the same device. Without this
671 // explicit restriction, heuristics can choose a different possible device
672 // for other nodes in the group.
LimitToAssignedDevice(const Node & node)673 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
674   if (node.assigned_device_name_index() < 0) {
675     return errors::Internal(
676         "Expected an assigned node as argument to LimitToAssignedDevice but "
677         "got: ",
678         node.DebugString());
679   }
680   int root = FindRoot(node.id());
681   Member& root_member = members_[root];
682   return root_member.AssignDevice(node, allow_soft_placement_);
683 }
684 
685 // For the given node, subject to the constraints previously given
686 // to this ColocationGraph, set its assigned_device_name. Returns OK
687 // if a satisfying device can be found, otherwise an error.
688 //
689 // Note: This method returns a pointer to a field within members_.
690 // The caller must not use the returned pointer after there is any possibility
691 // that the members_[i].possible_devices field has been modified.
GetDevicesForNode(Node * node,const std::vector<Device * > ** possible_devices)692 Status ColocationGraph::GetDevicesForNode(
693     Node* node, const std::vector<Device*>** possible_devices) {
694   *possible_devices = nullptr;
695   const int node_root = FindRoot(node->id());
696   if (!members_[node_root].possible_devices().empty()) {
697     *possible_devices = &members_[node_root].possible_devices();
698     return Status::OK();
699   }
700 
701   // We have not yet computed the possible devices for the
702   // colocated node set containing 'node', so we do so now using the
703   // constraints on the root node.
704 
705   // "devices" will contain the set of feasible placements for the
706   // colocated node set containing 'node'.
707   std::vector<Device*> devices;
708   if (DeviceNameUtils::HasSomeDetails(
709           members_[node_root].requested_device_name())) {
710     // The root node has a (possibly partial) device
711     // specification, so enumerate the physical devices that
712     // conform to it.
713     device_set_->FindMatchingDevices(
714         members_[node_root].requested_device_name(), &devices);
715 
716     if (!devices.empty()) {
717       // Filter devices into those that are compatible with the root
718       // node (and its children).
719       devices = FilterSupportedDevices(
720           devices, members_[node_root].supported_device_types(),
721           default_device_);
722     }
723 
724     // Perform soft placement if allow_soft_placement_ is set.
725     if (devices.empty() && allow_soft_placement_) {
726       // The soft_device_name is the same as the node's device name
727       // without specifying the device type or ID.
728       DeviceNameUtils::ParsedName soft_device_name =
729           members_[node_root].requested_device_name();
730       soft_device_name.type.clear();
731       soft_device_name.has_type = false;
732       soft_device_name.has_id = false;
733       device_set_->FindMatchingDevices(soft_device_name, &devices);
734       if (!devices.empty()) {
735         devices = FilterSupportedDevices(
736             devices, members_[node_root].supported_device_types(),
737             default_device_);
738       }
739     }
740 
741     if (devices.empty()) {
742       // Return an error when a physical device that matches an explicit
743       // device specification is not found. This ensures that we don't
744       // assign a node to GPU when the user wanted to force it on CPU.
745       string debug_info = DebugInfo(node_root);
746 
747       DeviceNameUtils::ParsedName specified_device_name;
748       if (DeviceNameUtils::ParseFullName(node->requested_device(),
749                                          &specified_device_name) &&
750           specified_device_name ==
751               members_[node_root].requested_device_name()) {
752         // The specified device and merged set device match, and
753         // will appear in the GraphDef (for debugging), so just
754         // print the specified device.
755         std::vector<Device*> devices_matching_nodedef;
756         device_set_->FindMatchingDevices(specified_device_name,
757                                          &devices_matching_nodedef);
758         if (devices_matching_nodedef.empty()) {
759           // Sometimes it is almost impossible to understand the problem
760           // without a list of available devices.
761           std::vector<string> device_names;
762           for (const Device* device : device_set_->devices()) {
763             device_names.push_back(device->name());
764           }
765           std::sort(device_names.begin(), device_names.end());
766 
767           string gpu_msg = "";
768           if (!IsGoogleCudaEnabled() &&
769               str_util::Lowercase(specified_device_name.type) == "gpu") {
770             gpu_msg =
771                 " The requested device appears to be a GPU, but CUDA is not "
772                 "enabled.";
773           }
774 
775           return errors::InvalidArgument(
776               errors::FormatNodeNameForError(node->name()),
777               "was explicitly assigned to ", node->requested_device(),
778               " but available devices are [ ",
779               str_util::Join(device_names, ", "), " ]. Make sure ",
780               "the device specification refers to a valid device.", gpu_msg);
781         } else if (specified_device_name.has_type) {
782           return errors::InvalidArgument(
783               "Could not satisfy explicit device specification '",
784               node->requested_device(), "' because no supported kernel for ",
785               specified_device_name.type, " devices is available.", debug_info,
786               "\nRegistered kernels:\n",
787               KernelsRegisteredForOp(node->type_string()));
788         } else {
789           return errors::InvalidArgument(
790               "Could not satisfy explicit device specification '",
791               node->requested_device(), debug_info);
792         }
793       } else {
794         // The specified device may be a valid device but the
795         // merged set device is different, so print both.
796         return errors::InvalidArgument(
797             "Could not satisfy explicit device specification '",
798             node->requested_device(), "' because the node ",
799             errors::FormatColocationNodeForError(node->name()),
800             " was colocated with a group of nodes that ",
801             "required incompatible device '",
802             DeviceNameUtils::ParsedNameToString(
803                 members_[node_root].requested_device_name()),
804             "'", debug_info);
805       }
806     }
807   } else {
808     // The device is completely unspecified, so enumerate the devices that
809     // support all of the nodes in the set.
810     if (device_set_->devices().empty()) {
811       return errors::Internal("No devices are registered");
812     }
813     devices = FilterSupportedDevices(
814         device_set_->devices(), members_[node_root].supported_device_types(),
815         default_device_);
816 
817     if (devices.empty()) {
818       return errors::InvalidArgument(
819           "Node had no OpKernel registered to support this operation: ",
820           "Operation was ", node->type_string(), " and inputs were ",
821           DataTypeVectorString(node->input_types()), DebugInfo(node_root));
822     }
823   }
824 
825   // Cache the result of the possible devices for this node group.
826   members_[node_root].set_possible_devices(std::move(devices));
827   *possible_devices = &members_[node_root].possible_devices();
828   return Status::OK();
829 }
830 
InitializeMembers()831 Status ColocationGraph::InitializeMembers() {
832   for (Node* node : graph_->op_nodes()) {
833     Status status = InitializeMember(*node, &members_[node->id()]);
834     if (!status.ok()) {
835       return AttachDef(status, *node);
836     }
837   }
838   return Status::OK();
839 }
840 
DebugString()841 string ColocationGraph::DebugString() {
842   std::unordered_set<int> roots;
843   std::vector<string> root_strings;
844   for (const Node* node : graph_->nodes()) {
845     if (!node->IsOp()) {
846       continue;
847     }
848     int node_root = FindRoot(node->id());
849     if (roots.count(node_root) == 0) {
850       root_strings.push_back(DebugInfo(node_root));
851       roots.insert(node_root);
852     }
853   }
854   return absl::StrJoin(root_strings, "\n");
855 }
856 
857 // Returns debugging info for the node referred to by 'node_root'.
DebugInfo(const int node_root)858 string ColocationGraph::DebugInfo(const int node_root) {
859   string text(
860       "\nColocation Debug Info:\n"
861       "Colocation group had the following types and devices: ");
862 
863   // If this node is part of a colocation group, then we want to
864   // collect the mapping of ops to supported devices, so that
865   // the user can see why an unsatisfiable placement occurred.
866 
867   std::unordered_map<string, string> type_to_devices;
868   std::vector<const Node*> colocation_nodes;
869   int num_nodes_found = 0;
870 
871   for (const Node* node : graph_->nodes()) {
872     if (!node->IsOp()) {
873       continue;
874     }
875     int id = node->id();
876     if (FindRoot(id) != node_root) {
877       continue;
878     }
879     ++num_nodes_found;
880     colocation_nodes.push_back(node);
881     const string& op_type = node->type_string();
882     string devices_registered;
883     for (const auto& device_type : members_[id].supported_device_types()) {
884       strings::StrAppend(&devices_registered,
885                          DeviceTypeString(device_type.first), " ");
886     }
887 
888     type_to_devices[op_type] = std::move(devices_registered);
889   }
890 
891   for (const auto& td : type_to_devices) {
892     strings::StrAppend(&text, "\n", td.first, ": ", td.second);
893   }
894   strings::StrAppend(&text,
895                      "\n\nColocation members and user-requested devices:");
896   for (const Node* node : colocation_nodes) {
897     strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
898                        ") ", node->requested_device());
899   }
900   strings::StrAppend(&text, "\n");
901 
902   if (num_nodes_found <= 0) {
903     text.clear();
904   }
905   return text;
906 }
907 
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,bool must_be_full_name,Member * member)908 Status ColocationGraph::InitializeMemberWithAssignedDevice(
909     const string& assigned_device_name, const string& node_type,
910     bool must_be_full_name, Member* member) {
911   // This node has already been assigned to a device, so we
912   // respect this placement, after sanity-checking it.
913   // NOTE: Since any assignment must have been performed by
914   // the TensorFlow runtime, we consider errors in this branch to
915   // be INTERNAL.
916   TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
917   if (!must_be_full_name) {
918     return Status::OK();
919   }
920   // Since assigned device must be a full specification, do extra checks.
921   const Device* assigned_device =
922       device_set_->FindDeviceByName(assigned_device_name);
923   if (assigned_device == nullptr) {
924     return errors::Internal("Assigned device '", assigned_device_name,
925                             "' does not match any device");
926   }
927 
928   for (const auto& d : member->supported_device_types()) {
929     if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
930       return Status::OK();
931     }
932   }
933 
934   return errors::Internal("Assigned device '", assigned_device_name,
935                           "' does not have registered OpKernel support "
936                           "for ",
937                           node_type);
938 }
939 
InitializeMember(const Node & node,Member * member)940 Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
941   TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(node, device_types_));
942 
943   if (node.has_assigned_device_name()) {
944     TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
945         node.assigned_device_name(), node.type_string(), true, member));
946   } else {
947     // This node has not yet been assigned to a device, so we
948     // calculate any constraints due to the set of registered
949     // kernels and any (partial) user-provided device specification
950     // in the NodeDef.
951 
952     // If no kernels are registered for this op type, fail with an error.
953     if (member->supported_device_types().empty()) {
954       std::set<string> registered_device_types;
955       for (Device* d : device_set_->devices()) {
956         registered_device_types.insert(d->device_type());
957       }
958       std::vector<string> attr_key_vals;
959       for (const auto& it : node.attrs()) {
960         const string& name = it.first;
961         const AttrValue& attr_value = it.second;
962         attr_key_vals.push_back(
963             strings::StrCat(name, "=", SummarizeAttrValue(attr_value)));
964       }
965       return errors::InvalidArgument(
966           "No OpKernel was registered to support Op '", node.type_string(),
967           "' used by ", errors::FormatNodeNameForError(node.name()),
968           "with these attrs: [", str_util::Join(attr_key_vals, ", "),
969           "]\n"
970           "Registered devices: [",
971           str_util::Join(registered_device_types, ", "), "]\n",
972           "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
973     }
974 
975     // If the NodeDef contains a device, then we interpret it as a
976     // (partial) device specification.
977     if (!node.requested_device().empty()) {
978       if (IsResourceGeneratorNode(node)) {
979         // Treat requested device on resource generating nodes as assigned
980         // device so that we don't override it.
981         TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
982             node.requested_device(), node.type_string(), false, member));
983       } else {
984         // The user has specified a device in the NodeDef, try to find a
985         // valid device matching their specification in the set of
986         // devices.
987         // NOTE: The full name may specify a device that is not in
988         // n.supported_device_types(), but we check that in AssignDevice().
989         TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
990       }
991     }
992   }
993   return Status::OK();
994 }
995 
996 }  // namespace tensorflow
997