• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/colocation_graph.h"
17 
18 #include <memory>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/common_runtime/composite_device.h"
30 #include "tensorflow/core/common_runtime/device.h"
31 #include "tensorflow/core/common_runtime/device_set.h"
32 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
33 #include "tensorflow/core/common_runtime/inspecting_placer.h"
34 #include "tensorflow/core/common_runtime/partitioning_utils.h"
35 #include "tensorflow/core/framework/attr_value.pb.h"
36 #include "tensorflow/core/framework/attr_value_util.h"
37 #include "tensorflow/core/framework/dataset.h"
38 #include "tensorflow/core/framework/device_attributes.pb.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/graph_node_util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/util/device_name_utils.h"
51 #include "tensorflow/core/util/dump_graph.h"
52 #include "tensorflow/core/util/port.h"
53 
54 namespace tensorflow {
55 
56 namespace {
57 
58 // We hoist the conversion from C-style string literal to StringPiece here,
59 // so that we can avoid the many repeated calls to strlen().
60 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
61 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
62 
63 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)64 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
65   std::vector<string> v;
66   v.reserve(devices.size());
67   for (Device* d : devices) {
68     v.push_back(d->name());
69   }
70   return v;
71 }
72 
73 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)74 std::vector<string> DeviceTypeAndPriorityToString(
75     const PrioritizedDeviceTypeVector& devices) {
76   std::vector<string> v;
77   v.reserve(devices.size());
78   for (const std::pair<DeviceType, int32>& device_and_type : devices) {
79     v.push_back(DeviceTypeString(device_and_type.first));
80   }
81   return v;
82 }
83 
IsRefOrResource(DataType data_type)84 bool IsRefOrResource(DataType data_type) {
85   return IsRefType(data_type) || data_type == DT_RESOURCE;
86 }
87 
88 // While Placer can override requested device on ops processing
89 // resources, i.e. node that take (and potentially return) a resource,
90 // it must not override requested device on ops generating a resource,
91 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
92 // output nodes.
IsRefOrResourceGeneratorNode(const Node & node)93 bool IsRefOrResourceGeneratorNode(const Node& node) {
94   return node.num_inputs() == 0 && node.num_outputs() == 1 &&
95          IsRefOrResource(node.output_type(0));
96 }
97 
IsExemptFromResourceInputColocation(const Node * node)98 bool IsExemptFromResourceInputColocation(const Node* node) {
99   // Note: Partitioned function calls, which place and partition their
100   // function bodies, are exempt from this check: they forward resource and
101   // ref inputs to operations that are appropriately placed, instead of
102   // dereferencing them.
103   const string& op_type = node->op_def().name();
104   auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
105   return exempt_ops.find(op_type) != exempt_ops.end();
106 }
107 
HasPriorities(const PrioritizedDeviceTypeVector & device_types)108 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
109   for (const auto& prioritized_device_type : device_types) {
110     if (prioritized_device_type.second != 0) return true;
111   }
112   return false;
113 }
114 
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)115 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
116                        const PrioritizedDeviceTypeVector& b_types) {
117   if (a_types.size() != b_types.size()) {
118     return false;
119   }
120   for (int i = 0; i < a_types.size(); ++i) {
121     if (a_types[i].first != b_types[i].first) {
122       return false;
123     }
124   }
125   return true;
126 }
127 
IsXlaDevice(absl::string_view device_type)128 bool IsXlaDevice(absl::string_view device_type) {
129   if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
130       device_type == "XLA_TPU_JIT") {
131     // Symbolic XLA device.
132     return true;
133   }
134 
135   return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
136           device_type == "TPU");
137 }
138 
// Returns true iff `device_type` is exactly the composite device type
// constant (kCompositeDeviceType, declared in composite_device.h).
bool IsCompositeDevice(absl::string_view device_type) {
  return device_type == kCompositeDeviceType;
}
142 
143 }  // namespace
144 
SetParentAndSupportedDevices(const Node & node,const std::vector<DeviceType> & types,const DeviceNameUtils::ParsedName * local_address_spec)145 Status Member::SetParentAndSupportedDevices(
146     const Node& node, const std::vector<DeviceType>& types,
147     const DeviceNameUtils::ParsedName* local_address_spec) {
148   int id = node.id();
149   if (id < 0) {
150     return errors::Internal("Placer should not be creating a Member for node: ",
151                             node.DebugString());
152   }
153   parent_ = id;
154   return SupportedDeviceTypesForNode(
155       types, node.def(), &supported_device_types_, local_address_spec);
156 }
157 
SetAssignedDeviceName(const string & device_name)158 Status Member::SetAssignedDeviceName(const string& device_name) {
159   if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
160     return errors::Internal(
161         "Setting assigned device name when there is a requested device set "
162         "is unsupported");
163   }
164   if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
165     return errors::Internal("Malformed assigned device '", device_name, "'");
166   }
167   // Set requested device to assigned_device to maintain the invariant that
168   // requested is a specialization of assigned.
169   requested_device_name_ = assigned_device_name_;
170   return Status::OK();
171 }
172 
SetResourceDeviceName(const Node & node)173 Status Member::SetResourceDeviceName(const Node& node) {
174   if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
175     return errors::Internal(
176         "Setting resource device name when there is a requested device set "
177         "is unsupported");
178   }
179 
180   if (!DeviceNameUtils::ParseFullName(node.requested_device(),
181                                       &resource_device_name_)) {
182     return errors::InvalidArgument("Malformed device specification '",
183                                    node.requested_device(),
184                                    "' in node: ", node.DebugString());
185   }
186 
187   // Set requested device to resource device to maintain the invariant that
188   // requested is a specialization of resource.
189   requested_device_name_ = resource_device_name_;
190   return Status::OK();
191 }
192 
SetRequestedDeviceName(const Node & node)193 Status Member::SetRequestedDeviceName(const Node& node) {
194   if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
195     return errors::Internal(
196         "Setting requested device name when there is an assigned device set "
197         "is unsupported");
198   }
199   if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
200     return errors::Internal(
201         "Setting requested device name when there is a resource device set "
202         "is unsupported");
203   }
204   if (!DeviceNameUtils::ParseFullName(node.requested_device(),
205                                       &requested_device_name_)) {
206     return errors::InvalidArgument("Malformed device specification '",
207                                    node.requested_device(),
208                                    "' in node: ", node.DebugString());
209   }
210   return Status::OK();
211 }
212 
// Exports this member's current device constraints (requested name, resource
// name, supported device types) into `possible_device`. Errors out if a
// device has already been assigned, since the guard below treats that state
// as an internal invariant violation.
Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Cannot fill PossibleDevices from a member that has non-empty assigned "
        "device. Did we start assigning devices to functions called by deep "
        "ops? ",
        DebugString());
  }
  possible_device->requested_device_name = requested_device_name_;
  possible_device->resource_device_name = resource_device_name_;
  possible_device->device_types = supported_device_types_;
  return Status::OK();
}
226 
IsEdgeFromCompositeDeviceToPhysicalDevice(const Member & src_root) const227 bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
228     const Member& src_root) const {
229   auto compatible_edge_from_composite_device_to_physical_device =
230       [](const DeviceNameUtils::ParsedName& src_device,
231          const DeviceNameUtils::ParsedName& dst_device) -> bool {
232     return src_device.has_type && dst_device.has_type &&
233            IsCompositeDevice(src_device.type) &&
234            !IsCompositeDevice(dst_device.type);
235   };
236   if (compatible_edge_from_composite_device_to_physical_device(
237           src_root.assigned_device_name_, assigned_device_name_) ||
238       compatible_edge_from_composite_device_to_physical_device(
239           src_root.resource_device_name_, resource_device_name_) ||
240       compatible_edge_from_composite_device_to_physical_device(
241           src_root.requested_device_name_, requested_device_name_)) {
242     return true;
243   }
244   return false;
245 }
246 
// Verifies that this member (root of `dst`'s colocation group) is device-
// compatible with `src_root` (root of `src`'s group), given that `src` and
// `dst` are connected by a reference/resource edge. Incompatible assigned
// or resource devices are hard errors; an incompatible *requested* device is
// not — it is overridden with the source's requested device instead.
Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  // Compatible requested devices: nothing to override.
  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return Status::OK();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for destination
  // node, but need to preserve the invariant that it will be a specialization
  // of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  // Re-specialize the overridden requested name against both of this
  // member's stronger constraints.
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return Status::OK();
}
299 
// Union-by-rank merge of the disjoint-set trees rooted at `x_root` and
// `y_root` inside `tree`. On return, `*new_root` points at the member that
// remains (or would remain) the root and `*old_root` at the one attached
// under it. When `dry_run` is true, no parent_/rank_ fields are modified;
// only the two output pointers are computed.
void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}
344 
345 // tree is non-const because we can change some `parent` pointers in some
346 // members for more efficient future lookups. The vector itself is not
347 // changed.
FindAndUpdateRoot(std::vector<Member> * tree,int node_id)348 int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
349   Member& member = (*tree)[node_id];
350   if (member.parent_ == node_id) {
351     // member.parent is the root of this disjoint tree.  Do nothing.
352   } else {
353     member.parent_ = FindAndUpdateRoot(tree, member.parent_);
354   }
355   // Now it is guaranteed that member.parent is the root of this disjoint
356   // tree.
357   return member.parent_;
358 }
359 
FindRoot(const std::vector<Member> & tree,int node_id)360 int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
361   const Member& member = tree[node_id];
362   if (member.parent_ == node_id) {
363     return member.parent_;
364   }
365   return FindRoot(tree, member.parent_);
366 }
367 
// Merges `other`'s assigned/resource/requested device names into this
// member. All three merges are first performed on local copies so that this
// member is left completely unchanged if any merge fails (all-or-nothing).
// `allow_soft_placement` relaxes only the requested-device merge.
Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  // Re-establish the specialization invariant on the merged requested name.
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = assigned_device_name_copy;
  resource_device_name_ = resource_device_name_copy;
  requested_device_name_ = requested_device_name_copy;
  return Status::OK();
}
398 
// Updates this to contain the intersection of the device types in
// this and "other". Delegates to the PrioritizedDeviceTypeVector overload.
bool Member::MergeSupportedDevices(const Member& other) {
  return MergeSupportedDevices(other.supported_device_types_);
}
404 
// Intersects supported_device_types_ with `other_devices` and reconciles
// priorities: if only one side has explicit (non-zero) priorities, that
// side's priorities are kept; if both have priorities and they agree, they
// are kept; if they disagree, all priorities are reset to 0 and the default
// DeviceTypeOrder ordering applies. Returns false — leaving
// supported_device_types_ unchanged — when the intersection is empty.
bool Member::MergeSupportedDevices(
    const PrioritizedDeviceTypeVector& other_devices) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;

  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type : other_devices) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
  DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither are prioritized then we just return the original i.e. target
    // prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect priorities of that in the
    // intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and agree then we go with that. If the
    // prioritization order is different, then we just fallback to the default
    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
    // merged priorities to 0, so that downstream merges work correctly as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      DeviceSet::SortPrioritizedDeviceTypeVector(&result);
    }
  }

  // Empty intersection: report failure without clobbering the current list.
  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}
