• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
17 
18 #include <functional>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 namespace graph_analyzer {
28 
29 //=== Subgraph::Identity
30 
Identity(InitializerList init)31 Subgraph::Identity::Identity(InitializerList init) {
32   for (auto element : init) {
33     insert(element);
34   }
35 }
36 
operator <(const Identity & other) const37 bool Subgraph::Identity::operator<(const Identity& other) const {
38   // Shorter sets go first.
39   if (this->size() < other.size()) {
40     return true;
41   }
42   if (this->size() > other.size()) {
43     return false;
44   }
45   for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
46        ++lit, ++rit) {
47     if (*lit < *rit) {
48       return true;
49     }
50     if (*lit > *rit) {
51       return false;
52     }
53   }
54   return false;  // Equal.
55 }
56 
operator ==(const Identity & other) const57 bool Subgraph::Identity::operator==(const Identity& other) const {
58   if (this->size() != other.size()) {
59     return false;
60   }
61   for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
62        ++lit, ++rit) {
63     if (*lit != *rit) {
64       return false;
65     }
66   }
67   return true;  // Equal.
68 }
69 
Hash() const70 size_t Subgraph::Identity::Hash() const {
71   std::hash<const GenNode*> hasher;
72   size_t result = 0;
73   for (auto ptr : *this) {
74     CombineHash(hasher(ptr), &result);
75   }
76   return result;
77 }
78 
Dump()79 string Subgraph::Dump() {
80   // TODO(babkin): this is simplified for now.
81   std::vector<string> nodes;
82   for (const auto& n : id_) {
83     if (specific_) {
84       nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name()));
85     } else {
86       nodes.emplace_back(n->opcode());
87     }
88   }
89   std::sort(nodes.begin(), nodes.end());
90 
91   return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", ");
92 }
93 
ExtractForSignature(SigNodeMap * result)94 void Subgraph::ExtractForSignature(SigNodeMap* result) {
95   // Mapping of nodes from the original graph to the new one.
96   SigNode::TranslationMap full_to_new;
97 
98   for (auto node : id_) {
99     auto newnode_ref = absl::make_unique<SigNode>(node->node_def());
100     auto newnode = newnode_ref.get();
101     (*result)[node->name()] = std::move(newnode_ref);
102     full_to_new[node] = newnode;
103   }
104 
105   for (const auto& mapping : full_to_new) {
106     mapping.second->CopyLinks(*mapping.first, full_to_new);
107   }
108 }
109 
110 //=== Subgraph
111 
Subgraph(const Identity & parent_id,GenNode * add_node)112 Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node)
113     : id_(parent_id) {
114   id_.insert(add_node);
115   hash_ = id_.Hash();
116 }
117 
118 //=== SubgraphIterator
119 
SubgraphIterator(const Subgraph::Identity * id)120 SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id)
121     : id_(id), id_it_(id_->begin()) {
122   if (!id_->empty()) {
123     link_map_it_ = (*id_it_)->links().begin();
124     // In case if the node has no links.
125     while (link_map_it_ == (*id_it_)->links().end()) {
126       if (++id_it_ == id_->end()) {
127         return;
128       }
129       link_map_it_ = (*id_it_)->links().begin();
130     }
131     link_idx_ = 0;
132     // The LinkTargetVector should never be empty but just in case safeguard
133     // against that too.
134     PropagateNext();
135   }
136 }
137 
Next()138 bool SubgraphIterator::Next() {
139   if (AtEnd()) {
140     return false;
141   }
142   ++link_idx_;
143   return PropagateNext();
144 }
145 
NextIfSamePort()146 bool SubgraphIterator::NextIfSamePort() {
147   if (AtEnd()) {
148     return false;
149   }
150   const int64 link_map_it_second_size = link_map_it_->second.size();
151   if (link_idx_ + 1 < link_map_it_second_size) {
152     ++link_idx_;
153     return true;
154   } else {
155     return false;
156   }
157 }
158 
SkipPort()159 void SubgraphIterator::SkipPort() {
160   if (AtEnd()) {
161     return;
162   }
163   link_idx_ = link_map_it_->second.size() - 1;
164 }
165 
SkipNode()166 void SubgraphIterator::SkipNode() {
167   if (AtEnd()) {
168     return;
169   }
170   for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) {
171     link_map_it_ = next;
172   }
173   link_idx_ = link_map_it_->second.size() - 1;
174 }
175 
PropagateNext()176 bool SubgraphIterator::PropagateNext() {
177   // Loops are used to skip over the empty entries.
178   const int64 link_map_it_second_size = link_map_it_->second.size();
179   while (link_idx_ >= link_map_it_second_size) {
180     ++link_map_it_;
181     while (link_map_it_ == (*id_it_)->links().end()) {
182       if (++id_it_ == id_->end()) {
183         return false;
184       }
185       link_map_it_ = (*id_it_)->links().begin();
186     }
187     link_idx_ = 0;
188   }
189   return true;
190 }
191 
operator ==(const SubgraphIterator & other) const192 bool SubgraphIterator::operator==(const SubgraphIterator& other) const {
193   if (id_ != other.id_) {
194     return false;
195   }
196   if (id_it_ != other.id_it_) {
197     return false;
198   }
199   // When AtEnd(), the rest of the fields are not valid.
200   if (AtEnd()) {
201     return true;
202   }
203   if (link_map_it_ != other.link_map_it_) {
204     return false;
205   }
206   if (link_idx_ != other.link_idx_) {
207     return false;
208   }
209   return true;
210 }
211 
212 //=== SubgraphPtrSet
213 
// Tries to add to this set the subgraph made of `parent_id` plus `node`.
//
// Returns a pointer to the newly inserted Subgraph, which stays owned by this
// set. Returns nullptr when nothing was added: either `node` is already part
// of `parent_id`, or an equivalent subgraph is already in the set.
Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id,
                                       GenNode* node) {
  // This may be another link to a node that is already in the parent.
  const bool already_in_parent = parent_id.find(node) != parent_id.end();
  if (already_in_parent) {
    return nullptr;
  }

  // Constructing an object just to check that an equivalent one is already
  // present is kind of ugly but storing the references rather than the objects
  // in the set avoids the need to make the object copyable.
  auto candidate = absl::make_unique<Subgraph>(parent_id, node);
  const bool already_known = find(candidate) != end();
  if (already_known) {
    // This subgraph was already found by extending from a different path.
    return nullptr;
  }

  // Take the raw pointer before ownership moves into the set.
  Subgraph* const inserted = candidate.get();
  insert(std::move(candidate));
  return inserted;
}
234 
235 }  // end namespace graph_analyzer
236 }  // end namespace grappler
237 }  // end namespace tensorflow
238