1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/colocation_graph.h"
17
18 #include <memory>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/common_runtime/composite_device.h"
30 #include "tensorflow/core/common_runtime/device.h"
31 #include "tensorflow/core/common_runtime/device_set.h"
32 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
33 #include "tensorflow/core/common_runtime/inspecting_placer.h"
34 #include "tensorflow/core/common_runtime/partitioning_utils.h"
35 #include "tensorflow/core/framework/attr_value.pb.h"
36 #include "tensorflow/core/framework/attr_value_util.h"
37 #include "tensorflow/core/framework/dataset.h"
38 #include "tensorflow/core/framework/device_attributes.pb.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/graph/algorithm.h"
45 #include "tensorflow/core/graph/graph_node_util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/stringpiece.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/util/device_name_utils.h"
51 #include "tensorflow/core/util/dump_graph.h"
52 #include "tensorflow/core/util/port.h"
53
54 namespace tensorflow {
55
56 namespace {
57
58 // We hoist the conversion from C-style string literal to StringPiece here,
59 // so that we can avoid the many repeated calls to strlen().
60 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
61 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
62
63 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)64 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
65 std::vector<string> v;
66 v.reserve(devices.size());
67 for (Device* d : devices) {
68 v.push_back(d->name());
69 }
70 return v;
71 }
72
73 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)74 std::vector<string> DeviceTypeAndPriorityToString(
75 const PrioritizedDeviceTypeVector& devices) {
76 std::vector<string> v;
77 v.reserve(devices.size());
78 for (const std::pair<DeviceType, int32>& device_and_type : devices) {
79 v.push_back(DeviceTypeString(device_and_type.first));
80 }
81 return v;
82 }
83
IsRefOrResource(DataType data_type)84 bool IsRefOrResource(DataType data_type) {
85 return IsRefType(data_type) || data_type == DT_RESOURCE;
86 }
87
88 // While Placer can override requested device on ops processing
89 // resources, i.e. node that take (and potentially return) a resource,
90 // it must not override requested device on ops generating a resource,
91 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
92 // output nodes.
IsRefOrResourceGeneratorNode(const Node & node)93 bool IsRefOrResourceGeneratorNode(const Node& node) {
94 return node.num_inputs() == 0 && node.num_outputs() == 1 &&
95 IsRefOrResource(node.output_type(0));
96 }
97
IsExemptFromResourceInputColocation(const Node * node)98 bool IsExemptFromResourceInputColocation(const Node* node) {
99 // Note: Partitioned function calls, which place and partition their
100 // function bodies, are exempt from this check: they forward resource and
101 // ref inputs to operations that are appropriately placed, instead of
102 // dereferencing them.
103 const string& op_type = node->op_def().name();
104 auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
105 return exempt_ops.find(op_type) != exempt_ops.end();
106 }
107
HasPriorities(const PrioritizedDeviceTypeVector & device_types)108 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
109 for (const auto& prioritized_device_type : device_types) {
110 if (prioritized_device_type.second != 0) return true;
111 }
112 return false;
113 }
114
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)115 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
116 const PrioritizedDeviceTypeVector& b_types) {
117 if (a_types.size() != b_types.size()) {
118 return false;
119 }
120 for (int i = 0; i < a_types.size(); ++i) {
121 if (a_types[i].first != b_types[i].first) {
122 return false;
123 }
124 }
125 return true;
126 }
127
IsXlaDevice(absl::string_view device_type)128 bool IsXlaDevice(absl::string_view device_type) {
129 if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
130 device_type == "XLA_TPU_JIT") {
131 // Symbolic XLA device.
132 return true;
133 }
134
135 return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
136 device_type == "TPU");
137 }
138
IsCompositeDevice(absl::string_view device_type)139 bool IsCompositeDevice(absl::string_view device_type) {
140 return device_type == kCompositeDeviceType;
141 }
142
143 } // namespace
144
SetParentAndSupportedDevices(const Node & node,const std::vector<DeviceType> & types,const DeviceNameUtils::ParsedName * local_address_spec)145 Status Member::SetParentAndSupportedDevices(
146 const Node& node, const std::vector<DeviceType>& types,
147 const DeviceNameUtils::ParsedName* local_address_spec) {
148 int id = node.id();
149 if (id < 0) {
150 return errors::Internal("Placer should not be creating a Member for node: ",
151 node.DebugString());
152 }
153 parent_ = id;
154 return SupportedDeviceTypesForNode(
155 types, node.def(), &supported_device_types_, local_address_spec);
156 }
157
SetAssignedDeviceName(const string & device_name)158 Status Member::SetAssignedDeviceName(const string& device_name) {
159 if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
160 return errors::Internal(
161 "Setting assigned device name when there is a requested device set "
162 "is unsupported");
163 }
164 if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
165 return errors::Internal("Malformed assigned device '", device_name, "'");
166 }
167 // Set requested device to assigned_device to maintain the invariant that
168 // requested is a specialization of assigned.
169 requested_device_name_ = assigned_device_name_;
170 return Status::OK();
171 }
172
SetResourceDeviceName(const Node & node)173 Status Member::SetResourceDeviceName(const Node& node) {
174 if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
175 return errors::Internal(
176 "Setting resource device name when there is a requested device set "
177 "is unsupported");
178 }
179
180 if (!DeviceNameUtils::ParseFullName(node.requested_device(),
181 &resource_device_name_)) {
182 return errors::InvalidArgument("Malformed device specification '",
183 node.requested_device(),
184 "' in node: ", node.DebugString());
185 }
186
187 // Set requested device to resource device to maintain the invariant that
188 // requested is a specialization of resource.
189 requested_device_name_ = resource_device_name_;
190 return Status::OK();
191 }
192
SetRequestedDeviceName(const Node & node)193 Status Member::SetRequestedDeviceName(const Node& node) {
194 if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
195 return errors::Internal(
196 "Setting requested device name when there is an assigned device set "
197 "is unsupported");
198 }
199 if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
200 return errors::Internal(
201 "Setting requested device name when there is a resource device set "
202 "is unsupported");
203 }
204 if (!DeviceNameUtils::ParseFullName(node.requested_device(),
205 &requested_device_name_)) {
206 return errors::InvalidArgument("Malformed device specification '",
207 node.requested_device(),
208 "' in node: ", node.DebugString());
209 }
210 return Status::OK();
211 }
212
FillPossibleDevices(PossibleDevices * possible_device) const213 Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
214 if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
215 return errors::Internal(
216 "Cannot fill PossibleDevices from a member that has non-empty assigned "
217 "device. Did we start assigning devices to functions called by deep "
218 "ops? ",
219 DebugString());
220 }
221 possible_device->requested_device_name = requested_device_name_;
222 possible_device->resource_device_name = resource_device_name_;
223 possible_device->device_types = supported_device_types_;
224 return Status::OK();
225 }
226
IsEdgeFromCompositeDeviceToPhysicalDevice(const Member & src_root) const227 bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
228 const Member& src_root) const {
229 auto compatible_edge_from_composite_device_to_physical_device =
230 [](const DeviceNameUtils::ParsedName& src_device,
231 const DeviceNameUtils::ParsedName& dst_device) -> bool {
232 return src_device.has_type && dst_device.has_type &&
233 IsCompositeDevice(src_device.type) &&
234 !IsCompositeDevice(dst_device.type);
235 };
236 if (compatible_edge_from_composite_device_to_physical_device(
237 src_root.assigned_device_name_, assigned_device_name_) ||
238 compatible_edge_from_composite_device_to_physical_device(
239 src_root.resource_device_name_, resource_device_name_) ||
240 compatible_edge_from_composite_device_to_physical_device(
241 src_root.requested_device_name_, requested_device_name_)) {
242 return true;
243 }
244 return false;
245 }
246
// Verifies that the colocation group rooted at `src_root` (source of a
// resource/ref edge) is device-compatible with this group (the destination
// root). Conflicts in assigned or resource devices are hard errors; a
// conflict in the requested device alone is resolved by overriding this
// group's requested device with the source's.
Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  // Assigned devices are hard placement decisions; they must be compatible.
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  // Resource devices must also agree between the two groups.
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  // Compatible requested devices: nothing further to reconcile.
  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return Status::OK();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for destination
  // node, but need to preserve the invariant that it will be a specialization
  // of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return Status::OK();
}
299
// Union-by-rank merge of the disjoint-set trees rooted at `x_root` and
// `y_root`. On return, *new_root points at the surviving root's Member and
// *old_root at the absorbed one. With `dry_run` set, no parent_/rank_
// fields are modified; only the would-be roots are reported.
void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}
344
345 // tree is non-const because we can change some `parent` pointers in some
346 // members for more efficient future lookups. The vector itself is not
347 // changed.
FindAndUpdateRoot(std::vector<Member> * tree,int node_id)348 int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
349 Member& member = (*tree)[node_id];
350 if (member.parent_ == node_id) {
351 // member.parent is the root of this disjoint tree. Do nothing.
352 } else {
353 member.parent_ = FindAndUpdateRoot(tree, member.parent_);
354 }
355 // Now it is guaranteed that member.parent is the root of this disjoint
356 // tree.
357 return member.parent_;
358 }
359
FindRoot(const std::vector<Member> & tree,int node_id)360 int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
361 const Member& member = tree[node_id];
362 if (member.parent_ == node_id) {
363 return member.parent_;
364 }
365 return FindRoot(tree, member.parent_);
366 }
367
// Merges `other`'s assigned/resource/requested device names into this
// member. All three merges are attempted on local copies first so that this
// member is left completely unmodified if any merge fails (all-or-nothing).
Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  // Only the requested-name merge may be softened by soft placement.
  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  // Re-establish the invariant on the merged requested name.
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = assigned_device_name_copy;
  resource_device_name_ = resource_device_name_copy;
  requested_device_name_ = requested_device_name_copy;
  return Status::OK();
}
398
399 // Updates this to contain the intersection of the device types in
400 // this and "other".
bool Member::MergeSupportedDevices(const Member& other) {
  // Delegates to the PrioritizedDeviceTypeVector overload; returns false
  // (leaving this member unchanged) when the intersection would be empty.
  return MergeSupportedDevices(other.supported_device_types_);
}
404
// Intersects supported_device_types_ with `other_devices`, reconciling the
// two sides' priorities per the rules in the inline comments. Returns false
// and leaves supported_device_types_ unchanged if the intersection is empty.
bool Member::MergeSupportedDevices(
    const PrioritizedDeviceTypeVector& other_devices) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;

  // O(n*m) scan; device-type lists are expected to be short.
  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type : other_devices) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
  DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither are prioritized then we just return the original i.e. target
    // prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect priorities of that in the
    // intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and agree then we go with that. If the
    // prioritization order is different, then we just fallback to the default
    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
    // merged priorities to 0, so that downstream merges work correctly as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      DeviceSet::SortPrioritizedDeviceTypeVector(&result);
    }
  }

  // An empty intersection means no device type satisfies both sides;
  // report failure without clobbering the current supported set.
  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}
