1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/colocation_graph.h"
17
18 #include <memory>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_set.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
31 #include "tensorflow/core/common_runtime/inspecting_placer.h"
32 #include "tensorflow/core/common_runtime/partitioning_utils.h"
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/attr_value_util.h"
35 #include "tensorflow/core/framework/dataset.h"
36 #include "tensorflow/core/framework/device_attributes.pb.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/node_def_util.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/types.pb.h"
42 #include "tensorflow/core/graph/graph_node_util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/stringpiece.h"
45 #include "tensorflow/core/lib/strings/str_util.h"
46 #include "tensorflow/core/lib/strings/strcat.h"
47 #include "tensorflow/core/util/device_name_utils.h"
48 #include "tensorflow/core/util/dump_graph.h"
49 #include "tensorflow/core/util/port.h"
50
51 namespace tensorflow {
52
53 namespace {
54
55 // We hoist the conversion from C-style string literal to StringPiece here,
56 // so that we can avoid the many repeated calls to strlen().
57 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
58 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
59
60 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)61 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
62 std::vector<string> v;
63 v.reserve(devices.size());
64 for (Device* d : devices) {
65 v.push_back(d->name());
66 }
67 return v;
68 }
69
70 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)71 std::vector<string> DeviceTypeAndPriorityToString(
72 const PrioritizedDeviceTypeVector& devices) {
73 std::vector<string> v;
74 v.reserve(devices.size());
75 for (const std::pair<DeviceType, int32>& device_and_type : devices) {
76 v.push_back(DeviceTypeString(device_and_type.first));
77 }
78 return v;
79 }
80
IsRefOrResource(DataType data_type)81 bool IsRefOrResource(DataType data_type) {
82 return IsRefType(data_type) || data_type == DT_RESOURCE;
83 }
84
85 // While Placer can override requested device on ops processing
86 // resources, i.e. node that take (and potentially return) a resource,
87 // it must not override requested device on ops generating a resource,
88 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
89 // output nodes.
IsRefOrResourceGeneratorNode(const Node & node)90 bool IsRefOrResourceGeneratorNode(const Node& node) {
91 return node.num_inputs() == 0 && node.num_outputs() == 1 &&
92 IsRefOrResource(node.output_type(0));
93 }
94
IsExemptFromResourceInputColocation(const Node * node)95 bool IsExemptFromResourceInputColocation(const Node* node) {
96 // Note: Partitioned function calls, which place and partition their
97 // function bodies, are exempt from this check: they forward resource and
98 // ref inputs to operations that are appropriately placed, instead of
99 // dereferencing them.
100 const string& op_type = node->op_def().name();
101 auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
102 return exempt_ops.find(op_type) != exempt_ops.end();
103 }
104
HasPriorities(const PrioritizedDeviceTypeVector & device_types)105 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
106 for (const auto& prioritized_device_type : device_types) {
107 if (prioritized_device_type.second != 0) return true;
108 }
109 return false;
110 }
111
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)112 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
113 const PrioritizedDeviceTypeVector& b_types) {
114 if (a_types.size() != b_types.size()) {
115 return false;
116 }
117 for (int i = 0; i < a_types.size(); ++i) {
118 if (a_types[i].first != b_types[i].first) {
119 return false;
120 }
121 }
122 return true;
123 }
124
IsXlaDevice(absl::string_view device_type)125 bool IsXlaDevice(absl::string_view device_type) {
126 if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
127 device_type == "XLA_TPU_JIT") {
128 // Symbolic XLA device.
129 return true;
130 }
131
132 return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
133 device_type == "TPU");
134 }
135
136 } // namespace
137
// Initializes this Member for `node`: records the node id as the initial
// union-find parent and computes the device types that can run the node.
// Fails with Internal if the node has a negative (invalid) id.
Status Member::SetParentAndSupportedDevices(
    const Node& node, const std::vector<DeviceType>& types,
    const DeviceNameUtils::ParsedName* local_address_spec) {
  int id = node.id();
  if (id < 0) {
    return errors::Internal("Placer should not be creating a Member for node: ",
                            node.DebugString());
  }
  parent_ = id;
  return SupportedDeviceTypesForNode(
      types, node.def(), &supported_device_types_, local_address_spec);
}
150
// Parses `device_name` into assigned_device_name_. Must be called before any
// requested device has been set; on success the requested name is overwritten
// to equal the assigned one.
Status Member::SetAssignedDeviceName(const string& device_name) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting assigned device name when there is a requested device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
    return errors::Internal("Malformed assigned device '", device_name, "'");
  }
  // Set requested device to assigned_device to maintain the invariant that
  // requested is a specialization of assigned.
  requested_device_name_ = assigned_device_name_;
  return Status::OK();
}
165
// Records `node`'s requested device as this member's resource device. Must be
// called before any requested device has been set; on success the requested
// name is overwritten to equal the resource one.
Status Member::SetResourceDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting resource device name when there is a requested device set "
        "is unsupported");
  }

  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &resource_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }

  // Set requested device to resource device to maintain the invariant that
  // requested is a specialization of resource.
  requested_device_name_ = resource_device_name_;
  return Status::OK();
}
185
// Parses `node`'s requested device into requested_device_name_. Must be
// called before either an assigned or a resource device has been set, since
// those setters establish requested as a specialization of them.
Status Member::SetRequestedDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is an assigned device set "
        "is unsupported");
  }
  if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is a resource device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &requested_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }
  return Status::OK();
}
205
// Exports this member's requested/resource device names and supported device
// types into `possible_device` (used when placing the body of a deep op).
// Fails if a device has already been assigned to this member.
Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Cannot fill PossibleDevices from a member that has non-empty assigned "
        "device. Did we start assigning devices to functions called by deep "
        "ops? ",
        DebugString());
  }
  possible_device->requested_device_name = requested_device_name_;
  possible_device->resource_device_name = resource_device_name_;
  possible_device->device_types = supported_device_types_;
  return Status::OK();
}
219
// Checks that this group (the dst root) can be merged with `src_root` across
// a reference/resource edge. Incompatible assigned or resource devices are
// hard errors; if only the *requested* devices conflict, this group's
// requested device is overridden by src_root's (optionally logged).
Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return Status::OK();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for destination
  // node, but need to preserve the invariant that it will be a specialization
  // of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return Status::OK();
}
272
// Union-by-rank merge of the disjoint-set trees rooted at `x_root` and
// `y_root`. On return, *new_root points at the surviving root and *old_root
// at the absorbed one. With dry_run=true the two roots are computed but no
// parent_/rank_ fields are mutated.
void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}
317
318 // tree is non-const because we can change some `parent` pointers in some
319 // members for more efficient future lookups. The vector itself is not
320 // changed.
FindAndUpdateRoot(std::vector<Member> * tree,int node_id)321 int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
322 Member& member = (*tree)[node_id];
323 if (member.parent_ == node_id) {
324 // member.parent is the root of this disjoint tree. Do nothing.
325 } else {
326 member.parent_ = FindAndUpdateRoot(tree, member.parent_);
327 }
328 // Now it is guaranteed that member.parent is the root of this disjoint
329 // tree.
330 return member.parent_;
331 }
332
FindRoot(const std::vector<Member> & tree,int node_id)333 int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
334 const Member& member = tree[node_id];
335 if (member.parent_ == node_id) {
336 return member.parent_;
337 }
338 return FindRoot(tree, member.parent_);
339 }
340
// Merges `other`'s assigned/resource/requested device names into this
// member's. All merges are performed on local copies first so that this
// Member is left untouched if any merge fails.
Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  // Only the requested-name merge honors soft placement.
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = assigned_device_name_copy;
  resource_device_name_ = resource_device_name_copy;
  requested_device_name_ = requested_device_name_copy;
  return Status::OK();
}
371
// Updates this to contain the intersection of the device types in
// this and "other". Returns false (leaving this unchanged) when the
// intersection is empty; see the overload below for details.
bool Member::MergeSupportedDevices(const Member& other) {
  return MergeSupportedDevices(other.supported_device_types_);
}
377
// Intersects supported_device_types_ with `other_devices`, reconciling the
// two priority orderings. Returns false — and leaves this member unchanged —
// if the intersection is empty.
bool Member::MergeSupportedDevices(
    const PrioritizedDeviceTypeVector& other_devices) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;

  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type : other_devices) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
  DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither are prioritized then we just return the original i.e. target
    // prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect priorities of that in the
    // intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and agree then we go with that. If the
    // prioritization order is different, then we just fallback to the default
    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
    // merged priorities to 0, so that downstream merges work correctly as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      DeviceSet::SortPrioritizedDeviceTypeVector(&result);
    }
  }

  // An empty intersection means the two groups have no device type in common;
  // report failure without clobbering the current supported set.
  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}
