• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/colocation_graph.h"
17 
18 #include <memory>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/common_runtime/composite_device.h"
30 #include "tensorflow/core/common_runtime/device.h"
31 #include "tensorflow/core/common_runtime/device_set.h"
32 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
33 #include "tensorflow/core/common_runtime/inspecting_placer.h"
34 #include "tensorflow/core/common_runtime/partitioning_utils.h"
35 #include "tensorflow/core/framework/attr_value.pb.h"
36 #include "tensorflow/core/framework/attr_value_util.h"
37 #include "tensorflow/core/framework/dataset.h"
38 #include "tensorflow/core/framework/device_attributes.pb.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/graph_node_util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/util/device_name_utils.h"
51 #include "tensorflow/core/util/dump_graph.h"
52 #include "tensorflow/core/util/port.h"
53 
54 namespace tensorflow {
55 
56 namespace {
57 
58 // We hoist the conversion from C-style string literal to StringPiece here,
59 // so that we can avoid the many repeated calls to strlen().
60 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
61 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
62 
63 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)64 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
65   std::vector<string> v;
66   v.reserve(devices.size());
67   for (Device* d : devices) {
68     v.push_back(d->name());
69   }
70   return v;
71 }
72 
73 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)74 std::vector<string> DeviceTypeAndPriorityToString(
75     const PrioritizedDeviceTypeVector& devices) {
76   std::vector<string> v;
77   v.reserve(devices.size());
78   for (const std::pair<DeviceType, int32>& device_and_type : devices) {
79     v.push_back(DeviceTypeString(device_and_type.first));
80   }
81   return v;
82 }
83 
IsRefOrResource(DataType data_type)84 bool IsRefOrResource(DataType data_type) {
85   return IsRefType(data_type) || data_type == DT_RESOURCE;
86 }
87 
88 // While Placer can override requested device on ops processing
89 // resources, i.e. node that take (and potentially return) a resource,
90 // it must not override requested device on ops generating a resource,
91 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
92 // output nodes.
IsRefOrResourceGeneratorNode(const Node & node)93 bool IsRefOrResourceGeneratorNode(const Node& node) {
94   return node.num_inputs() == 0 && node.num_outputs() == 1 &&
95          IsRefOrResource(node.output_type(0));
96 }
97 
IsExemptFromResourceInputColocation(const Node * node)98 bool IsExemptFromResourceInputColocation(const Node* node) {
99   // Note: Partitioned function calls, which place and partition their
100   // function bodies, are exempt from this check: they forward resource and
101   // ref inputs to operations that are appropriately placed, instead of
102   // dereferencing them.
103   const string& op_type = node->op_def().name();
104   auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
105   return exempt_ops.find(op_type) != exempt_ops.end();
106 }
107 
HasPriorities(const PrioritizedDeviceTypeVector & device_types)108 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
109   for (const auto& prioritized_device_type : device_types) {
110     if (prioritized_device_type.second != 0) return true;
111   }
112   return false;
113 }
114 
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)115 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
116                        const PrioritizedDeviceTypeVector& b_types) {
117   if (a_types.size() != b_types.size()) {
118     return false;
119   }
120   for (int i = 0; i < a_types.size(); ++i) {
121     if (a_types[i].first != b_types[i].first) {
122       return false;
123     }
124   }
125   return true;
126 }
127 
IsXlaDevice(absl::string_view device_type)128 bool IsXlaDevice(absl::string_view device_type) {
129   if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
130       device_type == "XLA_TPU_JIT") {
131     // Symbolic XLA device.
132     return true;
133   }
134 
135   return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
136           device_type == "TPU");
137 }
138 
IsCompositeDevice(absl::string_view device_type)139 bool IsCompositeDevice(absl::string_view device_type) {
140   return device_type == kCompositeDeviceType;
141 }
142 
143 }  // namespace
144 
SetParentAndSupportedDevices(const Node & node,const std::vector<DeviceType> & types,const DeviceNameUtils::ParsedName * local_address_spec)145 Status Member::SetParentAndSupportedDevices(
146     const Node& node, const std::vector<DeviceType>& types,
147     const DeviceNameUtils::ParsedName* local_address_spec) {
148   int id = node.id();
149   if (id < 0) {
150     return errors::Internal("Placer should not be creating a Member for node: ",
151                             node.DebugString());
152   }
153   parent_ = id;
154   return SupportedDeviceTypesForNode(
155       types, node.def(), &supported_device_types_, local_address_spec);
156 }
157 
// Parses `device_name` into assigned_device_name_. Must be called before any
// requested device has been set; also resets requested_device_name_ so that
// the "requested is a specialization of assigned" invariant holds.
Status Member::SetAssignedDeviceName(const string& device_name) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting assigned device name when there is a requested device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
    return errors::Internal("Malformed assigned device '", device_name, "'");
  }
  // Set requested device to assigned_device to maintain the invariant that
  // requested is a specialization of assigned.
  requested_device_name_ = assigned_device_name_;
  return Status::OK();
}
172 
// Parses `node`'s requested device into resource_device_name_. Must be called
// before any requested device has been set; also resets
// requested_device_name_ to preserve the specialization invariant.
Status Member::SetResourceDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting resource device name when there is a requested device set "
        "is unsupported");
  }

  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &resource_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }

  // Set requested device to resource device to maintain the invariant that
  // requested is a specialization of resource.
  requested_device_name_ = resource_device_name_;
  return Status::OK();
}
192 
// Parses `node`'s requested device into requested_device_name_. A requested
// device may only be set on a member that has neither an assigned nor a
// resource device (those setters each overwrite requested_device_name_).
Status Member::SetRequestedDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is an assigned device set "
        "is unsupported");
  }
  if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is a resource device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &requested_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }
  return Status::OK();
}
212 
// Copies this member's requested/resource device names and supported device
// types into `possible_device`. It is an internal error to call this on a
// member that already has an assigned device.
Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Cannot fill PossibleDevices from a member that has non-empty assigned "
        "device. Did we start assigning devices to functions called by deep "
        "ops? ",
        DebugString());
  }
  possible_device->requested_device_name = requested_device_name_;
  possible_device->resource_device_name = resource_device_name_;
  possible_device->device_types = supported_device_types_;
  return Status::OK();
}
226 
IsEdgeFromCompositeDeviceToPhysicalDevice(const Member & src_root) const227 bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
228     const Member& src_root) const {
229   auto compatible_edge_from_composite_device_to_physical_device =
230       [](const DeviceNameUtils::ParsedName& src_device,
231          const DeviceNameUtils::ParsedName& dst_device) -> bool {
232     return src_device.has_type && dst_device.has_type &&
233            IsCompositeDevice(src_device.type) &&
234            !IsCompositeDevice(dst_device.type);
235   };
236   if (compatible_edge_from_composite_device_to_physical_device(
237           src_root.assigned_device_name_, assigned_device_name_) ||
238       compatible_edge_from_composite_device_to_physical_device(
239           src_root.resource_device_name_, resource_device_name_) ||
240       compatible_edge_from_composite_device_to_physical_device(
241           src_root.requested_device_name_, requested_device_name_)) {
242     return true;
243   }
244   return false;
245 }
246 
// Checks that this member (the root of dst's colocation group) can be
// colocated with `src_root` across a reference/resource edge. Assigned and
// resource devices must be compatible (hard error otherwise). If only the
// requested devices conflict, the dst group's requested device is overridden
// by src's, since the resource producer's placement wins.
Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  // Requested devices already compatible: nothing to do.
  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return Status::OK();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for destination
  // node, but need to preserve the invariant that it will be a specialization
  // of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  // Adopt src's requested device, then re-specialize it against our own
  // assigned and resource devices to restore the invariant.
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return Status::OK();
}
299 
// Union-by-rank merge of the disjoint-set trees rooted at `x_root` and
// `y_root`. Outputs the surviving root in *new_root and the absorbed root in
// *old_root. When `dry_run` is true, no parent/rank fields are modified; the
// caller only learns which root would survive.
void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}
344 
345 // tree is non-const because we can change some `parent` pointers in some
346 // members for more efficient future lookups. The vector itself is not
347 // changed.
FindAndUpdateRoot(std::vector<Member> * tree,int node_id)348 int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
349   Member& member = (*tree)[node_id];
350   if (member.parent_ == node_id) {
351     // member.parent is the root of this disjoint tree.  Do nothing.
352   } else {
353     member.parent_ = FindAndUpdateRoot(tree, member.parent_);
354   }
355   // Now it is guaranteed that member.parent is the root of this disjoint
356   // tree.
357   return member.parent_;
358 }
359 
FindRoot(const std::vector<Member> & tree,int node_id)360 int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
361   const Member& member = tree[node_id];
362   if (member.parent_ == node_id) {
363     return member.parent_;
364   }
365   return FindRoot(tree, member.parent_);
366 }
367 
// Merges `other`'s assigned/resource/requested device names into this
// member's. All three merges are performed on local copies first so that a
// failure in any of them leaves this member unmodified (all-or-nothing).
// Soft placement, if allowed, only relaxes the requested-name merge.
Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  // Re-specialize the merged requested name against the merged assigned and
  // resource names to maintain the invariant.
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = assigned_device_name_copy;
  resource_device_name_ = resource_device_name_copy;
  requested_device_name_ = requested_device_name_copy;
  return Status::OK();
}
398 
// Updates this to contain the intersection of the device types in
// this and "other".
// Convenience overload that delegates to the vector-based implementation.
bool Member::MergeSupportedDevices(const Member& other) {
  return MergeSupportedDevices(other.supported_device_types_);
}
404 
// Intersects supported_device_types_ with `other_devices`, resolving the two
// sides' priority annotations. Returns false (and leaves this member
// unchanged) if the intersection is empty.
bool Member::MergeSupportedDevices(
    const PrioritizedDeviceTypeVector& other_devices) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;

  // O(n*m) scan; supported-device lists are expected to be short.
  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type : other_devices) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
  DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither are prioritized then we just return the original i.e. target
    // prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect priorities of that in the
    // intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and agree then we go with that. If the
    // prioritization order is different, then we just fallback to the default
    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
    // merged priorities to 0, so that downstream merges work correctly as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      DeviceSet::SortPrioritizedDeviceTypeVector(&result);
    }
  }

  // Empty intersection: the merge is not possible; keep the old set.
  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}