466
// Constrains this member (typically a colocation-group root) by `node`'s
// already-assigned device. Assignment is final, so the assigned name is
// merged normally while the resource and requested names are merged with
// override semantics (MergeOverrideDevNames). Any merge failure here is an
// internal error: an assigned device should always be satisfiable.
Status Member::AssignDevice(const Node& node) {
  // Fast path: this member already reflects the same assigned device.
  if (node.assigned_device_name_index() == assigned_device_name_index_) {
    return Status::OK();
  }

  DeviceNameUtils::ParsedName parsed;
  DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
  Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's assigned device name: ",
        DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        ". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's resource device name: ",
        DeviceNameUtils::ParsedNameToString(resource_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        ". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's requested device name: \"",
        DeviceNameUtils::ParsedNameToString(requested_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }

  assigned_device_name_index_ = node.assigned_device_name_index();
  // Clear cached possible_devices, if any.
  possible_devices_.clear();
  return Status::OK();
}
507
// Filters XLA device types out of supported_device_types_, unless this
// member was explicitly requested/assigned to (or holds a resource on) an
// XLA or composite device, in which case the list is left untouched.
void Member::MaybeExcludeXlaDevices() {
  for (const auto& parsed_name :
       {requested_device_name_, assigned_device_name_, resource_device_name_}) {
    // Don't exclude XLA devices from supported devices if member is explicitly
    // assigned to a CompositeDevice.
    if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
                                 IsCompositeDevice(parsed_name.type))) {
      return;
    }
  }

  PrioritizedDeviceTypeVector non_xla_types;
  absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
                  [&](const std::pair<DeviceType, int32>& entry) {
                    return !IsXlaDevice(entry.first.type_string());
                  });

  // TODO(b/141216278) Remove all XLA device types from the supported device
  // types if the node has no requested/assigned/resource XLA device.
  // Only shrink the list if filtering removed something AND at least one
  // non-XLA type remains (never leave the member with no supported types).
  if (!non_xla_types.empty() &&
      non_xla_types.size() < supported_device_types_.size()) {
    supported_device_types_ = std::move(non_xla_types);
  }
}
532
// Narrows this member's requested/resource device names and supported
// device types to the constraints in `devices`.
Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  // NOTE(review): MergeSupportedDevices returns false when the intersection
  // is empty, but the result is discarded here — presumably an empty
  // supported set is diagnosed later during placement; confirm.
  MergeSupportedDevices(devices.device_types);
  return Status::OK();
}
543
DebugString() const544 string Member::DebugString() const {
545 return absl::StrCat(
546 "Member(assigned_device_name_index_=", assigned_device_name_index_,
547 " requested_device_name_='",
548 DeviceNameUtils::ParsedNameToString(requested_device_name_),
549 "' assigned_device_name_='",
550 DeviceNameUtils::ParsedNameToString(assigned_device_name_),
551 "' resource_device_name_='",
552 DeviceNameUtils::ParsedNameToString(resource_device_name_),
553 "' supported_device_types_=[",
554 absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
555 ", "),
556 "] possible_devices_=[",
557 absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
558 }
559
GetSoftDeviceName() const560 DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
561 DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
562 if (!assigned_device_name_.has_type) {
563 soft_device_name.type.clear();
564 soft_device_name.has_type = false;
565 }
566 if (!assigned_device_name_.has_id) {
567 soft_device_name.has_id = false;
568 }
569 return soft_device_name;
570 }
571
GetPreferredSoftDeviceName() const572 DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
573 DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
574 if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
575 soft_device_name.type.clear();
576 soft_device_name.has_type = false;
577 }
578 if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
579 soft_device_name.has_id = false;
580 }
581 return soft_device_name;
582 }
583
584 // Returns ParsedName whose address space (i.e. job, replica, task) identifies
585 // the address space directly accessible by the local process. If the address
586 // space is fully specified and it is exactly the same as the address space
587 // of a device, then all kernels of that device should be registered in the
588 // local process.
LocalAddressSpec(const Device * client_device,const Device * default_local_device)589 static const DeviceNameUtils::ParsedName LocalAddressSpec(
590 const Device* client_device, const Device* default_local_device) {
591 if (client_device != nullptr) {
592 return DeviceNameUtils::AddressSpace(client_device->parsed_name());
593 }
594
595 if (default_local_device != nullptr) {
596 return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
597 }
598
599 // TODO(b/139617593) Return the name of the first local device in device_set_
600 // once we can trust the output of Device::IsLocal().
601 return DeviceNameUtils::ParsedName();
602 }
603
// Constructs a colocation graph over `graph`. Stores references to the
// caller-owned graph and device set (which must outlive this object) and
// pre-sizes members_ with one slot per node id.
ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  // One Member per node id; ids are assumed dense and zero-based.
  members_.resize(graph_.num_node_ids());
}
624
625 // Adds each node of the Graph to this ColocationGraph as a singleton.
626 //
627 // NOTE: The implementation assumes that the ids of nodes passed to
628 // this method are dense and zero-based; the memory used will be linear in
629 // the largest node ID.
630 // NOTE: If this method returns an error, *this is left in an undefined
631 // state.
// Groups nodes by their "_class" (loc:@...) colocation attribute, and puts
// every node in a group keyed by its own name.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group. Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef. The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it). Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          // Only "loc:@<group>" entries denote colocation groups; other
          // "_class" values are ignored.
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        // Malformed attribute: log and continue rather than failing
        // placement for the whole graph.
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return Status::OK();
}
683
// Colocates the endpoints of a resource/ref (or dataset-variant) edge,
// unless the edge goes from a composite device to a physical device, in
// which case no constraint is added.
Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
                                                  const Node* dst) {
  // Colocate `src` and `dst` to maintain the invariant that nodes
  // connected by reference edges are colocated.
  int src_root_id = FindAndUpdateRoot(src->id());
  int dst_root_id = FindAndUpdateRoot(dst->id());
  auto& src_root = members_[src_root_id];
  auto& dst_root = members_[dst_root_id];

  if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
    // If the src root is assigned to a composite device and the dst root is
    // assigned to a physical device, don't colocate the dst root with the src
    // root.
    return Status::OK();
  }
  // Reconcile device names across the edge before merging the groups.
  TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
      *src, src_root, *dst, log_device_placement_));
  Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
  if (!status.ok()) {
    // Wrap the failure with edge-specific context and attach dst's def.
    return AttachDef(
        errors::InvalidArgument(
            "Nodes were connected by a reference or resource connection "
            "(requiring them to be on the same device), but the two nodes "
            "were assigned two different devices: ",
            status.error_message()),
        *dst);
  }
  return Status::OK();
}
713
// Walks every data edge in the graph and colocates endpoints joined by
// reference/resource edges (and DT_VARIANT edges between dataset ops).
// Endpoints that require deeper placer inspection are collected into
// `inspection_required` instead of being colocated here.
Status ColocationGraph::ColocateResourceAndRefEdges(
    std::unordered_set<Node*>* inspection_required) {
  // If `node` has an input edge with reference type, add an edge from the
  // source of that edge to `node`.
  for (const Edge* edge : graph_.edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }
    Node* src = edge->src();
    Node* dst = edge->dst();
    // Defer either endpoint that the inspecting placer must handle.
    bool needs_inspection;
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *src, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(src);
      continue;
    }
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *dst, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(dst);
      continue;
    }

    DataType input_type = dst->input_type(edge->dst_input());

    // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
    // This is needed to get around the issue in b/135705778.
    if (input_type == DT_VARIANT &&
        data::DatasetOpKernel::IsDatasetOp(&src->op_def()) &&
        data::DatasetOpKernel::IsDatasetOp(&dst->op_def())) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
      continue;
    }

    // Even though we can look inside function calling ops, we make an exception
    // here mostly for performance reasons. Looking inside function calling ops
    // is extra overhead. It is only necessary when they return resources. When
    // they don't, we don't look inside them and make this exception here.
    // Looking inside, could potentially enable us to make better placement
    // decisions. It might be worth doing at some point.
    if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
        !IsExemptFromResourceInputColocation(dst)) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
    }
  }

  return Status::OK();
}
763
764 namespace {
765 // Returns tensor list element data type, if the node is one of the ops that
766 // operate with TensorLists. Otherwise returns DT_INVALID.
GetElementDataType(const Node & node)767 DataType GetElementDataType(const Node& node) {
768 static absl::flat_hash_set<std::string>* tensor_list_ops =
769 new absl::flat_hash_set<std::string>(
770 {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
771 "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
772 "TensorListScatterIntoExistingList", "TensorListPushBack",
773 "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
774 "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
775 "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});
776
777 if (tensor_list_ops->contains(node.type_string())) {
778 DataType element_type;
779 if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
780 return element_type;
781 }
782 }
783
784 return DT_INVALID;
785 }
786 } // namespace
787
// Restricts to CPU any colocation group whose nodes consume a DT_VARIANT
// tensor whose underlying element type is host-only (per
// DataTypeAlwaysOnHost). The underlying type is discovered by walking
// DT_VARIANT data edges backwards until a TensorList op with a concrete
// "element_dtype" is found.
Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
  auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };

  auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
    return entry.first == DEVICE_CPU;
  };

  for (Node* node : graph_.nodes()) {
    // Skip nodes that do not have DT_VARIANT inputs.
    if (absl::c_none_of(node->input_types(), is_variant)) continue;

    // Skip nodes that can't be placed on GPU anyway.
    Member& root = members_[FindAndUpdateRoot(node->id())];
    if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue;

    // Stop DFS traversal when found the underlying data type of a variant.
    absl::optional<bool> is_host_data_type;

    auto edge_filter = [&](const Edge& edge) -> bool {
      // We already found the underlying data type.
      if (is_host_data_type.has_value()) return false;

      // Otherwise follow only DT_VARIANT data edges.
      auto edge_dtype = [&]() -> DataType {
        return edge.src()->output_type(edge.src_output());
      };
      return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
    };

    auto enter = [&](Node* n) -> void {
      DataType element_type = GetElementDataType(*n);
      // To handle nested lists continue traversal after finding a TensorList
      // operation that uses DT_VARIANT for element type.
      if (element_type == DT_INVALID || element_type == DT_VARIANT) return;
      is_host_data_type = DataTypeAlwaysOnHost(element_type);
    };

    // Walk data edges backwards from `node` to find the variant's producer.
    ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
                   /*stable_comparator=*/nullptr, edge_filter);

    if (is_host_data_type.has_value() && *is_host_data_type) {
      VLOG(2) << "Limit node possible devices to CPU only, because it has a "
                 "DT_VARIANT input with host-only underlying data type: "
              << "node=" << node->name();

      // Restrict possible device types to CPU only.
      PossibleDevices possible_devices;
      absl::c_copy_if(root.supported_device_types(),
                      std::back_inserter(possible_devices.device_types),
                      is_cpu_device);

      TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
          possible_devices, /*allow_soft_placement=*/false));
    }
  }

  return Status::OK();
}
846
AddInspectionConstraints(const std::unordered_set<Node * > & inspection_required)847 Status ColocationGraph::AddInspectionConstraints(
848 const std::unordered_set<Node*>& inspection_required) {
849 for (Node* node : inspection_required) {
850 IOColocationGroups groups;
851 TF_RETURN_IF_ERROR(
852 inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
853 VLOG(2) << "Computed IOColocationGroups for node " << node->name()
854 << ":\n\t" << groups.DebugString();
855 TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
856 }
857 return Status::OK();
858 }
859
Initialize()860 Status ColocationGraph::Initialize() {
861 TF_RETURN_IF_ERROR(InitializeMembers());
862
863 std::unordered_set<Node*> inspection_required;
864 TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
865 TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
866 TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
867 TF_RETURN_IF_ERROR(ColocateAllNodes());
868
869 for (Node* node : graph_.op_nodes()) {
870 int root_id = FindAndUpdateRoot(node->id());
871 members_[root_id].MaybeExcludeXlaDevices();
872 }
873
874 return Status::OK();
875 }
876
877 // pair containing a node and whether this node has a resource input
878 // from the node requiring placer inspection.
879 using NodeAndBool = std::pair<const Node*, bool>;
880
881 namespace {
882
883 // Returns a vector of node names from `nodes`.
NodeAndBoolToString(const std::vector<NodeAndBool> & nodes)884 std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
885 std::vector<string> v;
886 v.reserve(nodes.size());
887 for (const NodeAndBool& node_and_bool : nodes) {
888 v.push_back(node_and_bool.first->name());
889 }
890 return v;
891 }
892
893 // Given a node requiring placer inspection and its IOColocationGroups,
894 // computes `group_nodes`.
895 // group_nodes[i] contains the nodes that are members of colocation
896 // group i. These nodes are inputs or outputs of `node`.
897 // group_nodes[i][j] is a pair containing a node and whether this node
898 // has a resource input from `node`.
899 // Note:
900 // The same node can be added multiple times to the same group.
901 // The same node can be added to multiple groups.
GetGroupNodes(const IOColocationGroups & groups,const Node & node,std::vector<std::vector<NodeAndBool>> * group_nodes)902 Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
903 std::vector<std::vector<NodeAndBool>>* group_nodes) {
904 group_nodes->reserve(groups.group_devices.size());
905 for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
906 const Node* src;
907 TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
908 int group_id = groups.input_groups[arg_idx];
909 (*group_nodes)[group_id].emplace_back(src, false);
910 }
911
912 for (const Edge* edge : node.out_edges()) {
913 if (edge->IsControlEdge()) {
914 continue;
915 }
916
917 int group_id = groups.output_groups[edge->src_output()];
918 (*group_nodes)[group_id].emplace_back(
919 edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
920 }
921
922 if (VLOG_IS_ON(2)) {
923 VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
924 for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
925 VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
926 << "]";
927 }
928 }
929 return Status::OK();
930 }
931
932 // Returns whether the device_type in `device_attributes` is supported.
IsSupportedDeviceType(const DeviceAttributes & device_attributes,const DeviceType & supported_type)933 bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
934 const DeviceType& supported_type) {
935 if (DeviceType(device_attributes.device_type()) == supported_type) {
936 return true;
937 }
938 return IsCompositeDevice(device_attributes.device_type());
939 }
940
941 } // namespace
942
// Applies the constraints in `groups` to the inputs and outputs of `node`
// (a node requiring placer inspection): members of the same group are
// colocated with each other, and each group is limited to its allowed
// devices.
Status ColocationGraph::ApplyIOColocationGroups(
    const IOColocationGroups& groups, const Node& node) {
  // Sanity checks: `groups` must describe exactly this node's inputs and
  // outputs; a mismatch is an internal bug rather than a user error.
  if (groups.input_groups.size() != node.num_inputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because input_groups.size() (",
        groups.input_groups.size(),
        ") is different from number of inputs into the op node (",
        node.num_inputs(), ")");
  }
  if (groups.output_groups.size() != node.num_outputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because output_groups.size() (",
        groups.output_groups.size(),
        ") is different from number of outputs into the op node (",
        node.num_outputs(), ")");
  }

  // group_nodes[i] contains the nodes that are members of colocation
  // group i. These nodes are inputs or outputs of `node`.
  // group_nodes[i][j] is a pair containing the node and whether this node
  // has a resource input from `node`.
  // The same node can be added multiple times to the same group.
  // The same node can be added to multiple groups.
  // NOTE: group ids are guaranteed to be contiguous: 0, 1, ..., num_groups-1.
  std::vector<std::vector<NodeAndBool>> group_nodes(
      groups.group_devices.size());
  TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));

  // Colocate the members of each group with the group's first member. A
  // resource connection (the pair's bool) uses the stricter resource/ref
  // colocation.
  for (const std::vector<NodeAndBool>& nodes : group_nodes) {
    for (int i = 1; i < nodes.size(); ++i) {
      VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
              << nodes[i].first->name() << "\"";
      if (nodes[i].second) {
        TF_RETURN_IF_ERROR(
            ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
      } else {
        TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
      }
    }
  }

  // Limit devices in each group
  for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
    // Nothing to do for empty groups. Groups can be empty if some output
    // of an op is not used.
    if (group_nodes[group_id].empty()) {
      continue;
    }
    const Node* group_node = group_nodes[group_id][0].first;
    const PossibleDevices& possible_devices = groups.group_devices[group_id];
    TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
  }

  return Status::OK();
}
1001
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)1002 Status ColocationGraph::ColocateNodeToGroup(
1003 std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
1004 colocation_group_root,
1005 const Node* node, StringPiece colocation_group) {
1006 const Node*& root_node = (*colocation_group_root)[colocation_group];
1007 if (root_node == nullptr) {
1008 // This is the first node of the colocation group, so
1009 // designate this node as the 'root' of that colocation group.
1010 root_node = node;
1011 } else {
1012 // Try to colocate the node with the root. If there is an
1013 // error, return it.
1014 Status s = ColocateNodes(*node, *root_node);
1015 if (!s.ok()) {
1016 if (!allow_soft_placement_) {
1017 return AttachDef(s, *node);
1018 }
1019 if (log_device_placement_) {
1020 LOG(INFO) << "Ignoring request to colocate node '" << node->name()
1021 << "' with nodes in colocation group '" << colocation_group
1022 << "' because soft placement is on and an attempt at doing "
1023 "so resulted in the following error: "
1024 << AttachDef(s, *node).ToString();
1025 }
1026 }
1027 }
1028 return Status::OK();
1029 }
1030
1031 // Merge the (possibly disjoint) sets containing nodes "x" and
1032 // "y". Returns OK if the all nodes in the union of these sets can
1033 // be placed on the same device type.
1034 //
1035 // NOTE: If this method returns an error, *this is left in an undefined
1036 // state.
ColocateNodes(const Node & x,const Node & y)1037 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
1038 int x_root = FindAndUpdateRoot(x.id());
1039 int y_root = FindAndUpdateRoot(y.id());
1040 return ColocateNodes(x, x_root, y, y_root);
1041 }
1042
1043 // This overload of ColocateNodes() allows a caller to provide the root node
1044 // ids for the two nodes. For large graphs, this noticeably reduces the
1045 // graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  // Already in the same set: nothing to merge.
  if (x_root == y_root) {
    return Status::OK();
  }

  Member* new_root_member;
  Member* old_root_member;
  // First a dry-run merge: determine which member would become the new root
  // without modifying the union-find structure, so the compatibility checks
  // below can reject the merge before the sets are actually joined.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}
