/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/colocation_graph.h"

#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

namespace {

// We hoist the conversion from C-style string literal to StringPiece here,
// so that we can avoid the many repeated calls to strlen().
const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);

// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (Device* d : devices) {
    v.push_back(d->name());
  }
  return v;
}

// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DeviceTypeAndPriorityToString(
    const PrioritizedDeviceTypeVector& devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (const std::pair<DeviceType, int32>& device_and_type : devices) {
    v.push_back(DeviceTypeString(device_and_type.first));
  }
  return v;
}

bool IsRefOrResource(DataType data_type) {
  return IsRefType(data_type) || data_type == DT_RESOURCE;
}

// While Placer can override requested device on ops processing
// resources, i.e. nodes that take (and potentially return) a resource,
// it must not override requested device on ops generating a resource,
// e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
// output nodes.
bool IsRefOrResourceGeneratorNode(const Node& node) {
  return node.num_inputs() == 0 && node.num_outputs() == 1 &&
         IsRefOrResource(node.output_type(0));
}

bool IsExemptFromResourceInputColocation(const Node* node) {
  // Note: Partitioned function calls, which place and partition their
  // function bodies, are exempt from this check: they forward resource and
  // ref inputs to operations that are appropriately placed, instead of
  // dereferencing them.
  const string& op_type = node->op_def().name();
  auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
  return exempt_ops.find(op_type) != exempt_ops.end();
}

bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
  for (const auto& prioritized_device_type : device_types) {
    if (prioritized_device_type.second != 0) return true;
  }
  return false;
}

bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
                       const PrioritizedDeviceTypeVector& b_types) {
  if (a_types.size() != b_types.size()) {
    return false;
  }
  for (int i = 0; i < a_types.size(); ++i) {
    if (a_types[i].first != b_types[i].first) {
      return false;
    }
  }
  return true;
}

bool IsXlaDevice(absl::string_view device_type) {
  if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
      device_type == "XLA_TPU_JIT") {
    // Symbolic XLA device.
    return true;
  }

  return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
          device_type == "TPU");
}

}  // namespace

Status Member::SetParentAndSupportedDevices(
    const Node& node, const std::vector<DeviceType>& types,
    const DeviceNameUtils::ParsedName* local_address_spec) {
  int id = node.id();
  if (id < 0) {
    return errors::Internal("Placer should not be creating a Member for node: ",
                            node.DebugString());
  }
  parent_ = id;
  return SupportedDeviceTypesForNode(
      types, node.def(), &supported_device_types_, local_address_spec);
}

Status Member::SetAssignedDeviceName(const string& device_name) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting assigned device name when there is a requested device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
    return errors::Internal("Malformed assigned device '", device_name, "'");
  }
  // Set requested device to assigned_device to maintain the invariant that
  // requested is a specialization of assigned.
  requested_device_name_ = assigned_device_name_;
  return Status::OK();
}

Status Member::SetResourceDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting resource device name when there is a requested device set "
        "is unsupported");
  }

  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &resource_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }

  // Set requested device to resource device to maintain the invariant that
  // requested is a specialization of resource.
  requested_device_name_ = resource_device_name_;
  return Status::OK();
}

Status Member::SetRequestedDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is an assigned device set "
        "is unsupported");
  }
  if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is a resource device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &requested_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }
  return Status::OK();
}

Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Cannot fill PossibleDevices from a member that has non-empty assigned "
        "device. Did we start assigning devices to functions called by deep "
        "ops? ",
        DebugString());
  }
  possible_device->requested_device_name = requested_device_name_;
  possible_device->resource_device_name = resource_device_name_;
  possible_device->device_types = supported_device_types_;
  return Status::OK();
}

Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return Status::OK();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for destination
  // node, but need to preserve the invariant that it will be a specialization
  // of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return Status::OK();
}

void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
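  //
  // Worked example (illustrative, not from the original comment): if x_root
  // has rank 1 and y_root has rank 2, x_root's parent_ is pointed at y_root
  // and y_root keeps rank 2; only when the two ranks are equal does the
  // surviving root's rank grow by one. This union-by-rank rule keeps the
  // trees shallow across repeated merges.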
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}

// tree is non-const because we can change some `parent` pointers in some
// members for more efficient future lookups. The vector itself is not
// changed.
int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
  Member& member = (*tree)[node_id];
  if (member.parent_ == node_id) {
    // member.parent is the root of this disjoint tree.  Do nothing.
  } else {
    member.parent_ = FindAndUpdateRoot(tree, member.parent_);
  }
  // Now it is guaranteed that member.parent is the root of this disjoint
  // tree.
  return member.parent_;
}

int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
  const Member& member = tree[node_id];
  if (member.parent_ == node_id) {
    return member.parent_;
  }
  return FindRoot(tree, member.parent_);
}

Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
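  //
  // Illustrative example (hypothetical names, not exhaustive): merging a
  // requested name of "/job:worker/task:0" with an incoming "/device:GPU:0"
  // yields "/job:worker/task:0/device:GPU:0"; two names that disagree on a
  // field both of them specify (e.g. "/device:GPU:0" vs "/device:CPU:0")
  // cannot be merged and produce an error, unless soft placement is allowed
  // to relax the requested name.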
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = assigned_device_name_copy;
  resource_device_name_ = resource_device_name_copy;
  requested_device_name_ = requested_device_name_copy;
  return Status::OK();
}

// Updates this to contain the intersection of the device types in
// this and "other".
bool Member::MergeSupportedDevices(const Member& other) {
  return MergeSupportedDevices(other.supported_device_types_);
}

bool Member::MergeSupportedDevices(
    const PrioritizedDeviceTypeVector& other_devices) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
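  //
  // Illustrative example: intersecting {GPU, CPU} (this member) with
  // {CPU, XLA_CPU} (other) keeps only CPU; the priority attached to CPU on
  // each side is carried into target_intersection and other_intersection
  // respectively and reconciled below.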
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;

  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type : other_devices) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
  DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither are prioritized then we just return the original i.e. target
    // prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect priorities of that in the
    // intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and agree then we go with that. If the
    // prioritization order is different, then we just fallback to the default
    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
    // merged priorities to 0, so that downstream merges work correctly as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      DeviceSet::SortPrioritizedDeviceTypeVector(&result);
    }
  }

  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}

Status Member::AssignDevice(const Node& node) {
  if (node.assigned_device_name_index() == assigned_device_name_index_) {
    return Status::OK();
  }

  DeviceNameUtils::ParsedName parsed;
  DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
  Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's assigned device name: ",
        DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        ". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's resource device name: ",
        DeviceNameUtils::ParsedNameToString(resource_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        ". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's requested device name: \"",
        DeviceNameUtils::ParsedNameToString(requested_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }

  assigned_device_name_index_ = node.assigned_device_name_index();
  // Clear cached possible_devices, if any.
  possible_devices_.clear();
  return Status::OK();
}

void Member::MaybeExcludeXlaDevices() {
  for (const auto& parsed_name :
       {requested_device_name_, assigned_device_name_, resource_device_name_}) {
    if (parsed_name.has_type && IsXlaDevice(parsed_name.type)) {
      return;
    }
  }

  PrioritizedDeviceTypeVector non_xla_types;
  absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
                  [&](const std::pair<DeviceType, int32>& entry) {
                    return !IsXlaDevice(entry.first.type_string());
                  });

  // TODO(b/141216278) Remove all XLA device types from the supported device
  // types if the node has no requested/assigned/resource XLA device.
  if (!non_xla_types.empty() &&
      non_xla_types.size() < supported_device_types_.size()) {
    supported_device_types_ = std::move(non_xla_types);
  }
}

Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  MergeSupportedDevices(devices.device_types);
  return Status::OK();
}

string Member::DebugString() const {
  return absl::StrCat(
      "Member(assigned_device_name_index_=", assigned_device_name_index_,
      " requested_device_name_='",
      DeviceNameUtils::ParsedNameToString(requested_device_name_),
      "' assigned_device_name_='",
      DeviceNameUtils::ParsedNameToString(assigned_device_name_),
      "' resource_device_name_='",
      DeviceNameUtils::ParsedNameToString(resource_device_name_),
      "' supported_device_types_=[",
      absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
                    ", "),
      "] possible_devices_=[",
      absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}

DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
  DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
  if (!assigned_device_name_.has_type) {
    soft_device_name.type.clear();
    soft_device_name.has_type = false;
  }
  if (!assigned_device_name_.has_id) {
    soft_device_name.has_id = false;
  }
  return soft_device_name;
}

DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
  DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
  if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
    soft_device_name.type.clear();
    soft_device_name.has_type = false;
  }
  if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
    soft_device_name.has_id = false;
  }
  return soft_device_name;
}

// Returns ParsedName whose address space (i.e. job, replica, task) identifies
// the address space directly accessible by the local process. If the address
// space is fully specified and it is exactly the same as the address space
// of a device, then all kernels of that device should be registered in the
// local process.
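//
// For example, the address space of
// "/job:worker/replica:0/task:1/device:GPU:0" is
// "/job:worker/replica:0/task:1".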
static const DeviceNameUtils::ParsedName LocalAddressSpec(
    const Device* client_device, const Device* default_local_device) {
  if (client_device != nullptr) {
    return DeviceNameUtils::AddressSpace(client_device->parsed_name());
  }

  if (default_local_device != nullptr) {
    return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
  }

  // TODO(b/139617593) Return the name of the first local device in device_set_
  // once we can trust the output of Device::IsLocal().
  return DeviceNameUtils::ParsedName();
}

ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      flib_def_(*flib_def),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  members_.resize(graph_.num_node_ids());
}

// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group.  Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef.  The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
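  //
  // Illustrative example: a node colocated with a (hypothetical) variable "v"
  // typically carries
  //   attr { key: "_class" value { list { s: "loc:@v" } } }
  // in its NodeDef; the "loc:@" prefix is stripped and "v" is used as the
  // group key below.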
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it).  Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return Status::OK();
}

Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
                                                  const Node* dst) {
  // Colocate `src` and `dst` to maintain the invariant that nodes
  // connected by reference edges are colocated.
  int src_root_id = FindAndUpdateRoot(src->id());
  int dst_root_id = FindAndUpdateRoot(dst->id());
  auto& src_root = members_[src_root_id];
  auto& dst_root = members_[dst_root_id];

  TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
      *src, src_root, *dst, log_device_placement_));
  Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
  if (!status.ok()) {
    return AttachDef(
        errors::InvalidArgument("Nodes were connected by a "
                                "reference connection (requiring them to "
                                "be on the same device), but the two nodes "
                                "were assigned two different devices: ",
                                status.error_message()),
        *dst);
  }
  return Status::OK();
}

Status ColocationGraph::ColocateResourceAndRefEdges(
    std::unordered_set<Node*>* inspection_required) {
  // If `node` has an input edge with reference type, add an edge from the
  // source of that edge to `node`.
  for (const Edge* edge : graph_.edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }
    Node* src = edge->src();
    Node* dst = edge->dst();
    bool needs_inspection;
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *src, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(src);
      continue;
    }
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *dst, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(dst);
      continue;
    }

    DataType input_type = dst->input_type(edge->dst_input());

    // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
    // This is needed to get around the issue in b/135705778.
    if (input_type == DT_VARIANT &&
        data::DatasetOpKernel::IsDatasetOp(&src->op_def()) &&
        data::DatasetOpKernel::IsDatasetOp(&dst->op_def())) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
      continue;
    }

    // Even though we can look inside function-calling ops, we make an
    // exception here, mostly for performance reasons: looking inside such ops
    // adds overhead and is only strictly necessary when they return resources.
    // When they don't, we skip the inspection. Looking inside anyway could
    // potentially enable better placement decisions and might be worth doing
    // at some point.
    if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
        !IsExemptFromResourceInputColocation(dst)) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
    }
  }

  return Status::OK();
}

Status ColocationGraph::AddInspectionConstraints(
    const std::unordered_set<Node*>& inspection_required) {
  for (Node* node : inspection_required) {
    IOColocationGroups groups;
    TF_RETURN_IF_ERROR(
        inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
    VLOG(2) << "Computed IOColocationGroups for node " << node->name()
            << ":\n\t" << groups.DebugString();
    TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
  }
  return Status::OK();
}

Status ColocationGraph::Initialize() {
  TF_RETURN_IF_ERROR(InitializeMembers());

  std::unordered_set<Node*> inspection_required;
  TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
  TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
  TF_RETURN_IF_ERROR(ColocateAllNodes());

  for (Node* node : graph_.op_nodes()) {
    int root_id = FindAndUpdateRoot(node->id());
    members_[root_id].MaybeExcludeXlaDevices();
  }

  return Status::OK();
}

// pair containing a node and whether this node has a resource input
// from the node requiring placer inspection.
using NodeAndBool = std::pair<const Node*, bool>;

namespace {

// Returns a vector of node names from `nodes`.
std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
  std::vector<string> v;
  v.reserve(nodes.size());
  for (const NodeAndBool& node_and_bool : nodes) {
    v.push_back(node_and_bool.first->name());
  }
  return v;
}

// Given a node requiring placer inspection and its IOColocationGroups,
// computes `group_nodes`.
// group_nodes[i] contains the nodes that are members of colocation
// group i. These nodes are inputs or outputs of `node`.
// group_nodes[i][j] is a pair containing a node and whether this node
// has a resource input from `node`.
// Note:
// The same node can be added multiple times to the same group.
// The same node can be added to multiple groups.
Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
                     std::vector<std::vector<NodeAndBool>>* group_nodes) {
  group_nodes->reserve(groups.group_devices.size());
  for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
    const Node* src;
    TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
    int group_id = groups.input_groups[arg_idx];
    (*group_nodes)[group_id].emplace_back(src, false);
  }

  for (const Edge* edge : node.out_edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }

    int group_id = groups.output_groups[edge->src_output()];
    (*group_nodes)[group_id].emplace_back(
        edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
    for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
      VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
              << "]";
    }
  }
  return Status::OK();
}

}  // namespace

Status ColocationGraph::ApplyIOColocationGroups(
    const IOColocationGroups& groups, const Node& node) {
  if (groups.input_groups.size() != node.num_inputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because input_groups.size() (",
        groups.input_groups.size(),
        ") is different from number of inputs into the op node (",
        node.num_inputs(), ")");
  }
  if (groups.output_groups.size() != node.num_outputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because output_groups.size() (",
        groups.output_groups.size(),
        ") is different from number of outputs into the op node (",
        node.num_outputs(), ")");
  }

  // group_nodes[i] contains the nodes that are members of colocation
  // group i. These nodes are inputs or outputs of `node`.
  // group_nodes[i][j] is a pair containing the node and whether this node
  // has a resource input from `node`.
  // The same node can be added multiple times to the same group.
  // The same node can be added to multiple groups.
  // NOTE: group ids are guaranteed to be [0, 1, ..., num_groups].
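  //
  // Illustrative example: for a call node whose two inputs are in groups
  // [0, 1] and whose single output is in group 0, group_nodes[0] holds the
  // producer of input 0 plus every consumer of the output, and group_nodes[1]
  // holds the producer of input 1.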
  std::vector<std::vector<NodeAndBool>> group_nodes(
      groups.group_devices.size());
  TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));

  // Colocate nodes in each group
  for (const std::vector<NodeAndBool>& nodes : group_nodes) {
    for (int i = 1; i < nodes.size(); ++i) {
      VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
              << nodes[i].first->name() << "\"";
      if (nodes[i].second) {
        TF_RETURN_IF_ERROR(
            ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
      } else {
        TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
      }
    }
  }

  // Limit devices in each group
  for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
    // Nothing to do for empty groups. Groups can be empty if some output
    // of an op is not used.
    if (group_nodes[group_id].empty()) {
      continue;
    }
    const Node* group_node = group_nodes[group_id][0].first;
    const PossibleDevices& possible_devices = groups.group_devices[group_id];
    TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
  }

  return Status::OK();
}

Status ColocationGraph::ColocateNodeToGroup(
    std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
        colocation_group_root,
    const Node* node, StringPiece colocation_group) {
  const Node*& root_node = (*colocation_group_root)[colocation_group];
  if (root_node == nullptr) {
    // This is the first node of the colocation group, so
    // designate this node as the 'root' of that colocation group.
    root_node = node;
  } else {
    // Try to colocate the node with the root.  If there is an
    // error, return it.
    Status s = ColocateNodes(*node, *root_node);
    if (!s.ok()) {
      if (!allow_soft_placement_) {
        return AttachDef(s, *node);
      }
      if (log_device_placement_) {
        LOG(INFO) << "Ignoring request to colocate node '" << node->name()
                  << "' with nodes in colocation group '" << colocation_group
                  << "' because soft placement is on and an attempt at doing "
                     "so resulted in the following error: "
                  << AttachDef(s, *node).ToString();
      }
    }
  }
  return Status::OK();
}

// Merge the (possibly disjoint) sets containing nodes "x" and
// "y". Returns OK if all the nodes in the union of these sets can
// be placed on the same device type.
//
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
  int x_root = FindAndUpdateRoot(x.id());
  int y_root = FindAndUpdateRoot(y.id());
  return ColocateNodes(x, x_root, y, y_root);
}

// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  if (x_root == y_root) {
    return Status::OK();
  }

  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}

Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
  if (node.assigned_device_name_index() < 0) {
    return errors::Internal(
        "Expected an assigned node as argument to LimitToAssignedDevice but "
        "got: ",
        node.DebugString());
  }
  int root = FindAndUpdateRoot(node.id());
  Member& root_member = members_[root];
  return root_member.AssignDevice(node);
}

void ColocationGraph::GetSoftDeviceCandidates(
    const Node& node, const Member& root_member, int root_id,
    std::vector<Device*>* possible_devices) {
  // Try to find supported devices that don't violate resource devices.
  // The soft_device_name is the same as the requested device name
  // without specifying the device type or ID (if the assigned and requested
  // devices do not specify them).
  DeviceNameUtils::ParsedName soft_device_name =
      root_member.GetPreferredSoftDeviceName();
  device_set_.FindMatchingDevices(soft_device_name, possible_devices);
  if (!possible_devices->empty()) {
    *possible_devices = FilterSupportedDevices(
        *possible_devices, root_member.supported_device_types(),
        default_local_device_);
  }

  if (!possible_devices->empty()) {
    return;
  }

  // TODO(iga): Disallow changing resource devices when this ColocationGraph
  // is for :
  // - a function called by an op requiring deep inspection, or
  // - a graph containing ops requiring inspection.
  // It is fairly tricky to make changing resource devices in presence of
  // ops requiring inspection work correctly. One thing it would require is to
  // communicate these "resource movement" decisions across Placer instances.

  // Failed to find supported devices that don't violate resource devices.
  // Try finding some devices that violated resource devices.
  // If we succeed, we will log a warning below.
  soft_device_name = root_member.GetSoftDeviceName();
  device_set_.FindMatchingDevices(soft_device_name, possible_devices);
  if (!possible_devices->empty()) {
    *possible_devices = FilterSupportedDevices(
        *possible_devices, root_member.supported_device_types(),
        default_local_device_);
  }

  if (!possible_devices->empty()) {
    LOG(WARNING)
        << "Failed to place the graph without changing the devices of some "
           "resources. Some of the operations (that had to be colocated with "
           "resource generating operations) are not supported on the "
           "resources' devices. Current candidate devices are [\n  "
        << absl::StrJoin(DevicesToString(*possible_devices), "\n  ")
        << "].\nSee below for details of this colocation group:"
        << DebugInfo(root_id);
  }
}

Status ColocationGraph::LimitToPossibleDevices(const Node& node,
                                               const PossibleDevices& devices) {
  int root = FindAndUpdateRoot(node.id());
  Member& root_member = members_[root];
  return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
}

Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindAndUpdateRoot(node->id());
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return Status::OK();
  }

  Member& root_member = members_[node_root];

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  // NOTE: Basing possible device computation on requested device name
  // is guaranteed to respect the assigned and resource device names because
  // requested device is always a specialization of both.
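  //
  // Illustrative example (hypothetical device name): if the root's requested
  // name is "/job:worker/device:GPU:*", FindMatchingDevices below returns
  // every GPU in that job, and FilterSupportedDevices then drops any device
  // whose type is not in the group's supported_device_types().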
1055   std::vector<Device*> devices;
1056   if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
1057     // The root node has a (possibly partial) device
1058     // specification, so enumerate the physical devices that
1059     // conform to it.
1060     device_set_.FindMatchingDevices(root_member.requested_device_name(),
1061                                     &devices);
1062 
1063     if (!devices.empty()) {
1064       // Filter devices into those that are compatible with the root
1065       // node (and its children).
1066       devices = FilterSupportedDevices(
1067           devices, root_member.supported_device_types(), default_local_device_);
1068     }
1069 
1070     // Perform soft placement if allow_soft_placement_ is set.
1071     if (devices.empty() && allow_soft_placement_) {
1072       GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
1073     }
1074 
1075     if (devices.empty()) {
1076       // Return an error when a physical device that matches an explicit
1077       // device specification is not found. This ensures that we don't
1078       // assign a node to GPU when the user wanted to force it on CPU.
1079       string debug_info = DebugInfo(node_root);
1080 
1081       DeviceNameUtils::ParsedName specified_device_name;
1082       if (DeviceNameUtils::ParseFullName(node->requested_device(),
1083                                          &specified_device_name) &&
1084           specified_device_name == root_member.requested_device_name()) {
1085         // The specified device and merged set device match, and
1086         // will appear in the GraphDef (for debugging), so just
1087         // print the specified device.
1088         std::vector<Device*> devices_matching_nodedef;
1089         device_set_.FindMatchingDevices(specified_device_name,
1090                                         &devices_matching_nodedef);
1091         if (devices_matching_nodedef.empty()) {
1092           // Sometimes it is almost impossible to understand the problem
1093           // without a list of available devices.
1094           std::vector<string> device_names;
1095           for (const Device* device : device_set_.devices()) {
1096             device_names.push_back(device->name());
1097           }
1098           std::sort(device_names.begin(), device_names.end());
1099 
1100           string gpu_msg = "";
1101           if (!IsGoogleCudaEnabled() &&
1102               absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
1103             gpu_msg =
1104                 " The requested device appears to be a GPU, but CUDA is not "
1105                 "enabled.";
1106           }
1107 
1108           return errors::InvalidArgument(
1109               errors::FormatNodeNameForError(node->name()),
1110               " was explicitly assigned to ", node->requested_device(),
1111               " but available devices are [ ",
1112               absl::StrJoin(device_names, ", "), " ]. Make sure ",
1113               "the device specification refers to a valid device.", gpu_msg);
1114         } else if (specified_device_name.has_type) {
1115           return errors::InvalidArgument(
1116               "Could not satisfy explicit device specification '",
1117               node->requested_device(), "' because no supported kernel for ",
1118               specified_device_name.type, " devices is available.", debug_info,
1119               "\nOp: ", node->type_string(),
1120               "\nNode attrs: ", node->attrs().DebugString(),
1121               "\nRegistered kernels:\n",
1122               KernelsRegisteredForOp(node->type_string()));
1123         } else {
1124           return errors::InvalidArgument(
1125               "Could not satisfy explicit device specification '",
1126               node->requested_device(), debug_info);
1127         }
1128       } else {
1129         // The specified device may be a valid device but the
1130         // merged set device is different, so print both.
1131         // TODO(b/129057603): There are many possibilities at this point.
1132         // Provide good error messages.
1133         return errors::InvalidArgument(
1134             "Could not satisfy explicit device specification '",
1135             node->requested_device(), "' because the node ",
1136             errors::FormatColocationNodeForError(node->name()),
1137             " was colocated with a group of nodes that ",
1138             "required incompatible device '",
1139             DeviceNameUtils::ParsedNameToString(
1140                 root_member.requested_device_name()),
1141             "'. All available devices [",
1142             absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
1143             debug_info);
1144       }
1145     }
1146   } else {
1147     // The device is completely unspecified, so enumerate the devices that
1148     // support all of the nodes in the set.
1149     if (device_set_.devices().empty()) {
1150       return errors::Internal("No devices are registered");
1151     }
1152     devices = FilterSupportedDevices(device_set_.devices(),
1153                                      root_member.supported_device_types(),
1154                                      default_local_device_);
1155 
1156     if (devices.empty()) {
1157       return errors::InvalidArgument(
1158           "Node had no OpKernel registered to support this operation: ",
1159           "Operation was ", node->type_string(), " and inputs were [",
1160           DataTypeVectorString(node->input_types()), "].\n",
1161           DebugInfo(node_root));
1162     }
1163   }
1164 
1165   // Cache the result of the possible devices for this node group.
1166   root_member.set_possible_devices(std::move(devices));
1167   *possible_devices = &root_member.possible_devices();
1168   return Status::OK();
1169 }
1170 
InitializeMembers()1171 Status ColocationGraph::InitializeMembers() {
1172   for (Node* node : graph_.op_nodes()) {
1173     Status status = InitializeMember(*node, &members_[node->id()]);
1174     if (!status.ok()) {
1175       return AttachDef(status, *node);
1176     }
1177   }
1178   return Status::OK();
1179 }
1180 
string ColocationGraph::DebugString() const {
  std::unordered_set<int> roots;
  std::vector<string> root_strings;
  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int node_root = FindRoot(node->id());
    if (roots.count(node_root) == 0) {
      root_strings.push_back(DebugInfo(node_root));
      roots.insert(node_root);
    }
  }
  return absl::StrJoin(root_strings, "\n");
}

// Returns debugging info for the node referred to by 'node_root'.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

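  // If no nodes belonging to this colocation group were found in the graph,
  // return an empty string rather than the header text alone.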
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}

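// Records and sanity-checks a device that the runtime has already assigned:
// the device must exist in the device set and must have a registered kernel
// for 'node_type'.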
Status ColocationGraph::InitializeMemberWithAssignedDevice(
    const string& assigned_device_name, const string& node_type,
    Member* member) {
  // This node has already been assigned to a device, so we
  // respect this placement, after sanity-checking it.
  // NOTE: Since any assignment must have been performed by
  // the TensorFlow runtime, we consider errors in this branch to
  // be INTERNAL.
  TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));

  // Since assigned device must be a full specification, do extra checks.
  const Device* assigned_device =
      device_set_.FindDeviceByName(assigned_device_name);
  if (assigned_device == nullptr) {
    // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
    // calls when they are supported.
    return errors::Internal(
        "Assigned device '", assigned_device_name,
        "' does not match any device. This error can happen when one attempts "
        "to run a tf.function with resource inputs residing on remote devices. "
        "This use case is currently not supported. Here are the devices "
        "available on this machine: [",
        absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
        "If you are seeing this error when running using a tf.Session, set "
        "experimental.share_cluster_devices_in_session to true in the "
        "tf.ConfigProto.");
  }

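  // Accept the assignment only if the assigned device's type is one of the
  // member's supported device types.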
  for (const auto& d : member->supported_device_types()) {
    if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
      return Status::OK();
    }
  }

  return errors::Internal("Assigned device '", assigned_device_name,
                          "' does not have registered OpKernel support "
                          "for ",
                          node_type);
}

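// Initializes 'member' for 'node': records the supported device types and
// either the runtime-assigned device or any (partial) user-requested device.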
Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
      node, device_types_, &local_address_spec_));

  if (node.has_assigned_device_name()) {
    TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
        node.assigned_device_name(), node.type_string(), member));
  } else {
    // This node has not yet been assigned to a device, so we
    // calculate any constraints due to the set of registered
    // kernels and any (partial) user-provided device specification
    // in the NodeDef.

    // If no kernels are registered for this op type, fail with an error.
    if (member->supported_device_types().empty()) {
      std::set<string> registered_device_types;
      for (Device* d : device_set_.devices()) {
        registered_device_types.insert(d->device_type());
      }
      return errors::InvalidArgument(
          "No OpKernel was registered to support Op '", node.type_string(),
          "' used by ", errors::FormatNodeNameForError(node.name()),
          " with these attrs: [", node.attrs().DebugString(),
          "]\n"
          "Registered devices: [",
          absl::StrJoin(registered_device_types, ", "), "]\n",
          "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
    }

    // If the NodeDef contains a device, then we interpret it as a
    // (partial) device specification.
    if (!node.requested_device().empty()) {
      if (IsRefOrResourceGeneratorNode(node)) {
        // Treat requested device on resource generating nodes as assigned
        // device so that we don't override it.
        TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
      } else {
        // The user has specified a device in the NodeDef; try to find a
        // valid device matching their specification in the set of
        // devices.
        // NOTE: The full name may specify a device that is not in
        // n.supported_device_types(), but we check that in AssignDevice().
        TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
      }
    }
  }
  return Status::OK();
}

// Returns a list of devices having type in supported_device_types.  The
// returned list is sorted by preferred type (higher numeric type is preferred).
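// If 'default_local_device' is among the supported devices, it is placed at
// the front of the returned list, ahead of the priority ordering.
// For example (illustrative device names only): with supported types
// [(GPU, 2), (CPU, 1)], devices [CPU:0, GPU:0], and no default local device,
// the result is [GPU:0, CPU:0].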
/*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
    const std::vector<Device*>& devices,
    const PrioritizedDeviceTypeVector& supported_device_types,
    const Device* default_local_device) {
  Device* filtered_default_device = nullptr;
  PrioritizedDeviceVector prioritized_filtered_devices;
  for (const auto& supported_device_type : supported_device_types) {
    for (Device* device : devices) {
      if (DeviceType(device->attributes().device_type()) ==
          supported_device_type.first) {
        if (default_local_device &&
            (device == default_local_device ||
             // TODO(nareshmodi, fishx): At times the device pointer in the
             // device set is different to the one passed in as the default
             // device. Figure out why this might be.
             device->name() == default_local_device->name())) {
          filtered_default_device = device;
        } else {
          prioritized_filtered_devices.emplace_back(
              device, supported_device_type.second);
        }
      }
    }
  }
  DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);

  std::vector<Device*> filtered_devices;
  if (filtered_default_device != nullptr) {
    filtered_devices.emplace_back(filtered_default_device);
  }
  for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
    filtered_devices.push_back(prioritized_filtered_device.first);
  }
  return filtered_devices;
}

}  // namespace tensorflow