466 
AssignDevice(const Node & node)467 Status Member::AssignDevice(const Node& node) {
468   if (node.assigned_device_name_index() == assigned_device_name_index_) {
469     return Status::OK();
470   }
471 
472   DeviceNameUtils::ParsedName parsed;
473   DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
474   Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
475   if (!s.ok()) {
476     return errors::Internal(
477         "Constraining by assigned device should not cause an error. Original "
478         "root's assigned device name: ",
479         DeviceNameUtils::ParsedNameToString(assigned_device_name_),
480         " node's assigned device name \"", node.assigned_device_name(),
481         ". Error: ", s.error_message());
482   }
483   s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
484   if (!s.ok()) {
485     return errors::Internal(
486         "Constraining by assigned device should not cause an error. Original "
487         "root's resource device name: ",
488         DeviceNameUtils::ParsedNameToString(resource_device_name_),
489         " node's assigned device name \"", node.assigned_device_name(),
490         ". Error: ", s.error_message());
491   }
492   s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
493   if (!s.ok()) {
494     return errors::Internal(
495         "Constraining by assigned device should not cause an error. Original "
496         "root's requested device name: \"",
497         DeviceNameUtils::ParsedNameToString(requested_device_name_),
498         "\", node's assigned device name \"", node.assigned_device_name(),
499         "\". Error: ", s.error_message());
500   }
501 
502   assigned_device_name_index_ = node.assigned_device_name_index();
503   // Clear cached possible_devices, if any.
504   possible_devices_.clear();
505   return Status::OK();
506 }
507 
MaybeExcludeXlaDevices()508 void Member::MaybeExcludeXlaDevices() {
509   for (const auto& parsed_name :
510        {requested_device_name_, assigned_device_name_, resource_device_name_}) {
511     // Don't exculde XLA devices from supported devices if member is explicitly
512     // assigned to a CompositeDevice.
513     if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
514                                  IsCompositeDevice(parsed_name.type))) {
515       return;
516     }
517   }
518 
519   PrioritizedDeviceTypeVector non_xla_types;
520   absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
521                   [&](const std::pair<DeviceType, int32>& entry) {
522                     return !IsXlaDevice(entry.first.type_string());
523                   });
524 
525   // TODO(b/141216278) Remove all XLA device types from the supported device
526   // types if the node has no requested/assigned/resource XLA device.
527   if (!non_xla_types.empty() &&
528       non_xla_types.size() < supported_device_types_.size()) {
529     supported_device_types_ = std::move(non_xla_types);
530   }
531 }
532 
// Narrows this member's requested/resource device names and supported device
// types to those compatible with `devices`. Soft placement, if allowed, only
// relaxes the requested-name merge.
Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  // NOTE(review): the bool result of MergeSupportedDevices (false on empty
  // intersection) is deliberately or accidentally ignored here — confirm
  // callers handle an unchanged supported_device_types_ in that case.
  MergeSupportedDevices(devices.device_types);
  return Status::OK();
}
543 
// Human-readable dump of this member's placement state, used in error
// messages and logging.
string Member::DebugString() const {
  return absl::StrCat(
      "Member(assigned_device_name_index_=", assigned_device_name_index_,
      " requested_device_name_='",
      DeviceNameUtils::ParsedNameToString(requested_device_name_),
      "' assigned_device_name_='",
      DeviceNameUtils::ParsedNameToString(assigned_device_name_),
      "' resource_device_name_='",
      DeviceNameUtils::ParsedNameToString(resource_device_name_),
      "' supported_device_types_=[",
      absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
                    ", "),
      "] possible_devices_=[",
      absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}
