• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/graph_view.h"
17 
18 #include <utility>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/graph/tensor_id.h"
25 #include "tensorflow/core/grappler/op_types.h"
26 #include "tensorflow/core/grappler/utils.h"
27 #include "tensorflow/core/grappler/utils/graph_view_internal.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/util/device_name_utils.h"
31 
32 namespace tensorflow {
33 namespace grappler {
34 namespace utils {
35 
FaninView(NodeView * node_view,int index)36 FaninView::FaninView(NodeView* node_view, int index)
37     : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
38                             index) {}
39 
FanoutView(NodeView * node_view,int index)40 FanoutView::FanoutView(NodeView* node_view, int index)
41     : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
42                             index) {}
43 
node() const44 const NodeDef* NodeView::node() const {
45   return &graph_view_->graph()->node(node_index_);
46 }
47 
HasFanin(const FanoutView & fanin) const48 bool NodeView::HasFanin(const FanoutView& fanin) const {
49   if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) {
50     return false;
51   }
52   return fanins_set_.contains(
53       {&graph_view_->graph_->node(fanin.node_index_), fanin.index()});
54 }
55 
HasFanout(const FaninView & fanout) const56 bool NodeView::HasFanout(const FaninView& fanout) const {
57   if (fanout.index() < Graph::kControlSlot ||
58       graph_view_ != fanout.graph_view_) {
59     return false;
60   }
61   NodeView* view = fanout.node_view();
62   if (view == nullptr) {
63     return false;
64   } else if (fanout.index() == Graph::kControlSlot) {
65     return view->fanins_set_.contains({this->node(), Graph::kControlSlot});
66   } else if (fanout.index() >= static_cast<int>(view->regular_fanins_.size())) {
67     return false;
68   }
69   return view->regular_fanins_[fanout.index()].node_index_ == node_index_;
70 }
71 
GetMissingFanin() const72 inline const FanoutView& NodeView::GetMissingFanin() const {
73   return graph_view_->missing_fanin_;
74 }
75 
GetMissingFanout() const76 inline const std::vector<FaninView>& NodeView::GetMissingFanout() const {
77   return graph_view_->missing_fanout_;
78 }
79 
namespace {
// Prefix placed at the front of InvalidArgument messages built with this
// constant.
constexpr char kGraphViewError[] = "GraphView::GraphView error: ";
}  // namespace
83 
GraphView(const GraphDef * graph,Status * status)84 GraphView::GraphView(const GraphDef* graph, Status* status)
85     : GraphViewInternal(graph) {
86   const int num_nodes = graph->node_size();
87   node_index_by_name_.reserve(num_nodes);
88   nodes_.reserve(num_nodes);
89   for (const NodeDef& node : graph->node()) {
90     if (!AddUniqueNodeInternal(&node)) {
91       *status = errors::InvalidArgument(
92           kGraphViewError, "graph has multiple nodes with the name '",
93           node.name(), "'.");
94       Reset();
95       return;
96     }
97   }
98   Status s;
99   for (NodeView& node_view : nodes_) {
100     s = CheckAndAddFaninsInternal(&node_view);
101     if (!s.ok()) {
102       *status = s;
103       Reset();
104       return;
105     }
106   }
107   *status = Status::OK();
108 }
109 
AddUniqueNodeInternal(const NodeDef * node)110 bool GraphView::AddUniqueNodeInternal(const NodeDef* node) {
111   const int node_index = node_index_by_name_.size();
112   auto it = node_index_by_name_.emplace(node->name(), node_index);
113   if (it.second) {
114     nodes_.emplace_back(this, node_index);
115     return true;
116   }
117   return false;
118 }
119 
CheckAndAddFaninsInternal(NodeView * node_view)120 Status GraphView::CheckAndAddFaninsInternal(NodeView* node_view) {
121   bool has_observed_control = false;
122   const NodeDef* node = node_view->node();
123   const string& node_name = node->name();
124   const int node_index = node_view->node_index_;
125   node_view->fanins_set_.reserve(node->input_size());
126   for (const string& input : node->input()) {
127     TensorId fanin_id = ParseTensorName(input);
128     if (fanin_id.node() == node_name) {
129       return errors::InvalidArgument(kGraphViewError, "node '", node_name,
130                                      "' has self cycle fanin '", input, "'.");
131     }
132     bool is_control = IsTensorIdControl(fanin_id);
133     if (!is_control && has_observed_control) {
134       return errors::InvalidArgument(kGraphViewError, "node '", node_name,
135                                      "' has regular fanin '", input,
136                                      "' after controlling fanins.");
137     }
138     auto it = node_index_by_name_.find(fanin_id.node());
139     if (it == node_index_by_name_.end()) {
140       return errors::InvalidArgument(kGraphViewError, "node '", node_name,
141                                      "' has missing fanin '", input, "'.");
142     }
143     const int fanin_node_index = it->second;
144     NodeView& fanin_node_view = nodes_[fanin_node_index];
145 
146     if (is_control) {
147       fanin_node_view.controlled_fanouts_.emplace_back(this, node_index,
148                                                        Graph::kControlSlot);
149       node_view->controlling_fanins_.emplace_back(this, fanin_node_index,
150                                                   Graph::kControlSlot);
151       node_view->fanins_set_.emplace(fanin_node_view.node(),
152                                      Graph::kControlSlot);
153       has_observed_control = true;
154     } else {
155       int fanin_node_view_regular_fanouts_by_port_size =
156           fanin_node_view.regular_fanouts_by_port_.size();
157       if (fanin_node_view_regular_fanouts_by_port_size < fanin_id.index() + 1) {
158         fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1);
159       }
160       fanin_node_view.regular_fanouts_by_port_[fanin_id.index()].emplace_back(
161           this, node_index, node_view->regular_fanins_.size());
162       ++fanin_node_view.num_regular_fanouts_;
163       node_view->regular_fanins_.emplace_back(this, fanin_node_index,
164                                               fanin_id.index());
165       node_view->fanins_set_.emplace(fanin_node_view.node(), fanin_id.index());
166     }
167   }
168   return Status::OK();
169 }
170 
MutableFaninView(MutableNodeView * node_view,int index)171 MutableFaninView::MutableFaninView(MutableNodeView* node_view, int index)
172     : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
173                             index) {}
174 
MutableFanoutView(MutableNodeView * node_view,int index)175 MutableFanoutView::MutableFanoutView(MutableNodeView* node_view, int index)
176     : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
177                             index) {}
178 
node() const179 NodeDef* MutableNodeView::node() const {
180   return graph_view_->graph()->mutable_node(node_index_);
181 }
182 
HasFanin(const MutableFanoutView & fanin) const183 bool MutableNodeView::HasFanin(const MutableFanoutView& fanin) const {
184   if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) {
185     return false;
186   }
187   return fanins_count_.contains(
188       {&graph_view_->graph_->node(fanin.node_index_), fanin.index()});
189 }
190 
HasFanout(const MutableFaninView & fanout) const191 bool MutableNodeView::HasFanout(const MutableFaninView& fanout) const {
192   if (fanout.index() < Graph::kControlSlot ||
193       graph_view_ != fanout.graph_view_) {
194     return false;
195   }
196   MutableNodeView* view = fanout.node_view();
197   if (view == nullptr) {
198     return false;
199   } else if (fanout.index() == Graph::kControlSlot) {
200     return view->fanins_count_.contains({this->node(), Graph::kControlSlot});
201   } else if (fanout.index() >= static_cast<int>(view->regular_fanins_.size())) {
202     return false;
203   }
204   return view->regular_fanins_[fanout.index()].node_index_ == node_index_;
205 }
206 
GetMissingFanin() const207 const MutableFanoutView& MutableNodeView::GetMissingFanin() const {
208   return graph_view_->missing_fanin_;
209 }
210 
GetMissingFanout() const211 const std::vector<MutableFaninView>& MutableNodeView::GetMissingFanout() const {
212   return graph_view_->missing_fanout_;
213 }
214 
215 namespace {
216 const char kMutationAddNodeError[] = "Mutation::AddNode error: ";
217 
IsTensorIdRegular(const TensorId & tensor_id)218 bool IsTensorIdRegular(const TensorId& tensor_id) {
219   return tensor_id.index() >= 0;
220 }
221 }  // namespace
222 
Mutation(MutableGraphView * graph_view)223 Mutation::Mutation(MutableGraphView* graph_view) : graph_view_(graph_view) {}
224 
AddNode(NodeDef && node,Status * status)225 MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) {
226   bool has_observed_control = false;
227   const string& node_name = node.name();
228   std::vector<SafeTensorId> regular_fanins;
229   absl::flat_hash_set<string> controlling_fanins;
230   const int num_fanins = node.input_size();
231   for (int i = 0; i < num_fanins; ++i) {
232     const string& input = node.input(i);
233     TensorId fanin_id = ParseTensorName(input);
234     if (fanin_id.node() == node_name) {
235       *status =
236           errors::InvalidArgument(kMutationAddNodeError, "node '", node_name,
237                                   "' has self cycle fanin '", input, "'.");
238       return MutationNewNode(this, mutation_counter_, internal::kMissingIndex);
239     }
240     bool is_control = IsTensorIdControl(fanin_id);
241     if (is_control) {
242       has_observed_control = true;
243       controlling_fanins.emplace(fanin_id.node());
244     } else if (has_observed_control) {
245       *status = errors::InvalidArgument(kMutationAddNodeError, "node '",
246                                         node_name, "' has regular fanin '",
247                                         input, "' after controlling fanins.");
248       return MutationNewNode(this, mutation_counter_, internal::kMissingIndex);
249     } else {
250       regular_fanins.emplace_back(fanin_id);
251     }
252   }
253 
254   node.mutable_input()->Clear();
255   new_nodes_.emplace_back(graph_view_, std::move(node));
256   MutationNewNodeHolder& mutation_node = new_nodes_.back();
257   mutation_node.regular_fanins = std::move(regular_fanins);
258   mutation_node.num_regular_fanins = mutation_node.regular_fanins.size();
259   mutation_node.controlling_fanins = std::move(controlling_fanins);
260   *status = Status::OK();
261   return MutationNewNode(this, mutation_counter_, new_nodes_.size() - 1);
262 }
263 
AddMutation(MutableNodeView * node,std::function<bool (MutableNodeViewDiff *)> mutate_fn)264 void Mutation::AddMutation(
265     MutableNodeView* node,
266     std::function<bool(MutableNodeViewDiff*)> mutate_fn) {
267   DCHECK(node->graph_view_ == graph_view_);
268   if (node->update_index_ == internal::kMissingIndex) {
269     MutableNodeViewDiff diff(graph_view_, node->node_index_);
270     // If mutation is a no-op return and do not add it to the `updated_nodes_`.
271     if (!mutate_fn(&diff)) return;
272     node->update_index_ = updated_nodes_.size();
273     updated_nodes_.push_back(std::move(diff));
274   } else if (!removed_nodes_.contains(node->node_index_)) {
275     MutableNodeViewDiff& diff = updated_nodes_[node->update_index_];
276     mutate_fn(&diff);
277   }
278 }
279 
RemoveNode(MutableNodeView * node)280 void Mutation::RemoveNode(MutableNodeView* node) {
281   auto& update_index = node->update_index_;
282   if (update_index != internal::kMissingIndex) {
283     int updated_nodes_size = updated_nodes_.size();
284     if (update_index < updated_nodes_size - 1) {
285       graph_view_->nodes_[updated_nodes_.back().node_index].update_index_ =
286           update_index;
287       std::swap(updated_nodes_[update_index], updated_nodes_.back());
288     }
289     updated_nodes_.pop_back();
290     update_index = internal::kMissingIndex;
291   }
292   removed_nodes_.insert(node->node_index_);
293 }
294 
UpdateNodeName(MutableNodeView * node,absl::string_view name)295 void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
296   AddMutation(node, [name](MutableNodeViewDiff* diff) {
297     return internal::UpdateName(diff, name);
298   });
299 }
300 
UpdateNodeName(const MutationNewNode & node,absl::string_view name)301 void Mutation::UpdateNodeName(const MutationNewNode& node,
302                               absl::string_view name) {
303   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
304   internal::UpdateName(&new_nodes_[node.index_], name);
305 }
306 
UpdateNodeOp(MutableNodeView * node,absl::string_view op)307 void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) {
308   AddMutation(node, [op](MutableNodeViewDiff* diff) {
309     return internal::UpdateOp(diff, op);
310   });
311 }
312 
UpdateNodeOp(const MutationNewNode & node,absl::string_view op)313 void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
314   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
315   internal::UpdateOp(&new_nodes_[node.index_], op);
316 }
317 
UpdateNodeDevice(MutableNodeView * node,absl::string_view device)318 void Mutation::UpdateNodeDevice(MutableNodeView* node,
319                                 absl::string_view device) {
320   AddMutation(node, [device](MutableNodeViewDiff* diff) {
321     return internal::UpdateDevice(diff, device);
322   });
323 }
324 
UpdateNodeDevice(const MutationNewNode & node,absl::string_view device)325 void Mutation::UpdateNodeDevice(const MutationNewNode& node,
326                                 absl::string_view device) {
327   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
328   internal::UpdateDevice(&new_nodes_[node.index_], device);
329 }
330 
AddOrUpdateRegularFanin(MutableNodeView * node,int index,const TensorId & fanin)331 void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index,
332                                        const TensorId& fanin) {
333   AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) {
334     return internal::AddOrUpdateRegularFanin(diff, index, fanin);
335   });
336 }
337 
AddOrUpdateRegularFanin(const MutationNewNode & node,int index,const TensorId & fanin)338 void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index,
339                                        const TensorId& fanin) {
340   DCHECK(node.mutation_ == this &&
341          node.mutation_counter_ == mutation_counter_ && index >= 0 &&
342          IsTensorIdRegular(fanin));
343   internal::AddOrUpdateRegularFanin(&new_nodes_[node.index_], index, fanin);
344 }
345 
RemoveRegularFanin(MutableNodeView * node,int index)346 void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) {
347   AddMutation(node, [index](MutableNodeViewDiff* diff) {
348     return internal::RemoveRegularFanin(diff, index);
349   });
350 }
351 
RemoveRegularFanin(const MutationNewNode & node,int index)352 void Mutation::RemoveRegularFanin(const MutationNewNode& node, int index) {
353   DCHECK(node.mutation_ == this &&
354          node.mutation_counter_ == mutation_counter_ && index >= 0);
355   internal::RemoveRegularFanin(&new_nodes_[node.index_], index);
356 }
357 
AddControllingFanin(MutableNodeView * node,absl::string_view fanin_node_name)358 void Mutation::AddControllingFanin(MutableNodeView* node,
359                                    absl::string_view fanin_node_name) {
360   AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) {
361     auto it = node->controlling_fanins_index_.find(fanin_node_name);
362     const int control_index = it != node->controlling_fanins_index_.end()
363                                   ? it->second
364                                   : internal::kMissingIndex;
365     return internal::AddControllingFanin(diff, control_index, fanin_node_name);
366   });
367 }
368 
AddControllingFanin(const MutationNewNode & node,absl::string_view fanin_node_name)369 void Mutation::AddControllingFanin(const MutationNewNode& node,
370                                    absl::string_view fanin_node_name) {
371   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
372   internal::AddControllingFanin(&new_nodes_[node.index_], fanin_node_name);
373 }
374 
RemoveControllingFanin(MutableNodeView * node,absl::string_view fanin_node_name)375 void Mutation::RemoveControllingFanin(MutableNodeView* node,
376                                       absl::string_view fanin_node_name) {
377   AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) {
378     auto it = node->controlling_fanins_index_.find(fanin_node_name);
379     const int control_index = it != node->controlling_fanins_index_.end()
380                                   ? it->second
381                                   : internal::kMissingIndex;
382     return internal::RemoveControllingFanin(diff, control_index,
383                                             fanin_node_name);
384   });
385 }
386 
RemoveControllingFanin(const MutationNewNode & node,absl::string_view fanin_node_name)387 void Mutation::RemoveControllingFanin(const MutationNewNode& node,
388                                       absl::string_view fanin_node_name) {
389   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
390   internal::RemoveControllingFanin(&new_nodes_[node.index_], fanin_node_name);
391 }
392 
AddOrUpdateNodeAttr(MutableNodeView * node,absl::string_view attr_name,const AttrValue & attr_value)393 void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node,
394                                    absl::string_view attr_name,
395                                    const AttrValue& attr_value) {
396   AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) {
397     return internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
398   });
399 }
400 
AddOrUpdateNodeAttr(const MutationNewNode & node,absl::string_view attr_name,const AttrValue & attr_value)401 void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node,
402                                    absl::string_view attr_name,
403                                    const AttrValue& attr_value) {
404   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
405   internal::AddOrUpdateAttribute(&new_nodes_[node.index_], attr_name,
406                                  attr_value);
407 }
408 
RemoveNodeAttr(MutableNodeView * node,absl::string_view attr_name)409 void Mutation::RemoveNodeAttr(MutableNodeView* node,
410                               absl::string_view attr_name) {
411   AddMutation(node, [attr_name](MutableNodeViewDiff* diff) {
412     return internal::RemoveAttribute(diff, attr_name);
413   });
414 }
415 
RemoveNodeAttr(const MutationNewNode & node,absl::string_view attr_name)416 void Mutation::RemoveNodeAttr(const MutationNewNode& node,
417                               absl::string_view attr_name) {
418   DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
419   internal::RemoveAttribute(&new_nodes_[node.index_], attr_name);
420 }
421 
ResetInternal()422 void Mutation::ResetInternal() {
423   updated_nodes_.clear();
424   removed_nodes_.clear();
425   new_nodes_.clear();
426 }
427 
Reset()428 void Mutation::Reset() {
429   for (const auto& update : updated_nodes_) {
430     graph_view_->nodes_[update.node_index].update_index_ =
431         internal::kMissingIndex;
432   }
433   ResetInternal();
434 }
435 
Apply()436 Status Mutation::Apply() { return graph_view_->ApplyMutationInternal(); }
437 
438 namespace {
439 const char kMutableGraphViewError[] =
440     "MutableGraphView::MutableGraphView error: ";
441 
442 const char kMutableGraphViewApplyError[] = "Mutation::Apply error: ";
443 
IncrementFaninCount(absl::flat_hash_map<internal::NodeDefAndPortIndex,int> * fanins_count,const internal::NodeDefAndPortIndex & fanin)444 inline void IncrementFaninCount(
445     absl::flat_hash_map<internal::NodeDefAndPortIndex, int>* fanins_count,
446     const internal::NodeDefAndPortIndex& fanin) {
447   ++(*fanins_count)[fanin];
448 }
449 
DecrementFaninCount(absl::flat_hash_map<internal::NodeDefAndPortIndex,int> * fanins_count,const internal::NodeDefAndPortIndex & fanin)450 inline void DecrementFaninCount(
451     absl::flat_hash_map<internal::NodeDefAndPortIndex, int>* fanins_count,
452     const internal::NodeDefAndPortIndex& fanin) {
453   auto it = fanins_count->find(fanin);
454   if (it != fanins_count->end()) {
455     if (it->second <= 1) {
456       fanins_count->erase(it);
457     } else {
458       --it->second;
459     }
460   }
461 }
462 }  // namespace
463 
MutableGraphView(GraphDef * graph,Status * status)464 MutableGraphView::MutableGraphView(GraphDef* graph, Status* status)
465     : GraphViewInternal(graph), mutation_(Mutation(this)) {
466   const int num_nodes = graph->node_size();
467   node_index_by_name_.reserve(num_nodes);
468   nodes_.reserve(num_nodes);
469   for (NodeDef& node : *graph->mutable_node()) {
470     if (!AddUniqueNodeInternal(&node)) {
471       *status = errors::InvalidArgument(
472           kMutableGraphViewError, "graph has multiple nodes with the name '",
473           node.name(), "'.");
474       Reset();
475       return;
476     }
477   }
478   std::vector<std::vector<TensorId>> fanins;
479   Status s = CheckFaninsInternal(&fanins);
480   if (!s.ok()) {
481     *status = s;
482     Reset();
483     return;
484   }
485   AddFaninsInternal(&fanins);
486   mutation_.ResetInternal();
487   *status = Status::OK();
488 }
489 
GetMutationBuilder()490 Mutation* MutableGraphView::GetMutationBuilder() { return &mutation_; }
491 
AddUniqueNodeInternal(NodeDef * node)492 bool MutableGraphView::AddUniqueNodeInternal(NodeDef* node) {
493   const int node_index = node_index_by_name_.size();
494   auto it = node_index_by_name_.emplace(node->name(), node_index);
495   if (it.second) {
496     nodes_.emplace_back(this, node_index);
497     return true;
498   }
499   return false;
500 }
501 
CheckFaninsInternal(std::vector<std::vector<TensorId>> * fanins)502 Status MutableGraphView::CheckFaninsInternal(
503     std::vector<std::vector<TensorId>>* fanins) {
504   const int num_nodes = nodes_.size();
505   fanins->reserve(num_nodes);
506   for (int i = 0; i < num_nodes; ++i) {
507     bool has_observed_control = false;
508     const NodeDef* node = nodes_[i].node();
509     const string& node_name = node->name();
510     std::vector<TensorId> node_fanins;
511     node_fanins.reserve(node->input_size());
512     for (const string& input : node->input()) {
513       TensorId fanin_id = ParseTensorName(input);
514       if (fanin_id.node() == node_name) {
515         return errors::InvalidArgument(kMutableGraphViewError, "node '",
516                                        node_name, "' has self cycle fanin '",
517                                        input, "'.");
518       }
519       bool is_control = IsTensorIdControl(fanin_id);
520       if (!is_control && has_observed_control) {
521         return errors::InvalidArgument(kMutableGraphViewError, "node '",
522                                        node_name, "' has regular fanin '",
523                                        input, "' after controlling fanins.");
524       }
525       if (!node_index_by_name_.contains(fanin_id.node())) {
526         return errors::InvalidArgument(kMutableGraphViewError, "node '",
527                                        node_name, "' has missing fanin '",
528                                        input, "'.");
529       }
530       if (is_control) {
531         has_observed_control = true;
532       }
533       node_fanins.push_back(std::move(fanin_id));
534     }
535     fanins->push_back(std::move(node_fanins));
536   }
537   return Status::OK();
538 }
539 
AddFaninsInternal(std::vector<std::vector<TensorId>> * fanins)540 void MutableGraphView::AddFaninsInternal(
541     std::vector<std::vector<TensorId>>* fanins) {
542   const int num_nodes = nodes_.size();
543   for (int i = 0; i < num_nodes; ++i) {
544     MutableNodeView& node_view = nodes_[i];
545     NodeDef* node = node_view.node();
546     std::vector<TensorId>& node_fanins = fanins->at(i);
547     absl::flat_hash_set<absl::string_view> observed_controls;
548     int pos = 0;
549     const int last_idx = node_fanins.size() - 1;
550     int last_pos = last_idx;
551     node_view.fanins_count_.reserve(node->input_size());
552     node_view.controlling_fanins_index_.reserve(node->input_size());
553     while (pos <= last_pos) {
554       const TensorId& fanin_id = node_fanins[pos];
555       bool is_control = IsTensorIdControl(fanin_id);
556       const int fanin_node_index = node_index_by_name_[fanin_id.node()];
557       MutableNodeView& fanin_node_view = nodes_[fanin_node_index];
558 
559       if (is_control) {
560         if (gtl::InsertIfNotPresent(&observed_controls, fanin_id.node())) {
561           fanin_node_view.controlled_fanouts_.emplace_back(
562               this, i, Graph::kControlSlot,
563               node_view.controlling_fanins_.size());
564           node_view.controlling_fanins_.emplace_back(
565               this, fanin_node_index, Graph::kControlSlot,
566               fanin_node_view.controlled_fanouts_.size() - 1);
567           IncrementFaninCount(
568               &node_view.fanins_count_,
569               {&graph_->node(fanin_node_index), Graph::kControlSlot});
570           node_view.controlling_fanins_index_.emplace(
571               fanin_id.node(), pos - node_view.NumRegularFanins());
572           ++pos;
573         } else {
574           node->mutable_input()->SwapElements(pos, last_pos);
575           std::swap(node_fanins[pos], node_fanins[last_pos]);
576           --last_pos;
577         }
578       } else {
579         int fanin_node_view_regular_fanouts_by_port_size =
580             fanin_node_view.regular_fanouts_by_port_.size();
581         if (fanin_node_view_regular_fanouts_by_port_size <
582             fanin_id.index() + 1) {
583           fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1);
584         }
585         auto& fanin_regular_fanouts =
586             fanin_node_view.regular_fanouts_by_port_[fanin_id.index()];
587         fanin_regular_fanouts.emplace_back(this, i,
588                                            node_view.regular_fanins_.size(),
589                                            node_view.regular_fanins_.size());
590         ++fanin_node_view.num_regular_fanouts_;
591         node_view.regular_fanins_.emplace_back(
592             this, fanin_node_index, fanin_id.index(),
593             fanin_regular_fanouts.size() - 1);
594         IncrementFaninCount(
595             &node_view.fanins_count_,
596             {&graph_->node(fanin_node_index), fanin_id.index()});
597         ++pos;
598       }
599     }
600     if (last_pos < last_idx) {
601       node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
602     }
603   }
604 }
605 
GetNodeNamesAndPartitionUpdatedNodes(absl::flat_hash_map<absl::string_view,int> * node_names,std::vector<RenamedOrOverwrittenNode> * renamed_nodes,std::vector<int> * inplace_nodes,std::vector<int> * empty_diff_node_indices)606 Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes(
607     absl::flat_hash_map<absl::string_view, int>* node_names,
608     std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
609     std::vector<int>* inplace_nodes,
610     std::vector<int>* empty_diff_node_indices) {
611   // For all nodes to be removed and renamed, mark their original names as
612   // missing and put associated node index in graph.
613   for (const auto& diff : mutation_.updated_nodes_) {
614     if (diff.update_name) {
615       const int index = diff.node_index;
616       const string& node_name = nodes_[index].GetName();
617       node_names->emplace(node_name, index);
618     }
619   }
620 
621   for (int node_index : mutation_.removed_nodes_) {
622     const string& node_name = nodes_[node_index].GetName();
623     node_names->emplace(node_name, node_index);
624   }
625 
626   auto name_conflict = [](const absl::string_view node_name) {
627     return errors::InvalidArgument(kMutableGraphViewApplyError,
628                                    "multiple nodes with the name: '", node_name,
629                                    "' exists in Mutation.");
630   };
631 
632   // Partition updated nodes by if they will be renamed or not.
633   const int num_updated_nodes = mutation_.updated_nodes_.size();
634   renamed_nodes->reserve(num_updated_nodes);
635   inplace_nodes->reserve(num_updated_nodes);
636   empty_diff_node_indices->reserve(num_updated_nodes);
637   for (int i = 0; i < num_updated_nodes; ++i) {
638     auto& diff = mutation_.updated_nodes_[i];
639     if (internal::IsEmpty(&diff)) {
640       empty_diff_node_indices->emplace_back(diff.node_index);
641       continue;
642     }
643     // Get name of updated node after potential mutation.
644     const string& node_name =
645         diff.update_name ? diff.name : nodes_[diff.node_index].GetName();
646     auto it = node_names->insert({node_name, internal::kNodeNamePresent});
647     if (!it.second) {
648       if (it.first->second == internal::kNodeNamePresent) {
649         // Another node in the mutation is already using this name, which will
650         // result in a conflict.
651         return name_conflict(node_name);
652       } else {
653         // Mark name as present (node was marked missing from either being
654         // removed or renamed).
655         it.first->second = internal::kNodeNamePresent;
656       }
657     }
658     if (diff.update_name) {
659       // Lookup new name of node in current graph. If a node has such name,
660       // store its index for later lookups as this node will be overwritten.
661       auto node_name_it = node_index_by_name_.find(node_name);
662       const int overwritten_node_index =
663           node_name_it != node_index_by_name_.end() ? node_name_it->second
664                                                     : internal::kMissingIndex;
665       renamed_nodes->emplace_back(i, overwritten_node_index);
666     } else {
667       inplace_nodes->push_back(i);
668     }
669   }
670 
671   // Get names of new nodes after potential mutation.
672   for (const auto& new_node : mutation_.new_nodes_) {
673     const string& node_name = new_node.node.name();
674     auto it = node_names->insert({node_name, internal::kNodeNamePresent});
675     if (it.second) {
676       continue;
677     }
678     if (it.first->second == internal::kNodeNamePresent) {
679       // Another node in the mutation is already using this name, which will
680       // result in a conflict.
681       return name_conflict(node_name);
682     } else {
683       // Mark name as present (node was marked missing from either being removed
684       // or renamed).
685       it.first->second = internal::kNodeNamePresent;
686     }
687   }
688 
689   return Status::OK();
690 }
691 
RemovedOrMissingNodeFanoutsWellFormed(const absl::flat_hash_map<absl::string_view,int> & node_names,const std::vector<RenamedOrOverwrittenNode> & renamed_nodes)692 Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
693     const absl::flat_hash_map<absl::string_view, int>& node_names,
694     const std::vector<RenamedOrOverwrittenNode>& renamed_nodes) {
695   auto bad_fanout = [](absl::string_view fanout_node_name,
696                        absl::string_view node_name) {
697     return errors::InvalidArgument(
698         kMutableGraphViewApplyError, "fanout '", fanout_node_name,
699         "' exist for missing node '", node_name, "'.");
700   };
701 
702   // Lookup nodes to be overwritten.
703   std::vector<bool> overwritten_nodes(NumNodes());
704   for (auto& renamed_node : renamed_nodes) {
705     if (renamed_node.overwritten_node_index_ == internal::kMissingIndex) {
706       continue;
707     }
708     overwritten_nodes[renamed_node.overwritten_node_index_] = true;
709   }
710 
711   // Check if removed nodes and previous state of renamed nodes have no fanouts.
712   for (const auto& node_name_state : node_names) {
713     if (node_name_state.second == internal::kNodeNamePresent) {
714       continue;
715     }
716     const MutableNodeView& node_view = nodes_[node_name_state.second];
717     for (const auto& regular_fanouts : node_view.GetRegularFanouts()) {
718       for (const auto& regular_fanout : regular_fanouts) {
719         // Check all fanouts of a single port.
720         MutableNodeView* fanout_view = regular_fanout.node_view();
721         if (fanout_view->update_index_ == internal::kMissingIndex) {
722           if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
723             // Fanout node will be removed, this can be ignored.
724             continue;
725           } else if (!overwritten_nodes[fanout_view->node_index_]) {
726             // Fanout is not updated or removed/overwritten.
727             return bad_fanout(fanout_view->GetName(), node_name_state.first);
728           }
729         } else {
730           auto& diff = mutation_.updated_nodes_[fanout_view->update_index_];
731           const int last_index = fanout_view->NumRegularFanins() -
732                                  diff.num_regular_inputs_to_remove - 1;
733           if (regular_fanout.index() > last_index) {
734             // Fanin of fanout is removed, this can be ignored.
735             continue;
736           }
737           // Check if fanin is updated.
738           if (diff.regular_inputs_to_update.find(regular_fanout.index()) ==
739               diff.regular_inputs_to_update.end()) {
740             return bad_fanout(fanout_view->GetName(), node_name_state.first);
741           }
742         }
743       }
744     }
745     for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
746       MutableNodeView* fanout_view = controlled_fanout.node_view();
747       if (fanout_view->update_index_ == internal::kMissingIndex) {
748         if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
749           // Fanout node will be removed, this can be ignored.
750           continue;
751         } else if (!overwritten_nodes[fanout_view->node_index_]) {
752           // Fanout is not updated or removed/overwritten.
753           return bad_fanout(fanout_view->GetName(), node_name_state.first);
754         }
755       } else {
756         auto& diff = mutation_.updated_nodes_[fanout_view->update_index_];
757         // Check if controlling fanin is removed.
758         if (diff.controlling_inputs_to_remove.find(
759                 controlled_fanout.fanin_index_) ==
760             diff.controlling_inputs_to_remove.end()) {
761           return bad_fanout(fanout_view->GetName(), node_name_state.first);
762         }
763       }
764     }
765   }
766 
767   return Status::OK();
768 }
769 
CheckNodeNamesAndFanins(const absl::flat_hash_map<absl::string_view,int> & node_names,const std::vector<RenamedOrOverwrittenNode> & renamed_nodes,const std::vector<int> & inplace_nodes)770 Status MutableGraphView::CheckNodeNamesAndFanins(
771     const absl::flat_hash_map<absl::string_view, int>& node_names,
772     const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
773     const std::vector<int>& inplace_nodes) {
774   // Check if removed/missing node fanouts are valid.
775   TF_RETURN_IF_ERROR(
776       RemovedOrMissingNodeFanoutsWellFormed(node_names, renamed_nodes));
777 
778   // Check if updated nodes and their fanins are valid.
779   for (auto& inplace_node : inplace_nodes) {
780     auto& diff = mutation_.updated_nodes_[inplace_node];
781     if (!internal::IsWellFormed(&diff, node_names)) {
782       return errors::InvalidArgument(
783           kMutableGraphViewApplyError, "inplace updated node '",
784           nodes_[diff.node_index].GetName(), "' is ill-formed.");
785     }
786   }
787   for (auto& renamed_node : renamed_nodes) {
788     auto& diff = mutation_.updated_nodes_[renamed_node.renamed_update_index_];
789     if (!internal::IsWellFormed(&diff, node_names)) {
790       return errors::InvalidArgument(
791           kMutableGraphViewApplyError, "renamed updated node '", diff.name,
792           "' ('", nodes_[diff.node_index].GetName(), "') is ill-formed.");
793     }
794   }
795 
796   // Check if new nodes and their fanins are valid.
797   for (auto& new_node : mutation_.new_nodes_) {
798     if (!internal::IsWellFormed(&new_node, node_names)) {
799       return errors::InvalidArgument(kMutableGraphViewApplyError, "new node '",
800                                      new_node.node.name(), "' is ill-formed.");
801     }
802   }
803 
804   return Status::OK();
805 }
806 
CheckKernelRegisteredForNodes()807 Status MutableGraphView::CheckKernelRegisteredForNodes() {
808   Status s;
809   for (auto& diff : mutation_.updated_nodes_) {
810     if (internal::IsEmpty(&diff)) {
811       continue;
812     }
813 
814     NodeDef* node = nodes_[diff.node_index].node();
815     diff.processed_attrs =
816         AttrValueMap(node->attr().begin(), node->attr().end());
817     for (const auto& attr_to_remove : diff.attrs_to_remove) {
818       (*diff.processed_attrs).erase(attr_to_remove);
819     }
820     for (const auto& attr_to_add : diff.attrs_to_add) {
821       gtl::InsertOrUpdate(&(*diff.processed_attrs), attr_to_add.first,
822                           attr_to_add.second);
823     }
824     const string& device = diff.update_device ? diff.device : node->device();
825     DeviceNameUtils::ParsedName name;
826     if (device.empty() || !DeviceNameUtils::ParseFullName(device, &name) ||
827         !name.has_type) {
828       continue;
829     }
830     s = IsKernelRegisteredForNode(diff.update_name ? diff.name : node->name(),
831                                   node->has_experimental_debug_info(),
832                                   node->experimental_debug_info(),
833                                   diff.update_op ? diff.op : node->op(), device,
834                                   AttrSlice(&(*diff.processed_attrs)));
835     if (!s.ok()) {
836       LOG(WARNING) << s.error_message();
837     }
838   }
839   for (const auto& new_node_holder : mutation_.new_nodes_) {
840     const auto& new_node_def = new_node_holder.node;
841     DeviceNameUtils::ParsedName name;
842     if (new_node_def.device().empty() ||
843         !DeviceNameUtils::ParseFullName(new_node_def.device(), &name) ||
844         !name.has_type) {
845       continue;
846     }
847     s = IsKernelRegisteredForNode(new_node_def);
848     if (!s.ok()) {
849       LOG(WARNING) << s.error_message();
850     }
851   }
852   return Status::OK();
853 }
854 
855 template <typename T>
ReplaceNodeFanouts(MutableNodeView * node,T * fanouts)856 void MutableGraphView::ReplaceNodeFanouts(MutableNodeView* node, T* fanouts) {
857   node->num_regular_fanouts_ = fanouts->num_regular_fanouts_;
858   node->regular_fanouts_by_port_ = std::move(fanouts->regular_fanouts_by_port_);
859   for (int i = 0, i_max = node->regular_fanouts_by_port_.size(); i < i_max;
860        ++i) {
861     for (int j = 0, j_max = node->regular_fanouts_by_port_[i].size(); j < j_max;
862          ++j) {
863       auto& fanout = node->regular_fanouts_by_port_[i][j];
864       auto* fanout_node_view = fanout.node_view();
865       auto& fanout_fanin = fanout_node_view->regular_fanins_[fanout.index()];
866       auto* fanout_fanins_count = &fanout_node_view->fanins_count_;
867       DecrementFaninCount(
868           fanout_fanins_count,
869           {&graph_->node(fanout_fanin.node_index_), fanout_fanin.index()});
870       fanout_fanin.node_index_ = node->node_index_;
871       IncrementFaninCount(
872           fanout_fanins_count,
873           {&graph_->node(node->node_index_), fanout_fanin.index()});
874     }
875   }
876   node->controlled_fanouts_ = std::move(fanouts->controlled_fanouts_);
877   for (int i = 0, i_max = node->controlled_fanouts_.size(); i < i_max; ++i) {
878     auto& fanout = node->controlled_fanouts_[i];
879     auto* fanout_node_view = fanout.node_view();
880     auto& fanout_fanin =
881         fanout_node_view->controlling_fanins_[fanout.fanin_index_];
882     auto* fanout_fanins_count = &fanout_node_view->fanins_count_;
883     DecrementFaninCount(
884         fanout_fanins_count,
885         {&graph_->node(fanout_fanin.node_index_), Graph::kControlSlot});
886     fanout_fanin.node_index_ = node->node_index_;
887     fanout_fanin.fanout_index_ = i;
888     IncrementFaninCount(fanout_fanins_count, {&graph_->node(node->node_index_),
889                                               Graph::kControlSlot});
890   }
891 }
892 
FixRenamedNodes(std::vector<RenamedOrOverwrittenNode> * renamed_nodes,absl::flat_hash_map<string,NodeViewFanouts> * renamed_fanouts,std::vector<bool> * overwritten_name_removed_nodes)893 void MutableGraphView::FixRenamedNodes(
894     std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
895     absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
896     std::vector<bool>* overwritten_name_removed_nodes) {
897   // Extract all renamed node fanouts.
898   renamed_fanouts->reserve(renamed_nodes->size());
899   for (auto& renamed : *renamed_nodes) {
900     auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_];
901     // Remove node index by name from graph.
902     node_index_by_name_.erase(nodes_[diff.node_index].GetName());
903     MutableNodeView& renamed_node = nodes_[diff.node_index];
904     renamed_fanouts->try_emplace(
905         renamed_node.GetName(),
906         std::move(renamed_node.regular_fanouts_by_port_),
907         renamed_node.num_regular_fanouts_,
908         std::move(renamed_node.controlled_fanouts_));
909   }
910 
911   // Replace renamed node fanouts with fanouts associated with updated name.
912   for (auto& renamed : *renamed_nodes) {
913     auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_];
914     MutableNodeView& renamed_node = nodes_[diff.node_index];
915     auto fanouts_it = renamed_fanouts->find(diff.name);
916     if (fanouts_it != renamed_fanouts->end()) {
917       // Another renamed node's fanout.
918       auto& fanouts = fanouts_it->second;
919       ReplaceNodeFanouts(&renamed_node, &fanouts);
920       renamed_fanouts->erase(fanouts_it);
921       // Node to be overwritten is being renamed, so it won't be overwritten.
922       renamed.overwritten_node_index_ = internal::kMissingIndex;
923     } else if (renamed.overwritten_node_index_ != internal::kMissingIndex) {
924       // Existing node in graph.
925       MutableNodeView& node_to_overwrite =
926           nodes_[renamed.overwritten_node_index_];
927       ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
928       node_index_by_name_.erase(node_to_overwrite.GetName());
929       if (mutation_.removed_nodes_.contains(node_to_overwrite.node_index_)) {
930         (*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
931       }
932     } else {
933       // No existing fanouts.
934       renamed_node.num_regular_fanouts_ = 0;
935     }
936 
937     // Update node name.
938     renamed_node.node()->set_name(diff.name);
939     diff.update_name = false;
940     diff.name.clear();
941     // Rehash renamed nodes with updated name.
942     node_index_by_name_.emplace(renamed_node.GetName(), diff.node_index);
943   }
944 }
945 
AddNewNodes(absl::flat_hash_map<string,NodeViewFanouts> * renamed_fanouts,std::vector<int> * new_node_indices)946 void MutableGraphView::AddNewNodes(
947     absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
948     std::vector<int>* new_node_indices) {
949   new_node_indices->reserve(mutation_.new_nodes_.size());
950   for (auto& new_node : mutation_.new_nodes_) {
951     int node_index;
952     auto graph_it = node_index_by_name_.find(new_node.node.name());
953     if (graph_it != node_index_by_name_.end()) {
954       // Overwrite existing node.
955       node_index = graph_it->second;
956       MutableNodeView& node_view = nodes_[node_index];
957       RemoveAllFaninFanoutInternal(&node_view);
958       auto* node_def = graph_->mutable_node(node_index);
959       node_def->mutable_op()->swap(*new_node.node.mutable_op());
960       node_def->mutable_device()->swap(*new_node.node.mutable_device());
961       node_def->mutable_input()->Clear();
962       node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
963       mutation_.removed_nodes_.erase(node_index);
964     } else {
965       // New node.
966       auto* new_node_def = graph_->add_node();
967       *new_node_def = std::move(new_node.node);
968       node_index = nodes_.size();
969       nodes_.emplace_back(this, node_index);
970       MutableNodeView& new_node_view = nodes_.back();
971       auto it = renamed_fanouts->find(new_node_view.GetName());
972       if (it != renamed_fanouts->end()) {
973         // Reuse fanouts of renamed node.
974         NodeViewFanouts& fanouts = it->second;
975         ReplaceNodeFanouts(&new_node_view, &fanouts);
976         renamed_fanouts->erase(it);
977       }
978       node_index_by_name_.emplace(new_node_view.GetName(), node_index);
979     }
980     new_node_indices->emplace_back(node_index);
981   }
982 }
983 
FixRenamedFanouts(const absl::flat_hash_map<string,NodeViewFanouts> & renamed_fanouts)984 void MutableGraphView::FixRenamedFanouts(
985     const absl::flat_hash_map<string, NodeViewFanouts>& renamed_fanouts) {
986   // Leftover fanouts in renamed_fanouts are due to nodes not existing anymore
987   // or a node being renamed without another node taking its place. For these
988   // leftover fanouts, mark their respective fanin fanout_index_ to
989   // internal::kMissingIndex as an indicator so when it comes to updating or
990   // removing fanins inplace, nodes with the same index don't get affected and
991   // other fanouts are accidentally removed.
992   for (auto& renamed_fanout : renamed_fanouts) {
993     for (auto& regular_fanouts :
994          renamed_fanout.second.regular_fanouts_by_port_) {
995       for (auto& fanout : regular_fanouts) {
996         auto* fanout_node_view = fanout.node_view();
997         auto& fanin = fanout_node_view->regular_fanins_[fanout.index()];
998         fanout_node_view->fanins_count_.erase(
999             {fanin.node_view()->node(), fanin.index()});
1000         fanin.fanout_index_ = internal::kMissingIndex;
1001       }
1002     }
1003     for (auto& fanout : renamed_fanout.second.controlled_fanouts_) {
1004       auto* fanout_node_view = fanout.node_view();
1005       auto& fanin = fanout_node_view->controlling_fanins_[fanout.fanin_index_];
1006       fanout_node_view->fanins_count_.erase(
1007           {fanin.node_view()->node(), Graph::kControlSlot});
1008       fanout_node_view->controlling_fanins_index_.erase(renamed_fanout.first);
1009       fanin.fanout_index_ = internal::kMissingIndex;
1010     }
1011   }
1012 }
1013 
RemoveRegularFaninFanoutInternal(MutableNodeView * node_view,int i)1014 inline void MutableGraphView::RemoveRegularFaninFanoutInternal(
1015     MutableNodeView* node_view, int i) {
1016   MutableFanoutView& fanin = node_view->regular_fanins_[i];
1017   // Fanin was marked as removed via FixRenamedFanouts.
1018   if (fanin.fanout_index_ == internal::kMissingIndex) {
1019     return;
1020   }
1021 
1022   DecrementFaninCount(&node_view->fanins_count_,
1023                       {&graph_->node(fanin.node_index_), fanin.index()});
1024   auto* fanin_node_view = fanin.node_view();
1025   auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin.index()];
1026   int fanouts_size = fanouts.size();
1027   if (fanin.fanout_index_ < fanouts_size - 1) {
1028     // Swap fanout with last fanout in vector, and update it's associated fanin
1029     // index.
1030     MutableFaninView& last_fanout = fanouts.back();
1031     last_fanout.node_view()
1032         ->regular_fanins_[last_fanout.index()]
1033         .fanout_index_ = fanin.fanout_index_;
1034     std::swap(last_fanout, fanouts[fanin.fanout_index_]);
1035   }
1036   // Remove fanout.
1037   fanouts.pop_back();
1038   --fanin.node_view()->num_regular_fanouts_;
1039 
1040   // Resize fanouts. Fanouts may not be removed sequentially in relation to
1041   // output port, so trailing empty output ports may be left behind. It is
1042   // necessary to loop through all of the output ports to determine the maximum
1043   // output port before resizing.
1044   int last_fanout_index = fanin_node_view->regular_fanouts_by_port_.size();
1045   for (int i = fanin_node_view->regular_fanouts_by_port_.size() - 1; i >= 0;
1046        --i) {
1047     if (fanin_node_view->regular_fanouts_by_port_[i].empty()) {
1048       last_fanout_index = i;
1049     } else {
1050       break;
1051     }
1052   }
1053   int fanin_node_view_regular_fanouts_by_port_size =
1054       fanin_node_view->regular_fanouts_by_port_.size();
1055   if (last_fanout_index < fanin_node_view_regular_fanouts_by_port_size) {
1056     fanin_node_view->regular_fanouts_by_port_.resize(last_fanout_index);
1057   }
1058 }
1059 
AddRegularFaninInternal(MutableNodeView * node_view,const SafeTensorId & fanin_id)1060 inline void MutableGraphView::AddRegularFaninInternal(
1061     MutableNodeView* node_view, const SafeTensorId& fanin_id) {
1062   MutableNodeView* fanin_node_view = GetNode(fanin_id.node());
1063   // Resize fanouts to include new output port index.
1064   int fanin_node_view_regular_fanouts_by_port_size =
1065       fanin_node_view->regular_fanouts_by_port_.size();
1066   if (fanin_node_view_regular_fanouts_by_port_size < fanin_id.index() + 1) {
1067     fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1);
1068   }
1069 
1070   // Add node as fanout to fanin.
1071   auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()];
1072   fanouts.emplace_back(this, node_view->node_index(),
1073                        node_view->regular_fanins_.size(),
1074                        node_view->regular_fanins_.size());
1075   ++fanin_node_view->num_regular_fanouts_;
1076 
1077   // Add fanin to node.
1078   node_view->regular_fanins_.emplace_back(this, fanin_node_view->node_index(),
1079                                           fanin_id.index(), fanouts.size() - 1);
1080   IncrementFaninCount(
1081       &node_view->fanins_count_,
1082       {&graph_->node(fanin_node_view->node_index()), fanin_id.index()});
1083 }
1084 
UpdateRegularFaninInternal(MutableNodeView * node_view,const int i,const SafeTensorId & fanin_id)1085 inline void MutableGraphView::UpdateRegularFaninInternal(
1086     MutableNodeView* node_view, const int i, const SafeTensorId& fanin_id) {
1087   // Remove fanin.
1088   RemoveRegularFaninFanoutInternal(node_view, i);
1089 
1090   MutableNodeView* fanin_node_view = GetNode(fanin_id.node());
1091   // Resize fanouts to include new output port index.
1092   int fanin_node_view_regular_fanouts_by_port_size =
1093       fanin_node_view->regular_fanouts_by_port_.size();
1094   if (fanin_node_view_regular_fanouts_by_port_size < fanin_id.index() + 1) {
1095     fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1);
1096   }
1097 
1098   // Add node as fanout to fanin.
1099   auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()];
1100   fanouts.emplace_back(this, node_view->node_index(), i, i);
1101   ++fanin_node_view->num_regular_fanouts_;
1102 
1103   // Replace fanin in node.
1104   node_view->regular_fanins_[i] =
1105       MutableFanoutView(this, fanin_node_view->node_index(), fanin_id.index(),
1106                         fanouts.size() - 1);
1107   IncrementFaninCount(
1108       &node_view->fanins_count_,
1109       {&graph_->node(fanin_node_view->node_index()), fanin_id.index()});
1110 }
1111 
RemoveControllingFaninFanoutInternal(MutableNodeView * node_view,int i)1112 inline void MutableGraphView::RemoveControllingFaninFanoutInternal(
1113     MutableNodeView* node_view, int i) {
1114   auto& control_to_remove = node_view->controlling_fanins_[i];
1115   if (control_to_remove.fanout_index_ != internal::kMissingIndex) {
1116     // Update internal state associated with node.
1117     node_view->fanins_count_.erase(
1118         {control_to_remove.node_view()->node(), Graph::kControlSlot});
1119     node_view->controlling_fanins_index_.erase(
1120         control_to_remove.node_view()->GetName());
1121 
1122     // Remove controlled fanout from controlling fanin, via swapping last
1123     // controlled fanout in controlling fanin with controlled fanout to be
1124     // removed.
1125     auto* control_to_remove_view = control_to_remove.node_view();
1126     int control_to_remove_view_controlled_fanouts_size =
1127         control_to_remove_view->controlled_fanouts_.size();
1128     if (control_to_remove.fanout_index_ <
1129         control_to_remove_view_controlled_fanouts_size - 1) {
1130       auto& control_to_remove_view_last_control =
1131           control_to_remove_view->controlled_fanouts_.back();
1132       control_to_remove_view_last_control.node_view()
1133           ->controlling_fanins_[control_to_remove_view_last_control
1134                                     .fanin_index_]
1135           .fanout_index_ = control_to_remove.fanout_index_;
1136       std::swap(control_to_remove_view_last_control,
1137                 control_to_remove_view
1138                     ->controlled_fanouts_[control_to_remove.fanout_index_]);
1139     }
1140     control_to_remove_view->controlled_fanouts_.pop_back();
1141   }
1142 }
1143 
RemoveControllingFaninInternal(MutableNodeView * node_view,const std::set<int> & indices_to_remove)1144 inline void MutableGraphView::RemoveControllingFaninInternal(
1145     MutableNodeView* node_view, const std::set<int>& indices_to_remove) {
1146   const int num_regular_fanins = node_view->NumRegularFanins();
1147   auto* mutable_input = node_view->node()->mutable_input();
1148   // Iterate in descending order so indices stay consistent.
1149   for (auto rit = indices_to_remove.rbegin(); rit != indices_to_remove.rend();
1150        ++rit) {
1151     const int control_index = *rit;
1152     RemoveControllingFaninFanoutInternal(node_view, control_index);
1153 
1154     // Swap last controlling fanin in node with controlling fanin to be removed.
1155     int node_view_controlling_fanins_size =
1156         node_view->controlling_fanins_.size();
1157     if (control_index < node_view_controlling_fanins_size - 1) {
1158       auto& last_control = node_view->controlling_fanins_.back();
1159       auto* last_control_view = last_control.node_view();
1160       last_control_view->controlled_fanouts_[last_control.fanout_index_]
1161           .fanin_index_ = control_index;
1162       node_view->controlling_fanins_index_.find(last_control_view->GetName())
1163           ->second = control_index;
1164       mutable_input->SwapElements(
1165           num_regular_fanins + control_index,
1166           num_regular_fanins + node_view->NumControllingFanins() - 1);
1167       std::swap(last_control, node_view->controlling_fanins_[control_index]);
1168     }
1169     mutable_input->RemoveLast();
1170     node_view->controlling_fanins_.pop_back();
1171   }
1172 }
1173 
AddControllingFaninInternal(MutableNodeView * node_view,absl::string_view fanin_node_name)1174 inline void MutableGraphView::AddControllingFaninInternal(
1175     MutableNodeView* node_view, absl::string_view fanin_node_name) {
1176   NodeDef* node = node_view->node();
1177   // Add controlling fanin to NodeDef.
1178   node->add_input(AsControlDependency(string(fanin_node_name)));
1179   MutableNodeView* fanin_node_view = GetNode(fanin_node_name);
1180   const int index = node_view->controlling_fanins_.size();
1181   fanin_node_view->controlled_fanouts_.emplace_back(
1182       this, node_view->node_index(), Graph::kControlSlot, index);
1183   node_view->controlling_fanins_.emplace_back(
1184       this, fanin_node_view->node_index(), Graph::kControlSlot,
1185       fanin_node_view->controlled_fanouts_.size() - 1);
1186   IncrementFaninCount(
1187       &node_view->fanins_count_,
1188       {&graph_->node(fanin_node_view->node_index()), Graph::kControlSlot});
1189   // Parse new fanin string for node name.
1190   TensorId tensor_id = ParseTensorName(node->input(node->input_size() - 1));
1191   node_view->controlling_fanins_index_.emplace(tensor_id.node(), index);
1192 }
1193 
ApplyNodeUpdates()1194 void MutableGraphView::ApplyNodeUpdates() {
1195   for (auto& diff : mutation_.updated_nodes_) {
1196     if (internal::IsEmpty(&diff)) {
1197       continue;
1198     }
1199     MutableNodeView& node_view = nodes_[diff.node_index];
1200     diff.node_index = internal::kMissingIndex;
1201     // Clean up node view.
1202     node_view.update_index_ = internal::kMissingIndex;
1203 
1204     NodeDef* node_def = node_view.node();
1205 
1206     // Set updated fields and attributes of node.
1207     if (diff.update_op) {
1208       node_def->set_op(diff.op);
1209     }
1210     if (diff.update_device) {
1211       node_def->set_device(diff.device);
1212     }
1213     node_def->mutable_attr()->swap((*diff.processed_attrs));
1214 
1215     // Updated fanins. Only one of `regular_inputs_to_remove_` or
1216     // `regular_inputs_to_add_` can be set.
1217     if (diff.num_regular_inputs_to_remove > 0) {
1218       // Truncate trailing regular fanins.
1219       const int first_index =
1220           node_view.NumRegularFanins() - diff.num_regular_inputs_to_remove;
1221       for (int i = first_index; i < node_view.NumRegularFanins(); ++i) {
1222         RemoveRegularFaninFanoutInternal(&node_view, i);
1223       }
1224       node_view.regular_fanins_.resize(first_index);
1225       node_def->mutable_input()->DeleteSubrange(
1226           node_view.NumRegularFanins(), diff.num_regular_inputs_to_remove);
1227     } else if (diff.num_regular_inputs_to_add > 0) {
1228       // Append regular fanins.
1229       node_def->mutable_input()->Reserve(node_def->mutable_input()->size() +
1230                                          diff.num_regular_inputs_to_add);
1231       int curr_index = node_view.NumRegularFanins();
1232       int curr_control_start = curr_index;
1233       for (const SafeTensorId& fanin : diff.regular_inputs_to_add) {
1234         AddRegularFaninInternal(&node_view, fanin);
1235         node_def->add_input(SafeTensorIdToString(fanin));
1236         node_def->mutable_input()->SwapElements(curr_index,
1237                                                 node_def->input_size() - 1);
1238         if (curr_control_start == curr_index) {
1239           curr_control_start = node_def->input_size() - 1;
1240         }
1241         ++curr_index;
1242       }
1243       // Rotate shifted controlling fanins to match up with
1244       // `node_view.controlling_fanins_` as `num_regular_inputs_to_add_` may not
1245       // be a multiple of `num_regular_inputs_to_add_`. This is to prevent
1246       // rehashing controlling fanins in `node_view.controlling_fanins_index_`.
1247       if (node_view.NumControllingFanins() > 1 &&
1248           curr_control_start != node_view.NumRegularFanins()) {
1249         std::rotate(
1250             node_def->mutable_input()->begin() + node_view.NumRegularFanins(),
1251             node_def->mutable_input()->begin() + curr_control_start,
1252             node_def->mutable_input()->end());
1253       }
1254     }
1255 
1256     for (const auto& update_fanin : diff.regular_inputs_to_update) {
1257       UpdateRegularFaninInternal(&node_view, update_fanin.first,
1258                                  update_fanin.second);
1259       node_def->set_input(update_fanin.first,
1260                           SafeTensorIdToString(update_fanin.second));
1261     }
1262 
1263     RemoveControllingFaninInternal(&node_view,
1264                                    diff.controlling_inputs_to_remove);
1265 
1266     node_def->mutable_input()->Reserve(node_def->mutable_input()->size() +
1267                                        diff.controlling_inputs_to_add.size());
1268     for (const auto& control_to_add : diff.controlling_inputs_to_add) {
1269       AddControllingFaninInternal(&node_view, control_to_add);
1270     }
1271   }
1272 }
1273 
SetNewNodesFanins(const std::vector<int> & new_node_indices)1274 void MutableGraphView::SetNewNodesFanins(
1275     const std::vector<int>& new_node_indices) {
1276   auto new_node = mutation_.new_nodes_.begin();
1277   for (const int new_node_index : new_node_indices) {
1278     MutableNodeView& new_node_view = nodes_[new_node_index];
1279     NodeDef* new_node_def = new_node_view.node();
1280     new_node_def->mutable_input()->Reserve(new_node->num_regular_fanins +
1281                                            new_node->controlling_fanins.size());
1282     for (const SafeTensorId& fanin : new_node->regular_fanins) {
1283       AddRegularFaninInternal(&new_node_view, fanin);
1284       new_node_def->add_input(SafeTensorIdToString(fanin));
1285     }
1286     for (const string& control_to_add : new_node->controlling_fanins) {
1287       AddControllingFaninInternal(&new_node_view, control_to_add);
1288     }
1289     ++new_node;
1290   }
1291 }
1292 
RemoveAllFaninFanoutInternal(MutableNodeView * node_view)1293 inline void MutableGraphView::RemoveAllFaninFanoutInternal(
1294     MutableNodeView* node_view) {
1295   const int num_regular_fanins = node_view->NumRegularFanins();
1296   for (int i = 0; i < num_regular_fanins; ++i) {
1297     RemoveRegularFaninFanoutInternal(node_view, i);
1298   }
1299   std::vector<MutableFanoutView>().swap(node_view->regular_fanins_);
1300   const int num_controlling_fanins = node_view->NumControllingFanins();
1301   for (int i = 0; i < num_controlling_fanins; ++i) {
1302     RemoveControllingFaninFanoutInternal(node_view, i);
1303   }
1304   std::vector<MutableFanoutView>().swap(node_view->controlling_fanins_);
1305 }
1306 
RemoveNodesInternal(const std::vector<RenamedOrOverwrittenNode> & renamed_nodes,const std::vector<bool> & overwritten_name_removed_nodes)1307 void MutableGraphView::RemoveNodesInternal(
1308     const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
1309     const std::vector<bool>& overwritten_name_removed_nodes) {
1310   // Get all nodes overwritten by renamed nodes and remove their fanins.
1311   std::vector<int> overwritten_nodes;
1312   overwritten_nodes.reserve(renamed_nodes.size());
1313   for (const auto& renamed : renamed_nodes) {
1314     if (renamed.overwritten_node_index_ != internal::kMissingIndex) {
1315       auto& node = nodes_[renamed.overwritten_node_index_];
1316       RemoveAllFaninFanoutInternal(&node);
1317       overwritten_nodes.emplace_back(renamed.overwritten_node_index_);
1318     }
1319   }
1320 
1321   // Get all nodes explicitly marked for removal and remove their fanins.
1322   std::vector<int> node_indices_to_remove;
1323   node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
1324                                  overwritten_nodes.size());
1325   for (int node_index : mutation_.removed_nodes_) {
1326     auto& node = nodes_[node_index];
1327     RemoveAllFaninFanoutInternal(&node);
1328     node_indices_to_remove.push_back(node_index);
1329     if (!overwritten_name_removed_nodes[node_index]) {
1330       node_index_by_name_.erase(node.GetName());
1331     }
1332   }
1333   node_indices_to_remove.insert(node_indices_to_remove.end(),
1334                                 overwritten_nodes.begin(),
1335                                 overwritten_nodes.end());
1336   std::set<int> sorted_node_indices_to_remove(node_indices_to_remove.begin(),
1337                                               node_indices_to_remove.end());
1338 
1339   // Iterate in descending order so indices stay consistent.
1340   for (auto rit = sorted_node_indices_to_remove.rbegin();
1341        rit != sorted_node_indices_to_remove.rend(); ++rit) {
1342     const int removed_node_index = *rit;
1343     MutableNodeView& last_node = nodes_.back();
1344     if (last_node.node_index_ > removed_node_index) {
1345       last_node.node_index_ = removed_node_index;
1346       for (auto& regular_fanin : last_node.regular_fanins_) {
1347         // Update fanouts of regular fanins with new index.
1348         regular_fanin.node_view()
1349             ->regular_fanouts_by_port_[regular_fanin.index()]
1350                                       [regular_fanin.fanout_index_]
1351             .node_index_ = removed_node_index;
1352       }
1353       for (auto& controlling_fanin : last_node.controlling_fanins_) {
1354         // Update fanouts of controlling fanins with new index.
1355         controlling_fanin.node_view()
1356             ->controlled_fanouts_[controlling_fanin.fanout_index_]
1357             .node_index_ = removed_node_index;
1358       }
1359       for (auto& regular_fanouts : last_node.regular_fanouts_by_port_) {
1360         for (auto& regular_fanout : regular_fanouts) {
1361           // Update fanins of regular fanouts.
1362           MutableNodeView* fanout_node_view = regular_fanout.node_view();
1363           fanout_node_view->regular_fanins_[regular_fanout.fanin_index_]
1364               .node_index_ = removed_node_index;
1365         }
1366       }
1367       for (auto& controlled_fanout : last_node.controlled_fanouts_) {
1368         // Update fanins of controlled fanouts.
1369         MutableNodeView* fanout_node_view = controlled_fanout.node_view();
1370         fanout_node_view->controlling_fanins_[controlled_fanout.fanin_index_]
1371             .node_index_ = removed_node_index;
1372       }
1373 
1374       const int last_node_index = nodes_.size() - 1;
1375       std::swap(nodes_[last_node_index], nodes_[removed_node_index]);
1376       graph()->mutable_node()->SwapElements(last_node_index,
1377                                             removed_node_index);
1378       node_index_by_name_.find(nodes_[removed_node_index].GetName())->second =
1379           removed_node_index;
1380     }
1381     nodes_.pop_back();
1382   }
1383   if (!sorted_node_indices_to_remove.empty()) {
1384     const int current_size = graph()->node_size();
1385     const int num_to_remove = sorted_node_indices_to_remove.size();
1386     graph()->mutable_node()->DeleteSubrange(current_size - num_to_remove,
1387                                             num_to_remove);
1388   }
1389 }
1390 
namespace {
// Value the position counter in SortTopologically holds once every node has
// been assigned a slot (it starts at `num_nodes - 1` and is decremented once
// per processed node).
constexpr int kTopologicalSortDone = -1;

// Prefix shared by all error messages emitted by SortTopologically.
constexpr char kMutableGraphViewSortTopologicallyError[] =
    "MutableGraphView::SortTopologically error: ";

// Colour of a node during the DFS: not yet reached, currently on the active
// path, or completely handled.
enum TraversalState : uint8_t {
  PENDING = 0,
  PROCESSING = 1,
  PROCESSED = 2,
};

// Marks which half of a simulated recursive call a stack entry stands for when
// the DFS is run iteratively: `ENTER` for going into the call, `EXIT` for
// coming back out of it.
enum RecursionStackState : bool {
  ENTER = false,
  EXIT = true,
};

// One simulated recursive call of the iterative DFS: the node it concerns and
// whether the call is being entered or exited.
struct RecursionStackEntry {
  RecursionStackEntry(int index, RecursionStackState state)
      : node_index{index}, recursion_state{state} {}

