/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/colocation_graph.h"

#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

namespace {

// We hoist the conversion from C-style string literal to StringPiece here,
// so that we can avoid the many repeated calls to strlen().
const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);

// Returns a list of devices having type in supported_device_types. The
// returned list is sorted by preferred type (higher numeric type is preferred).
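// The default_device, if present in `devices` and of a supported type, is
// placed first in the returned list; remaining devices are ordered by
// priority, then by DeviceSet::DeviceTypeOrder, then by device name.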
std::vector<Device*> FilterSupportedDevices(
    const std::vector<Device*>& devices,
    const PrioritizedDeviceTypeVector& supported_device_types,
    const Device* default_device) {
  Device* filtered_default_device = nullptr;
  std::vector<std::pair<Device*, int32>> prioritized_filtered_devices;
  for (const auto& supported_device_type : supported_device_types) {
    for (Device* device : devices) {
      if (DeviceType(device->attributes().device_type()) ==
          supported_device_type.first) {
        if (device == default_device) {
          filtered_default_device = device;
        } else {
          prioritized_filtered_devices.emplace_back(
              device, supported_device_type.second);
        }
      }
    }
  }

  auto device_sort = [](const std::pair<Device*, int32>& a,
                        const std::pair<Device*, int32>& b) {
    if (a.second != b.second) {
      return a.second > b.second;
    }

    auto a_priority =
        DeviceSet::DeviceTypeOrder(DeviceType(a.first->device_type()));
    auto b_priority =
        DeviceSet::DeviceTypeOrder(DeviceType(b.first->device_type()));
    // First sort by prioritized device type (higher is preferred) and
    // then by device name (lexicographically).
    if (a_priority != b_priority) {
      return a_priority > b_priority;
    }
    return StringPiece(a.first->name()) < StringPiece(b.first->name());
  };
  std::sort(prioritized_filtered_devices.begin(),
            prioritized_filtered_devices.end(), device_sort);

  std::vector<Device*> filtered_devices;
  if (filtered_default_device != nullptr) {
    filtered_devices.emplace_back(filtered_default_device);
  }
  for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
    filtered_devices.push_back(prioritized_filtered_device.first);
  }
  return filtered_devices;
}

// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (Device* d : devices) {
    v.push_back(d->name());
  }
  return v;
}

// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DeviceTypeAndPriorityToString(
    const PrioritizedDeviceTypeVector& devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (const std::pair<DeviceType, int32>& device_and_type : devices) {
    v.push_back(DeviceTypeString(device_and_type.first));
  }
  return v;
}

// While the Placer can override the requested device on ops that process
// resources, i.e. nodes that take (and potentially return) a resource, it
// must not override the requested device on ops that generate a resource,
// e.g. VarHandleOp, _Arg. Such ops are currently no-input, single
// resource/ref-output nodes.
bool IsResourceGeneratorNode(const Node& node) {
  return node.num_inputs() == 0 && node.num_outputs() == 1 &&
         (IsRefType(node.output_type(0)) || node.output_type(0) == DT_RESOURCE);
}

bool IsExemptFromResourceInputColocation(const Node* node) {
  // Note: Partitioned function calls, which place and partition their
  // function bodies, are exempt from this check: they forward resource and
  // ref inputs to operations that are appropriately placed, instead of
  // dereferencing them.
  const string& op_type = node->op_def().name();
  return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
}

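// Returns true if any entry in `device_types` carries a non-default
// (non-zero) priority.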
bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
  for (const auto& prioritized_device_type : device_types) {
    if (prioritized_device_type.second != 0) return true;
  }
  return false;
}

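// Returns true iff `a_types` and `b_types` contain the same device types in
// the same order. The priority values themselves are not compared.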
bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
                       const PrioritizedDeviceTypeVector& b_types) {
  if (a_types.size() != b_types.size()) {
    return false;
  }
  for (int i = 0; i < a_types.size(); ++i) {
    if (a_types[i].first != b_types[i].first) {
      return false;
    }
  }
  return true;
}

}  // namespace

Status Member::SetParentAndSupportedDevices(
    const Node& node, const std::vector<DeviceType>& types) {
  int id = node.id();
  if (id < 0) {
    return errors::Internal("Placer should not be creating a Member for node: ",
                            node.DebugString());
  }
  parent_ = id;
  return SupportedDeviceTypesForNode(types, node.def(),
                                     &supported_device_types_);
}

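// Records `device_name` as this member's assigned device. Fails if a
// requested device has already been recorded for this member.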
Status Member::SetAssignedDeviceName(const string& device_name) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting assigned device name when there is a requested device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
    return errors::Internal("Malformed assigned device '", device_name, "'");
  }
  // Set requested device to assigned_device to maintain the invariant that
  // requested is a specialization of assigned.
  requested_device_name_ = assigned_device_name_;
  return Status::OK();
}

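// Records the device requested in `node`'s NodeDef as this member's
// requested device. Fails if an assigned device has already been recorded
// for this member.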
Status Member::SetRequestedDeviceName(const Node& node) {
  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &requested_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is an assigned device set "
        "is unsupported");
  }
  return Status::OK();
}

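// Checks that the colocation group rooted at `src_root` (the source of a
// resource/ref edge into `dst`) can be placed together with this group (the
// destination's root). Incompatible assigned devices produce an error; if
// only the requested devices differ, this group's requested device is
// overridden with the source's.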
Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is ", src.name(), " , and the dst node is ",
        dst.name());
  }

  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return Status::OK();
  }

  // If we are here, the assigned devices are compatible but the requested
  // ones are not. We will be overriding the requested device for the
  // destination node, but we need to preserve the invariant that it remains
  // a specialization of the assigned device.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  return Status::OK();
}

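// Merges the disjoint-set trees rooted at `x_root` and `y_root` using union
// by rank. When `dry_run` is true, no parent pointers or ranks are modified;
// only *new_root and *old_root are set, so callers can validate the merge
// before committing it.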
void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}

// tree is non-const because we can change some `parent` pointers in some
// members for more efficient future lookups. The vector itself is not
// changed.
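// FindRoot performs path compression: after it returns, every member on the
// path from `node_id` to the root points directly at the root.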
int Member::FindRoot(std::vector<Member>* tree, int node_id) {
  Member& member = (*tree)[node_id];
  if (member.parent_ == node_id) {
    // member.parent is the root of this disjoint tree. Do nothing.
  } else {
    member.parent_ = FindRoot(tree, member.parent_);
  }
  // Now it is guaranteed that member.parent is the root of this disjoint
  // tree.
  return member.parent_;
}

Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned" invariant holds
  // for this and `other`, it will hold after the two merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  // We checked for all errors, now change the devices.
  assigned_device_name_ = assigned_device_name_copy;
  requested_device_name_ = requested_device_name_copy;
  return Status::OK();
}

// Updates this to contain the intersection of the device types in
// this and "other".
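// For example, if this member supports [(GPU, 2), (CPU, 1)] and "other"
// supports only [(CPU, 0)], the intersection is CPU; since only this member
// carries explicit priorities, the merged result is [(CPU, 1)].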
bool Member::MergeSupportedDevices(const Member& other) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;
  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type :
         other.supported_device_types_) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  // Sort the devices by priority order.
  auto device_sort = [](const std::pair<DeviceType, int32>& a,
                        const std::pair<DeviceType, int32>& b) {
    // First look at set priorities.
    if (a.second != b.second) {
      return a.second > b.second;
    }
    // Then fall back to the default priorities.
    auto a_priority = DeviceSet::DeviceTypeOrder(a.first);
    auto b_priority = DeviceSet::DeviceTypeOrder(b.first);
    if (a_priority != b_priority) {
      return a_priority > b_priority;
    }
    // Finally just look at the Device type strings.
    return a.first.type_string() < b.first.type_string();
  };

  std::sort(target_intersection.begin(), target_intersection.end(),
            device_sort);
  std::sort(other_intersection.begin(), other_intersection.end(), device_sort);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither is prioritized then we just return the original, i.e. the
    // target prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect priorities of that in the
    // intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and they agree, then we go with that. If the
    // prioritization order is different, then we just fall back to the
    // default, i.e. what the DeviceTypeOrder suggests. In that case, we also
    // set the merged priorities to 0, so that downstream merges work
    // correctly as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      std::sort(result.begin(), result.end(), device_sort);
    }
  }

  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}

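// Constrains this member's assigned and requested device names to the device
// already assigned to `node`, and records the node's assigned device name
// index so the rest of the group can share it. Any cached possible_devices_
// are cleared.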
Status Member::AssignDevice(const Node& node, bool allow_soft_placement) {
  if (node.assigned_device_name_index() == assigned_device_name_index_) {
    return Status::OK();
  }

  DeviceNameUtils::ParsedName parsed;
  DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
  Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed,
                                            allow_soft_placement);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's assigned device name: ",
        DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        " node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeDevNames(&requested_device_name_, parsed,
                                     allow_soft_placement);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's requested device name: \"",
        DeviceNameUtils::ParsedNameToString(requested_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }

  assigned_device_name_index_ = node.assigned_device_name_index();
  // Clear cached possible_devices, if any.
  possible_devices_.clear();
  return Status::OK();
}
string Member::DebugString() {
  return absl::StrCat(
      "Member(assigned_device_name_index_=", assigned_device_name_index_,
      " requested_device_name_=",
      DeviceNameUtils::ParsedNameToString(requested_device_name_),
      " assigned_device_name_=",
      DeviceNameUtils::ParsedNameToString(assigned_device_name_),
      " supported_device_types_=[",
      absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
                    ", "),
      "] possible_devices_=[",
      absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}
ColocationGraph::ColocationGraph(const Graph* graph,
                                 const DeviceSet* device_set,
                                 const Device* default_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(graph),
      device_set_(device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      default_device_(default_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  members_.resize(graph->num_node_ids());
}

// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group. Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef. The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_->op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it). Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    bool found_spec = false;
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr && attr_value->has_list()) {
      for (const string& class_spec : attr_value->list().s()) {
        StringPiece spec(class_spec);
        if (str_util::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
          found_spec = true;
          TF_RETURN_IF_ERROR(
              ColocateNodeToGroup(&colocation_group_root, node, spec));
        }
      }
    }

    // TODO(iga): Even when the node has a spec, we need to colocate the
    // node to its "name group" because other nodes can still use
    // "loc:@<this_node_name>" in their colocation specs.
    if (!found_spec) {
      // If the node does not specify a colocation group, then use the
      // name of this node as the colocation group.
      TF_RETURN_IF_ERROR(
          ColocateNodeToGroup(&colocation_group_root, node, node->name()));
    }
  }

  return Status::OK();
}

Status ColocationGraph::ColocateResourceOrRefEdge(Node* src, Node* dst) {
  // Colocate `src` and `dst` to maintain the invariant that nodes
  // connected by reference edges are colocated.
  int src_root_id = FindRoot(src->id());
  int dst_root_id = FindRoot(dst->id());
  auto& src_root = members_[src_root_id];
  auto& dst_root = members_[dst_root_id];

  TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
      *src, src_root, *dst, log_device_placement_));
  Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
  if (!status.ok()) {
    return AttachDef(
        errors::InvalidArgument("Nodes were connected by a "
                                "reference connection (requiring them to "
                                "be on the same device), but the two nodes "
                                "were assigned two different devices: ",
                                status.error_message()),
        *dst);
  }
  return Status::OK();
}

Status ColocationGraph::ColocateResourceAndRefEdges() {
  // Enumerate the constraint edges, and use them to update the disjoint
  // node set.
  // If `node` has an input edge with reference type, add an edge from the
  // source of that edge to `node`.
  for (const Edge* edge : graph_->edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }
    Node* src = edge->src();
    Node* dst = edge->dst();
    DataType input_type = dst->input_type(edge->dst_input());
    if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
        !IsExemptFromResourceInputColocation(dst)) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
    }
  }
  return Status::OK();
}

Status ColocationGraph::Initialize() {
  TF_RETURN_IF_ERROR(InitializeMembers());
  TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges());
  TF_RETURN_IF_ERROR(ColocateAllNodes());
  return Status::OK();
}

Status ColocationGraph::ColocateNodeToGroup(
    std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
        colocation_group_root,
    const Node* node, StringPiece colocation_group) {
  const Node*& root_node = (*colocation_group_root)[colocation_group];
  if (root_node == nullptr) {
    // This is the first node of the colocation group, so
    // designate this node as the 'root' of that colocation group.
    root_node = node;
  } else {
    // Try to colocate the node with the root. If there is an
    // error, return it.
    Status s = ColocateNodes(*node, *root_node);
    if (!s.ok()) {
      if (!allow_soft_placement_) {
        return AttachDef(s, *node);
      }
      if (log_device_placement_) {
        LOG(INFO) << "Ignoring request to colocate node '" << node->name()
                  << "' with nodes in colocation group '" << colocation_group
                  << "' because soft placement is on and an attempt at doing "
                     "so resulted in the following error: "
                  << AttachDef(s, *node).ToString();
      }
    }
  }
  return Status::OK();
}

// Merge the (possibly disjoint) sets containing nodes "x" and
// "y". Returns OK if all the nodes in the union of these sets can
// be placed on the same device type.
//
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
  int x_root = FindRoot(x.id());
  int y_root = FindRoot(y.id());
  return ColocateNodes(x, x_root, y, y_root);
}

// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  if (x_root == y_root) {
    return Status::OK();
  }

  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return Status::OK();
}

// Limits the possible devices of `node`'s colocation group to the device
// to which `node` is assigned. This makes sure that all nodes in this
// colocation group will be assigned to the same device. Without this
// explicit restriction, heuristics can choose a different possible device
// for other nodes in the group.
Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
  if (node.assigned_device_name_index() < 0) {
    return errors::Internal(
        "Expected an assigned node as argument to LimitToAssignedDevice but "
        "got: ",
        node.DebugString());
  }
  int root = FindRoot(node.id());
  Member& root_member = members_[root];
  return root_member.AssignDevice(node, allow_soft_placement_);
}

// For the given node, subject to the constraints previously given
// to this ColocationGraph, set its assigned_device_name. Returns OK
// if a satisfying device can be found, otherwise an error.
//
// Note: This method returns a pointer to a field within members_.
// The caller must not use the returned pointer after there is any possibility
// that the members_[i].possible_devices field has been modified.
Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindRoot(node->id());
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return Status::OK();
  }

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  std::vector<Device*> devices;
  if (DeviceNameUtils::HasSomeDetails(
          members_[node_root].requested_device_name())) {
    // The root node has a (possibly partial) device
    // specification, so enumerate the physical devices that
    // conform to it.
    device_set_->FindMatchingDevices(
        members_[node_root].requested_device_name(), &devices);

    if (!devices.empty()) {
      // Filter devices into those that are compatible with the root
      // node (and its children).
      devices = FilterSupportedDevices(
          devices, members_[node_root].supported_device_types(),
          default_device_);
    }

    // Perform soft placement if allow_soft_placement_ is set.
    if (devices.empty() && allow_soft_placement_) {
      // The soft_device_name is the same as the node's device name
      // without specifying the device type or ID.
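      // e.g. a request for '/job:worker/device:GPU:0' is relaxed to
      // '/job:worker', so any device within that job can be matched.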
      DeviceNameUtils::ParsedName soft_device_name =
          members_[node_root].requested_device_name();
      soft_device_name.type.clear();
      soft_device_name.has_type = false;
      soft_device_name.has_id = false;
      device_set_->FindMatchingDevices(soft_device_name, &devices);
      if (!devices.empty()) {
        devices = FilterSupportedDevices(
            devices, members_[node_root].supported_device_types(),
            default_device_);
      }
    }

    if (devices.empty()) {
      // Return an error when a physical device that matches an explicit
      // device specification is not found. This ensures that we don't
      // assign a node to GPU when the user wanted to force it on CPU.
      string debug_info = DebugInfo(node_root);

      DeviceNameUtils::ParsedName specified_device_name;
      if (DeviceNameUtils::ParseFullName(node->requested_device(),
                                         &specified_device_name) &&
          specified_device_name ==
              members_[node_root].requested_device_name()) {
        // The specified device and merged set device match, and
        // will appear in the GraphDef (for debugging), so just
        // print the specified device.
        std::vector<Device*> devices_matching_nodedef;
        device_set_->FindMatchingDevices(specified_device_name,
                                         &devices_matching_nodedef);
        if (devices_matching_nodedef.empty()) {
          // Sometimes it is almost impossible to understand the problem
          // without a list of available devices.
          std::vector<string> device_names;
          for (const Device* device : device_set_->devices()) {
            device_names.push_back(device->name());
          }
          std::sort(device_names.begin(), device_names.end());

          string gpu_msg = "";
          if (!IsGoogleCudaEnabled() &&
              str_util::Lowercase(specified_device_name.type) == "gpu") {
            gpu_msg =
                " The requested device appears to be a GPU, but CUDA is not "
                "enabled.";
          }

          return errors::InvalidArgument(
              errors::FormatNodeNameForError(node->name()),
              " was explicitly assigned to ", node->requested_device(),
              " but available devices are [ ",
              str_util::Join(device_names, ", "), " ]. Make sure ",
              "the device specification refers to a valid device.", gpu_msg);
        } else if (specified_device_name.has_type) {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "' because no supported kernel for ",
              specified_device_name.type, " devices is available.", debug_info,
              "\nRegistered kernels:\n",
              KernelsRegisteredForOp(node->type_string()));
        } else {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "'", debug_info);
        }
      } else {
        // The specified device may be a valid device but the
        // merged set device is different, so print both.
        return errors::InvalidArgument(
            "Could not satisfy explicit device specification '",
            node->requested_device(), "' because the node ",
            errors::FormatColocationNodeForError(node->name()),
            " was colocated with a group of nodes that ",
            "required incompatible device '",
            DeviceNameUtils::ParsedNameToString(
                members_[node_root].requested_device_name()),
            "'", debug_info);
      }
    }
  } else {
    // The device is completely unspecified, so enumerate the devices that
    // support all of the nodes in the set.
    if (device_set_->devices().empty()) {
      return errors::Internal("No devices are registered");
    }
    devices = FilterSupportedDevices(
        device_set_->devices(), members_[node_root].supported_device_types(),
        default_device_);

    if (devices.empty()) {
      return errors::InvalidArgument(
          "Node had no OpKernel registered to support this operation: ",
          "Operation was ", node->type_string(), " and inputs were ",
          DataTypeVectorString(node->input_types()), DebugInfo(node_root));
    }
  }

  // Cache the result of the possible devices for this node group.
  members_[node_root].set_possible_devices(std::move(devices));
  *possible_devices = &members_[node_root].possible_devices();
  return Status::OK();
}

Status ColocationGraph::InitializeMembers() {
  for (Node* node : graph_->op_nodes()) {
    Status status = InitializeMember(*node, &members_[node->id()]);
    if (!status.ok()) {
      return AttachDef(status, *node);
    }
  }
  return Status::OK();
}

string ColocationGraph::DebugString() {
  std::unordered_set<int> roots;
  std::vector<string> root_strings;
  for (const Node* node : graph_->nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int node_root = FindRoot(node->id());
    if (roots.count(node_root) == 0) {
      root_strings.push_back(DebugInfo(node_root));
      roots.insert(node_root);
    }
  }
  return absl::StrJoin(root_strings, "\n");
}

// Returns debugging info for the node referred to by 'node_root'.
string ColocationGraph::DebugInfo(const int node_root) {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  for (const Node* node : graph_->nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);
    const string& op_type = node->type_string();
    string devices_registered;
    for (const auto& device_type : members_[id].supported_device_types()) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    type_to_devices[op_type] = std::move(devices_registered);
  }

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members and user-requested devices:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
  }
  strings::StrAppend(&text, "\n");

  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}

Status ColocationGraph::InitializeMemberWithAssignedDevice(
    const string& assigned_device_name, const string& node_type,
    bool must_be_full_name, Member* member) {
  // This node has already been assigned to a device, so we
  // respect this placement, after sanity-checking it.
  // NOTE: Since any assignment must have been performed by
  // the TensorFlow runtime, we consider errors in this branch to
  // be INTERNAL.
  TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
  if (!must_be_full_name) {
    return Status::OK();
  }
  // Since assigned device must be a full specification, do extra checks.
  const Device* assigned_device =
      device_set_->FindDeviceByName(assigned_device_name);
  if (assigned_device == nullptr) {
    return errors::Internal("Assigned device '", assigned_device_name,
                            "' does not match any device");
  }

  for (const auto& d : member->supported_device_types()) {
    if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
      return Status::OK();
    }
  }

  return errors::Internal("Assigned device '", assigned_device_name,
                          "' does not have registered OpKernel support "
                          "for ",
                          node_type);
}

Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(node, device_types_));

  if (node.has_assigned_device_name()) {
    TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
        node.assigned_device_name(), node.type_string(), true, member));
  } else {
    // This node has not yet been assigned to a device, so we
    // calculate any constraints due to the set of registered
    // kernels and any (partial) user-provided device specification
    // in the NodeDef.

    // If no kernels are registered for this op type, fail with an error.
    if (member->supported_device_types().empty()) {
      std::set<string> registered_device_types;
      for (Device* d : device_set_->devices()) {
        registered_device_types.insert(d->device_type());
      }
      std::vector<string> attr_key_vals;
      for (const auto& it : node.attrs()) {
        const string& name = it.first;
        const AttrValue& attr_value = it.second;
        attr_key_vals.push_back(
            strings::StrCat(name, "=", SummarizeAttrValue(attr_value)));
      }
      return errors::InvalidArgument(
          "No OpKernel was registered to support Op '", node.type_string(),
          "' used by ", errors::FormatNodeNameForError(node.name()),
          " with these attrs: [", str_util::Join(attr_key_vals, ", "),
          "]\n"
          "Registered devices: [",
          str_util::Join(registered_device_types, ", "), "]\n",
          "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
    }

    // If the NodeDef contains a device, then we interpret it as a
    // (partial) device specification.
    if (!node.requested_device().empty()) {
      if (IsResourceGeneratorNode(node)) {
        // Treat requested device on resource generating nodes as assigned
        // device so that we don't override it.
        TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
            node.requested_device(), node.type_string(), false, member));
      } else {
        // The user has specified a device in the NodeDef, try to find a
        // valid device matching their specification in the set of
        // devices.
        // NOTE: The full name may specify a device that is not in
        // n.supported_device_types(), but we check that in AssignDevice().
        TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
      }
    }
  }
  return Status::OK();
}

}  // namespace tensorflow