559 
GetSoftDeviceName() const560 DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
561   DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
562   if (!assigned_device_name_.has_type) {
563     soft_device_name.type.clear();
564     soft_device_name.has_type = false;
565   }
566   if (!assigned_device_name_.has_id) {
567     soft_device_name.has_id = false;
568   }
569   return soft_device_name;
570 }
571 
GetPreferredSoftDeviceName() const572 DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
573   DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
574   if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
575     soft_device_name.type.clear();
576     soft_device_name.has_type = false;
577   }
578   if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
579     soft_device_name.has_id = false;
580   }
581   return soft_device_name;
582 }
583 
584 // Returns ParsedName whose address space (i.e. job, replica, task) identifies
585 // the address space directly accessible by the local process. If the address
586 // space is fully specified and it is exactly the same as the address space
587 // of a device, then all kernels of that device should be registered in the
588 // local process.
LocalAddressSpec(const Device * client_device,const Device * default_local_device)589 static const DeviceNameUtils::ParsedName LocalAddressSpec(
590     const Device* client_device, const Device* default_local_device) {
591   if (client_device != nullptr) {
592     return DeviceNameUtils::AddressSpace(client_device->parsed_name());
593   }
594 
595   if (default_local_device != nullptr) {
596     return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
597   }
598 
599   // TODO(b/139617593) Return the name of the first local device in device_set_
600   // once we can trust the output of Device::IsLocal().
601   return DeviceNameUtils::ParsedName();
602 }
603 
// Builds a colocation graph over `graph`. All pointer arguments except
// `default_local_device` must be non-null (they are dereferenced into
// references below); one Member slot is pre-allocated per node id.
ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      flib_def_(*flib_def),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  // One Member per node id; ids are assumed dense (see ColocateAllNodes).
  members_.resize(graph_.num_node_ids());
}
625 
// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group.  Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef.  The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it).  Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          // Only entries of the form "loc:@<group>" name colocation groups.
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        // Malformed attribute: warn but keep placing the graph.
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return Status::OK();
}
684 
ColocateResourceOrRefEdge(const Node * src,const Node * dst)685 Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
686                                                   const Node* dst) {
687   // Colocate `src` and `dst` to maintain the invariant that nodes
688   // connected by reference edges are colocated.
689   int src_root_id = FindAndUpdateRoot(src->id());
690   int dst_root_id = FindAndUpdateRoot(dst->id());
691   auto& src_root = members_[src_root_id];
692   auto& dst_root = members_[dst_root_id];
693 
694   if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
695     // If the src root is assigned to a composite device and the dst root is
696     // assigned to a physical device, don't colocate the dst root with the src
697     // root.
698     return Status::OK();
699   }
700   TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
701       *src, src_root, *dst, log_device_placement_));
702   Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
703   if (!status.ok()) {
704     return AttachDef(
705         errors::InvalidArgument(
706             "Nodes were connected by a reference or resource connection "
707             "(requiring them to be on the same device), but the two nodes "
708             "were assigned two different devices: ",
709             status.error_message()),
710         *dst);
711   }
712   return Status::OK();
713 }
714 
ColocateResourceAndRefEdges(std::unordered_set<Node * > * inspection_required)715 Status ColocationGraph::ColocateResourceAndRefEdges(
716     std::unordered_set<Node*>* inspection_required) {
717   // If `node` has an input edge with reference type, add an edge from the
718   // source of that edge to `node`.
719   for (const Edge* edge : graph_.edges()) {
720     if (edge->IsControlEdge()) {
721       continue;
722     }
723     Node* src = edge->src();
724     Node* dst = edge->dst();
725     bool needs_inspection;
726     TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
727         *src, &needs_inspection));
728     if (needs_inspection) {
729       inspection_required->insert(src);
730       continue;
731     }
732     TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
733         *dst, &needs_inspection));
734     if (needs_inspection) {
735       inspection_required->insert(dst);
736       continue;
737     }
738 
739     DataType input_type = dst->input_type(edge->dst_input());
740 
741     // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
742     // This is needed to get around the issue in b/135705778.
743     if (input_type == DT_VARIANT &&
744         data::DatasetOpKernel::IsDatasetOp(&src->op_def()) &&
745         data::DatasetOpKernel::IsDatasetOp(&dst->op_def())) {
746       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
747       continue;
748     }
749 
750     // Even though we can look inside function calling ops, we make an exception
751     // here mostly for performance reasons. Looking inside function calling ops
752     // is extra overhead. It is only necessary when they return resources. When
753     // they don't, we don't look inside them and make this exception here.
754     // Looking inside, could potentially enable us to make better placement
755     // decisions. It might be worth doing at some point.
756     if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
757         !IsExemptFromResourceInputColocation(dst)) {
758       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
759     }
760   }
761 
762   return Status::OK();
763 }
764 
765 namespace {
766 // Returns tensor list element data type, if the node is one of the ops that
767 // operate with TensorLists. Otherwise returns DT_INVALID.
GetElementDataType(const Node & node)768 DataType GetElementDataType(const Node& node) {
769   static absl::flat_hash_set<std::string>* tensor_list_ops =
770       new absl::flat_hash_set<std::string>(
771           {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
772            "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
773            "TensorListScatterIntoExistingList", "TensorListPushBack",
774            "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
775            "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
776            "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});
777 
778   if (tensor_list_ops->contains(node.type_string())) {
779     DataType element_type;
780     if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
781       return element_type;
782     }
783   }
784 
785   return DT_INVALID;
786 }
787 }  // namespace
788 
// For every node that consumes a DT_VARIANT value whose underlying data
// type must live on the host (per DataTypeAlwaysOnHost), restricts the
// node's colocation group to CPU-only device types. The underlying type is
// discovered by a reverse DFS from the node, following only DT_VARIANT data
// edges back to the TensorList op that produced the list.
Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
  auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };

  auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
    return entry.first == DEVICE_CPU;
  };

  for (Node* node : graph_.nodes()) {
    // Skip nodes that do not have DT_VARIANT inputs.
    if (absl::c_none_of(node->input_types(), is_variant)) continue;

    // Skip nodes that can't be placed on GPU anyway.
    Member& root = members_[FindAndUpdateRoot(node->id())];
    if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue;

    // Stop DFS traversal when found the underlying data type of a variant.
    // Unset until the traversal reaches a TensorList op with a concrete
    // (non-variant) element type.
    absl::optional<bool> is_host_data_type;

    auto edge_filter = [&](const Edge& edge) -> bool {
      // We already found the underlying data type.
      if (is_host_data_type.has_value()) return false;

      // Otherwise follow only DT_VARIANT data edges.
      auto edge_dtype = [&]() -> DataType {
        return edge.src()->output_type(edge.src_output());
      };
      return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
    };

    auto enter = [&](Node* n) -> void {
      DataType element_type = GetElementDataType(*n);
      // To handle nested lists continue traversal after finding a TensorList
      // operation that uses DT_VARIANT for element type.
      if (element_type == DT_INVALID || element_type == DT_VARIANT) return;
      is_host_data_type = DataTypeAlwaysOnHost(element_type);
    };

    // Walk the fan-in of `node` along variant edges only; `enter` records
    // the verdict, which then short-circuits `edge_filter`.
    ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
                   /*stable_comparator=*/nullptr, edge_filter);

    if (is_host_data_type.has_value() && *is_host_data_type) {
      VLOG(2) << "Limit node possible devices to CPU only, because it has a "
                 "DT_VARIANT input with host-only underlying data type: "
              << "node=" << node->name();

      // Restrict possible device types to CPU only.
      PossibleDevices possible_devices;
      absl::c_copy_if(root.supported_device_types(),
                      std::back_inserter(possible_devices.device_types),
                      is_cpu_device);

      TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
          possible_devices, /*allow_soft_placement=*/false));
    }
  }

  return Status::OK();
}
847 
AddInspectionConstraints(const std::unordered_set<Node * > & inspection_required)848 Status ColocationGraph::AddInspectionConstraints(
849     const std::unordered_set<Node*>& inspection_required) {
850   for (Node* node : inspection_required) {
851     IOColocationGroups groups;
852     TF_RETURN_IF_ERROR(
853         inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
854     VLOG(2) << "Computed IOColocationGroups for node " << node->name()
855             << ":\n\t" << groups.DebugString();
856     TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
857   }
858   return Status::OK();
859 }
860 
// Builds all colocation constraints for the graph; call before querying
// devices for any node. ColocateResourceAndRefEdges runs first because it
// discovers the nodes that AddInspectionConstraints must process.
Status ColocationGraph::Initialize() {
  TF_RETURN_IF_ERROR(InitializeMembers());

  std::unordered_set<Node*> inspection_required;
  TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
  TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
  TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
  TF_RETURN_IF_ERROR(ColocateAllNodes());

  // With all groups formed, let each group's root prune XLA devices from
  // its supported set.
  for (Node* node : graph_.op_nodes()) {
    int root_id = FindAndUpdateRoot(node->id());
    members_[root_id].MaybeExcludeXlaDevices();
  }

  return Status::OK();
}
877 
// A (node, flag) pair: the flag records whether `node` receives a resource
// input from the node that requires placer inspection.
using NodeAndBool = std::pair<const Node*, bool>;
881 
882 namespace {
883 
884 // Returns a vector of node names from `nodes`.
NodeAndBoolToString(const std::vector<NodeAndBool> & nodes)885 std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
886   std::vector<string> v;
887   v.reserve(nodes.size());
888   for (const NodeAndBool& node_and_bool : nodes) {
889     v.push_back(node_and_bool.first->name());
890   }
891   return v;
892 }
893 
894 // Given a node requiring placer inspection and its IOColocationGroups,
895 // computes `group_nodes`.
896 // group_nodes[i] contains the nodes that are members of colocation
897 // group i. These nodes are inputs or outputs of `node`.
898 // group_nodes[i][j] is a pair containing a node and whether this node
899 // has a resource input from `node`.
900 // Note:
901 // The same node can be added multiple times to the same group.
902 // The same node can be added to multiple groups.
GetGroupNodes(const IOColocationGroups & groups,const Node & node,std::vector<std::vector<NodeAndBool>> * group_nodes)903 Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
904                      std::vector<std::vector<NodeAndBool>>* group_nodes) {
905   group_nodes->reserve(groups.group_devices.size());
906   for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
907     const Node* src;
908     TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
909     int group_id = groups.input_groups[arg_idx];
910     (*group_nodes)[group_id].emplace_back(src, false);
911   }
912 
913   for (const Edge* edge : node.out_edges()) {
914     if (edge->IsControlEdge()) {
915       continue;
916     }
917 
918     int group_id = groups.output_groups[edge->src_output()];
919     (*group_nodes)[group_id].emplace_back(
920         edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
921   }
922 
923   if (VLOG_IS_ON(2)) {
924     VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
925     for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
926       VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
927               << "]";
928     }
929   }
930   return Status::OK();
931 }
932 
933 // Returns whether the device_type in `device_attributes` is supported.
IsSupportedDeviceType(const DeviceAttributes & device_attributes,const DeviceType & supported_type)934 bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
935                            const DeviceType& supported_type) {
936   if (DeviceType(device_attributes.device_type()) == supported_type) {
937     return true;
938   }
939   return IsCompositeDevice(device_attributes.device_type());
940 }
941 
942 }  // namespace
943 
ApplyIOColocationGroups(const IOColocationGroups & groups,const Node & node)944 Status ColocationGraph::ApplyIOColocationGroups(
945     const IOColocationGroups& groups, const Node& node) {
946   if (groups.input_groups.size() != node.num_inputs()) {
947     return errors::Internal(
948         "Cannot apply input/output device constraints to node ",
949         node.DebugString(), " because input_groups.size() (",
950         groups.input_groups.size(),
951         ") is different from number of inputs into the op node (",
952         node.num_inputs(), ")");
953   }
954   if (groups.output_groups.size() != node.num_outputs()) {
955     return errors::Internal(
956         "Cannot apply input/output device constraints to node ",
957         node.DebugString(), " because output_groups.size() (",
958         groups.output_groups.size(),
959         ") is different from number of outputs into the op node (",
960         node.num_outputs(), ")");
961   }
962 
963   // group_nodes[i] contains the nodes that are members of colocation
964   // group i. These nodes are inputs or outputs of `node`.
965   // group_nodes[i][j] is a pair containing the node and whether this node
966   // has a resource input from `node`.
967   // The same node can be added multiple times to the same group.
968   // The same node can be added to multiple groups.
969   // NOTE: group ids are guarantees to be [0, 1, ..., num_groups].
970   std::vector<std::vector<NodeAndBool>> group_nodes(
971       groups.group_devices.size());
972   TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));
973 
974   // Colocate nodes in each group
975   for (const std::vector<NodeAndBool>& nodes : group_nodes) {
976     for (int i = 1; i < nodes.size(); ++i) {
977       VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
978               << nodes[i].first->name() << "\"";
979       if (nodes[i].second) {
980         TF_RETURN_IF_ERROR(
981             ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
982       } else {
983         TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
984       }
985     }
986   }
987 
988   // Limit devices in each group
989   for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
990     // Nothing to do for empty groups. Groups can be empty if some output
991     // of an op is not used.
992     if (group_nodes[group_id].empty()) {
993       continue;
994     }
995     const Node* group_node = group_nodes[group_id][0].first;
996     const PossibleDevices& possible_devices = groups.group_devices[group_id];
997     TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
998   }
999 
1000   return Status::OK();
1001 }
1002 
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)1003 Status ColocationGraph::ColocateNodeToGroup(
1004     std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
1005         colocation_group_root,
1006     const Node* node, StringPiece colocation_group) {
1007   const Node*& root_node = (*colocation_group_root)[colocation_group];
1008   if (root_node == nullptr) {
1009     // This is the first node of the colocation group, so
1010     // designate this node as the 'root' of that colocation group.
1011     root_node = node;
1012   } else {
1013     // Try to colocate the node with the root.  If there is an
1014     // error, return it.
1015     Status s = ColocateNodes(*node, *root_node);
1016     if (!s.ok()) {
1017       if (!allow_soft_placement_) {
1018         return AttachDef(s, *node);
1019       }
1020       if (log_device_placement_) {
1021         LOG(INFO) << "Ignoring request to colocate node '" << node->name()
1022                   << "' with nodes in colocation group '" << colocation_group
1023                   << "' because soft placement is on and an attempt at doing "
1024                      "so resulted in the following error: "
1025                   << AttachDef(s, *node).ToString();
1026       }
1027     }
1028   }
1029   return Status::OK();
1030 }
1031 
1032 // Merge the (possibly disjoint) sets containing nodes "x" and
1033 // "y". Returns OK if the all nodes in the union of these sets can
1034 // be placed on the same device type.
1035 //
1036 // NOTE: If this method returns an error, *this is left in an undefined
1037 // state.
ColocateNodes(const Node & x,const Node & y)1038 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
1039   int x_root = FindAndUpdateRoot(x.id());
1040   int y_root = FindAndUpdateRoot(y.id());
1041   return ColocateNodes(x, x_root, y, y_root);
1042 }
1043 
// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
//
// The merge follows a dry-run/commit protocol: Merge() is first called with
// dry_run=true to identify the prospective new/old roots without mutating
// the union-find structure; the compatibility checks run on those roots,
// and only if they all pass is Merge() called again with dry_run=false to
// actually link the sets.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  if (x_root == y_root) {
    return Status::OK();
  }

  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}
