/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/core/common_runtime/colocation_graph.h"

#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/port.h"
53
54 namespace tensorflow {
55
56 namespace {
57
58 // We hoist the conversion from C-style string literal to StringPiece here,
59 // so that we can avoid the many repeated calls to strlen().
60 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
61 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
62
63 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)64 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
65 std::vector<string> v;
66 v.reserve(devices.size());
67 for (Device* d : devices) {
68 v.push_back(d->name());
69 }
70 return v;
71 }
72
73 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)74 std::vector<string> DeviceTypeAndPriorityToString(
75 const PrioritizedDeviceTypeVector& devices) {
76 std::vector<string> v;
77 v.reserve(devices.size());
78 for (const std::pair<DeviceType, int32>& device_and_type : devices) {
79 v.push_back(DeviceTypeString(device_and_type.first));
80 }
81 return v;
82 }
83
IsRefOrResource(DataType data_type)84 bool IsRefOrResource(DataType data_type) {
85 return IsRefType(data_type) || data_type == DT_RESOURCE;
86 }
87
88 // While Placer can override requested device on ops processing
89 // resources, i.e. node that take (and potentially return) a resource,
90 // it must not override requested device on ops generating a resource,
91 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
92 // output nodes.
IsRefOrResourceGeneratorNode(const Node & node)93 bool IsRefOrResourceGeneratorNode(const Node& node) {
94 return node.num_inputs() == 0 && node.num_outputs() == 1 &&
95 IsRefOrResource(node.output_type(0));
96 }
97
IsExemptFromResourceInputColocation(const Node * node)98 bool IsExemptFromResourceInputColocation(const Node* node) {
99 // Note: Partitioned function calls, which place and partition their
100 // function bodies, are exempt from this check: they forward resource and
101 // ref inputs to operations that are appropriately placed, instead of
102 // dereferencing them.
103 const string& op_type = node->op_def().name();
104 auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
105 return exempt_ops.find(op_type) != exempt_ops.end();
106 }
107
HasPriorities(const PrioritizedDeviceTypeVector & device_types)108 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
109 for (const auto& prioritized_device_type : device_types) {
110 if (prioritized_device_type.second != 0) return true;
111 }
112 return false;
113 }
114
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)115 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
116 const PrioritizedDeviceTypeVector& b_types) {
117 if (a_types.size() != b_types.size()) {
118 return false;
119 }
120 for (int i = 0; i < a_types.size(); ++i) {
121 if (a_types[i].first != b_types[i].first) {
122 return false;
123 }
124 }
125 return true;
126 }
127
IsXlaDevice(absl::string_view device_type)128 bool IsXlaDevice(absl::string_view device_type) {
129 if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
130 device_type == "XLA_TPU_JIT") {
131 // Symbolic XLA device.
132 return true;
133 }
134
135 return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
136 device_type == "TPU");
137 }
138
IsCompositeDevice(absl::string_view device_type)139 bool IsCompositeDevice(absl::string_view device_type) {
140 return device_type == kCompositeDeviceType;
141 }
142
143 } // namespace
144
SetParentAndSupportedDevices(const Node & node,const std::vector<DeviceType> & types,const DeviceNameUtils::ParsedName * local_address_spec)145 Status Member::SetParentAndSupportedDevices(
146 const Node& node, const std::vector<DeviceType>& types,
147 const DeviceNameUtils::ParsedName* local_address_spec) {
148 int id = node.id();
149 if (id < 0) {
150 return errors::Internal("Placer should not be creating a Member for node: ",
151 node.DebugString());
152 }
153 parent_ = id;
154 return SupportedDeviceTypesForNode(
155 types, node.def(), &supported_device_types_, local_address_spec);
156 }
157
SetAssignedDeviceName(const string & device_name)158 Status Member::SetAssignedDeviceName(const string& device_name) {
159 if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
160 return errors::Internal(
161 "Setting assigned device name when there is a requested device set "
162 "is unsupported");
163 }
164 if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
165 return errors::Internal("Malformed assigned device '", device_name, "'");
166 }
167 // Set requested device to assigned_device to maintain the invariant that
168 // requested is a specialization of assigned.
169 requested_device_name_ = assigned_device_name_;
170 return Status::OK();
171 }
172
SetResourceDeviceName(const Node & node)173 Status Member::SetResourceDeviceName(const Node& node) {
174 if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
175 return errors::Internal(
176 "Setting resource device name when there is a requested device set "
177 "is unsupported");
178 }
179
180 if (!DeviceNameUtils::ParseFullName(node.requested_device(),
181 &resource_device_name_)) {
182 return errors::InvalidArgument("Malformed device specification '",
183 node.requested_device(),
184 "' in node: ", node.DebugString());
185 }
186
187 // Set requested device to resource device to maintain the invariant that
188 // requested is a specialization of resource.
189 requested_device_name_ = resource_device_name_;
190 return Status::OK();
191 }
192
SetRequestedDeviceName(const Node & node)193 Status Member::SetRequestedDeviceName(const Node& node) {
194 if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
195 return errors::Internal(
196 "Setting requested device name when there is an assigned device set "
197 "is unsupported");
198 }
199 if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
200 return errors::Internal(
201 "Setting requested device name when there is a resource device set "
202 "is unsupported");
203 }
204 if (!DeviceNameUtils::ParseFullName(node.requested_device(),
205 &requested_device_name_)) {
206 return errors::InvalidArgument("Malformed device specification '",
207 node.requested_device(),
208 "' in node: ", node.DebugString());
209 }
210 return Status::OK();
211 }
212
FillPossibleDevices(PossibleDevices * possible_device) const213 Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
214 if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
215 return errors::Internal(
216 "Cannot fill PossibleDevices from a member that has non-empty assigned "
217 "device. Did we start assigning devices to functions called by deep "
218 "ops? ",
219 DebugString());
220 }
221 possible_device->requested_device_name = requested_device_name_;
222 possible_device->resource_device_name = resource_device_name_;
223 possible_device->device_types = supported_device_types_;
224 return Status::OK();
225 }
226
IsEdgeFromCompositeDeviceToPhysicalDevice(const Member & src_root) const227 bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
228 const Member& src_root) const {
229 auto compatible_edge_from_composite_device_to_physical_device =
230 [](const DeviceNameUtils::ParsedName& src_device,
231 const DeviceNameUtils::ParsedName& dst_device) -> bool {
232 return src_device.has_type && dst_device.has_type &&
233 IsCompositeDevice(src_device.type) &&
234 !IsCompositeDevice(dst_device.type);
235 };
236 if (compatible_edge_from_composite_device_to_physical_device(
237 src_root.assigned_device_name_, assigned_device_name_) ||
238 compatible_edge_from_composite_device_to_physical_device(
239 src_root.resource_device_name_, resource_device_name_) ||
240 compatible_edge_from_composite_device_to_physical_device(
241 src_root.requested_device_name_, requested_device_name_)) {
242 return true;
243 }
244 return false;
245 }
246
EnsureCompatibilityAcrossResourceEdge(const Node & src,const Member & src_root,const Node & dst,bool log_device_placement)247 Status Member::EnsureCompatibilityAcrossResourceEdge(
248 const Node& src, const Member& src_root,
249 const Node& dst, /*dst_root is this*/
250 bool log_device_placement) {
251 if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
252 assigned_device_name_)) {
253 return errors::InvalidArgument(
254 "Cannot place the graph because a reference or resource edge "
255 "connects colocation groups with incompatible assigned devices: ",
256 DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
257 " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
258 ". The edge src node is ", src.name(), " , and the dst node is ",
259 dst.name());
260 }
261
262 if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
263 resource_device_name_)) {
264 return errors::InvalidArgument(
265 "Cannot place the graph because a reference or resource edge "
266 "connects colocation groups with incompatible resource devices: ",
267 DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
268 " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
269 ". The edge src node is ", src.name(), " , and the dst node is ",
270 dst.name());
271 }
272
273 if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
274 requested_device_name_)) {
275 return Status::OK();
276 }
277
278 // If we are here, assigned and resource devices are compatible but requested
279 // ones are not. We will be overriding the requested device for destination
280 // node, but need to preserve the invariant that it will be a specialization
281 // of the assigned and resource devices.
282 if (log_device_placement) {
283 LOG(INFO) << "Ignoring device specification "
284 << DeviceNameUtils::ParsedNameToString(requested_device_name_)
285 << " for node '" << dst.name()
286 << "' because the input edge from '" << src.name()
287 << "' is a reference connection and already has a device "
288 "field set to "
289 << DeviceNameUtils::ParsedNameToString(
290 src_root.requested_device_name_);
291 }
292 requested_device_name_ = src_root.requested_device_name_;
293 DeviceNameUtils::EnsureSpecification(&requested_device_name_,
294 assigned_device_name_);
295 DeviceNameUtils::EnsureSpecification(&requested_device_name_,
296 resource_device_name_);
297 return Status::OK();
298 }
299
Merge(std::vector<Member> * tree,int x_root,int y_root,Member ** new_root,Member ** old_root,bool dry_run)300 void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
301 Member** new_root, Member** old_root, bool dry_run) {
302 Member& x_root_member = (*tree)[x_root];
303 Member& y_root_member = (*tree)[y_root];
304
305 // Merge the sets by setting the parent pointer of the smaller tree's root
306 // node to point to the root of the larger tree. Together with path
307 // compression in ColocationGraph::FindRoot, this ensures that we do not
308 // experience pathological performance on graphs such as chains.
309 int new_root_id, old_root_id;
310 if (x_root_member.rank_ < y_root_member.rank_) {
311 // The tree rooted at x_root is shallower, so connect it to
312 // y_root. The rank of y_root is unchanged because its new
313 // child has strictly less rank.
314 if (!dry_run) {
315 x_root_member.parent_ = y_root;
316 }
317 new_root_id = y_root;
318 old_root_id = x_root;
319 } else if (x_root_member.rank_ > y_root_member.rank_) {
320 // The tree rooted at y_root is shallower, so connect it to
321 // x_root. The rank of x_root is unchanged because its new
322 // child has strictly less rank.
323 if (!dry_run) {
324 y_root_member.parent_ = x_root;
325 }
326 new_root_id = x_root;
327 old_root_id = y_root;
328 } else {
329 if (!dry_run) {
330 // Both trees have the same rank, so break the tie by choosing
331 // x_root as the new root.
332 y_root_member.parent_ = x_root;
333 // Increment the rank of the tree rooted at x_root, because it
334 // is now strictly deeper than before.
335 ++x_root_member.rank_;
336 }
337 new_root_id = x_root;
338 old_root_id = y_root;
339 }
340
341 *new_root = &(*tree)[new_root_id];
342 *old_root = &(*tree)[old_root_id];
343 }
344
345 // tree is non-const because we can change some `parent` pointers in some
346 // members for more efficient future lookups. The vector itself is not
347 // changed.
FindAndUpdateRoot(std::vector<Member> * tree,int node_id)348 int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
349 Member& member = (*tree)[node_id];
350 if (member.parent_ == node_id) {
351 // member.parent is the root of this disjoint tree. Do nothing.
352 } else {
353 member.parent_ = FindAndUpdateRoot(tree, member.parent_);
354 }
355 // Now it is guaranteed that member.parent is the root of this disjoint
356 // tree.
357 return member.parent_;
358 }
359
FindRoot(const std::vector<Member> & tree,int node_id)360 int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
361 const Member& member = tree[node_id];
362 if (member.parent_ == node_id) {
363 return member.parent_;
364 }
365 return FindRoot(tree, member.parent_);
366 }
367
MergeDeviceNames(const Member & other,bool allow_soft_placement)368 Status Member::MergeDeviceNames(const Member& other,
369 bool allow_soft_placement) {
370 // Assuming the "requested is a specialization of assigned and resource
371 // devices" invariant holds for this and `other`, it will hold after the
372 // merges below.
373 DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
374 TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
375 &assigned_device_name_copy, other.assigned_device_name_));
376
377 DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
378 TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
379 &resource_device_name_copy, other.resource_device_name_));
380
381 DeviceNameUtils::ParsedName requested_device_name_copy =
382 requested_device_name_;
383 TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
384 &requested_device_name_copy, other.requested_device_name_,
385 allow_soft_placement));
386
387 DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
388 assigned_device_name_copy);
389 DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
390 resource_device_name_copy);
391
392 // We checked for all errors, now change the devices.
393 assigned_device_name_ = assigned_device_name_copy;
394 resource_device_name_ = resource_device_name_copy;
395 requested_device_name_ = requested_device_name_copy;
396 return Status::OK();
397 }
398
399 // Updates this to contain the intersection of the device types in
400 // this and "other".
MergeSupportedDevices(const Member & other)401 bool Member::MergeSupportedDevices(const Member& other) {
402 return MergeSupportedDevices(other.supported_device_types_);
403 }
404
MergeSupportedDevices(const PrioritizedDeviceTypeVector & other_devices)405 bool Member::MergeSupportedDevices(
406 const PrioritizedDeviceTypeVector& other_devices) {
407 // Generate intersection with priorities.
408 // Each vector contains the same device types but with different priorities.
409 // The priorities are taken from the corresponding source vector.
410 PrioritizedDeviceTypeVector target_intersection;
411 PrioritizedDeviceTypeVector other_intersection;
412
413 for (const auto& prioritized_device_type : supported_device_types_) {
414 bool found = false;
415 for (const auto& other_prioritized_device_type : other_devices) {
416 if (prioritized_device_type.first ==
417 other_prioritized_device_type.first) {
418 found = true;
419 other_intersection.push_back(other_prioritized_device_type);
420 break;
421 }
422 }
423 if (found) {
424 target_intersection.push_back(prioritized_device_type);
425 }
426 }
427
428 DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
429 DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);
430
431 PrioritizedDeviceTypeVector result;
432
433 bool is_target_prioritized = HasPriorities(target_intersection);
434 bool is_other_prioritized = HasPriorities(other_intersection);
435 if (!is_target_prioritized && !is_other_prioritized) {
436 // If neither are prioritized then we just return the original i.e. target
437 // prioritization.
438 result = target_intersection;
439 } else if (is_target_prioritized && !is_other_prioritized) {
440 // If only one is prioritized, then we respect priorities of that in the
441 // intersection.
442 result = target_intersection;
443 } else if (!is_target_prioritized && is_other_prioritized) {
444 result = other_intersection;
445 } else {
446 // If both have priorities and agree then we go with that. If the
447 // prioritization order is different, then we just fallback to the default
448 // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
449 // merged priorities to 0, so that downstream merges work correctly as well.
450 if (ArePrioritiesSame(target_intersection, other_intersection)) {
451 result = target_intersection;
452 } else {
453 for (const auto& prioritized_device : target_intersection) {
454 result.push_back(std::make_pair(prioritized_device.first, 0));
455 }
456 DeviceSet::SortPrioritizedDeviceTypeVector(&result);
457 }
458 }
459
460 if (result.empty()) {
461 return false;
462 }
463 supported_device_types_ = result;
464 return true;
465 }
466
AssignDevice(const Node & node)467 Status Member::AssignDevice(const Node& node) {
468 if (node.assigned_device_name_index() == assigned_device_name_index_) {
469 return Status::OK();
470 }
471
472 DeviceNameUtils::ParsedName parsed;
473 DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
474 Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
475 if (!s.ok()) {
476 return errors::Internal(
477 "Constraining by assigned device should not cause an error. Original "
478 "root's assigned device name: ",
479 DeviceNameUtils::ParsedNameToString(assigned_device_name_),
480 " node's assigned device name \"", node.assigned_device_name(),
481 ". Error: ", s.error_message());
482 }
483 s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
484 if (!s.ok()) {
485 return errors::Internal(
486 "Constraining by assigned device should not cause an error. Original "
487 "root's resource device name: ",
488 DeviceNameUtils::ParsedNameToString(resource_device_name_),
489 " node's assigned device name \"", node.assigned_device_name(),
490 ". Error: ", s.error_message());
491 }
492 s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
493 if (!s.ok()) {
494 return errors::Internal(
495 "Constraining by assigned device should not cause an error. Original "
496 "root's requested device name: \"",
497 DeviceNameUtils::ParsedNameToString(requested_device_name_),
498 "\", node's assigned device name \"", node.assigned_device_name(),
499 "\". Error: ", s.error_message());
500 }
501
502 assigned_device_name_index_ = node.assigned_device_name_index();
503 // Clear cached possible_devices, if any.
504 possible_devices_.clear();
505 return Status::OK();
506 }
507
MaybeExcludeXlaDevices()508 void Member::MaybeExcludeXlaDevices() {
509 for (const auto& parsed_name :
510 {requested_device_name_, assigned_device_name_, resource_device_name_}) {
511 // Don't exculde XLA devices from supported devices if member is explicitly
512 // assigned to a CompositeDevice.
513 if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
514 IsCompositeDevice(parsed_name.type))) {
515 return;
516 }
517 }
518
519 PrioritizedDeviceTypeVector non_xla_types;
520 absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
521 [&](const std::pair<DeviceType, int32>& entry) {
522 return !IsXlaDevice(entry.first.type_string());
523 });
524
525 // TODO(b/141216278) Remove all XLA device types from the supported device
526 // types if the node has no requested/assigned/resource XLA device.
527 if (!non_xla_types.empty() &&
528 non_xla_types.size() < supported_device_types_.size()) {
529 supported_device_types_ = std::move(non_xla_types);
530 }
531 }
532
LimitToPossibleDevices(const PossibleDevices & devices,bool allow_soft_placement)533 Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
534 bool allow_soft_placement) {
535 TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
536 &requested_device_name_, devices.requested_device_name,
537 allow_soft_placement));
538 TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
539 &resource_device_name_, devices.resource_device_name));
540 MergeSupportedDevices(devices.device_types);
541 return Status::OK();
542 }
543
DebugString() const544 string Member::DebugString() const {
545 return absl::StrCat(
546 "Member(assigned_device_name_index_=", assigned_device_name_index_,
547 " requested_device_name_='",
548 DeviceNameUtils::ParsedNameToString(requested_device_name_),
549 "' assigned_device_name_='",
550 DeviceNameUtils::ParsedNameToString(assigned_device_name_),
551 "' resource_device_name_='",
552 DeviceNameUtils::ParsedNameToString(resource_device_name_),
553 "' supported_device_types_=[",
554 absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
555 ", "),
556 "] possible_devices_=[",
557 absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
558 }
559
GetSoftDeviceName() const560 DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
561 DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
562 if (!assigned_device_name_.has_type) {
563 soft_device_name.type.clear();
564 soft_device_name.has_type = false;
565 }
566 if (!assigned_device_name_.has_id) {
567 soft_device_name.has_id = false;
568 }
569 return soft_device_name;
570 }
571
GetPreferredSoftDeviceName() const572 DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
573 DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
574 if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
575 soft_device_name.type.clear();
576 soft_device_name.has_type = false;
577 }
578 if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
579 soft_device_name.has_id = false;
580 }
581 return soft_device_name;
582 }
583
584 // Returns ParsedName whose address space (i.e. job, replica, task) identifies
585 // the address space directly accessible by the local process. If the address
586 // space is fully specified and it is exactly the same as the address space
587 // of a device, then all kernels of that device should be registered in the
588 // local process.
LocalAddressSpec(const Device * client_device,const Device * default_local_device)589 static const DeviceNameUtils::ParsedName LocalAddressSpec(
590 const Device* client_device, const Device* default_local_device) {
591 if (client_device != nullptr) {
592 return DeviceNameUtils::AddressSpace(client_device->parsed_name());
593 }
594
595 if (default_local_device != nullptr) {
596 return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
597 }
598
599 // TODO(b/139617593) Return the name of the first local device in device_set_
600 // once we can trust the output of Device::IsLocal().
601 return DeviceNameUtils::ParsedName();
602 }
603
ColocationGraph(const Graph * graph,const FunctionStack & stack,const FunctionLibraryDefinition * flib_def,const DeviceSet * device_set,const Device * default_local_device,bool allow_soft_placement,bool log_device_placement)604 ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
605 const FunctionLibraryDefinition* flib_def,
606 const DeviceSet* device_set,
607 const Device* default_local_device,
608 bool allow_soft_placement,
609 bool log_device_placement)
610 : graph_(*graph),
611 stack_(stack),
612 flib_def_(*flib_def),
613 inspecting_placer_(stack, flib_def, device_set, default_local_device,
614 allow_soft_placement, log_device_placement),
615 inspection_required_checker_(graph, flib_def),
616 device_set_(*device_set),
617 device_types_(device_set->PrioritizedDeviceTypeList()),
618 local_address_spec_(
619 LocalAddressSpec(device_set->client_device(), default_local_device)),
620 default_local_device_(default_local_device),
621 allow_soft_placement_(allow_soft_placement),
622 log_device_placement_(log_device_placement) {
623 members_.resize(graph_.num_node_ids());
624 }
625
626 // Adds each node of the Graph to this ColocationGraph as a singleton.
627 //
628 // NOTE: The implementation assumes that the ids of nodes passed to
629 // this method are dense and zero-based; the memory used will be linear in
630 // the largest node ID.
631 // NOTE: If this method returns an error, *this is left in an undefined
632 // state.
ColocateAllNodes()633 Status ColocationGraph::ColocateAllNodes() {
634 // This maps from a colocation group identifier to the 'root' of that
635 // colocation group. Note that the keys in this map are StringPiece; the
636 // actual strings are stored under the NodeDef. The lifetime of this map
637 // is limited to this ColocateAllNodes() method, and no part of the
638 // NodeDef trees are changed during the lifetime of this method, so using
639 // StringPiece as a key is safe.
640 //
641 // Also, as a further optimization, we remove the "loc:@" prefix from
642 // "class" attribute values, when they are used as keys in this table.
643 // This allows us to use StringPiece values that refer to substrings of
644 // 'string' values stored in NodeDef attribute lists, as well as StringPiece
645 // values that refer to 'string' values from NodeDef::name(), without
646 // performing any string allocations.
647 std::unordered_map<StringPiece, const Node*, StringPieceHasher>
648 colocation_group_root;
649
650 for (const Node* node : graph_.op_nodes()) {
651 // When adding the node, identify whether it is part of a colocation
652 // group.
653
654 // This code is effectively the equivalent of GetNodeAttr() for a string
655 // array, but it avoids all internal allocations (the allocation of the
656 // backing store of the std::vector<string> as well as the copies of the
657 // strings within it). Instead, we combine the query of the colocation
658 // attribute with the calls to ColocateNodeToGroup.
659 const AttrValue* attr_value =
660 node->attrs().Find(kColocationAttrNameStringPiece);
661 if (attr_value != nullptr) {
662 if (attr_value->has_list()) {
663 for (const string& class_spec : attr_value->list().s()) {
664 StringPiece spec(class_spec);
665 if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
666 TF_RETURN_IF_ERROR(
667 ColocateNodeToGroup(&colocation_group_root, node, spec));
668 }
669 }
670 } else if (!attr_value->s().empty()) {
671 LOG(ERROR) << "The value for colocation attribute '_class' must be a "
672 "list of strings, not a single string: "
673 << node->DebugString();
674 }
675 }
676
677 // Each node belongs to a colocation group with the node's name.
678 TF_RETURN_IF_ERROR(
679 ColocateNodeToGroup(&colocation_group_root, node, node->name()));
680 }
681
682 return Status::OK();
683 }
684
ColocateResourceOrRefEdge(const Node * src,const Node * dst)685 Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
686 const Node* dst) {
687 // Colocate `src` and `dst` to maintain the invariant that nodes
688 // connected by reference edges are colocated.
689 int src_root_id = FindAndUpdateRoot(src->id());
690 int dst_root_id = FindAndUpdateRoot(dst->id());
691 auto& src_root = members_[src_root_id];
692 auto& dst_root = members_[dst_root_id];
693
694 if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
695 // If the src root is assigned to a composite device and the dst root is
696 // assigned to a physical device, don't colocate the dst root with the src
697 // root.
698 return Status::OK();
699 }
700 TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
701 *src, src_root, *dst, log_device_placement_));
702 Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
703 if (!status.ok()) {
704 return AttachDef(
705 errors::InvalidArgument(
706 "Nodes were connected by a reference or resource connection "
707 "(requiring them to be on the same device), but the two nodes "
708 "were assigned two different devices: ",
709 status.error_message()),
710 *dst);
711 }
712 return Status::OK();
713 }
714
ColocateResourceAndRefEdges(std::unordered_set<Node * > * inspection_required)715 Status ColocationGraph::ColocateResourceAndRefEdges(
716 std::unordered_set<Node*>* inspection_required) {
717 // If `node` has an input edge with reference type, add an edge from the
718 // source of that edge to `node`.
719 for (const Edge* edge : graph_.edges()) {
720 if (edge->IsControlEdge()) {
721 continue;
722 }
723 Node* src = edge->src();
724 Node* dst = edge->dst();
725 bool needs_inspection;
726 TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
727 *src, &needs_inspection));
728 if (needs_inspection) {
729 inspection_required->insert(src);
730 continue;
731 }
732 TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
733 *dst, &needs_inspection));
734 if (needs_inspection) {
735 inspection_required->insert(dst);
736 continue;
737 }
738
739 DataType input_type = dst->input_type(edge->dst_input());
740
741 // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
742 // This is needed to get around the issue in b/135705778.
743 if (input_type == DT_VARIANT &&
744 data::DatasetOpKernel::IsDatasetOp(&src->op_def()) &&
745 data::DatasetOpKernel::IsDatasetOp(&dst->op_def())) {
746 TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
747 continue;
748 }
749
750 // Even though we can look inside function calling ops, we make an exception
751 // here mostly for performance reasons. Looking inside function calling ops
752 // is extra overhead. It is only necessary when they return resources. When
753 // they don't, we don't look inside them and make this exception here.
754 // Looking inside, could potentially enable us to make better placement
755 // decisions. It might be worth doing at some point.
756 if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
757 !IsExemptFromResourceInputColocation(dst)) {
758 TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
759 }
760 }
761
762 return Status::OK();
763 }
764
765 namespace {
766 // Returns tensor list element data type, if the node is one of the ops that
767 // operate with TensorLists. Otherwise returns DT_INVALID.
GetElementDataType(const Node & node)768 DataType GetElementDataType(const Node& node) {
769 static absl::flat_hash_set<std::string>* tensor_list_ops =
770 new absl::flat_hash_set<std::string>(
771 {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
772 "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
773 "TensorListScatterIntoExistingList", "TensorListPushBack",
774 "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
775 "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
776 "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});
777
778 if (tensor_list_ops->contains(node.type_string())) {
779 DataType element_type;
780 if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
781 return element_type;
782 }
783 }
784
785 return DT_INVALID;
786 }
787 } // namespace
788
AddHostOnlyDataTypesConstraints()789 Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
790 auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };
791
792 auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
793 return entry.first == DEVICE_CPU;
794 };
795
796 for (Node* node : graph_.nodes()) {
797 // Skip nodes that do not have DT_VARIANT inputs.
798 if (absl::c_none_of(node->input_types(), is_variant)) continue;
799
800 // Skip nodes that can't be placed on GPU anyway.
801 Member& root = members_[FindAndUpdateRoot(node->id())];
802 if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue;
803
804 // Stop DFS traversal when found the underlying data type of a variant.
805 absl::optional<bool> is_host_data_type;
806
807 auto edge_filter = [&](const Edge& edge) -> bool {
808 // We already found the underlying data type.
809 if (is_host_data_type.has_value()) return false;
810
811 // Otherwise follow only DT_VARIANT data edges.
812 auto edge_dtype = [&]() -> DataType {
813 return edge.src()->output_type(edge.src_output());
814 };
815 return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
816 };
817
818 auto enter = [&](Node* n) -> void {
819 DataType element_type = GetElementDataType(*n);
820 // To handle nested lists continue traversal after finding a TensorList
821 // operation that uses DT_VARIANT for element type.
822 if (element_type == DT_INVALID || element_type == DT_VARIANT) return;
823 is_host_data_type = DataTypeAlwaysOnHost(element_type);
824 };
825
826 ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
827 /*stable_comparator=*/nullptr, edge_filter);
828
829 if (is_host_data_type.has_value() && *is_host_data_type) {
830 VLOG(2) << "Limit node possible devices to CPU only, because it has a "
831 "DT_VARIANT input with host-only underlying data type: "
832 << "node=" << node->name();
833
834 // Restrict possible device types to CPU only.
835 PossibleDevices possible_devices;
836 absl::c_copy_if(root.supported_device_types(),
837 std::back_inserter(possible_devices.device_types),
838 is_cpu_device);
839
840 TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
841 possible_devices, /*allow_soft_placement=*/false));
842 }
843 }
844
845 return Status::OK();
846 }
847
AddInspectionConstraints(const std::unordered_set<Node * > & inspection_required)848 Status ColocationGraph::AddInspectionConstraints(
849 const std::unordered_set<Node*>& inspection_required) {
850 for (Node* node : inspection_required) {
851 IOColocationGroups groups;
852 TF_RETURN_IF_ERROR(
853 inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
854 VLOG(2) << "Computed IOColocationGroups for node " << node->name()
855 << ":\n\t" << groups.DebugString();
856 TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
857 }
858 return Status::OK();
859 }
860
Initialize()861 Status ColocationGraph::Initialize() {
862 TF_RETURN_IF_ERROR(InitializeMembers());
863
864 std::unordered_set<Node*> inspection_required;
865 TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
866 TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
867 TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
868 TF_RETURN_IF_ERROR(ColocateAllNodes());
869
870 for (Node* node : graph_.op_nodes()) {
871 int root_id = FindAndUpdateRoot(node->id());
872 members_[root_id].MaybeExcludeXlaDevices();
873 }
874
875 return Status::OK();
876 }
877
// A (node, bool) pair used when grouping the inputs/outputs of a node that
// requires placer inspection: the bool records whether this node receives a
// resource input from the node being inspected.
using NodeAndBool = std::pair<const Node*, bool>;
881
882 namespace {
883
884 // Returns a vector of node names from `nodes`.
NodeAndBoolToString(const std::vector<NodeAndBool> & nodes)885 std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
886 std::vector<string> v;
887 v.reserve(nodes.size());
888 for (const NodeAndBool& node_and_bool : nodes) {
889 v.push_back(node_and_bool.first->name());
890 }
891 return v;
892 }
893
894 // Given a node requiring placer inspection and its IOColocationGroups,
895 // computes `group_nodes`.
896 // group_nodes[i] contains the nodes that are members of colocation
897 // group i. These nodes are inputs or outputs of `node`.
898 // group_nodes[i][j] is a pair containing a node and whether this node
899 // has a resource input from `node`.
900 // Note:
901 // The same node can be added multiple times to the same group.
902 // The same node can be added to multiple groups.
GetGroupNodes(const IOColocationGroups & groups,const Node & node,std::vector<std::vector<NodeAndBool>> * group_nodes)903 Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
904 std::vector<std::vector<NodeAndBool>>* group_nodes) {
905 group_nodes->reserve(groups.group_devices.size());
906 for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
907 const Node* src;
908 TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
909 int group_id = groups.input_groups[arg_idx];
910 (*group_nodes)[group_id].emplace_back(src, false);
911 }
912
913 for (const Edge* edge : node.out_edges()) {
914 if (edge->IsControlEdge()) {
915 continue;
916 }
917
918 int group_id = groups.output_groups[edge->src_output()];
919 (*group_nodes)[group_id].emplace_back(
920 edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
921 }
922
923 if (VLOG_IS_ON(2)) {
924 VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
925 for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
926 VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
927 << "]";
928 }
929 }
930 return Status::OK();
931 }
932
933 // Returns whether the device_type in `device_attributes` is supported.
IsSupportedDeviceType(const DeviceAttributes & device_attributes,const DeviceType & supported_type)934 bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
935 const DeviceType& supported_type) {
936 if (DeviceType(device_attributes.device_type()) == supported_type) {
937 return true;
938 }
939 return IsCompositeDevice(device_attributes.device_type());
940 }
941
942 } // namespace
943
ApplyIOColocationGroups(const IOColocationGroups & groups,const Node & node)944 Status ColocationGraph::ApplyIOColocationGroups(
945 const IOColocationGroups& groups, const Node& node) {
946 if (groups.input_groups.size() != node.num_inputs()) {
947 return errors::Internal(
948 "Cannot apply input/output device constraints to node ",
949 node.DebugString(), " because input_groups.size() (",
950 groups.input_groups.size(),
951 ") is different from number of inputs into the op node (",
952 node.num_inputs(), ")");
953 }
954 if (groups.output_groups.size() != node.num_outputs()) {
955 return errors::Internal(
956 "Cannot apply input/output device constraints to node ",
957 node.DebugString(), " because output_groups.size() (",
958 groups.output_groups.size(),
959 ") is different from number of outputs into the op node (",
960 node.num_outputs(), ")");
961 }
962
963 // group_nodes[i] contains the nodes that are members of colocation
964 // group i. These nodes are inputs or outputs of `node`.
965 // group_nodes[i][j] is a pair containing the node and whether this node
966 // has a resource input from `node`.
967 // The same node can be added multiple times to the same group.
968 // The same node can be added to multiple groups.
969 // NOTE: group ids are guarantees to be [0, 1, ..., num_groups].
970 std::vector<std::vector<NodeAndBool>> group_nodes(
971 groups.group_devices.size());
972 TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));
973
974 // Colocate nodes in each group
975 for (const std::vector<NodeAndBool>& nodes : group_nodes) {
976 for (int i = 1; i < nodes.size(); ++i) {
977 VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
978 << nodes[i].first->name() << "\"";
979 if (nodes[i].second) {
980 TF_RETURN_IF_ERROR(
981 ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
982 } else {
983 TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
984 }
985 }
986 }
987
988 // Limit devices in each group
989 for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
990 // Nothing to do for empty groups. Groups can be empty if some output
991 // of an op is not used.
992 if (group_nodes[group_id].empty()) {
993 continue;
994 }
995 const Node* group_node = group_nodes[group_id][0].first;
996 const PossibleDevices& possible_devices = groups.group_devices[group_id];
997 TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
998 }
999
1000 return Status::OK();
1001 }
1002
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)1003 Status ColocationGraph::ColocateNodeToGroup(
1004 std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
1005 colocation_group_root,
1006 const Node* node, StringPiece colocation_group) {
1007 const Node*& root_node = (*colocation_group_root)[colocation_group];
1008 if (root_node == nullptr) {
1009 // This is the first node of the colocation group, so
1010 // designate this node as the 'root' of that colocation group.
1011 root_node = node;
1012 } else {
1013 // Try to colocate the node with the root. If there is an
1014 // error, return it.
1015 Status s = ColocateNodes(*node, *root_node);
1016 if (!s.ok()) {
1017 if (!allow_soft_placement_) {
1018 return AttachDef(s, *node);
1019 }
1020 if (log_device_placement_) {
1021 LOG(INFO) << "Ignoring request to colocate node '" << node->name()
1022 << "' with nodes in colocation group '" << colocation_group
1023 << "' because soft placement is on and an attempt at doing "
1024 "so resulted in the following error: "
1025 << AttachDef(s, *node).ToString();
1026 }
1027 }
1028 }
1029 return Status::OK();
1030 }
1031
1032 // Merge the (possibly disjoint) sets containing nodes "x" and
1033 // "y". Returns OK if the all nodes in the union of these sets can
1034 // be placed on the same device type.
1035 //
1036 // NOTE: If this method returns an error, *this is left in an undefined
1037 // state.
ColocateNodes(const Node & x,const Node & y)1038 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
1039 int x_root = FindAndUpdateRoot(x.id());
1040 int y_root = FindAndUpdateRoot(y.id());
1041 return ColocateNodes(x, x_root, y, y_root);
1042 }
1043
// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  // Already in the same colocation set: nothing to merge.
  if (x_root == y_root) {
    return Status::OK();
  }

  // Dry-run the merge first so we learn which member would become the new
  // root; all compatibility validation happens against that pair before
  // anything is mutated.
  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}