466 
AssignDevice(const Node & node)467 Status Member::AssignDevice(const Node& node) {
468   if (node.assigned_device_name_index() == assigned_device_name_index_) {
469     return Status::OK();
470   }
471 
472   DeviceNameUtils::ParsedName parsed;
473   DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
474   Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
475   if (!s.ok()) {
476     return errors::Internal(
477         "Constraining by assigned device should not cause an error. Original "
478         "root's assigned device name: ",
479         DeviceNameUtils::ParsedNameToString(assigned_device_name_),
480         " node's assigned device name \"", node.assigned_device_name(),
481         ". Error: ", s.error_message());
482   }
483   s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
484   if (!s.ok()) {
485     return errors::Internal(
486         "Constraining by assigned device should not cause an error. Original "
487         "root's resource device name: ",
488         DeviceNameUtils::ParsedNameToString(resource_device_name_),
489         " node's assigned device name \"", node.assigned_device_name(),
490         ". Error: ", s.error_message());
491   }
492   s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
493   if (!s.ok()) {
494     return errors::Internal(
495         "Constraining by assigned device should not cause an error. Original "
496         "root's requested device name: \"",
497         DeviceNameUtils::ParsedNameToString(requested_device_name_),
498         "\", node's assigned device name \"", node.assigned_device_name(),
499         "\". Error: ", s.error_message());
500   }
501 
502   assigned_device_name_index_ = node.assigned_device_name_index();
503   // Clear cached possible_devices, if any.
504   possible_devices_.clear();
505   return Status::OK();
506 }
507 
// Filters XLA device types out of supported_device_types_, unless one of
// this member's requested/assigned/resource device names explicitly refers
// to an XLA device or to a CompositeDevice (in which case the list is left
// untouched).
void Member::MaybeExcludeXlaDevices() {
  for (const auto& parsed_name :
       {requested_device_name_, assigned_device_name_, resource_device_name_}) {
    // Don't exclude XLA devices from supported devices if member is explicitly
    // assigned to a CompositeDevice.
    if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
                                 IsCompositeDevice(parsed_name.type))) {
      return;
    }
  }

  PrioritizedDeviceTypeVector non_xla_types;
  absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
                  [&](const std::pair<DeviceType, int32>& entry) {
                    return !IsXlaDevice(entry.first.type_string());
                  });

  // TODO(b/141216278) Remove all XLA device types from the supported device
  // types if the node has no requested/assigned/resource XLA device.
  // Only shrink the list when something was actually filtered out AND at
  // least one non-XLA type remains (never leave the member with an empty
  // supported-device list).
  if (!non_xla_types.empty() &&
      non_xla_types.size() < supported_device_types_.size()) {
    supported_device_types_ = std::move(non_xla_types);
  }
}
532 
// Narrows this member's constraints to `devices`: merges the requested and
// resource device names and intersects the supported device types.
// `allow_soft_placement` relaxes only the requested-device merge.
Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  // NOTE(review): MergeSupportedDevices returns false when the intersection
  // is empty, and that result is ignored here — presumably an empty
  // supported list is surfaced later during device lookup; confirm.
  MergeSupportedDevices(devices.device_types);
  return Status::OK();
}
543 
DebugString() const544 string Member::DebugString() const {
545   return absl::StrCat(
546       "Member(assigned_device_name_index_=", assigned_device_name_index_,
547       " requested_device_name_='",
548       DeviceNameUtils::ParsedNameToString(requested_device_name_),
549       "' assigned_device_name_='",
550       DeviceNameUtils::ParsedNameToString(assigned_device_name_),
551       "' resource_device_name_='",
552       DeviceNameUtils::ParsedNameToString(resource_device_name_),
553       "' supported_device_types_=[",
554       absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
555                     ", "),
556       "] possible_devices_=[",
557       absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
558 }
559 
GetSoftDeviceName() const560 DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
561   DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
562   if (!assigned_device_name_.has_type) {
563     soft_device_name.type.clear();
564     soft_device_name.has_type = false;
565   }
566   if (!assigned_device_name_.has_id) {
567     soft_device_name.has_id = false;
568   }
569   return soft_device_name;
570 }
571 
GetPreferredSoftDeviceName() const572 DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
573   DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
574   if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
575     soft_device_name.type.clear();
576     soft_device_name.has_type = false;
577   }
578   if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
579     soft_device_name.has_id = false;
580   }
581   return soft_device_name;
582 }
583 
584 // Returns ParsedName whose address space (i.e. job, replica, task) identifies
585 // the address space directly accessible by the local process. If the address
586 // space is fully specified and it is exactly the same as the address space
587 // of a device, then all kernels of that device should be registered in the
588 // local process.
LocalAddressSpec(const Device * client_device,const Device * default_local_device)589 static const DeviceNameUtils::ParsedName LocalAddressSpec(
590     const Device* client_device, const Device* default_local_device) {
591   if (client_device != nullptr) {
592     return DeviceNameUtils::AddressSpace(client_device->parsed_name());
593   }
594 
595   if (default_local_device != nullptr) {
596     return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
597   }
598 
599   // TODO(b/139617593) Return the name of the first local device in device_set_
600   // once we can trust the output of Device::IsLocal().
601   return DeviceNameUtils::ParsedName();
602 }
603 
// Builds the colocation state for `graph`: captures the device set and
// placement flags, computes the local address space from the client device
// (falling back to `default_local_device`), and pre-sizes one Member slot
// per node id.
ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  // One Member per node id; ids are assumed dense (see ColocateAllNodes).
  members_.resize(graph_.num_node_ids());
}
624 
// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group.  Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef.  The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it).  Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          // Only strings carrying the "loc:@" prefix name colocation groups;
          // everything else in the "_class" attribute is ignored.
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        // A malformed single-string "_class" attribute is logged but not
        // treated as a placement error.
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return Status::OK();
}
683 
ColocateResourceOrRefEdge(const Node * src,const Node * dst)684 Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
685                                                   const Node* dst) {
686   // Colocate `src` and `dst` to maintain the invariant that nodes
687   // connected by reference edges are colocated.
688   int src_root_id = FindAndUpdateRoot(src->id());
689   int dst_root_id = FindAndUpdateRoot(dst->id());
690   auto& src_root = members_[src_root_id];
691   auto& dst_root = members_[dst_root_id];
692 
693   if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
694     // If the src root is assigned to a composite device and the dst root is
695     // assigned to a physical device, don't colocate the dst root with the src
696     // root.
697     return Status::OK();
698   }
699   TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
700       *src, src_root, *dst, log_device_placement_));
701   Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
702   if (!status.ok()) {
703     return AttachDef(
704         errors::InvalidArgument(
705             "Nodes were connected by a reference or resource connection "
706             "(requiring them to be on the same device), but the two nodes "
707             "were assigned two different devices: ",
708             status.error_message()),
709         *dst);
710   }
711   return Status::OK();
712 }
713 
ColocateResourceAndRefEdges(std::unordered_set<Node * > * inspection_required)714 Status ColocationGraph::ColocateResourceAndRefEdges(
715     std::unordered_set<Node*>* inspection_required) {
716   // If `node` has an input edge with reference type, add an edge from the
717   // source of that edge to `node`.
718   for (const Edge* edge : graph_.edges()) {
719     if (edge->IsControlEdge()) {
720       continue;
721     }
722     Node* src = edge->src();
723     Node* dst = edge->dst();
724     bool needs_inspection;
725     TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
726         *src, &needs_inspection));
727     if (needs_inspection) {
728       inspection_required->insert(src);
729       continue;
730     }
731     TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
732         *dst, &needs_inspection));
733     if (needs_inspection) {
734       inspection_required->insert(dst);
735       continue;
736     }
737 
738     DataType input_type = dst->input_type(edge->dst_input());
739 
740     // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
741     // This is needed to get around the issue in b/135705778.
742     if (input_type == DT_VARIANT &&
743         data::DatasetOpKernel::IsDatasetOp(&src->op_def()) &&
744         data::DatasetOpKernel::IsDatasetOp(&dst->op_def())) {
745       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
746       continue;
747     }
748 
749     // Even though we can look inside function calling ops, we make an exception
750     // here mostly for performance reasons. Looking inside function calling ops
751     // is extra overhead. It is only necessary when they return resources. When
752     // they don't, we don't look inside them and make this exception here.
753     // Looking inside, could potentially enable us to make better placement
754     // decisions. It might be worth doing at some point.
755     if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
756         !IsExemptFromResourceInputColocation(dst)) {
757       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
758     }
759   }
760 
761   return Status::OK();
762 }
763 
764 namespace {
765 // Returns tensor list element data type, if the node is one of the ops that
766 // operate with TensorLists. Otherwise returns DT_INVALID.
GetElementDataType(const Node & node)767 DataType GetElementDataType(const Node& node) {
768   static absl::flat_hash_set<std::string>* tensor_list_ops =
769       new absl::flat_hash_set<std::string>(
770           {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
771            "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
772            "TensorListScatterIntoExistingList", "TensorListPushBack",
773            "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
774            "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
775            "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});
776 
777   if (tensor_list_ops->contains(node.type_string())) {
778     DataType element_type;
779     if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
780       return element_type;
781     }
782   }
783 
784   return DT_INVALID;
785 }
786 }  // namespace
787 
// Restricts a colocation group to CPU-only devices when one of its nodes
// consumes a DT_VARIANT tensor whose underlying element type is host-only
// (e.g. a TensorList of strings). The underlying type is discovered by
// walking backwards from the node along DT_VARIANT data edges until a
// TensorList op with a concrete (non-variant) element_dtype is found.
Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
  auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };

  auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
    return entry.first == DEVICE_CPU;
  };

  for (Node* node : graph_.nodes()) {
    // Skip nodes that do not have DT_VARIANT inputs.
    if (absl::c_none_of(node->input_types(), is_variant)) continue;

    // Skip nodes that can't be placed on GPU anyway.
    // NOTE: `root` is a reference into members_ and is reused after the DFS
    // below; the DFS callbacks only read graph_ and never touch members_.
    Member& root = members_[FindAndUpdateRoot(node->id())];
    if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue;

    // Stop DFS traversal when found the underlying data type of a variant.
    absl::optional<bool> is_host_data_type;

    auto edge_filter = [&](const Edge& edge) -> bool {
      // We already found the underlying data type.
      if (is_host_data_type.has_value()) return false;

      // Otherwise follow only DT_VARIANT data edges.
      auto edge_dtype = [&]() -> DataType {
        return edge.src()->output_type(edge.src_output());
      };
      return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
    };

    auto enter = [&](Node* n) -> void {
      DataType element_type = GetElementDataType(*n);
      // To handle nested lists continue traversal after finding a TensorList
      // operation that uses DT_VARIANT for element type.
      if (element_type == DT_INVALID || element_type == DT_VARIANT) return;
      is_host_data_type = DataTypeAlwaysOnHost(element_type);
    };

    ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
                   /*stable_comparator=*/nullptr, edge_filter);

    if (is_host_data_type.has_value() && *is_host_data_type) {
      VLOG(2) << "Limit node possible devices to CPU only, because it has a "
                 "DT_VARIANT input with host-only underlying data type: "
              << "node=" << node->name();

      // Restrict possible device types to CPU only.
      PossibleDevices possible_devices;
      absl::c_copy_if(root.supported_device_types(),
                      std::back_inserter(possible_devices.device_types),
                      is_cpu_device);

      TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
          possible_devices, /*allow_soft_placement=*/false));
    }
  }

  return Status::OK();
}
846 
AddInspectionConstraints(const std::unordered_set<Node * > & inspection_required)847 Status ColocationGraph::AddInspectionConstraints(
848     const std::unordered_set<Node*>& inspection_required) {
849   for (Node* node : inspection_required) {
850     IOColocationGroups groups;
851     TF_RETURN_IF_ERROR(
852         inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
853     VLOG(2) << "Computed IOColocationGroups for node " << node->name()
854             << ":\n\t" << groups.DebugString();
855     TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
856   }
857   return Status::OK();
858 }
859 
// Builds all colocation constraints for the graph. The passes below run in
// a fixed order: members are initialized first, resource/ref edges are
// colocated (collecting nodes needing deeper inspection), host-only-dtype
// and inspection constraints are applied, and explicit colocation
// attributes are merged last.
Status ColocationGraph::Initialize() {
  TF_RETURN_IF_ERROR(InitializeMembers());

  // Filled by ColocateResourceAndRefEdges() with nodes the placer must
  // inspect; consumed by AddInspectionConstraints().
  std::unordered_set<Node*> inspection_required;
  TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
  TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
  TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
  TF_RETURN_IF_ERROR(ColocateAllNodes());

  // With all merging done, give each group root a chance to exclude XLA
  // devices from its candidate set.
  for (Node* node : graph_.op_nodes()) {
    int root_id = FindAndUpdateRoot(node->id());
    members_[root_id].MaybeExcludeXlaDevices();
  }

  return Status::OK();
}
876 
// Pair of (node, flag) where the flag records whether this node has a
// resource input coming from the node that requires placer inspection.
using NodeAndBool = std::pair<const Node*, bool>;
880 
881 namespace {
882 
883 // Returns a vector of node names from `nodes`.
NodeAndBoolToString(const std::vector<NodeAndBool> & nodes)884 std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
885   std::vector<string> v;
886   v.reserve(nodes.size());
887   for (const NodeAndBool& node_and_bool : nodes) {
888     v.push_back(node_and_bool.first->name());
889   }
890   return v;
891 }
892 
893 // Given a node requiring placer inspection and its IOColocationGroups,
894 // computes `group_nodes`.
895 // group_nodes[i] contains the nodes that are members of colocation
896 // group i. These nodes are inputs or outputs of `node`.
897 // group_nodes[i][j] is a pair containing a node and whether this node
898 // has a resource input from `node`.
899 // Note:
900 // The same node can be added multiple times to the same group.
901 // The same node can be added to multiple groups.
GetGroupNodes(const IOColocationGroups & groups,const Node & node,std::vector<std::vector<NodeAndBool>> * group_nodes)902 Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
903                      std::vector<std::vector<NodeAndBool>>* group_nodes) {
904   group_nodes->reserve(groups.group_devices.size());
905   for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
906     const Node* src;
907     TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
908     int group_id = groups.input_groups[arg_idx];
909     (*group_nodes)[group_id].emplace_back(src, false);
910   }
911 
912   for (const Edge* edge : node.out_edges()) {
913     if (edge->IsControlEdge()) {
914       continue;
915     }
916 
917     int group_id = groups.output_groups[edge->src_output()];
918     (*group_nodes)[group_id].emplace_back(
919         edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
920   }
921 
922   if (VLOG_IS_ON(2)) {
923     VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
924     for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
925       VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
926               << "]";
927     }
928   }
929   return Status::OK();
930 }
931 
932 // Returns whether the device_type in `device_attributes` is supported.
IsSupportedDeviceType(const DeviceAttributes & device_attributes,const DeviceType & supported_type)933 bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
934                            const DeviceType& supported_type) {
935   if (DeviceType(device_attributes.device_type()) == supported_type) {
936     return true;
937   }
938   return IsCompositeDevice(device_attributes.device_type());
939 }
940 
941 }  // namespace
942 
// Applies the device and colocation constraints in `groups` to the inputs
// and outputs of `node` (an op requiring placer inspection): first merges
// the members of each group into one colocation set, then limits each set
// to the group's possible devices. Returns Internal if `groups` does not
// match the node's input/output arity.
Status ColocationGraph::ApplyIOColocationGroups(
    const IOColocationGroups& groups, const Node& node) {
  if (groups.input_groups.size() != node.num_inputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because input_groups.size() (",
        groups.input_groups.size(),
        ") is different from number of inputs into the op node (",
        node.num_inputs(), ")");
  }
  if (groups.output_groups.size() != node.num_outputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because output_groups.size() (",
        groups.output_groups.size(),
        ") is different from number of outputs into the op node (",
        node.num_outputs(), ")");
  }

  // group_nodes[i] contains the nodes that are members of colocation
  // group i. These nodes are inputs or outputs of `node`.
  // group_nodes[i][j] is a pair containing the node and whether this node
  // has a resource input from `node`.
  // The same node can be added multiple times to the same group.
  // The same node can be added to multiple groups.
  // NOTE: group ids are guarantees to be [0, 1, ..., num_groups].
  std::vector<std::vector<NodeAndBool>> group_nodes(
      groups.group_devices.size());
  TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));

  // Colocate nodes in each group: every member is merged with the group's
  // first member, using the resource-edge path when the member consumes a
  // resource output of `node`.
  for (const std::vector<NodeAndBool>& nodes : group_nodes) {
    for (int i = 1; i < nodes.size(); ++i) {
      VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
              << nodes[i].first->name() << "\"";
      if (nodes[i].second) {
        TF_RETURN_IF_ERROR(
            ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
      } else {
        TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
      }
    }
  }

  // Limit devices in each group. Constraining any one member suffices
  // because the group was merged above.
  for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
    // Nothing to do for empty groups. Groups can be empty if some output
    // of an op is not used.
    if (group_nodes[group_id].empty()) {
      continue;
    }
    const Node* group_node = group_nodes[group_id][0].first;
    const PossibleDevices& possible_devices = groups.group_devices[group_id];
    TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
  }

  return Status::OK();
}
1001 
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)1002 Status ColocationGraph::ColocateNodeToGroup(
1003     std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
1004         colocation_group_root,
1005     const Node* node, StringPiece colocation_group) {
1006   const Node*& root_node = (*colocation_group_root)[colocation_group];
1007   if (root_node == nullptr) {
1008     // This is the first node of the colocation group, so
1009     // designate this node as the 'root' of that colocation group.
1010     root_node = node;
1011   } else {
1012     // Try to colocate the node with the root.  If there is an
1013     // error, return it.
1014     Status s = ColocateNodes(*node, *root_node);
1015     if (!s.ok()) {
1016       if (!allow_soft_placement_) {
1017         return AttachDef(s, *node);
1018       }
1019       if (log_device_placement_) {
1020         LOG(INFO) << "Ignoring request to colocate node '" << node->name()
1021                   << "' with nodes in colocation group '" << colocation_group
1022                   << "' because soft placement is on and an attempt at doing "
1023                      "so resulted in the following error: "
1024                   << AttachDef(s, *node).ToString();
1025       }
1026     }
1027   }
1028   return Status::OK();
1029 }
1030 
1031 // Merge the (possibly disjoint) sets containing nodes "x" and
1032 // "y". Returns OK if the all nodes in the union of these sets can
1033 // be placed on the same device type.
1034 //
1035 // NOTE: If this method returns an error, *this is left in an undefined
1036 // state.
ColocateNodes(const Node & x,const Node & y)1037 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
1038   int x_root = FindAndUpdateRoot(x.id());
1039   int y_root = FindAndUpdateRoot(y.id());
1040   return ColocateNodes(x, x_root, y, y_root);
1041 }
1042 
// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  if (x_root == y_root) {
    return Status::OK();
  }

  // Dry-run merge: determine which root would survive the union and which
  // would be absorbed, without committing the union yet. The real merge
  // happens only after all compatibility checks below succeed.
  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}
