1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
17
18 #include <functional>
19
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
24
25 namespace tensorflow {
26 namespace grappler {
27 namespace graph_analyzer {
28
29 //=== Subgraph::Identity
30
Identity(InitializerList init)31 Subgraph::Identity::Identity(InitializerList init) {
32 for (auto element : init) {
33 insert(element);
34 }
35 }
36
operator <(const Identity & other) const37 bool Subgraph::Identity::operator<(const Identity& other) const {
38 // Shorter sets go first.
39 if (this->size() < other.size()) {
40 return true;
41 }
42 if (this->size() > other.size()) {
43 return false;
44 }
45 for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
46 ++lit, ++rit) {
47 if (*lit < *rit) {
48 return true;
49 }
50 if (*lit > *rit) {
51 return false;
52 }
53 }
54 return false; // Equal.
55 }
56
operator ==(const Identity & other) const57 bool Subgraph::Identity::operator==(const Identity& other) const {
58 if (this->size() != other.size()) {
59 return false;
60 }
61 for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
62 ++lit, ++rit) {
63 if (*lit != *rit) {
64 return false;
65 }
66 }
67 return true; // Equal.
68 }
69
Hash() const70 size_t Subgraph::Identity::Hash() const {
71 std::hash<const GenNode*> hasher;
72 size_t result = 0;
73 for (auto ptr : *this) {
74 CombineHash(hasher(ptr), &result);
75 }
76 return result;
77 }
78
Dump()79 string Subgraph::Dump() {
80 // TODO(babkin): this is simplified for now.
81 std::vector<string> nodes;
82 for (const auto& n : id_) {
83 if (specific_) {
84 nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name()));
85 } else {
86 nodes.emplace_back(n->opcode());
87 }
88 }
89 std::sort(nodes.begin(), nodes.end());
90
91 return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", ");
92 }
93
ExtractForSignature(SigNodeMap * result)94 void Subgraph::ExtractForSignature(SigNodeMap* result) {
95 // Mapping of nodes from the original graph to the new one.
96 SigNode::TranslationMap full_to_new;
97
98 for (auto node : id_) {
99 auto newnode_ref = absl::make_unique<SigNode>(node->node_def());
100 auto newnode = newnode_ref.get();
101 (*result)[node->name()] = std::move(newnode_ref);
102 full_to_new[node] = newnode;
103 }
104
105 for (const auto& mapping : full_to_new) {
106 mapping.second->CopyLinks(*mapping.first, full_to_new);
107 }
108 }
109
110 //=== Subgraph
111
Subgraph(const Identity & parent_id,GenNode * add_node)112 Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node)
113 : id_(parent_id) {
114 id_.insert(add_node);
115 hash_ = id_.Hash();
116 }
117
118 //=== SubgraphIterator
119
SubgraphIterator(const Subgraph::Identity * id)120 SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id)
121 : id_(id), id_it_(id_->begin()) {
122 if (!id_->empty()) {
123 link_map_it_ = (*id_it_)->links().begin();
124 // In case if the node has no links.
125 while (link_map_it_ == (*id_it_)->links().end()) {
126 if (++id_it_ == id_->end()) {
127 return;
128 }
129 link_map_it_ = (*id_it_)->links().begin();
130 }
131 link_idx_ = 0;
132 // The LinkTargetVector should never be empty but just in case safeguard
133 // against that too.
134 PropagateNext();
135 }
136 }
137
Next()138 bool SubgraphIterator::Next() {
139 if (AtEnd()) {
140 return false;
141 }
142 ++link_idx_;
143 return PropagateNext();
144 }
145
NextIfSamePort()146 bool SubgraphIterator::NextIfSamePort() {
147 if (AtEnd()) {
148 return false;
149 }
150 const int64 link_map_it_second_size = link_map_it_->second.size();
151 if (link_idx_ + 1 < link_map_it_second_size) {
152 ++link_idx_;
153 return true;
154 } else {
155 return false;
156 }
157 }
158
SkipPort()159 void SubgraphIterator::SkipPort() {
160 if (AtEnd()) {
161 return;
162 }
163 link_idx_ = link_map_it_->second.size() - 1;
164 }
165
SkipNode()166 void SubgraphIterator::SkipNode() {
167 if (AtEnd()) {
168 return;
169 }
170 for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) {
171 link_map_it_ = next;
172 }
173 link_idx_ = link_map_it_->second.size() - 1;
174 }
175
PropagateNext()176 bool SubgraphIterator::PropagateNext() {
177 // Loops are used to skip over the empty entries.
178 const int64 link_map_it_second_size = link_map_it_->second.size();
179 while (link_idx_ >= link_map_it_second_size) {
180 ++link_map_it_;
181 while (link_map_it_ == (*id_it_)->links().end()) {
182 if (++id_it_ == id_->end()) {
183 return false;
184 }
185 link_map_it_ = (*id_it_)->links().begin();
186 }
187 link_idx_ = 0;
188 }
189 return true;
190 }
191
operator ==(const SubgraphIterator & other) const192 bool SubgraphIterator::operator==(const SubgraphIterator& other) const {
193 if (id_ != other.id_) {
194 return false;
195 }
196 if (id_it_ != other.id_it_) {
197 return false;
198 }
199 // When AtEnd(), the rest of the fields are not valid.
200 if (AtEnd()) {
201 return true;
202 }
203 if (link_map_it_ != other.link_map_it_) {
204 return false;
205 }
206 if (link_idx_ != other.link_idx_) {
207 return false;
208 }
209 return true;
210 }
211
212 //=== SubgraphPtrSet
213
// Tries to grow the subgraph `parent_id` by `node` and store the result in
// this set. Returns the newly stored subgraph, or nullptr if `node` is
// already in the parent, or if an equal subgraph is already present (found
// earlier through a different extension path).
Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id,
                                       GenNode* node) {
  // Just another link to a node the parent already contains: nothing new.
  if (parent_id.count(node) != 0) {
    return nullptr;
  }

  // The set holds pointers rather than objects, so Subgraph need not be
  // copyable. The price is that the duplicate check needs a fully built
  // candidate to compare against.
  auto candidate = absl::make_unique<Subgraph>(parent_id, node);
  const bool already_known = find(candidate) != end();
  if (already_known) {
    return nullptr;
  }

  Subgraph* const stored = candidate.get();
  insert(std::move(candidate));
  return stored;
}
234
235 } // end namespace graph_analyzer
236 } // end namespace grappler
237 } // end namespace tensorflow
238