1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/colocation.h"
17
18 #include <cstring>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/grappler/utils.h"
22
23 namespace tensorflow {
24 namespace grappler {
25
26 namespace {
27
// Returns the name of the root (representative) node of the colocation group
// that node_name belongs to.
//
// map is a disjoint-set forest: each known node name maps to the name of its
// parent, and a root is a node that maps to itself.
//   * If node_name is not in the map yet, it is registered as the root of a
//     new single-node group and returned.
//   * Otherwise the parent links are followed up to the root, and every node
//     visited on the way is re-pointed directly at that root (path
//     compression), so later lookups through the same nodes take one step
//     instead of re-walking the whole chain.
// When node_name is already present, only the mapped values of existing
// entries are rewritten; no key is inserted or erased. A parent name that has
// no entry of its own is treated as the root and nothing is inserted for it.
string GetColocationGroupRoot(std::unordered_map<string, string>* map,
                              const string& node_name) {
  const auto entry = map->find(node_name);
  if (entry == map->end()) {
    // Unknown node: it starts its own group and is its own root.
    map->emplace(node_name, node_name);
    return node_name;
  }

  // Pass 1: climb from node_name's parent until reaching a self-parented
  // entry. find() is used instead of operator[] so the walk never inserts.
  string root = entry->second;
  for (auto it = map->find(root); it != map->end() && it->second != root;
       it = map->find(root)) {
    root = it->second;
  }

  // Pass 2: re-walk the same chain and attach each visited node directly to
  // the root. The old parent is saved before the entry is overwritten.
  for (auto it = map->find(node_name);
       it != map->end() && it->first != root;) {
    const string next = it->second;
    it->second = root;
    it = map->find(next);
  }
  return root;
}
48
// Joins two colocation groups by making the group rooted at left absorb the
// group rooted at right.
// left and right are expected to be the root nodes of their groups. The call
// is a no-op when both names are the same or when either name has no entry
// in the map.
void MergeColocationGroup(std::unordered_map<string, string>* map,
                          const string& left, const string& right) {
  // Same root on both sides: the groups are already one.
  if (left == right) {
    return;
  }

  const auto right_entry = map->find(right);
  if (right_entry == map->end() || map->count(left) == 0) {
    // At least one of the two nodes is unknown; leave the map untouched.
    return;
  }

  // Re-parent the right root under the left root, which merges the groups.
  right_entry->second = left;
}
63 } // namespace
64
65 // Use of disjoint set algorithm to build the colocation groups from the input
66 // graph. The core data structure in use is a hash map from one node to its
67 // parent node. Whenever we see two nodes colocate with each other, we merge
68 // their colocation groups together. After we traverse all colocation pairs
69 // in the graph, we will have several disjoint sets. Then we pick the root node
70 // of each disjoint set as the representative node, and let all other nodes in
71 // the group colocate with the representative node.
ReassignColocation(GraphDef * graph)72 void ReassignColocation(GraphDef* graph) {
73 constexpr char kClassAttr[] = "_class";
74 constexpr char kColocPrefix[] = "loc:@";
75
76 // A hashmap that maps from a node name to its parent node name.
77 std::unordered_map<string, string> coloc_groups;
78 NodeMap node_map(graph);
79 for (const auto& node : graph->node()) {
80 auto iter = node.attr().find(kClassAttr);
81 if (iter != node.attr().end() && iter->second.has_list()) {
82 for (const auto& str : iter->second.list().s()) {
83 size_t pos = str.find(kColocPrefix);
84 if (pos == 0) {
85 // After we find a colocation, update the colocation groups.
86 string colocate_node = str.substr(pos + strlen(kColocPrefix));
87 MergeColocationGroup(
88 &coloc_groups, GetColocationGroupRoot(&coloc_groups, node.name()),
89 GetColocationGroupRoot(&coloc_groups, colocate_node));
90 }
91 }
92 }
93 }
94
95 // We use the root node of each colocation groups as its representative
96 // node. For each node in one group, colocate with the representative node
97 // if the node is in the graph.
98 for (const auto& pair : coloc_groups) {
99 if (pair.first != pair.second) {
100 // This is a child node.
101 NodeDef* node = node_map.GetNode(pair.first);
102 if (node) {
103 // Colocate this node with the root node.
104 AttrValue new_value;
105 new_value.mutable_list()->add_s(
106 kColocPrefix + GetColocationGroupRoot(&coloc_groups, pair.first));
107 node->mutable_attr()->erase(kClassAttr);
108 node->mutable_attr()->insert({kClassAttr, new_value});
109 }
110 } else {
111 // This is a root node. Clear the _class attribute.
112 NodeDef* node = node_map.GetNode(pair.first);
113 if (node) { // root node should always exist in the graph as guaranteed
114 // by order of merging. Just put check here to ensure safety.
115 node->mutable_attr()->erase(kClassAttr);
116 }
117 }
118 }
119 }
120
121 } // namespace grappler
122 } // namespace tensorflow
123