1092
LimitToAssignedDevice(const Node & node)1093 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
1094 if (node.assigned_device_name_index() < 0) {
1095 return errors::Internal(
1096 "Expected an assigned node as argument to LimitToAssignedDevice but "
1097 "got: ",
1098 node.DebugString());
1099 }
1100 int root = FindAndUpdateRoot(node.id());
1101 Member& root_member = members_[root];
1102 return root_member.AssignDevice(node);
1103 }
1104
GetSoftDeviceCandidates(const Node & node,const Member & root_member,int root_id,std::vector<Device * > * possible_devices)1105 void ColocationGraph::GetSoftDeviceCandidates(
1106 const Node& node, const Member& root_member, int root_id,
1107 std::vector<Device*>* possible_devices) {
1108 // Try to find supported devices that don't violate resource devices.
1109 // The soft_device_name is the same as the requested device name
1110 // without specifying the device type or ID (if assigned and requested
1111 // devices does not specify them).
1112 DeviceNameUtils::ParsedName soft_device_name =
1113 root_member.GetPreferredSoftDeviceName();
1114 device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1115 if (!possible_devices->empty()) {
1116 *possible_devices = FilterSupportedDevices(
1117 *possible_devices, root_member.supported_device_types(),
1118 default_local_device_);
1119 }
1120
1121 if (!possible_devices->empty()) {
1122 return;
1123 }
1124
1125 // TODO(iga): Disallow changing resource devices when this ColocationGraph
1126 // is for :
1127 // - a function called by an op requiring deep inspection, or
1128 // - a graph containing ops requiring inspection.
1129 // It is fairly tricky to make changing resource devices in presence of
1130 // ops requiring inspection work correctly. One thing it would require is to
1131 // communicate these "resource movement" decisions across Placer instances.
1132
1133 // Failed to find supported devices that don't violate resource devices.
1134 // Try finding some devices that violated resource devices.
1135 // If we succeed, we will log a warning below.
1136 soft_device_name = root_member.GetSoftDeviceName();
1137 device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1138 if (!possible_devices->empty()) {
1139 *possible_devices = FilterSupportedDevices(
1140 *possible_devices, root_member.supported_device_types(),
1141 default_local_device_);
1142 }
1143
1144 if (!possible_devices->empty()) {
1145 LOG(WARNING)
1146 << "Failed to place the graph without changing the devices of some "
1147 "resources. Some of the operations (that had to be colocated with "
1148 "resource generating operations) are not supported on the "
1149 "resources' devices. Current candidate devices are [\n "
1150 << absl::StrJoin(DevicesToString(*possible_devices), "\n ")
1151 << "].\nSee below for details of this colocation group:"
1152 << DebugInfo(root_id);
1153 }
1154 }
1155
LimitToPossibleDevices(const Node & node,const PossibleDevices & devices)1156 Status ColocationGraph::LimitToPossibleDevices(const Node& node,
1157 const PossibleDevices& devices) {
1158 int root = FindAndUpdateRoot(node.id());
1159 Member& root_member = members_[root];
1160 return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
1161 }
1162
// Computes the set of devices on which the colocation group containing
// `node` may be placed, caching the result on the group's root Member.
// On success, `*possible_devices` points at that cached vector.
Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindAndUpdateRoot(node->id());
  // Fast path: this group's possible devices were already computed.
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return Status::OK();
  }

  Member& root_member = members_[node_root];

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  // NOTE: Basing possible device computation on requested device name
  // is guaranteed to respect the assigned and resource device names because
  // requested device is always a specialization of both.
  std::vector<Device*> devices;
  if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
    // The root node has a (possibly partial) device
    // specification, so enumerate the physical devices that
    // conform to it.
    device_set_.FindMatchingDevices(root_member.requested_device_name(),
                                    &devices);

    if (!devices.empty()) {
      // Filter devices into those that are compatible with the root
      // node (and its children).
      devices = FilterSupportedDevices(
          devices, root_member.supported_device_types(), default_local_device_);
    }

    // Perform soft placement if allow_soft_placement_ is set.
    if (devices.empty() && allow_soft_placement_) {
      GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
    }

    if (devices.empty()) {
      // Return an error when a physical device that matches an explicit
      // device specification is not found. This ensures that we don't
      // assign a node to GPU when the user wanted to force it on CPU.
      string debug_info = DebugInfo(node_root);

      DeviceNameUtils::ParsedName specified_device_name;
      if (DeviceNameUtils::ParseFullName(node->requested_device(),
                                         &specified_device_name) &&
          specified_device_name == root_member.requested_device_name()) {
        // The specified device and merged set device match, and
        // will appear in the GraphDef (for debugging), so just
        // print the specified device.
        std::vector<Device*> devices_matching_nodedef;
        device_set_.FindMatchingDevices(specified_device_name,
                                        &devices_matching_nodedef);
        if (devices_matching_nodedef.empty()) {
          // Sometimes it is almost impossible to understand the problem
          // without a list of available devices.
          std::vector<string> device_names;
          for (const Device* device : device_set_.devices()) {
            device_names.push_back(device->name());
          }
          std::sort(device_names.begin(), device_names.end());

          string gpu_msg = "";
          if (!IsGoogleCudaEnabled() &&
              absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
            gpu_msg =
                " The requested device appears to be a GPU, but CUDA is not "
                "enabled.";
          }

          return errors::InvalidArgument(
              errors::FormatNodeNameForError(node->name()),
              " was explicitly assigned to ", node->requested_device(),
              " but available devices are [ ",
              absl::StrJoin(device_names, ", "), " ]. Make sure ",
              "the device specification refers to a valid device.", gpu_msg);
        } else if (specified_device_name.has_type) {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "' because no supported kernel for ",
              specified_device_name.type, " devices is available.", debug_info,
              "\nOp: ", node->type_string(),
              "\nNode attrs: ", node->attrs().DebugString(),
              "\nRegistered kernels:\n",
              KernelsRegisteredForOp(node->type_string()));
        } else {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), debug_info);
        }
      } else {
        // The specified device may be a valid device but the
        // merged set device is different, so print both.
        // TODO(b/129057603): There are many possibilities at this point.
        // Provide good error messages.
        return errors::InvalidArgument(
            "Could not satisfy explicit device specification '",
            node->requested_device(), "' because the node ",
            errors::FormatColocationNodeForError(node->name()),
            " was colocated with a group of nodes that ",
            "required incompatible device '",
            DeviceNameUtils::ParsedNameToString(
                root_member.requested_device_name()),
            "'. All available devices [",
            absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
            debug_info);
      }
    }
  } else {
    // The device is completely unspecified, so enumerate the devices that
    // support all of the nodes in the set.
    if (device_set_.devices().empty()) {
      return errors::Internal("No devices are registered");
    }
    devices = FilterSupportedDevices(device_set_.devices(),
                                     root_member.supported_device_types(),
                                     default_local_device_);

    if (devices.empty()) {
      return errors::InvalidArgument(
          "Node had no OpKernel registered to support this operation: ",
          "Operation was ", node->type_string(), " and inputs were [",
          DataTypeVectorString(node->input_types()), "].\n",
          DebugInfo(node_root));
    }
  }

  // Cache the result of the possible devices for this node group.
  root_member.set_possible_devices(std::move(devices));
  *possible_devices = &root_member.possible_devices();
  return Status::OK();
}
1298
InitializeMembers()1299 Status ColocationGraph::InitializeMembers() {
1300 for (Node* node : graph_.op_nodes()) {
1301 Status status = InitializeMember(*node, &members_[node->id()]);
1302 if (!status.ok()) {
1303 return AttachDef(status, *node);
1304 }
1305 }
1306 return Status::OK();
1307 }
1308
DebugString() const1309 string ColocationGraph::DebugString() const {
1310 std::unordered_set<int> roots;
1311 std::vector<string> root_strings;
1312 for (const Node* node : graph_.nodes()) {
1313 if (!node->IsOp()) {
1314 continue;
1315 }
1316 int node_root = FindRoot(node->id());
1317 if (roots.count(node_root) == 0) {
1318 root_strings.push_back(DebugInfo(node_root));
1319 roots.insert(node_root);
1320 }
1321 }
1322 return absl::StrJoin(root_strings, "\n");
1323 }
1324
1325 // Returns debugging info for the node referred to by 'node_root'.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    // Only members of the group rooted at `node_root` are of interest.
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    // Errors are deliberately ignored: a node with no supported devices
    // simply contributes an empty device list to the report.
    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  // An empty group (no op nodes with this root) yields an empty string.
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}
1386
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,Member * member)1387 Status ColocationGraph::InitializeMemberWithAssignedDevice(
1388 const string& assigned_device_name, const string& node_type,
1389 Member* member) {
1390 // This node has already been assigned to a device, so we
1391 // respect this placement, after sanity-checking it.
1392 // NOTE: Since any assignment must have been performed by
1393 // the TensorFlow runtime, we consider errors in this branch to
1394 // be INTERNAL.
1395 TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
1396
1397 // Since assigned device must be a full specification, do extra checks.
1398 const Device* assigned_device =
1399 device_set_.FindDeviceByName(assigned_device_name);
1400 if (assigned_device == nullptr) {
1401 // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
1402 // calls when they are supported.
1403 return errors::Internal(
1404 "Assigned device '", assigned_device_name,
1405 "' does not match any device. This error can happen when one attempts "
1406 "to run a tf.function with resource inputs residing on remote devices. "
1407 "This use case is currently not supported. Here are the devices "
1408 "available on this machine: [",
1409 absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "].",
1410 "If you are seeing this error when running using a tf.Session, set "
1411 "share_cluster_devices_in_session to true in the tf.ConfigProto.");
1412 }
1413
1414 for (const auto& d : member->supported_device_types()) {
1415 if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
1416 return Status::OK();
1417 }
1418 }
1419
1420 return errors::Internal("Assigned device '", assigned_device_name,
1421 "' does not have registered OpKernel support "
1422 "for ",
1423 node_type);
1424 }
1425
// Initializes `member` for `node`: records the supported device types and
// seeds the device-name constraints from the node's assigned or (partially)
// requested device, if any.
Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
      node, device_types_, &local_address_spec_));

  if (node.has_assigned_device_name()) {
    TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
        node.assigned_device_name(), node.type_string(), member));
  } else {
    // This node has not yet been assigned to a device, so we
    // calculate any constraints due to the set of registered
    // kernels and any (partial) user-provided device specification
    // in the NodeDef.

    // If no kernels are registered for this op type, fail with an error.
    if (member->supported_device_types().empty()) {
      std::set<string> registered_device_types;
      for (Device* d : device_set_.devices()) {
        registered_device_types.insert(d->device_type());
      }
      return errors::InvalidArgument(
          "No OpKernel was registered to support Op '", node.type_string(),
          "' used by ", errors::FormatNodeNameForError(node.name()),
          " with these attrs: [", node.attrs().DebugString(),
          "]\n"
          "Registered devices: [",
          absl::StrJoin(registered_device_types, ", "), "]\n",
          "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
    }

    // If the NodeDef contains a device, then we interpret it as a
    // (partial) device specification.
    if (!node.requested_device().empty()) {
      if (IsRefOrResourceGeneratorNode(node)) {
        // Treat requested device on resource generating nodes as assigned
        // device so that we don't override it.
        TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
      } else {
        // The user has specified a device in the NodeDef, try to find a
        // valid device matching their specification in the set of
        // devices.
        // NOTE: The full name may specify a device that is not in
        // n.supported_device_types(), but we check that in AssignDevice().
        TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
      }
    }
  }
  return Status::OK();
}
1474
1475 // Returns a list of devices having type in supported_device_types. The
1476 // returned list is sorted by preferred type (higher numeric type is preferred).
/*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
    const std::vector<Device*>& devices,
    const PrioritizedDeviceTypeVector& supported_device_types,
    const Device* default_local_device) {
  Device* filtered_default_device = nullptr;
  PrioritizedDeviceVector prioritized_filtered_devices;
  // Walk the supported device types and tag each matching device with its
  // type's priority so the sort below can order them. The default local
  // device is kept aside so it can be placed first in the result.
  for (const auto& supported_device_type : supported_device_types) {
    for (Device* device : devices) {
      if (IsSupportedDeviceType(device->attributes(),
                                supported_device_type.first)) {
        if (default_local_device &&
            (device == default_local_device ||
             // TODO(nareshmodi, fishx): At times the device pointer in the
             // device set is different to the one passed in as the default
             // device. Figure out why this might be.
             device->name() == default_local_device->name())) {
          filtered_default_device = device;
        } else {
          prioritized_filtered_devices.emplace_back(
              device, supported_device_type.second);
        }
      }
    }
  }
  DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);

  // The default local device (if it survived filtering) always comes first;
  // the remaining devices follow in sorted priority order.
  std::vector<Device*> filtered_devices;
  if (filtered_default_device != nullptr) {
    filtered_devices.emplace_back(filtered_default_device);
  }
  for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
    filtered_devices.push_back(prioritized_filtered_device.first);
  }
  return filtered_devices;
}
1512
1513 } // namespace tensorflow
1514