439
// Constrains this group's three device names by `node`'s already-assigned
// device. Merge failures are Internal errors: an assigned device is expected
// to always be consistent with the group's constraints.
Status Member::AssignDevice(const Node& node) {
  // Fast path: this group already adopted the same assigned-device index.
  if (node.assigned_device_name_index() == assigned_device_name_index_) {
    return Status::OK();
  }

  DeviceNameUtils::ParsedName parsed;
  DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
  Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's assigned device name: ",
        DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        ". Error: ", s.error_message());
  }
  // Resource and requested names use MergeOverrideDevNames: the assigned
  // device wins over whatever was previously recorded.
  s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's resource device name: ",
        DeviceNameUtils::ParsedNameToString(resource_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        ". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's requested device name: \"",
        DeviceNameUtils::ParsedNameToString(requested_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }

  assigned_device_name_index_ = node.assigned_device_name_index();
  // Clear cached possible_devices, if any.
  possible_devices_.clear();
  return Status::OK();
}
480
// Filters XLA device types out of supported_device_types_, unless the user
// explicitly targeted an XLA device via the requested, assigned, or resource
// device name.
void Member::MaybeExcludeXlaDevices() {
  // An explicit XLA device anywhere means the XLA types must be kept.
  for (const auto& parsed_name :
       {requested_device_name_, assigned_device_name_, resource_device_name_}) {
    if (parsed_name.has_type && IsXlaDevice(parsed_name.type)) {
      return;
    }
  }

  PrioritizedDeviceTypeVector non_xla_types;
  absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
                  [&](const std::pair<DeviceType, int32>& entry) {
                    return !IsXlaDevice(entry.first.type_string());
                  });

  // TODO(b/141216278) Remove all XLA device types from the supported device
  // types if the node has no requested/assigned/resource XLA device.
  // Only replace the set when the filter removed something AND at least one
  // non-XLA type remains (never leave the supported set empty).
  if (!non_xla_types.empty() &&
      non_xla_types.size() < supported_device_types_.size()) {
    supported_device_types_ = std::move(non_xla_types);
  }
}
502
// Narrows this member's requested/resource device names and supported device
// types to `devices` (typically computed by the inspecting placer).
Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  // NOTE(review): the bool result of MergeSupportedDevices (false when the
  // device-type intersection is empty) is discarded here — confirm that an
  // empty intersection is intentionally tolerated at this point.
  MergeSupportedDevices(devices.device_types);
  return Status::OK();
}
513
// Human-readable dump of this member's device constraints; used in logs and
// error messages.
string Member::DebugString() const {
  return absl::StrCat(
      "Member(assigned_device_name_index_=", assigned_device_name_index_,
      " requested_device_name_='",
      DeviceNameUtils::ParsedNameToString(requested_device_name_),
      "' assigned_device_name_='",
      DeviceNameUtils::ParsedNameToString(assigned_device_name_),
      "' resource_device_name_='",
      DeviceNameUtils::ParsedNameToString(resource_device_name_),
      "' supported_device_types_=[",
      absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
                    ", "),
      "] possible_devices_=[",
      absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}