1093 
LimitToAssignedDevice(const Node & node)1094 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
1095   if (node.assigned_device_name_index() < 0) {
1096     return errors::Internal(
1097         "Expected an assigned node as argument to LimitToAssignedDevice but "
1098         "got: ",
1099         node.DebugString());
1100   }
1101   int root = FindAndUpdateRoot(node.id());
1102   Member& root_member = members_[root];
1103   return root_member.AssignDevice(node);
1104 }
1105 
GetSoftDeviceCandidates(const Node & node,const Member & root_member,int root_id,std::vector<Device * > * possible_devices)1106 void ColocationGraph::GetSoftDeviceCandidates(
1107     const Node& node, const Member& root_member, int root_id,
1108     std::vector<Device*>* possible_devices) {
1109   // Try to find supported devices that don't violate resource devices.
1110   // The soft_device_name is the same as the requested device name
1111   // without specifying the device type or ID (if assigned and requested
1112   // devices does not specify them).
1113   DeviceNameUtils::ParsedName soft_device_name =
1114       root_member.GetPreferredSoftDeviceName();
1115   device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1116   if (!possible_devices->empty()) {
1117     *possible_devices = FilterSupportedDevices(
1118         *possible_devices, root_member.supported_device_types(),
1119         default_local_device_);
1120   }
1121 
1122   if (!possible_devices->empty()) {
1123     return;
1124   }
1125 
1126   // TODO(iga): Disallow changing resource devices when this ColocationGraph
1127   // is for :
1128   // - a function called by an op requiring deep inspection, or
1129   // - a graph containing ops requiring inspection.
1130   // It is fairly tricky to make changing resource devices in presence of
1131   // ops requiring inspection work correctly. One thing it would require is to
1132   // communicate these "resource movement" decisions across Placer instances.
1133 
1134   // Failed to find supported devices that don't violate resource devices.
1135   // Try finding some devices that violated resource devices.
1136   // If we succeed, we will log a warning below.
1137   soft_device_name = root_member.GetSoftDeviceName();
1138   device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1139   if (!possible_devices->empty()) {
1140     *possible_devices = FilterSupportedDevices(
1141         *possible_devices, root_member.supported_device_types(),
1142         default_local_device_);
1143   }
1144 
1145   if (!possible_devices->empty()) {
1146     LOG(WARNING)
1147         << "Failed to place the graph without changing the devices of some "
1148            "resources. Some of the operations (that had to be colocated with "
1149            "resource generating operations) are not supported on the "
1150            "resources' devices. Current candidate devices are [\n  "
1151         << absl::StrJoin(DevicesToString(*possible_devices), "\n  ")
1152         << "].\nSee below for details of this colocation group:"
1153         << DebugInfo(root_id);
1154   }
1155 }
1156 
LimitToPossibleDevices(const Node & node,const PossibleDevices & devices)1157 Status ColocationGraph::LimitToPossibleDevices(const Node& node,
1158                                                const PossibleDevices& devices) {
1159   int root = FindAndUpdateRoot(node.id());
1160   Member& root_member = members_[root];
1161   return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
1162 }
1163 
// Computes (and caches on the group root) the list of devices on which the
// colocation group containing `node` may be placed. On success,
// *possible_devices points at the root Member's cached vector. Returns
// InvalidArgument with a detailed, user-facing explanation when no device
// satisfies the constraints.
Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindAndUpdateRoot(node->id());
  // Fast path: the group's possible devices were already computed.
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return Status::OK();
  }

  Member& root_member = members_[node_root];

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  // NOTE: Basing possible device computation on requested device name
  // is guaranteed to respect the assigned and resource device names because
  // requested device is always a specialization of both.
  std::vector<Device*> devices;
  if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
    // The root node has a (possibly partial) device
    // specification, so enumerate the physical devices that
    // conform to it.
    device_set_.FindMatchingDevices(root_member.requested_device_name(),
                                    &devices);

    if (!devices.empty()) {
      // Filter devices into those that are compatible with the root
      // node (and its children).
      devices = FilterSupportedDevices(
          devices, root_member.supported_device_types(), default_local_device_);
    }

    // Perform soft placement if allow_soft_placement_ is set.
    if (devices.empty() && allow_soft_placement_) {
      GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
    }

    if (devices.empty()) {
      // Return an error when a physical device that matches an explicit
      // device specification is not found. This ensures that we don't
      // assign a node to GPU when the user wanted to force it on CPU.
      string debug_info = DebugInfo(node_root);

      DeviceNameUtils::ParsedName specified_device_name;
      if (DeviceNameUtils::ParseFullName(node->requested_device(),
                                         &specified_device_name) &&
          specified_device_name == root_member.requested_device_name()) {
        // The specified device and merged set device match, and
        // will appear in the GraphDef (for debugging), so just
        // print the specified device.
        std::vector<Device*> devices_matching_nodedef;
        device_set_.FindMatchingDevices(specified_device_name,
                                        &devices_matching_nodedef);
        if (devices_matching_nodedef.empty()) {
          // Sometimes it is almost impossible to understand the problem
          // without a list of available devices.
          std::vector<string> device_names;
          for (const Device* device : device_set_.devices()) {
            device_names.push_back(device->name());
          }
          std::sort(device_names.begin(), device_names.end());

          // Special-case the common "requested GPU but CUDA not built in"
          // confusion with an extra hint.
          string gpu_msg = "";
          if (!IsGoogleCudaEnabled() &&
              absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
            gpu_msg =
                " The requested device appears to be a GPU, but CUDA is not "
                "enabled.";
          }

          return errors::InvalidArgument(
              errors::FormatNodeNameForError(node->name()),
              " was explicitly assigned to ", node->requested_device(),
              " but available devices are [ ",
              absl::StrJoin(device_names, ", "), " ]. Make sure ",
              "the device specification refers to a valid device.", gpu_msg);
        } else if (specified_device_name.has_type) {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "' because no supported kernel for ",
              specified_device_name.type, " devices is available.", debug_info,
              "\nOp: ", node->type_string(),
              "\nNode attrs: ", node->attrs().DebugString(),
              "\nRegistered kernels:\n",
              KernelsRegisteredForOp(node->type_string()));
        } else {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), debug_info);
        }
      } else {
        // The specified device may be a valid device but the
        // merged set device is different, so print both.
        // TODO(b/129057603): There are many possibilities at this point.
        // Provide good error messages.
        return errors::InvalidArgument(
            "Could not satisfy explicit device specification '",
            node->requested_device(), "' because the node ",
            errors::FormatColocationNodeForError(node->name()),
            " was colocated with a group of nodes that ",
            "required incompatible device '",
            DeviceNameUtils::ParsedNameToString(
                root_member.requested_device_name()),
            "'. All available devices [",
            absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
            debug_info);
      }
    }
  } else {
    // The device is completely unspecified, so enumerate the devices that
    // support all of the nodes in the set.
    if (device_set_.devices().empty()) {
      return errors::Internal("No devices are registered");
    }
    devices = FilterSupportedDevices(device_set_.devices(),
                                     root_member.supported_device_types(),
                                     default_local_device_);

    if (devices.empty()) {
      return errors::InvalidArgument(
          "Node had no OpKernel registered to support this operation: ",
          "Operation was ", node->type_string(), " and inputs were [",
          DataTypeVectorString(node->input_types()), "].\n",
          DebugInfo(node_root));
    }
  }

  // Cache the result of the possible devices for this node group.
  root_member.set_possible_devices(std::move(devices));
  *possible_devices = &root_member.possible_devices();
  return Status::OK();
}
1299 
InitializeMembers()1300 Status ColocationGraph::InitializeMembers() {
1301   for (Node* node : graph_.op_nodes()) {
1302     Status status = InitializeMember(*node, &members_[node->id()]);
1303     if (!status.ok()) {
1304       return AttachDef(status, *node);
1305     }
1306   }
1307   return Status::OK();
1308 }
1309 
DebugString() const1310 string ColocationGraph::DebugString() const {
1311   std::unordered_set<int> roots;
1312   std::vector<string> root_strings;
1313   for (const Node* node : graph_.nodes()) {
1314     if (!node->IsOp()) {
1315       continue;
1316     }
1317     int node_root = FindRoot(node->id());
1318     if (roots.count(node_root) == 0) {
1319       root_strings.push_back(DebugInfo(node_root));
1320       roots.insert(node_root);
1321     }
1322   }
1323   return absl::StrJoin(root_strings, "\n");
1324 }
1325 
// Returns debugging info for the node referred to by 'node_root':
// the op types in the group with their registered device types, plus each
// member node's requested and (if any) framework-assigned device. Returns
// an empty string when the root has no member op nodes.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    // Only nodes whose group root matches `node_root` belong to this group.
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    // Errors are ignored here on purpose: an op with no registered kernels
    // simply reports an empty device list in the debug output.
    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  // An empty group produces no debug text at all.
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}
1387 
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,Member * member)1388 Status ColocationGraph::InitializeMemberWithAssignedDevice(
1389     const string& assigned_device_name, const string& node_type,
1390     Member* member) {
1391   // This node has already been assigned to a device, so we
1392   // respect this placement, after sanity-checking it.
1393   // NOTE: Since any assignment must have been performed by
1394   // the TensorFlow runtime, we consider errors in this branch to
1395   // be INTERNAL.
1396   TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
1397 
1398   // Since assigned device must be a full specification, do extra checks.
1399   const Device* assigned_device =
1400       device_set_.FindDeviceByName(assigned_device_name);
1401   if (assigned_device == nullptr) {
1402     // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
1403     // calls when they are supported.
1404     return errors::Internal(
1405         "Assigned device '", assigned_device_name,
1406         "' does not match any device. This error can happen when one attempts "
1407         "to run a tf.function with resource inputs residing on remote devices. "
1408         "This use case is currently not supported. Here are the devices "
1409         "available on this machine: [",
1410         absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "].",
1411         "If you are seeing this error when running using a tf.Session, set "
1412         "share_cluster_devices_in_session to true in the tf.ConfigProto.");
1413   }
1414 
1415   for (const auto& d : member->supported_device_types()) {
1416     if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
1417       return Status::OK();
1418     }
1419   }
1420 
1421   return errors::Internal("Assigned device '", assigned_device_name,
1422                           "' does not have registered OpKernel support "
1423                           "for ",
1424                           node_type);
1425 }
1426 
InitializeMember(const Node & node,Member * member)1427 Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
1428   TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
1429       node, device_types_, &local_address_spec_));
1430 
1431   if (node.has_assigned_device_name()) {
1432     TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
1433         node.assigned_device_name(), node.type_string(), member));
1434   } else {
1435     // This node has not yet been assigned to a device, so we
1436     // calculate any constraints due to the set of registered
1437     // kernels and any (partial) user-provided device specification
1438     // in the NodeDef.
1439 
1440     // If no kernels are registered for this op type, fail with an error.
1441     if (member->supported_device_types().empty()) {
1442       std::set<string> registered_device_types;
1443       for (Device* d : device_set_.devices()) {
1444         registered_device_types.insert(d->device_type());
1445       }
1446       return errors::InvalidArgument(
1447           "No OpKernel was registered to support Op '", node.type_string(),
1448           "' used by ", errors::FormatNodeNameForError(node.name()),
1449           " with these attrs: [", node.attrs().DebugString(),
1450           "]\n"
1451           "Registered devices: [",
1452           absl::StrJoin(registered_device_types, ", "), "]\n",
1453           "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
1454     }
1455 
1456     // If the NodeDef contains a device, then we interpret it as a
1457     // (partial) device specification.
1458     if (!node.requested_device().empty()) {
1459       if (IsRefOrResourceGeneratorNode(node)) {
1460         // Treat requested device on resource generating nodes as assigned
1461         // device so that we don't override it.
1462         TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
1463       } else {
1464         // The user has specified a device in the NodeDef, try to find a
1465         // valid device matching their specification in the set of
1466         // devices.
1467         // NOTE: The full name may specify a device that is not in
1468         // n.supported_device_types(), but we check that in AssignDevice().
1469         TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
1470       }
1471     }
1472   }
1473   return Status::OK();
1474 }
1475 
1476 // Returns a list of devices having type in supported_device_types.  The
1477 // returned list is sorted by preferred type (higher numeric type is preferred).
FilterSupportedDevices(const std::vector<Device * > & devices,const PrioritizedDeviceTypeVector & supported_device_types,const Device * default_local_device)1478 /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
1479     const std::vector<Device*>& devices,
1480     const PrioritizedDeviceTypeVector& supported_device_types,
1481     const Device* default_local_device) {
1482   Device* filtered_default_device = nullptr;
1483   PrioritizedDeviceVector prioritized_filtered_devices;
1484   for (const auto& supported_device_type : supported_device_types) {
1485     for (Device* device : devices) {
1486       if (IsSupportedDeviceType(device->attributes(),
1487                                 supported_device_type.first)) {
1488         if (default_local_device &&
1489             (device == default_local_device ||
1490              // TODO(nareshmodi, fishx): At times the device pointer in the
1491              // device set is different to the one passed in as the default
1492              // device. Figure out why this might be.
1493              device->name() == default_local_device->name())) {
1494           filtered_default_device = device;
1495         } else {
1496           prioritized_filtered_devices.emplace_back(
1497               device, supported_device_type.second);
1498         }
1499       }
1500     }
1501   }
1502   DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);
1503 
1504   std::vector<Device*> filtered_devices;
1505   if (filtered_default_device != nullptr) {
1506     filtered_devices.emplace_back(filtered_default_device);
1507   }
1508   for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
1509     filtered_devices.push_back(prioritized_filtered_device.first);
1510   }
1511   return filtered_devices;
1512 }
1513 
1514 }  // namespace tensorflow
1515