  const int node_index;
  const RecursionStackState recursion_state;
};

// A directed edge between two nodes, identified by node index.
struct Edge {
  Edge(int source, int destination) : from{source}, to{destination} {}

  const int from;
  const int to;
};
}  // namespace
1425 
SortTopologically(bool ignore_cycles,absl::Span<const TopologicalDependency> extra_dependencies)1426 Status MutableGraphView::SortTopologically(
1427     bool ignore_cycles,
1428     absl::Span<const TopologicalDependency> extra_dependencies) {
1429   if (!mutation_.updated_nodes_.empty() || !mutation_.new_nodes_.empty()) {
1430     // Cannot sort when there is an active mutation due to indices possibly
1431     // being changed or invalidated.
1432     return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
1433                                    "active mutation exists.");
1434   }
1435 
1436   const int num_nodes = nodes_.size();
1437 
1438   // Group extra dependencies by `from` node.
1439   absl::flat_hash_map<int, std::vector<int>> extra_dependencies_by_parent;
1440   for (const auto& extra_dependency : extra_dependencies) {
1441     if (extra_dependency.graph_view_ != this ||
1442         extra_dependency.from_ == extra_dependency.to_ ||
1443         extra_dependency.from_ < 0 || extra_dependency.from_ >= num_nodes ||
1444         extra_dependency.to_ < 0 || extra_dependency.to_ >= num_nodes) {
1445       return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
1446                                      "invalid extra dependencies.");
1447     }
1448     extra_dependencies_by_parent[extra_dependency.from_].push_back(
1449         extra_dependency.to_);
1450   }
1451 
1452   // Reversed colored post-order DFS traversal. This does not fail on cycles,
1453   // but there are no guarantees on ordering within a cycle.
1454   std::vector<TraversalState> traversal_state(num_nodes, PENDING);
1455   int curr_pos = num_nodes - 1;
1456   std::vector<int> order(num_nodes);
1457   std::vector<Edge> edges_in_cycle;
1458 
1459   auto push_onto_stack = [this](
1460                              const int curr_index, const int fanout_index,
1461                              std::vector<RecursionStackEntry>* recursion_stack,
1462                              std::vector<TraversalState>* traversal_state,
1463                              std::vector<Edge>* edges_in_cycle) {
1464     // Ignore NextIteration -> Merge connections to break control flow cycles.
1465     if (IsNextIteration(graph_->node(curr_index)) &&
1466         IsMerge(graph_->node(fanout_index))) {
1467       return;
1468     }
1469     auto& fanout_traversal_state = (*traversal_state)[fanout_index];
1470     if (fanout_traversal_state == PROCESSING) {
1471       // Cycle detected.
1472       edges_in_cycle->push_back({curr_index, fanout_index});
1473     } else if (fanout_traversal_state == PENDING) {
1474       // Unvisited node, simply add to stack for future traversal.
1475       recursion_stack->push_back({fanout_index, ENTER});
1476     }
1477   };
1478 
1479   auto process_fanouts = [this, &extra_dependencies_by_parent,
1480                           &push_onto_stack](
1481                              const int curr_index,
1482                              std::vector<RecursionStackEntry>* recursion_stack,
1483                              std::vector<TraversalState>* traversal_state,
1484                              std::vector<Edge>* edges_in_cycle) {
1485     const auto& node_view = nodes_[curr_index];
1486     // Regular fanouts.
1487     for (const auto& regular_fanouts_port_i : node_view.GetRegularFanouts()) {
1488       for (const auto& regular_fanout : regular_fanouts_port_i) {
1489         push_onto_stack(curr_index, regular_fanout.node_index_, recursion_stack,
1490                         traversal_state, edges_in_cycle);
1491       }
1492     }
1493     // Controlled fanouts.
1494     for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
1495       push_onto_stack(curr_index, controlled_fanout.node_index_,
1496                       recursion_stack, traversal_state, edges_in_cycle);
1497     }
1498     // Extra dependencies.
1499     auto it = extra_dependencies_by_parent.find(curr_index);
1500     if (it != extra_dependencies_by_parent.end()) {
1501       for (const auto& extra_fanout : it->second) {
1502         push_onto_stack(curr_index, extra_fanout, recursion_stack,
1503                         traversal_state, edges_in_cycle);
1504       }
1505     }
1506   };
1507 
1508   auto reversed_postorder_dfs =
1509       [&process_fanouts](const MutableNodeView& root_node_view,
1510                          std::vector<int>* order,
1511                          std::vector<TraversalState>* traversal_state,
1512                          int* curr_pos, std::vector<Edge>* edges_in_cycle) {
1513         std::vector<RecursionStackEntry> recursion_stack;
1514         // Add the root to stack to start the traversal.
1515         const int root_index = root_node_view.node_index_;
1516         auto& root_traversal_state = (*traversal_state)[root_index];
1517         if (root_traversal_state == PENDING) {
1518           recursion_stack.push_back({root_index, ENTER});
1519         }
1520         while (!recursion_stack.empty()) {
1521           auto curr_entry = recursion_stack.back();
1522           recursion_stack.pop_back();
1523           const int curr_index = curr_entry.node_index;
1524           auto& curr_traversal_state = (*traversal_state)[curr_index];
1525           if (curr_traversal_state == PROCESSED) {
1526             // Node already processed which can be ignored.
1527             continue;
1528           } else if (curr_entry.recursion_state == EXIT) {
1529             // Node from recursion stack where all fanouts were visited.
1530             // Instead of adding node index to a vector, simply set what its
1531             // index would be, so there will not be a need for inversion later
1532             // on. The value set is in decending order so the reversed
1533             // post-order is returned.
1534             (*order)[curr_index] = *curr_pos;
1535             curr_traversal_state = PROCESSED;
1536             --(*curr_pos);
1537           } else {
1538             // Process current node and fanouts.
1539             curr_traversal_state = PROCESSING;
1540             recursion_stack.push_back({curr_index, EXIT});
1541             process_fanouts(curr_index, &recursion_stack, traversal_state,
1542                             edges_in_cycle);
1543           }
1544         }
1545       };
1546 
1547   // Determine sources to start DFS (nodes with no inputs) and unique fanout
1548   // nodes.
1549   for (int i = num_nodes - 1; i >= 0; --i) {
1550     auto& node = nodes_[i];
1551     if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
1552       reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
1553                              &edges_in_cycle);
1554     }
1555   }
1556 
1557   if (!ignore_cycles && !edges_in_cycle.empty()) {
1558     std::vector<string> edges_formatted;
1559     edges_formatted.reserve(edges_in_cycle.size());
1560     for (const auto& edge : edges_in_cycle) {
1561       edges_formatted.push_back(
1562           absl::StrCat("'", graph_->node(edge.from).name(), "' -> '",
1563                        graph_->node(edge.to).name(), "'"));
1564     }
1565     const string edges_str =
1566         absl::StrCat("{", absl::StrJoin(edges_formatted, ", "), "}");
1567     return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
1568                                    "detected edge(s) creating cycle(s) ",
1569                                    edges_str, ".");
1570   }
1571   if (curr_pos != kTopologicalSortDone) {
1572     // Not all nodes were processed.
1573     if (!ignore_cycles) {
1574       return errors::InvalidArgument(
1575           kMutableGraphViewSortTopologicallyError,
1576           "was not able to sort all nodes topologically.");
1577     }
1578     // Otherwise process all nodes regardless of cycles.
1579     for (const auto& node : nodes_) {
1580       reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
1581                              &edges_in_cycle);
1582     }
1583   }
1584 
1585   // Permute nodes by reversed post-order DFS.
1586   std::vector<MutableNodeView> permuted_nodes(num_nodes);
1587   for (int i = 0; i < num_nodes; ++i) {
1588     permuted_nodes[order[i]] = std::move(nodes_[i]);
1589   }
1590   nodes_.swap(permuted_nodes);
1591 
1592   // Fix up indices of MutableNodeViews.
1593   for (MutableNodeView& node_view : nodes_) {
1594     const int prev_node_index = node_view.node_index_;
1595     if (prev_node_index != order[prev_node_index]) {
1596       const string& node_name = graph_->node(prev_node_index).name();
1597       node_view.node_index_ = order[prev_node_index];
1598       node_index_by_name_.find(node_name)->second = node_view.node_index_;
1599     }
1600     for (MutableFanoutView& regular_fanin : node_view.regular_fanins_) {
1601       regular_fanin.node_index_ = order[regular_fanin.node_index_];
1602     }
1603     for (MutableFanoutView& controlling_fanin : node_view.controlling_fanins_) {
1604       controlling_fanin.node_index_ = order[controlling_fanin.node_index_];
1605     }
1606     for (std::vector<MutableFaninView>& regular_fanouts_port_i :
1607          node_view.regular_fanouts_by_port_) {
1608       for (MutableFaninView& regular_fanout : regular_fanouts_port_i) {
1609         regular_fanout.node_index_ = order[regular_fanout.node_index_];
1610       }
1611     }
1612     for (MutableFaninView& controlled_fanout : node_view.controlled_fanouts_) {
1613       controlled_fanout.node_index_ = order[controlled_fanout.node_index_];
1614     }
1615   }
1616 
1617   // Permute graph NodeDefs.
1618   PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false);
1619 
1620   return Status::OK();
1621 }
1622 
ValidateInternal(absl::flat_hash_map<absl::string_view,int> * node_names,std::vector<RenamedOrOverwrittenNode> * renamed_nodes,std::vector<int> * inplace_nodes,std::vector<int> * empty_diff_node_indices)1623 inline Status MutableGraphView::ValidateInternal(
1624     absl::flat_hash_map<absl::string_view, int>* node_names,
1625     std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
1626     std::vector<int>* inplace_nodes,
1627     std::vector<int>* empty_diff_node_indices) {
1628   // Get node names and partition updated_nodes_ by if they are renamed or not,
1629   // skipping empty MutableNodeViewDiff.
1630   TF_RETURN_IF_ERROR(GetNodeNamesAndPartitionUpdatedNodes(
1631       node_names, renamed_nodes, inplace_nodes, empty_diff_node_indices));
1632 
1633   // Check existence of fanins and validity (i.e. no self loops).
1634   TF_RETURN_IF_ERROR(
1635       CheckNodeNamesAndFanins(*node_names, *renamed_nodes, *inplace_nodes));
1636 
1637   // Check if nodes after mutation have kernels registered.
1638   TF_RETURN_IF_ERROR(CheckKernelRegisteredForNodes());
1639 
1640   return Status::OK();
1641 }
1642 
ApplyMutationInternal()1643 Status MutableGraphView::ApplyMutationInternal() {
1644   // Node name -> node index mapping. If a node index is -1, the associated node
1645   // with key node name exists. Otherwise the node index is the node's index in
1646   // the graph.
1647   absl::flat_hash_map<absl::string_view, int> node_names;
1648   // Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are
1649   // renamed (and possibly have other fields mutated).
1650   std::vector<RenamedOrOverwrittenNode> renamed_nodes;
1651   // Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are
1652   // not renamed but have fields mutated.
1653   std::vector<int> inplace_nodes;
1654   // Indices of nodes in graph where MutableNodeViewDiff are empty.
1655   // `update_index_` of nodes associated to empty MutableNodeViewDiff should be
1656   // cleared after validation success.
1657   std::vector<int> empty_diff_node_indices;
1658 
1659   // Check if this mutation is valid before applying, and partition
1660   // updated_nodes_ into inplace mutated nodes and renamed nodes.
1661   TF_RETURN_IF_ERROR(ValidateInternal(
1662       &node_names, &renamed_nodes, &inplace_nodes, &empty_diff_node_indices));
1663 
1664   // Clear `update_index_` of MutableNodeView with empty associated
1665   // MutableNodeViewDiff.
1666   for (const int empty_diff_node_index : empty_diff_node_indices) {
1667     nodes_[empty_diff_node_index].update_index_ = internal::kMissingIndex;
1668   }
1669 
1670   // Node name and associated fanouts.
1671   absl::flat_hash_map<string, NodeViewFanouts> renamed_fanouts;
1672   // Removed nodes where name was overwritten by a renamed node.
1673   std::vector<bool> overwritten_name_removed_nodes(nodes_.size());
1674   // Fix renaming of existing nodes by swapping fanouts and rehashing names.
1675   // This will also overwrite removed or unmodified nodes.
1676   FixRenamedNodes(&renamed_nodes, &renamed_fanouts,
1677                   &overwritten_name_removed_nodes);
1678 
1679   // Indices of nodes in graph where new nodes were inserted/appended. These
1680   // will be corresponding to `new_nodes_` in order.
1681   std::vector<int> new_node_indices;
1682   // Add new nodes, overwriting removed or unmodified nodes.
1683   AddNewNodes(&renamed_fanouts, &new_node_indices);
1684 
1685   // For abandoned fanouts, mark their respective fanins so the original node
1686   // associated will not have their fanouts removed and be left in an
1687   // inconsistent state.
1688   FixRenamedFanouts(renamed_fanouts);
1689 
1690   // Apply mutations to updated nodes (renamed nodes are treated as inplace
1691   // nodes as they have already been renamed). Removed nodes are ignored.
1692   ApplyNodeUpdates();
1693 
1694   // Set fanins of new nodes.
1695   SetNewNodesFanins(new_node_indices);
1696 
1697   // Remove overwritten nodes and updated nodes set to be removed.
1698   RemoveNodesInternal(renamed_nodes, overwritten_name_removed_nodes);
1699 
1700   mutation_.ResetInternal();
1701 
1702   mutation_.mutation_counter_++;
1703 
1704   return Status::OK();
1705 }
1706 
1707 }  // namespace utils
1708 }  // namespace grappler
1709 }  // namespace tensorflow
1710