529
GetSoftDeviceName() const530 DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
531 DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
532 if (!assigned_device_name_.has_type) {
533 soft_device_name.type.clear();
534 soft_device_name.has_type = false;
535 }
536 if (!assigned_device_name_.has_id) {
537 soft_device_name.has_id = false;
538 }
539 return soft_device_name;
540 }
541
GetPreferredSoftDeviceName() const542 DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
543 DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
544 if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
545 soft_device_name.type.clear();
546 soft_device_name.has_type = false;
547 }
548 if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
549 soft_device_name.has_id = false;
550 }
551 return soft_device_name;
552 }
553
554 // Returns ParsedName whose address space (i.e. job, replica, task) identifies
555 // the address space directly accessible by the local process. If the address
556 // space is fully specified and it is exactly the same as the address space
557 // of a device, then all kernels of that device should be registered in the
558 // local process.
LocalAddressSpec(const Device * client_device,const Device * default_local_device)559 static const DeviceNameUtils::ParsedName LocalAddressSpec(
560 const Device* client_device, const Device* default_local_device) {
561 if (client_device != nullptr) {
562 return DeviceNameUtils::AddressSpace(client_device->parsed_name());
563 }
564
565 if (default_local_device != nullptr) {
566 return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
567 }
568
569 // TODO(b/139617593) Return the name of the first local device in device_set_
570 // once we can trust the output of Device::IsLocal().
571 return DeviceNameUtils::ParsedName();
572 }
573
// Builds a ColocationGraph over `graph`. All pointer arguments must outlive
// this object (they are dereferenced and stored as references/pointers).
// members_ is sized so that every node id in the graph has a slot.
ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      flib_def_(*flib_def),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  members_.resize(graph_.num_node_ids());
}
595
// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group. Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef. The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it). Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        // A malformed (single-string) colocation attribute is tolerated with
        // a log message rather than a hard error.
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return Status::OK();
}
654
// Colocates the endpoints of a single resource/ref (or dataset-variant) edge.
Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
                                                  const Node* dst) {
  // Colocate `src` and `dst` to maintain the invariant that nodes
  // connected by reference edges are colocated.
  int src_root_id = FindAndUpdateRoot(src->id());
  int dst_root_id = FindAndUpdateRoot(dst->id());
  auto& src_root = members_[src_root_id];
  auto& dst_root = members_[dst_root_id];

  // Verify (and, for requested devices, repair) device-name compatibility
  // before unioning the two groups.
  TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
      *src, src_root, *dst, log_device_placement_));
  Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
  if (!status.ok()) {
    return AttachDef(
        errors::InvalidArgument("Nodes were connected by a "
                                "reference connection (requiring them to "
                                "be on the same device), but the two nodes "
                                "were assigned two different devices: ",
                                status.error_message()),
        *dst);
  }
  return Status::OK();
}
678
// Walks all non-control edges and colocates endpoints connected by
// resource/ref edges (and dataset-variant edges). Nodes requiring placer
// inspection are collected into `inspection_required` instead of being
// colocated here.
Status ColocationGraph::ColocateResourceAndRefEdges(
    std::unordered_set<Node*>* inspection_required) {
  // If `node` has an input edge with reference type, add an edge from the
  // source of that edge to `node`.
  for (const Edge* edge : graph_.edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }
    Node* src = edge->src();
    Node* dst = edge->dst();
    // Either endpoint needing inspection defers this edge's handling to
    // AddInspectionConstraints.
    bool needs_inspection;
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *src, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(src);
      continue;
    }
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *dst, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(dst);
      continue;
    }

    DataType input_type = dst->input_type(edge->dst_input());

    // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
    // This is needed to get around the issue in b/135705778.
    if (input_type == DT_VARIANT &&
        data::DatasetOpKernel::IsDatasetOp(&src->op_def()) &&
        data::DatasetOpKernel::IsDatasetOp(&dst->op_def())) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
      continue;
    }

    // Even though we can look inside function calling ops, we make an exception
    // here mostly for performance reasons. Looking inside function calling ops
    // is extra overhead. It is only necessary when they return resources. When
    // they don't, we don't look inside them and make this exception here.
    // Looking inside, could potentially enable us to make better placement
    // decisions. It might be worth doing at some point.
    if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
        !IsExemptFromResourceInputColocation(dst)) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
    }
  }

  return Status::OK();
}
728
// For each node requiring placer inspection, computes the colocation groups
// of its inputs/outputs (by inspecting the function it calls) and applies
// those groups as constraints on this colocation graph.
Status ColocationGraph::AddInspectionConstraints(
    const std::unordered_set<Node*>& inspection_required) {
  for (Node* node : inspection_required) {
    IOColocationGroups groups;
    TF_RETURN_IF_ERROR(
        inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
    VLOG(2) << "Computed IOColocationGroups for node " << node->name()
            << ":\n\t" << groups.DebugString();
    TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
  }
  return Status::OK();
}
741
// Top-level setup: initializes per-node members, then applies constraints
// from resource/ref edges, deep-op inspection, and explicit colocation
// groups, in that order.
Status ColocationGraph::Initialize() {
  TF_RETURN_IF_ERROR(InitializeMembers());

  std::unordered_set<Node*> inspection_required;
  TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
  TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
  TF_RETURN_IF_ERROR(ColocateAllNodes());

  // Finally, drop XLA device types from each group's supported set unless
  // the group explicitly targets an XLA device.
  for (Node* node : graph_.op_nodes()) {
    int root_id = FindAndUpdateRoot(node->id());
    members_[root_id].MaybeExcludeXlaDevices();
  }

  return Status::OK();
}
757
// A (node, flag) pair. The flag records whether `node` receives a resource
// input from the node requiring placer inspection (used by GetGroupNodes
// below to choose between resource-edge and plain colocation).
using NodeAndBool = std::pair<const Node*, bool>;
761
762 namespace {
763
764 // Returns a vector of node names from `nodes`.
NodeAndBoolToString(const std::vector<NodeAndBool> & nodes)765 std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
766 std::vector<string> v;
767 v.reserve(nodes.size());
768 for (const NodeAndBool& node_and_bool : nodes) {
769 v.push_back(node_and_bool.first->name());
770 }
771 return v;
772 }
773
774 // Given a node requiring placer inspection and its IOColocationGroups,
775 // computes `group_nodes`.
776 // group_nodes[i] contains the nodes that are members of colocation
777 // group i. These nodes are inputs or outputs of `node`.
778 // group_nodes[i][j] is a pair containing a node and whether this node
779 // has a resource input from `node`.
780 // Note:
781 // The same node can be added multiple times to the same group.
782 // The same node can be added to multiple groups.
GetGroupNodes(const IOColocationGroups & groups,const Node & node,std::vector<std::vector<NodeAndBool>> * group_nodes)783 Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
784 std::vector<std::vector<NodeAndBool>>* group_nodes) {
785 group_nodes->reserve(groups.group_devices.size());
786 for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
787 const Node* src;
788 TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
789 int group_id = groups.input_groups[arg_idx];
790 (*group_nodes)[group_id].emplace_back(src, false);
791 }
792
793 for (const Edge* edge : node.out_edges()) {
794 if (edge->IsControlEdge()) {
795 continue;
796 }
797
798 int group_id = groups.output_groups[edge->src_output()];
799 (*group_nodes)[group_id].emplace_back(
800 edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
801 }
802
803 if (VLOG_IS_ON(2)) {
804 VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
805 for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
806 VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
807 << "]";
808 }
809 }
810 return Status::OK();
811 }
812
813 } // namespace
814
// Applies the input/output colocation groups computed by the inspecting
// placer to `node`: validates that the group assignments cover exactly the
// node's inputs and outputs, colocates the members of each group, and then
// restricts each group to its allowed devices.
Status ColocationGraph::ApplyIOColocationGroups(
    const IOColocationGroups& groups, const Node& node) {
  // Sanity check: one group id per input ...
  if (groups.input_groups.size() != node.num_inputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because input_groups.size() (",
        groups.input_groups.size(),
        ") is different from number of inputs into the op node (",
        node.num_inputs(), ")");
  }
  // ... and one group id per output.
  if (groups.output_groups.size() != node.num_outputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because output_groups.size() (",
        groups.output_groups.size(),
        ") is different from number of outputs into the op node (",
        node.num_outputs(), ")");
  }

  // group_nodes[i] contains the nodes that are members of colocation
  // group i. These nodes are inputs or outputs of `node`.
  // group_nodes[i][j] is a pair containing the node and whether this node
  // has a resource input from `node`.
  // The same node can be added multiple times to the same group.
  // The same node can be added to multiple groups.
  // NOTE: group ids are guaranteed to be [0, 1, ..., num_groups].
  std::vector<std::vector<NodeAndBool>> group_nodes(
      groups.group_devices.size());
  TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));

  // Colocate nodes in each group with the group's first member. Pairs whose
  // bool is set receive a resource from `node`, so they go through the
  // resource/ref-edge colocation path instead of plain colocation.
  for (const std::vector<NodeAndBool>& nodes : group_nodes) {
    for (int i = 1; i < nodes.size(); ++i) {
      VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
              << nodes[i].first->name() << "\"";
      if (nodes[i].second) {
        TF_RETURN_IF_ERROR(
            ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
      } else {
        TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
      }
    }
  }

  // Limit devices in each group. Constraining any single member (here, the
  // first) is sufficient because the members were just colocated above and
  // therefore share a root.
  for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
    // Nothing to do for empty groups. Groups can be empty if some output
    // of an op is not used.
    if (group_nodes[group_id].empty()) {
      continue;
    }
    const Node* group_node = group_nodes[group_id][0].first;
    const PossibleDevices& possible_devices = groups.group_devices[group_id];
    TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
  }

  return Status::OK();
}
873
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)874 Status ColocationGraph::ColocateNodeToGroup(
875 std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
876 colocation_group_root,
877 const Node* node, StringPiece colocation_group) {
878 const Node*& root_node = (*colocation_group_root)[colocation_group];
879 if (root_node == nullptr) {
880 // This is the first node of the colocation group, so
881 // designate this node as the 'root' of that colocation group.
882 root_node = node;
883 } else {
884 // Try to colocate the node with the root. If there is an
885 // error, return it.
886 Status s = ColocateNodes(*node, *root_node);
887 if (!s.ok()) {
888 if (!allow_soft_placement_) {
889 return AttachDef(s, *node);
890 }
891 if (log_device_placement_) {
892 LOG(INFO) << "Ignoring request to colocate node '" << node->name()
893 << "' with nodes in colocation group '" << colocation_group
894 << "' because soft placement is on and an attempt at doing "
895 "so resulted in the following error: "
896 << AttachDef(s, *node).ToString();
897 }
898 }
899 }
900 return Status::OK();
901 }
902
903 // Merge the (possibly disjoint) sets containing nodes "x" and
904 // "y". Returns OK if the all nodes in the union of these sets can
905 // be placed on the same device type.
906 //
907 // NOTE: If this method returns an error, *this is left in an undefined
908 // state.
ColocateNodes(const Node & x,const Node & y)909 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
910 int x_root = FindAndUpdateRoot(x.id());
911 int y_root = FindAndUpdateRoot(y.id());
912 return ColocateNodes(x, x_root, y, y_root);
913 }
914
// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  // Same root means the nodes are already colocated.
  if (x_root == y_root) {
    return Status::OK();
  }

  // Dry-run merge: determines which member would become the new root
  // without modifying members_. The committing merge happens only after the
  // validation below succeeds.
  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  // NOTE(review): this call appears to mutate new_root_member even though
  // the committing Merge has not happened yet -- presumably why the class
  // documents "*this is undefined on error"; confirm before reordering.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}