1093
LimitToAssignedDevice(const Node & node)1094 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
1095 if (node.assigned_device_name_index() < 0) {
1096 return errors::Internal(
1097 "Expected an assigned node as argument to LimitToAssignedDevice but "
1098 "got: ",
1099 node.DebugString());
1100 }
1101 int root = FindAndUpdateRoot(node.id());
1102 Member& root_member = members_[root];
1103 return root_member.AssignDevice(node);
1104 }
1105
GetSoftDeviceCandidates(const Node & node,const Member & root_member,int root_id,std::vector<Device * > * possible_devices)1106 void ColocationGraph::GetSoftDeviceCandidates(
1107 const Node& node, const Member& root_member, int root_id,
1108 std::vector<Device*>* possible_devices) {
1109 // Try to find supported devices that don't violate resource devices.
1110 // The soft_device_name is the same as the requested device name
1111 // without specifying the device type or ID (if assigned and requested
1112 // devices does not specify them).
1113 DeviceNameUtils::ParsedName soft_device_name =
1114 root_member.GetPreferredSoftDeviceName();
1115 device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1116 if (!possible_devices->empty()) {
1117 *possible_devices = FilterSupportedDevices(
1118 *possible_devices, root_member.supported_device_types(),
1119 default_local_device_);
1120 }
1121
1122 if (!possible_devices->empty()) {
1123 return;
1124 }
1125
1126 // TODO(iga): Disallow changing resource devices when this ColocationGraph
1127 // is for :
1128 // - a function called by an op requiring deep inspection, or
1129 // - a graph containing ops requiring inspection.
1130 // It is fairly tricky to make changing resource devices in presence of
1131 // ops requiring inspection work correctly. One thing it would require is to
1132 // communicate these "resource movement" decisions across Placer instances.
1133
1134 // Failed to find supported devices that don't violate resource devices.
1135 // Try finding some devices that violated resource devices.
1136 // If we succeed, we will log a warning below.
1137 soft_device_name = root_member.GetSoftDeviceName();
1138 device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1139 if (!possible_devices->empty()) {
1140 *possible_devices = FilterSupportedDevices(
1141 *possible_devices, root_member.supported_device_types(),
1142 default_local_device_);
1143 }
1144
1145 if (!possible_devices->empty()) {
1146 LOG(WARNING)
1147 << "Failed to place the graph without changing the devices of some "
1148 "resources. Some of the operations (that had to be colocated with "
1149 "resource generating operations) are not supported on the "
1150 "resources' devices. Current candidate devices are [\n "
1151 << absl::StrJoin(DevicesToString(*possible_devices), "\n ")
1152 << "].\nSee below for details of this colocation group:"
1153 << DebugInfo(root_id);
1154 }
1155 }
1156
LimitToPossibleDevices(const Node & node,const PossibleDevices & devices)1157 Status ColocationGraph::LimitToPossibleDevices(const Node& node,
1158 const PossibleDevices& devices) {
1159 int root = FindAndUpdateRoot(node.id());
1160 Member& root_member = members_[root];
1161 return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
1162 }
1163
GetDevicesForNode(Node * node,const std::vector<Device * > ** possible_devices)1164 Status ColocationGraph::GetDevicesForNode(
1165 Node* node, const std::vector<Device*>** possible_devices) {
1166 *possible_devices = nullptr;
1167 const int node_root = FindAndUpdateRoot(node->id());
1168 if (!members_[node_root].possible_devices().empty()) {
1169 *possible_devices = &members_[node_root].possible_devices();
1170 return Status::OK();
1171 }
1172
1173 Member& root_member = members_[node_root];
1174
1175 // We have not yet computed the possible devices for the
1176 // colocated node set containing 'node', so we do so now using the
1177 // constraints on the root node.
1178
1179 // "devices" will contain the set of feasible placements for the
1180 // colocated node set containing 'node'.
1181 // NOTE: Basing possible device computation on requested device name
1182 // is guaranteed to respect the assigned and resource device names because
1183 // requested device is always a specialization of both.
1184 std::vector<Device*> devices;
1185 if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
1186 // The root node has a (possibly partial) device
1187 // specification, so enumerate the physical devices that
1188 // conform to it.
1189 device_set_.FindMatchingDevices(root_member.requested_device_name(),
1190 &devices);
1191
1192 if (!devices.empty()) {
1193 // Filter devices into those that are compatible with the root
1194 // node (and its children).
1195 devices = FilterSupportedDevices(
1196 devices, root_member.supported_device_types(), default_local_device_);
1197 }
1198
1199 // Perform soft placement if allow_soft_placement_ is set.
1200 if (devices.empty() && allow_soft_placement_) {
1201 GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
1202 }
1203
1204 if (devices.empty()) {
1205 // Return an error when a physical device that matches an explicit
1206 // device specification is not found. This ensures that we don't
1207 // assign a node to GPU when the user wanted to force it on CPU.
1208 string debug_info = DebugInfo(node_root);
1209
1210 DeviceNameUtils::ParsedName specified_device_name;
1211 if (DeviceNameUtils::ParseFullName(node->requested_device(),
1212 &specified_device_name) &&
1213 specified_device_name == root_member.requested_device_name()) {
1214 // The specified device and merged set device match, and
1215 // will appear in the GraphDef (for debugging), so just
1216 // print the specified device.
1217 std::vector<Device*> devices_matching_nodedef;
1218 device_set_.FindMatchingDevices(specified_device_name,
1219 &devices_matching_nodedef);
1220 if (devices_matching_nodedef.empty()) {
1221 // Sometimes it is almost impossible to understand the problem
1222 // without a list of available devices.
1223 std::vector<string> device_names;
1224 for (const Device* device : device_set_.devices()) {
1225 device_names.push_back(device->name());
1226 }
1227 std::sort(device_names.begin(), device_names.end());
1228
1229 string gpu_msg = "";
1230 if (!IsGoogleCudaEnabled() &&
1231 absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
1232 gpu_msg =
1233 " The requested device appears to be a GPU, but CUDA is not "
1234 "enabled.";
1235 }
1236
1237 return errors::InvalidArgument(
1238 errors::FormatNodeNameForError(node->name()),
1239 " was explicitly assigned to ", node->requested_device(),
1240 " but available devices are [ ",
1241 absl::StrJoin(device_names, ", "), " ]. Make sure ",
1242 "the device specification refers to a valid device.", gpu_msg);
1243 } else if (specified_device_name.has_type) {
1244 return errors::InvalidArgument(
1245 "Could not satisfy explicit device specification '",
1246 node->requested_device(), "' because no supported kernel for ",
1247 specified_device_name.type, " devices is available.", debug_info,
1248 "\nOp: ", node->type_string(),
1249 "\nNode attrs: ", node->attrs().DebugString(),
1250 "\nRegistered kernels:\n",
1251 KernelsRegisteredForOp(node->type_string()));
1252 } else {
1253 return errors::InvalidArgument(
1254 "Could not satisfy explicit device specification '",
1255 node->requested_device(), debug_info);
1256 }
1257 } else {
1258 // The specified device may be a valid device but the
1259 // merged set device is different, so print both.
1260 // TODO(b/129057603): There are many possibilities at this point.
1261 // Provide good error messages.
1262 return errors::InvalidArgument(
1263 "Could not satisfy explicit device specification '",
1264 node->requested_device(), "' because the node ",
1265 errors::FormatColocationNodeForError(node->name()),
1266 " was colocated with a group of nodes that ",
1267 "required incompatible device '",
1268 DeviceNameUtils::ParsedNameToString(
1269 root_member.requested_device_name()),
1270 "'. All available devices [",
1271 absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
1272 debug_info);
1273 }
1274 }
1275 } else {
1276 // The device is completely unspecified, so enumerate the devices that
1277 // support all of the nodes in the set.
1278 if (device_set_.devices().empty()) {
1279 return errors::Internal("No devices are registered");
1280 }
1281 devices = FilterSupportedDevices(device_set_.devices(),
1282 root_member.supported_device_types(),
1283 default_local_device_);
1284
1285 if (devices.empty()) {
1286 return errors::InvalidArgument(
1287 "Node had no OpKernel registered to support this operation: ",
1288 "Operation was ", node->type_string(), " and inputs were [",
1289 DataTypeVectorString(node->input_types()), "].\n",
1290 DebugInfo(node_root));
1291 }
1292 }
1293
1294 // Cache the result of the possible devices for this node group.
1295 root_member.set_possible_devices(std::move(devices));
1296 *possible_devices = &root_member.possible_devices();
1297 return Status::OK();
1298 }
1299
InitializeMembers()1300 Status ColocationGraph::InitializeMembers() {
1301 for (Node* node : graph_.op_nodes()) {
1302 Status status = InitializeMember(*node, &members_[node->id()]);
1303 if (!status.ok()) {
1304 return AttachDef(status, *node);
1305 }
1306 }
1307 return Status::OK();
1308 }
1309
DebugString() const1310 string ColocationGraph::DebugString() const {
1311 std::unordered_set<int> roots;
1312 std::vector<string> root_strings;
1313 for (const Node* node : graph_.nodes()) {
1314 if (!node->IsOp()) {
1315 continue;
1316 }
1317 int node_root = FindRoot(node->id());
1318 if (roots.count(node_root) == 0) {
1319 root_strings.push_back(DebugInfo(node_root));
1320 roots.insert(node_root);
1321 }
1322 }
1323 return absl::StrJoin(root_strings, "\n");
1324 }
1325
// Returns debugging info for the node referred to by 'node_root'.
// Produces a human-readable summary of the colocation group: the group's
// root member, the kernel-supporting devices per op type, and each member
// node with its requested/assigned devices. Returns an empty string when the
// group has no op nodes.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    // Only report nodes belonging to the requested colocation group.
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    // Best-effort debug output: errors are ignored, and a node with no
    // supported devices simply reports an empty device list.
    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  // One line per op type listing the devices with registered kernels.
  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  // A group with no op nodes produces no output at all.
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}
1387
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,Member * member)1388 Status ColocationGraph::InitializeMemberWithAssignedDevice(
1389 const string& assigned_device_name, const string& node_type,
1390 Member* member) {
1391 // This node has already been assigned to a device, so we
1392 // respect this placement, after sanity-checking it.
1393 // NOTE: Since any assignment must have been performed by
1394 // the TensorFlow runtime, we consider errors in this branch to
1395 // be INTERNAL.
1396 TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
1397
1398 // Since assigned device must be a full specification, do extra checks.
1399 const Device* assigned_device =
1400 device_set_.FindDeviceByName(assigned_device_name);
1401 if (assigned_device == nullptr) {
1402 // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
1403 // calls when they are supported.
1404 return errors::Internal(
1405 "Assigned device '", assigned_device_name,
1406 "' does not match any device. This error can happen when one attempts "
1407 "to run a tf.function with resource inputs residing on remote devices. "
1408 "This use case is currently not supported. Here are the devices "
1409 "available on this machine: [",
1410 absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "].",
1411 "If you are seeing this error when running using a tf.Session, set "
1412 "share_cluster_devices_in_session to true in the tf.ConfigProto.");
1413 }
1414
1415 for (const auto& d : member->supported_device_types()) {
1416 if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
1417 return Status::OK();
1418 }
1419 }
1420
1421 return errors::Internal("Assigned device '", assigned_device_name,
1422 "' does not have registered OpKernel support "
1423 "for ",
1424 node_type);
1425 }
1426
InitializeMember(const Node & node,Member * member)1427 Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
1428 TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
1429 node, device_types_, &local_address_spec_));
1430
1431 if (node.has_assigned_device_name()) {
1432 TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
1433 node.assigned_device_name(), node.type_string(), member));
1434 } else {
1435 // This node has not yet been assigned to a device, so we
1436 // calculate any constraints due to the set of registered
1437 // kernels and any (partial) user-provided device specification
1438 // in the NodeDef.
1439
1440 // If no kernels are registered for this op type, fail with an error.
1441 if (member->supported_device_types().empty()) {
1442 std::set<string> registered_device_types;
1443 for (Device* d : device_set_.devices()) {
1444 registered_device_types.insert(d->device_type());
1445 }
1446 return errors::InvalidArgument(
1447 "No OpKernel was registered to support Op '", node.type_string(),
1448 "' used by ", errors::FormatNodeNameForError(node.name()),
1449 " with these attrs: [", node.attrs().DebugString(),
1450 "]\n"
1451 "Registered devices: [",
1452 absl::StrJoin(registered_device_types, ", "), "]\n",
1453 "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
1454 }
1455
1456 // If the NodeDef contains a device, then we interpret it as a
1457 // (partial) device specification.
1458 if (!node.requested_device().empty()) {
1459 if (IsRefOrResourceGeneratorNode(node)) {
1460 // Treat requested device on resource generating nodes as assigned
1461 // device so that we don't override it.
1462 TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
1463 } else {
1464 // The user has specified a device in the NodeDef, try to find a
1465 // valid device matching their specification in the set of
1466 // devices.
1467 // NOTE: The full name may specify a device that is not in
1468 // n.supported_device_types(), but we check that in AssignDevice().
1469 TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
1470 }
1471 }
1472 }
1473 return Status::OK();
1474 }
1475
1476 // Returns a list of devices having type in supported_device_types. The
1477 // returned list is sorted by preferred type (higher numeric type is preferred).
FilterSupportedDevices(const std::vector<Device * > & devices,const PrioritizedDeviceTypeVector & supported_device_types,const Device * default_local_device)1478 /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
1479 const std::vector<Device*>& devices,
1480 const PrioritizedDeviceTypeVector& supported_device_types,
1481 const Device* default_local_device) {
1482 Device* filtered_default_device = nullptr;
1483 PrioritizedDeviceVector prioritized_filtered_devices;
1484 for (const auto& supported_device_type : supported_device_types) {
1485 for (Device* device : devices) {
1486 if (IsSupportedDeviceType(device->attributes(),
1487 supported_device_type.first)) {
1488 if (default_local_device &&
1489 (device == default_local_device ||
1490 // TODO(nareshmodi, fishx): At times the device pointer in the
1491 // device set is different to the one passed in as the default
1492 // device. Figure out why this might be.
1493 device->name() == default_local_device->name())) {
1494 filtered_default_device = device;
1495 } else {
1496 prioritized_filtered_devices.emplace_back(
1497 device, supported_device_type.second);
1498 }
1499 }
1500 }
1501 }
1502 DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);
1503
1504 std::vector<Device*> filtered_devices;
1505 if (filtered_default_device != nullptr) {
1506 filtered_devices.emplace_back(filtered_default_device);
1507 }
1508 for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
1509 filtered_devices.push_back(prioritized_filtered_device.first);
1510 }
1511 return filtered_devices;
1512 }
1513
1514 } // namespace tensorflow
1515