1092 
LimitToAssignedDevice(const Node & node)1093 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
1094   if (node.assigned_device_name_index() < 0) {
1095     return errors::Internal(
1096         "Expected an assigned node as argument to LimitToAssignedDevice but "
1097         "got: ",
1098         node.DebugString());
1099   }
1100   int root = FindAndUpdateRoot(node.id());
1101   Member& root_member = members_[root];
1102   return root_member.AssignDevice(node);
1103 }
1104 
GetSoftDeviceCandidates(const Node & node,const Member & root_member,int root_id,std::vector<Device * > * possible_devices)1105 void ColocationGraph::GetSoftDeviceCandidates(
1106     const Node& node, const Member& root_member, int root_id,
1107     std::vector<Device*>* possible_devices) {
1108   // Try to find supported devices that don't violate resource devices.
1109   // The soft_device_name is the same as the requested device name
1110   // without specifying the device type or ID (if assigned and requested
1111   // devices does not specify them).
1112   DeviceNameUtils::ParsedName soft_device_name =
1113       root_member.GetPreferredSoftDeviceName();
1114   device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1115   if (!possible_devices->empty()) {
1116     *possible_devices = FilterSupportedDevices(
1117         *possible_devices, root_member.supported_device_types(),
1118         default_local_device_);
1119   }
1120 
1121   if (!possible_devices->empty()) {
1122     return;
1123   }
1124 
1125   // TODO(iga): Disallow changing resource devices when this ColocationGraph
1126   // is for :
1127   // - a function called by an op requiring deep inspection, or
1128   // - a graph containing ops requiring inspection.
1129   // It is fairly tricky to make changing resource devices in presence of
1130   // ops requiring inspection work correctly. One thing it would require is to
1131   // communicate these "resource movement" decisions across Placer instances.
1132 
1133   // Failed to find supported devices that don't violate resource devices.
1134   // Try finding some devices that violated resource devices.
1135   // If we succeed, we will log a warning below.
1136   soft_device_name = root_member.GetSoftDeviceName();
1137   device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1138   if (!possible_devices->empty()) {
1139     *possible_devices = FilterSupportedDevices(
1140         *possible_devices, root_member.supported_device_types(),
1141         default_local_device_);
1142   }
1143 
1144   if (!possible_devices->empty()) {
1145     LOG(WARNING)
1146         << "Failed to place the graph without changing the devices of some "
1147            "resources. Some of the operations (that had to be colocated with "
1148            "resource generating operations) are not supported on the "
1149            "resources' devices. Current candidate devices are [\n  "
1150         << absl::StrJoin(DevicesToString(*possible_devices), "\n  ")
1151         << "].\nSee below for details of this colocation group:"
1152         << DebugInfo(root_id);
1153   }
1154 }
1155 
LimitToPossibleDevices(const Node & node,const PossibleDevices & devices)1156 Status ColocationGraph::LimitToPossibleDevices(const Node& node,
1157                                                const PossibleDevices& devices) {
1158   int root = FindAndUpdateRoot(node.id());
1159   Member& root_member = members_[root];
1160   return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
1161 }
1162 
// Computes (and caches on the group root) the set of devices on which the
// colocation group containing `node` may be placed. On success,
// `*possible_devices` points at the vector owned by the root Member. When
// no device satisfies the constraints, returns InvalidArgument with a
// message tailored to how the conflict arose.
Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindAndUpdateRoot(node->id());
  // Fast path: the group's possible devices were already computed and cached.
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return Status::OK();
  }

  Member& root_member = members_[node_root];

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  // NOTE: Basing possible device computation on requested device name
  // is guaranteed to respect the assigned and resource device names because
  // requested device is always a specialization of both.
  std::vector<Device*> devices;
  if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
    // The root node has a (possibly partial) device
    // specification, so enumerate the physical devices that
    // conform to it.
    device_set_.FindMatchingDevices(root_member.requested_device_name(),
                                    &devices);

    if (!devices.empty()) {
      // Filter devices into those that are compatible with the root
      // node (and its children).
      devices = FilterSupportedDevices(
          devices, root_member.supported_device_types(), default_local_device_);
    }

    // Perform soft placement if allow_soft_placement_ is set.
    if (devices.empty() && allow_soft_placement_) {
      GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
    }

    if (devices.empty()) {
      // Return an error when a physical device that matches an explicit
      // device specification is not found. This ensures that we don't
      // assign a node to GPU when the user wanted to force it on CPU.
      string debug_info = DebugInfo(node_root);

      DeviceNameUtils::ParsedName specified_device_name;
      if (DeviceNameUtils::ParseFullName(node->requested_device(),
                                         &specified_device_name) &&
          specified_device_name == root_member.requested_device_name()) {
        // The specified device and merged set device match, and
        // will appear in the GraphDef (for debugging), so just
        // print the specified device.
        std::vector<Device*> devices_matching_nodedef;
        device_set_.FindMatchingDevices(specified_device_name,
                                        &devices_matching_nodedef);
        if (devices_matching_nodedef.empty()) {
          // Sometimes it is almost impossible to understand the problem
          // without a list of available devices.
          std::vector<string> device_names;
          for (const Device* device : device_set_.devices()) {
            device_names.push_back(device->name());
          }
          std::sort(device_names.begin(), device_names.end());

          string gpu_msg = "";
          if (!IsGoogleCudaEnabled() &&
              absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
            gpu_msg =
                " The requested device appears to be a GPU, but CUDA is not "
                "enabled.";
          }

          return errors::InvalidArgument(
              errors::FormatNodeNameForError(node->name()),
              " was explicitly assigned to ", node->requested_device(),
              " but available devices are [ ",
              absl::StrJoin(device_names, ", "), " ]. Make sure ",
              "the device specification refers to a valid device.", gpu_msg);
        } else if (specified_device_name.has_type) {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "' because no supported kernel for ",
              specified_device_name.type, " devices is available.", debug_info,
              "\nOp: ", node->type_string(),
              "\nNode attrs: ", node->attrs().DebugString(),
              "\nRegistered kernels:\n",
              KernelsRegisteredForOp(node->type_string()));
        } else {
          // NOTE(review): this message never closes the opening quote around
          // the requested device — confirm and fix in a behavior change.
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), debug_info);
        }
      } else {
        // The specified device may be a valid device but the
        // merged set device is different, so print both.
        // TODO(b/129057603): There are many possibilities at this point.
        // Provide good error messages.
        return errors::InvalidArgument(
            "Could not satisfy explicit device specification '",
            node->requested_device(), "' because the node ",
            errors::FormatColocationNodeForError(node->name()),
            " was colocated with a group of nodes that ",
            "required incompatible device '",
            DeviceNameUtils::ParsedNameToString(
                root_member.requested_device_name()),
            "'. All available devices [",
            absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
            debug_info);
      }
    }
  } else {
    // The device is completely unspecified, so enumerate the devices that
    // support all of the nodes in the set.
    if (device_set_.devices().empty()) {
      return errors::Internal("No devices are registered");
    }
    devices = FilterSupportedDevices(device_set_.devices(),
                                     root_member.supported_device_types(),
                                     default_local_device_);

    if (devices.empty()) {
      return errors::InvalidArgument(
          "Node had no OpKernel registered to support this operation: ",
          "Operation was ", node->type_string(), " and inputs were [",
          DataTypeVectorString(node->input_types()), "].\n",
          DebugInfo(node_root));
    }
  }

  // Cache the result of the possible devices for this node group.
  root_member.set_possible_devices(std::move(devices));
  *possible_devices = &root_member.possible_devices();
  return Status::OK();
}
1298 
InitializeMembers()1299 Status ColocationGraph::InitializeMembers() {
1300   for (Node* node : graph_.op_nodes()) {
1301     Status status = InitializeMember(*node, &members_[node->id()]);
1302     if (!status.ok()) {
1303       return AttachDef(status, *node);
1304     }
1305   }
1306   return Status::OK();
1307 }
1308 
DebugString() const1309 string ColocationGraph::DebugString() const {
1310   std::unordered_set<int> roots;
1311   std::vector<string> root_strings;
1312   for (const Node* node : graph_.nodes()) {
1313     if (!node->IsOp()) {
1314       continue;
1315     }
1316     int node_root = FindRoot(node->id());
1317     if (roots.count(node_root) == 0) {
1318       root_strings.push_back(DebugInfo(node_root));
1319       roots.insert(node_root);
1320     }
1321   }
1322   return absl::StrJoin(root_strings, "\n");
1323 }
1324 
// Returns debugging info for the node referred to by 'node_root'.
// The text lists, for the whole colocation group: each op type with the
// device types registered for it, then every member node with its
// user-requested and (if present) framework-assigned device. Returns an
// empty string when no node belongs to 'node_root'.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  // Scan the whole graph for op nodes whose union-find root is node_root.
  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    // Collect the device types with registered kernels for this node;
    // lookup failures are deliberately ignored (the list stays empty).
    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  // An empty group produces no debug text at all.
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}
1386 
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,Member * member)1387 Status ColocationGraph::InitializeMemberWithAssignedDevice(
1388     const string& assigned_device_name, const string& node_type,
1389     Member* member) {
1390   // This node has already been assigned to a device, so we
1391   // respect this placement, after sanity-checking it.
1392   // NOTE: Since any assignment must have been performed by
1393   // the TensorFlow runtime, we consider errors in this branch to
1394   // be INTERNAL.
1395   TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
1396 
1397   // Since assigned device must be a full specification, do extra checks.
1398   const Device* assigned_device =
1399       device_set_.FindDeviceByName(assigned_device_name);
1400   if (assigned_device == nullptr) {
1401     // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
1402     // calls when they are supported.
1403     return errors::Internal(
1404         "Assigned device '", assigned_device_name,
1405         "' does not match any device. This error can happen when one attempts "
1406         "to run a tf.function with resource inputs residing on remote devices. "
1407         "This use case is currently not supported. Here are the devices "
1408         "available on this machine: [",
1409         absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "].",
1410         "If you are seeing this error when running using a tf.Session, set "
1411         "share_cluster_devices_in_session to true in the tf.ConfigProto.");
1412   }
1413 
1414   for (const auto& d : member->supported_device_types()) {
1415     if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
1416       return Status::OK();
1417     }
1418   }
1419 
1420   return errors::Internal("Assigned device '", assigned_device_name,
1421                           "' does not have registered OpKernel support "
1422                           "for ",
1423                           node_type);
1424 }
1425 
InitializeMember(const Node & node,Member * member)1426 Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
1427   TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
1428       node, device_types_, &local_address_spec_));
1429 
1430   if (node.has_assigned_device_name()) {
1431     TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
1432         node.assigned_device_name(), node.type_string(), member));
1433   } else {
1434     // This node has not yet been assigned to a device, so we
1435     // calculate any constraints due to the set of registered
1436     // kernels and any (partial) user-provided device specification
1437     // in the NodeDef.
1438 
1439     // If no kernels are registered for this op type, fail with an error.
1440     if (member->supported_device_types().empty()) {
1441       std::set<string> registered_device_types;
1442       for (Device* d : device_set_.devices()) {
1443         registered_device_types.insert(d->device_type());
1444       }
1445       return errors::InvalidArgument(
1446           "No OpKernel was registered to support Op '", node.type_string(),
1447           "' used by ", errors::FormatNodeNameForError(node.name()),
1448           " with these attrs: [", node.attrs().DebugString(),
1449           "]\n"
1450           "Registered devices: [",
1451           absl::StrJoin(registered_device_types, ", "), "]\n",
1452           "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
1453     }
1454 
1455     // If the NodeDef contains a device, then we interpret it as a
1456     // (partial) device specification.
1457     if (!node.requested_device().empty()) {
1458       if (IsRefOrResourceGeneratorNode(node)) {
1459         // Treat requested device on resource generating nodes as assigned
1460         // device so that we don't override it.
1461         TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
1462       } else {
1463         // The user has specified a device in the NodeDef, try to find a
1464         // valid device matching their specification in the set of
1465         // devices.
1466         // NOTE: The full name may specify a device that is not in
1467         // n.supported_device_types(), but we check that in AssignDevice().
1468         TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
1469       }
1470     }
1471   }
1472   return Status::OK();
1473 }
1474 
1475 // Returns a list of devices having type in supported_device_types.  The
1476 // returned list is sorted by preferred type (higher numeric type is preferred).
FilterSupportedDevices(const std::vector<Device * > & devices,const PrioritizedDeviceTypeVector & supported_device_types,const Device * default_local_device)1477 /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
1478     const std::vector<Device*>& devices,
1479     const PrioritizedDeviceTypeVector& supported_device_types,
1480     const Device* default_local_device) {
1481   Device* filtered_default_device = nullptr;
1482   PrioritizedDeviceVector prioritized_filtered_devices;
1483   for (const auto& supported_device_type : supported_device_types) {
1484     for (Device* device : devices) {
1485       if (IsSupportedDeviceType(device->attributes(),
1486                                 supported_device_type.first)) {
1487         if (default_local_device &&
1488             (device == default_local_device ||
1489              // TODO(nareshmodi, fishx): At times the device pointer in the
1490              // device set is different to the one passed in as the default
1491              // device. Figure out why this might be.
1492              device->name() == default_local_device->name())) {
1493           filtered_default_device = device;
1494         } else {
1495           prioritized_filtered_devices.emplace_back(
1496               device, supported_device_type.second);
1497         }
1498       }
1499     }
1500   }
1501   DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);
1502 
1503   std::vector<Device*> filtered_devices;
1504   if (filtered_default_device != nullptr) {
1505     filtered_devices.emplace_back(filtered_default_device);
1506   }
1507   for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
1508     filtered_devices.push_back(prioritized_filtered_device.first);
1509   }
1510   return filtered_devices;
1511 }
1512 
1513 }  // namespace tensorflow
1514