964
LimitToAssignedDevice(const Node & node)965 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
966 if (node.assigned_device_name_index() < 0) {
967 return errors::Internal(
968 "Expected an assigned node as argument to LimitToAssignedDevice but "
969 "got: ",
970 node.DebugString());
971 }
972 int root = FindAndUpdateRoot(node.id());
973 Member& root_member = members_[root];
974 return root_member.AssignDevice(node);
975 }
976
GetSoftDeviceCandidates(const Node & node,const Member & root_member,int root_id,std::vector<Device * > * possible_devices)977 void ColocationGraph::GetSoftDeviceCandidates(
978 const Node& node, const Member& root_member, int root_id,
979 std::vector<Device*>* possible_devices) {
980 // Try to find supported devices that don't violate resource devices.
981 // The soft_device_name is the same as the requested device name
982 // without specifying the device type or ID (if assigned and requested
983 // devices does not specify them).
984 DeviceNameUtils::ParsedName soft_device_name =
985 root_member.GetPreferredSoftDeviceName();
986 device_set_.FindMatchingDevices(soft_device_name, possible_devices);
987 if (!possible_devices->empty()) {
988 *possible_devices = FilterSupportedDevices(
989 *possible_devices, root_member.supported_device_types(),
990 default_local_device_);
991 }
992
993 if (!possible_devices->empty()) {
994 return;
995 }
996
997 // TODO(iga): Disallow changing resource devices when this ColocationGraph
998 // is for :
999 // - a function called by an op requiring deep inspection, or
1000 // - a graph containing ops requiring inspection.
1001 // It is fairly tricky to make changing resource devices in presence of
1002 // ops requiring inspection work correctly. One thing it would require is to
1003 // communicate these "resource movement" decisions across Placer instances.
1004
1005 // Failed to find supported devices that don't violate resource devices.
1006 // Try finding some devices that violated resource devices.
1007 // If we succceed, we will log a warning below.
1008 soft_device_name = root_member.GetSoftDeviceName();
1009 device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1010 if (!possible_devices->empty()) {
1011 *possible_devices = FilterSupportedDevices(
1012 *possible_devices, root_member.supported_device_types(),
1013 default_local_device_);
1014 }
1015
1016 if (!possible_devices->empty()) {
1017 LOG(WARNING)
1018 << "Failed to place the graph without changing the devices of some "
1019 "resources. Some of the operations (that had to be colocated with "
1020 "resource generating operations) are not supported on the "
1021 "resources' devices. Current candidate devices are [\n "
1022 << absl::StrJoin(DevicesToString(*possible_devices), "\n ")
1023 << "].\nSee below for details of this colocation group:"
1024 << DebugInfo(root_id);
1025 }
1026 }
1027
LimitToPossibleDevices(const Node & node,const PossibleDevices & devices)1028 Status ColocationGraph::LimitToPossibleDevices(const Node& node,
1029 const PossibleDevices& devices) {
1030 int root = FindAndUpdateRoot(node.id());
1031 Member& root_member = members_[root];
1032 return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
1033 }
1034
// Returns (via `possible_devices`) the devices on which the colocation
// group containing `node` may be placed. The list is computed once per
// group root and cached on the root member; later calls for nodes in the
// same group hit the cache. On failure, returns a detailed error explaining
// why no device satisfies the group's constraints.
Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindAndUpdateRoot(node->id());
  // Fast path: the group's device list was already computed and cached.
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return Status::OK();
  }

  Member& root_member = members_[node_root];

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  // NOTE: Basing possible device computation on requested device name
  // is guaranteed to respect the assigned and resource device names because
  // requested device is always a specialization of both.
  std::vector<Device*> devices;
  if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
    // The root node has a (possibly partial) device
    // specification, so enumerate the physical devices that
    // conform to it.
    device_set_.FindMatchingDevices(root_member.requested_device_name(),
                                    &devices);

    if (!devices.empty()) {
      // Filter devices into those that are compatible with the root
      // node (and its children).
      devices = FilterSupportedDevices(
          devices, root_member.supported_device_types(), default_local_device_);
    }

    // Perform soft placement if allow_soft_placement_ is set.
    if (devices.empty() && allow_soft_placement_) {
      GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
    }

    if (devices.empty()) {
      // Return an error when a physical device that matches an explicit
      // device specification is not found. This ensures that we don't
      // assign a node to GPU when the user wanted to force it on CPU.
      string debug_info = DebugInfo(node_root);

      DeviceNameUtils::ParsedName specified_device_name;
      if (DeviceNameUtils::ParseFullName(node->requested_device(),
                                         &specified_device_name) &&
          specified_device_name == root_member.requested_device_name()) {
        // The specified device and merged set device match, and
        // will appear in the GraphDef (for debugging), so just
        // print the specified device.
        std::vector<Device*> devices_matching_nodedef;
        device_set_.FindMatchingDevices(specified_device_name,
                                        &devices_matching_nodedef);
        if (devices_matching_nodedef.empty()) {
          // Sometimes it is almost impossible to understand the problem
          // without a list of available devices.
          std::vector<string> device_names;
          for (const Device* device : device_set_.devices()) {
            device_names.push_back(device->name());
          }
          std::sort(device_names.begin(), device_names.end());

          string gpu_msg = "";
          if (!IsGoogleCudaEnabled() &&
              absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
            gpu_msg =
                " The requested device appears to be a GPU, but CUDA is not "
                "enabled.";
          }

          return errors::InvalidArgument(
              errors::FormatNodeNameForError(node->name()),
              " was explicitly assigned to ", node->requested_device(),
              " but available devices are [ ",
              absl::StrJoin(device_names, ", "), " ]. Make sure ",
              "the device specification refers to a valid device.", gpu_msg);
        } else if (specified_device_name.has_type) {
          // Matching devices exist, but none has a registered kernel for
          // this op on that device type.
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "' because no supported kernel for ",
              specified_device_name.type, " devices is available.", debug_info,
              "\nOp: ", node->type_string(),
              "\nNode attrs: ", node->attrs().DebugString(),
              "\nRegistered kernels:\n",
              KernelsRegisteredForOp(node->type_string()));
        } else {
          // NOTE(review): this message is missing the closing "'" after the
          // requested device -- candidate for a later message fix.
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), debug_info);
        }
      } else {
        // The specified device may be a valid device but the
        // merged set device is different, so print both.
        // TODO(b/129057603): There are many possibilities at this point.
        // Provide good error messages.
        return errors::InvalidArgument(
            "Could not satisfy explicit device specification '",
            node->requested_device(), "' because the node ",
            errors::FormatColocationNodeForError(node->name()),
            " was colocated with a group of nodes that ",
            "required incompatible device '",
            DeviceNameUtils::ParsedNameToString(
                root_member.requested_device_name()),
            "'. All available devices [",
            absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
            debug_info);
      }
    }
  } else {
    // The device is completely unspecified, so enumerate the devices that
    // support all of the nodes in the set.
    if (device_set_.devices().empty()) {
      return errors::Internal("No devices are registered");
    }
    devices = FilterSupportedDevices(device_set_.devices(),
                                     root_member.supported_device_types(),
                                     default_local_device_);

    if (devices.empty()) {
      return errors::InvalidArgument(
          "Node had no OpKernel registered to support this operation: ",
          "Operation was ", node->type_string(), " and inputs were [",
          DataTypeVectorString(node->input_types()), "].\n",
          DebugInfo(node_root));
    }
  }

  // Cache the result of the possible devices for this node group.
  root_member.set_possible_devices(std::move(devices));
  *possible_devices = &root_member.possible_devices();
  return Status::OK();
}
1170
InitializeMembers()1171 Status ColocationGraph::InitializeMembers() {
1172 for (Node* node : graph_.op_nodes()) {
1173 Status status = InitializeMember(*node, &members_[node->id()]);
1174 if (!status.ok()) {
1175 return AttachDef(status, *node);
1176 }
1177 }
1178 return Status::OK();
1179 }
1180
DebugString() const1181 string ColocationGraph::DebugString() const {
1182 std::unordered_set<int> roots;
1183 std::vector<string> root_strings;
1184 for (const Node* node : graph_.nodes()) {
1185 if (!node->IsOp()) {
1186 continue;
1187 }
1188 int node_root = FindRoot(node->id());
1189 if (roots.count(node_root) == 0) {
1190 root_strings.push_back(DebugInfo(node_root));
1191 roots.insert(node_root);
1192 }
1193 }
1194 return absl::StrJoin(root_strings, "\n");
1195 }
1196
// Returns debugging info for the node referred to by 'node_root'.
// The output lists, for the colocation group rooted at `node_root`: the
// root member's state, each member op type with its kernel-registered
// device types, and each member node with its requested/assigned devices.
// Returns an empty string if the group contains no op nodes.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  // Scan all op nodes and collect the ones whose colocation root is
  // `node_root`.
  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    // Record which device types have kernels registered for this node's op.
    // Errors are ignored: this is best-effort debug output.
    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  // One line per op type with its registered device types.
  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  // A group with no op nodes yields no debug text at all.
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}
1258
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,Member * member)1259 Status ColocationGraph::InitializeMemberWithAssignedDevice(
1260 const string& assigned_device_name, const string& node_type,
1261 Member* member) {
1262 // This node has already been assigned to a device, so we
1263 // respect this placement, after sanity-checking it.
1264 // NOTE: Since any assignment must have been performed by
1265 // the TensorFlow runtime, we consider errors in this branch to
1266 // be INTERNAL.
1267 TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
1268
1269 // Since assigned device must be a full specification, do extra checks.
1270 const Device* assigned_device =
1271 device_set_.FindDeviceByName(assigned_device_name);
1272 if (assigned_device == nullptr) {
1273 // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
1274 // calls when they are supported.
1275 return errors::Internal(
1276 "Assigned device '", assigned_device_name,
1277 "' does not match any device. This error can happen when one attempts "
1278 "to run a tf.function with resource inputs residing on remote devices. "
1279 "This use case is currently not supported. Here are the devices "
1280 "available on this machine: [",
1281 absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "].",
1282 "If you are seeing this error when running using a tf.Session, set "
1283 "experimental.share_cluster_devices_in_session to true in the "
1284 "tf.ConfigProto.");
1285 }
1286
1287 for (const auto& d : member->supported_device_types()) {
1288 if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
1289 return Status::OK();
1290 }
1291 }
1292
1293 return errors::Internal("Assigned device '", assigned_device_name,
1294 "' does not have registered OpKernel support "
1295 "for ",
1296 node_type);
1297 }
1298
// Initializes `member` for `node`: records the device types with kernels
// registered for the node's op, then seeds the member's device-name
// constraints from the node's assigned or requested device, if any.
Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
      node, device_types_, &local_address_spec_));

  if (node.has_assigned_device_name()) {
    // The runtime already placed this node; validate and adopt that device.
    TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
        node.assigned_device_name(), node.type_string(), member));
  } else {
    // This node has not yet been assigned to a device, so we
    // calculate any constraints due to the set of registered
    // kernels and any (partial) user-provided device specification
    // in the NodeDef.

    // If no kernels are registered for this op type, fail with an error.
    if (member->supported_device_types().empty()) {
      std::set<string> registered_device_types;
      for (Device* d : device_set_.devices()) {
        registered_device_types.insert(d->device_type());
      }
      return errors::InvalidArgument(
          "No OpKernel was registered to support Op '", node.type_string(),
          "' used by ", errors::FormatNodeNameForError(node.name()),
          " with these attrs: [", node.attrs().DebugString(),
          "]\n"
          "Registered devices: [",
          absl::StrJoin(registered_device_types, ", "), "]\n",
          "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
    }

    // If the NodeDef contains a device, then we interpret it as a
    // (partial) device specification.
    if (!node.requested_device().empty()) {
      if (IsRefOrResourceGeneratorNode(node)) {
        // Treat requested device on resource generating nodes as assigned
        // device so that we don't override it.
        TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
      } else {
        // The user has specified a device in the NodeDef, try to find a
        // valid device matching their specification in the set of
        // devices.
        // NOTE: The full name may specify a device that is not in
        // n.supported_device_types(), but we check that in AssignDevice().
        TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
      }
    }
  }
  return Status::OK();
}
1347
1348 // Returns a list of devices having type in supported_device_types. The
1349 // returned list is sorted by preferred type (higher numeric type is preferred).
FilterSupportedDevices(const std::vector<Device * > & devices,const PrioritizedDeviceTypeVector & supported_device_types,const Device * default_local_device)1350 /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
1351 const std::vector<Device*>& devices,
1352 const PrioritizedDeviceTypeVector& supported_device_types,
1353 const Device* default_local_device) {
1354 Device* filtered_default_device = nullptr;
1355 PrioritizedDeviceVector prioritized_filtered_devices;
1356 for (const auto& supported_device_type : supported_device_types) {
1357 for (Device* device : devices) {
1358 if (DeviceType(device->attributes().device_type()) ==
1359 supported_device_type.first) {
1360 if (default_local_device &&
1361 (device == default_local_device ||
1362 // TODO(nareshmodi, fishx): At times the device pointer in the
1363 // device set is different to the one passed in as the default
1364 // device. Figure out why this might be.
1365 device->name() == default_local_device->name())) {
1366 filtered_default_device = device;
1367 } else {
1368 prioritized_filtered_devices.emplace_back(
1369 device, supported_device_type.second);
1370 }
1371 }
1372 }
1373 }
1374 DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);
1375
1376 std::vector<Device*> filtered_devices;
1377 if (filtered_default_device != nullptr) {
1378 filtered_devices.emplace_back(filtered_default_device);
1379 }
1380 for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
1381 filtered_devices.push_back(prioritized_filtered_device.first);
1382 }
1383 return filtered_devices;
1384 }
1385
1386 } // namespace tensorflow
1387