1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/graph_view.h"
17
18 #include <utility>
19
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/graph/tensor_id.h"
25 #include "tensorflow/core/grappler/op_types.h"
26 #include "tensorflow/core/grappler/utils.h"
27 #include "tensorflow/core/grappler/utils/graph_view_internal.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/util/device_name_utils.h"
31
32 namespace tensorflow {
33 namespace grappler {
34 namespace utils {
35
FaninView(NodeView * node_view,int index)36 FaninView::FaninView(NodeView* node_view, int index)
37 : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
38 index) {}
39
FanoutView(NodeView * node_view,int index)40 FanoutView::FanoutView(NodeView* node_view, int index)
41 : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
42 index) {}
43
node() const44 const NodeDef* NodeView::node() const {
45 return &graph_view_->graph()->node(node_index_);
46 }
47
HasFanin(const FanoutView & fanin) const48 bool NodeView::HasFanin(const FanoutView& fanin) const {
49 if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) {
50 return false;
51 }
52 return fanins_set_.contains(
53 {&graph_view_->graph_->node(fanin.node_index_), fanin.index()});
54 }
55
HasFanout(const FaninView & fanout) const56 bool NodeView::HasFanout(const FaninView& fanout) const {
57 if (fanout.index() < Graph::kControlSlot ||
58 graph_view_ != fanout.graph_view_) {
59 return false;
60 }
61 NodeView* view = fanout.node_view();
62 if (view == nullptr) {
63 return false;
64 } else if (fanout.index() == Graph::kControlSlot) {
65 return view->fanins_set_.contains({this->node(), Graph::kControlSlot});
66 } else if (fanout.index() >= static_cast<int>(view->regular_fanins_.size())) {
67 return false;
68 }
69 return view->regular_fanins_[fanout.index()].node_index_ == node_index_;
70 }
71
GetMissingFanin() const72 inline const FanoutView& NodeView::GetMissingFanin() const {
73 return graph_view_->missing_fanin_;
74 }
75
GetMissingFanout() const76 inline const std::vector<FaninView>& NodeView::GetMissingFanout() const {
77 return graph_view_->missing_fanout_;
78 }
79
namespace {
// Prefix for every error status reported while constructing a GraphView.
constexpr char kGraphViewError[] = "GraphView::GraphView error: ";
}  // namespace
83
GraphView(const GraphDef * graph,Status * status)84 GraphView::GraphView(const GraphDef* graph, Status* status)
85 : GraphViewInternal(graph) {
86 const int num_nodes = graph->node_size();
87 node_index_by_name_.reserve(num_nodes);
88 nodes_.reserve(num_nodes);
89 for (const NodeDef& node : graph->node()) {
90 if (!AddUniqueNodeInternal(&node)) {
91 *status = errors::InvalidArgument(
92 kGraphViewError, "graph has multiple nodes with the name '",
93 node.name(), "'.");
94 Reset();
95 return;
96 }
97 }
98 Status s;
99 for (NodeView& node_view : nodes_) {
100 s = CheckAndAddFaninsInternal(&node_view);
101 if (!s.ok()) {
102 *status = s;
103 Reset();
104 return;
105 }
106 }
107 *status = Status::OK();
108 }
109
AddUniqueNodeInternal(const NodeDef * node)110 bool GraphView::AddUniqueNodeInternal(const NodeDef* node) {
111 const int node_index = node_index_by_name_.size();
112 auto it = node_index_by_name_.emplace(node->name(), node_index);
113 if (it.second) {
114 nodes_.emplace_back(this, node_index);
115 return true;
116 }
117 return false;
118 }
119
// Validates the inputs of `node_view`'s NodeDef and records each one as a
// fanin of this node and as a fanout of the producing node.
//
// Returns InvalidArgument when an input:
//   - names the node itself (self cycle),
//   - is a regular input that follows a control input, or
//   - names a node that is not in the graph.
// Inputs processed before the failing one stay recorded; the GraphView
// constructor calls Reset() when this returns an error.
//
// Control inputs go to `controlling_fanins_` and to the producer's
// `controlled_fanouts_`. Regular inputs go to `regular_fanins_` and to the
// producer's `regular_fanouts_by_port_[port]`. Both kinds are also added to
// `fanins_set_`. A control input that appears more than once is appended to
// `controlling_fanins_` once per occurrence; unlike
// MutableGraphView::AddFaninsInternal, no deduplication happens here.
Status GraphView::CheckAndAddFaninsInternal(NodeView* node_view) {
  bool has_observed_control = false;
  const NodeDef* node = node_view->node();
  const string& node_name = node->name();
  const int node_index = node_view->node_index_;
  node_view->fanins_set_.reserve(node->input_size());
  for (const string& input : node->input()) {
    TensorId fanin_id = ParseTensorName(input);
    if (fanin_id.node() == node_name) {
      return errors::InvalidArgument(kGraphViewError, "node '", node_name,
                                     "' has self cycle fanin '", input, "'.");
    }
    bool is_control = IsTensorIdControl(fanin_id);
    // All regular inputs must precede all control inputs.
    if (!is_control && has_observed_control) {
      return errors::InvalidArgument(kGraphViewError, "node '", node_name,
                                     "' has regular fanin '", input,
                                     "' after controlling fanins.");
    }
    auto it = node_index_by_name_.find(fanin_id.node());
    if (it == node_index_by_name_.end()) {
      return errors::InvalidArgument(kGraphViewError, "node '", node_name,
                                     "' has missing fanin '", input, "'.");
    }
    const int fanin_node_index = it->second;
    NodeView& fanin_node_view = nodes_[fanin_node_index];

    if (is_control) {
      fanin_node_view.controlled_fanouts_.emplace_back(this, node_index,
                                                       Graph::kControlSlot);
      node_view->controlling_fanins_.emplace_back(this, fanin_node_index,
                                                  Graph::kControlSlot);
      node_view->fanins_set_.emplace(fanin_node_view.node(),
                                     Graph::kControlSlot);
      has_observed_control = true;
    } else {
      // Grow the producer's per-port fanout table so `fanin_id.index()` is a
      // valid slot.
      int fanin_node_view_regular_fanouts_by_port_size =
          fanin_node_view.regular_fanouts_by_port_.size();
      if (fanin_node_view_regular_fanouts_by_port_size < fanin_id.index() + 1) {
        fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1);
      }
      // The current size of `regular_fanins_` is the input position this
      // fanin is about to occupy, so it is read before the append below.
      fanin_node_view.regular_fanouts_by_port_[fanin_id.index()].emplace_back(
          this, node_index, node_view->regular_fanins_.size());
      ++fanin_node_view.num_regular_fanouts_;
      node_view->regular_fanins_.emplace_back(this, fanin_node_index,
                                              fanin_id.index());
      node_view->fanins_set_.emplace(fanin_node_view.node(), fanin_id.index());
    }
  }
  return Status::OK();
}
170
MutableFaninView(MutableNodeView * node_view,int index)171 MutableFaninView::MutableFaninView(MutableNodeView* node_view, int index)
172 : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
173 index) {}
174
MutableFanoutView(MutableNodeView * node_view,int index)175 MutableFanoutView::MutableFanoutView(MutableNodeView* node_view, int index)
176 : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
177 index) {}
178
node() const179 NodeDef* MutableNodeView::node() const {
180 return graph_view_->graph()->mutable_node(node_index_);
181 }
182
HasFanin(const MutableFanoutView & fanin) const183 bool MutableNodeView::HasFanin(const MutableFanoutView& fanin) const {
184 if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) {
185 return false;
186 }
187 return fanins_count_.contains(
188 {&graph_view_->graph_->node(fanin.node_index_), fanin.index()});
189 }
190
HasFanout(const MutableFaninView & fanout) const191 bool MutableNodeView::HasFanout(const MutableFaninView& fanout) const {
192 if (fanout.index() < Graph::kControlSlot ||
193 graph_view_ != fanout.graph_view_) {
194 return false;
195 }
196 MutableNodeView* view = fanout.node_view();
197 if (view == nullptr) {
198 return false;
199 } else if (fanout.index() == Graph::kControlSlot) {
200 return view->fanins_count_.contains({this->node(), Graph::kControlSlot});
201 } else if (fanout.index() >= static_cast<int>(view->regular_fanins_.size())) {
202 return false;
203 }
204 return view->regular_fanins_[fanout.index()].node_index_ == node_index_;
205 }
206
GetMissingFanin() const207 const MutableFanoutView& MutableNodeView::GetMissingFanin() const {
208 return graph_view_->missing_fanin_;
209 }
210
GetMissingFanout() const211 const std::vector<MutableFaninView>& MutableNodeView::GetMissingFanout() const {
212 return graph_view_->missing_fanout_;
213 }
214
215 namespace {
216 const char kMutationAddNodeError[] = "Mutation::AddNode error: ";
217
IsTensorIdRegular(const TensorId & tensor_id)218 bool IsTensorIdRegular(const TensorId& tensor_id) {
219 return tensor_id.index() >= 0;
220 }
221 } // namespace
222
Mutation(MutableGraphView * graph_view)223 Mutation::Mutation(MutableGraphView* graph_view) : graph_view_(graph_view) {}
224
AddNode(NodeDef && node,Status * status)225 MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) {
226 bool has_observed_control = false;
227 const string& node_name = node.name();
228 std::vector<SafeTensorId> regular_fanins;
229 absl::flat_hash_set<string> controlling_fanins;
230 const int num_fanins = node.input_size();
231 for (int i = 0; i < num_fanins; ++i) {
232 const string& input = node.input(i);
233 TensorId fanin_id = ParseTensorName(input);
234 if (fanin_id.node() == node_name) {
235 *status =
236 errors::InvalidArgument(kMutationAddNodeError, "node '", node_name,
237 "' has self cycle fanin '", input, "'.");
238 return MutationNewNode(this, mutation_counter_, internal::kMissingIndex);
239 }
240 bool is_control = IsTensorIdControl(fanin_id);
241 if (is_control) {
242 has_observed_control = true;
243 controlling_fanins.emplace(fanin_id.node());
244 } else if (has_observed_control) {
245 *status = errors::InvalidArgument(kMutationAddNodeError, "node '",
246 node_name, "' has regular fanin '",
247 input, "' after controlling fanins.");
248 return MutationNewNode(this, mutation_counter_, internal::kMissingIndex);
249 } else {
250 regular_fanins.emplace_back(fanin_id);
251 }
252 }
253
254 node.mutable_input()->Clear();
255 new_nodes_.emplace_back(graph_view_, std::move(node));
256 MutationNewNodeHolder& mutation_node = new_nodes_.back();
257 mutation_node.regular_fanins = std::move(regular_fanins);
258 mutation_node.num_regular_fanins = mutation_node.regular_fanins.size();
259 mutation_node.controlling_fanins = std::move(controlling_fanins);
260 *status = Status::OK();
261 return MutationNewNode(this, mutation_counter_, new_nodes_.size() - 1);
262 }
263
AddMutation(MutableNodeView * node,std::function<bool (MutableNodeViewDiff *)> mutate_fn)264 void Mutation::AddMutation(
265 MutableNodeView* node,
266 std::function<bool(MutableNodeViewDiff*)> mutate_fn) {
267 DCHECK(node->graph_view_ == graph_view_);
268 if (node->update_index_ == internal::kMissingIndex) {
269 MutableNodeViewDiff diff(graph_view_, node->node_index_);
270 // If mutation is a no-op return and do not add it to the `updated_nodes_`.
271 if (!mutate_fn(&diff)) return;
272 node->update_index_ = updated_nodes_.size();
273 updated_nodes_.push_back(std::move(diff));
274 } else if (!removed_nodes_.contains(node->node_index_)) {
275 MutableNodeViewDiff& diff = updated_nodes_[node->update_index_];
276 mutate_fn(&diff);
277 }
278 }
279
RemoveNode(MutableNodeView * node)280 void Mutation::RemoveNode(MutableNodeView* node) {
281 auto& update_index = node->update_index_;
282 if (update_index != internal::kMissingIndex) {
283 int updated_nodes_size = updated_nodes_.size();
284 if (update_index < updated_nodes_size - 1) {
285 graph_view_->nodes_[updated_nodes_.back().node_index].update_index_ =
286 update_index;
287 std::swap(updated_nodes_[update_index], updated_nodes_.back());
288 }
289 updated_nodes_.pop_back();
290 update_index = internal::kMissingIndex;
291 }
292 removed_nodes_.insert(node->node_index_);
293 }
294
UpdateNodeName(MutableNodeView * node,absl::string_view name)295 void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
296 AddMutation(node, [name](MutableNodeViewDiff* diff) {
297 return internal::UpdateName(diff, name);
298 });
299 }
300
UpdateNodeName(const MutationNewNode & node,absl::string_view name)301 void Mutation::UpdateNodeName(const MutationNewNode& node,
302 absl::string_view name) {
303 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
304 internal::UpdateName(&new_nodes_[node.index_], name);
305 }
306
UpdateNodeOp(MutableNodeView * node,absl::string_view op)307 void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) {
308 AddMutation(node, [op](MutableNodeViewDiff* diff) {
309 return internal::UpdateOp(diff, op);
310 });
311 }
312
UpdateNodeOp(const MutationNewNode & node,absl::string_view op)313 void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
314 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
315 internal::UpdateOp(&new_nodes_[node.index_], op);
316 }
317
UpdateNodeDevice(MutableNodeView * node,absl::string_view device)318 void Mutation::UpdateNodeDevice(MutableNodeView* node,
319 absl::string_view device) {
320 AddMutation(node, [device](MutableNodeViewDiff* diff) {
321 return internal::UpdateDevice(diff, device);
322 });
323 }
324
UpdateNodeDevice(const MutationNewNode & node,absl::string_view device)325 void Mutation::UpdateNodeDevice(const MutationNewNode& node,
326 absl::string_view device) {
327 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
328 internal::UpdateDevice(&new_nodes_[node.index_], device);
329 }
330
AddOrUpdateRegularFanin(MutableNodeView * node,int index,const TensorId & fanin)331 void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index,
332 const TensorId& fanin) {
333 AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) {
334 return internal::AddOrUpdateRegularFanin(diff, index, fanin);
335 });
336 }
337
AddOrUpdateRegularFanin(const MutationNewNode & node,int index,const TensorId & fanin)338 void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index,
339 const TensorId& fanin) {
340 DCHECK(node.mutation_ == this &&
341 node.mutation_counter_ == mutation_counter_ && index >= 0 &&
342 IsTensorIdRegular(fanin));
343 internal::AddOrUpdateRegularFanin(&new_nodes_[node.index_], index, fanin);
344 }
345
RemoveRegularFanin(MutableNodeView * node,int index)346 void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) {
347 AddMutation(node, [index](MutableNodeViewDiff* diff) {
348 return internal::RemoveRegularFanin(diff, index);
349 });
350 }
351
RemoveRegularFanin(const MutationNewNode & node,int index)352 void Mutation::RemoveRegularFanin(const MutationNewNode& node, int index) {
353 DCHECK(node.mutation_ == this &&
354 node.mutation_counter_ == mutation_counter_ && index >= 0);
355 internal::RemoveRegularFanin(&new_nodes_[node.index_], index);
356 }
357
AddControllingFanin(MutableNodeView * node,absl::string_view fanin_node_name)358 void Mutation::AddControllingFanin(MutableNodeView* node,
359 absl::string_view fanin_node_name) {
360 AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) {
361 auto it = node->controlling_fanins_index_.find(fanin_node_name);
362 const int control_index = it != node->controlling_fanins_index_.end()
363 ? it->second
364 : internal::kMissingIndex;
365 return internal::AddControllingFanin(diff, control_index, fanin_node_name);
366 });
367 }
368
AddControllingFanin(const MutationNewNode & node,absl::string_view fanin_node_name)369 void Mutation::AddControllingFanin(const MutationNewNode& node,
370 absl::string_view fanin_node_name) {
371 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
372 internal::AddControllingFanin(&new_nodes_[node.index_], fanin_node_name);
373 }
374
RemoveControllingFanin(MutableNodeView * node,absl::string_view fanin_node_name)375 void Mutation::RemoveControllingFanin(MutableNodeView* node,
376 absl::string_view fanin_node_name) {
377 AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) {
378 auto it = node->controlling_fanins_index_.find(fanin_node_name);
379 const int control_index = it != node->controlling_fanins_index_.end()
380 ? it->second
381 : internal::kMissingIndex;
382 return internal::RemoveControllingFanin(diff, control_index,
383 fanin_node_name);
384 });
385 }
386
RemoveControllingFanin(const MutationNewNode & node,absl::string_view fanin_node_name)387 void Mutation::RemoveControllingFanin(const MutationNewNode& node,
388 absl::string_view fanin_node_name) {
389 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
390 internal::RemoveControllingFanin(&new_nodes_[node.index_], fanin_node_name);
391 }
392
AddOrUpdateNodeAttr(MutableNodeView * node,absl::string_view attr_name,const AttrValue & attr_value)393 void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node,
394 absl::string_view attr_name,
395 const AttrValue& attr_value) {
396 AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) {
397 return internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
398 });
399 }
400
AddOrUpdateNodeAttr(const MutationNewNode & node,absl::string_view attr_name,const AttrValue & attr_value)401 void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node,
402 absl::string_view attr_name,
403 const AttrValue& attr_value) {
404 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
405 internal::AddOrUpdateAttribute(&new_nodes_[node.index_], attr_name,
406 attr_value);
407 }
408
RemoveNodeAttr(MutableNodeView * node,absl::string_view attr_name)409 void Mutation::RemoveNodeAttr(MutableNodeView* node,
410 absl::string_view attr_name) {
411 AddMutation(node, [attr_name](MutableNodeViewDiff* diff) {
412 return internal::RemoveAttribute(diff, attr_name);
413 });
414 }
415
RemoveNodeAttr(const MutationNewNode & node,absl::string_view attr_name)416 void Mutation::RemoveNodeAttr(const MutationNewNode& node,
417 absl::string_view attr_name) {
418 DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
419 internal::RemoveAttribute(&new_nodes_[node.index_], attr_name);
420 }
421
ResetInternal()422 void Mutation::ResetInternal() {
423 updated_nodes_.clear();
424 removed_nodes_.clear();
425 new_nodes_.clear();
426 }
427
Reset()428 void Mutation::Reset() {
429 for (const auto& update : updated_nodes_) {
430 graph_view_->nodes_[update.node_index].update_index_ =
431 internal::kMissingIndex;
432 }
433 ResetInternal();
434 }
435
Apply()436 Status Mutation::Apply() { return graph_view_->ApplyMutationInternal(); }
437
438 namespace {
439 const char kMutableGraphViewError[] =
440 "MutableGraphView::MutableGraphView error: ";
441
442 const char kMutableGraphViewApplyError[] = "Mutation::Apply error: ";
443
IncrementFaninCount(absl::flat_hash_map<internal::NodeDefAndPortIndex,int> * fanins_count,const internal::NodeDefAndPortIndex & fanin)444 inline void IncrementFaninCount(
445 absl::flat_hash_map<internal::NodeDefAndPortIndex, int>* fanins_count,
446 const internal::NodeDefAndPortIndex& fanin) {
447 ++(*fanins_count)[fanin];
448 }
449
DecrementFaninCount(absl::flat_hash_map<internal::NodeDefAndPortIndex,int> * fanins_count,const internal::NodeDefAndPortIndex & fanin)450 inline void DecrementFaninCount(
451 absl::flat_hash_map<internal::NodeDefAndPortIndex, int>* fanins_count,
452 const internal::NodeDefAndPortIndex& fanin) {
453 auto it = fanins_count->find(fanin);
454 if (it != fanins_count->end()) {
455 if (it->second <= 1) {
456 fanins_count->erase(it);
457 } else {
458 --it->second;
459 }
460 }
461 }
462 } // namespace
463
MutableGraphView(GraphDef * graph,Status * status)464 MutableGraphView::MutableGraphView(GraphDef* graph, Status* status)
465 : GraphViewInternal(graph), mutation_(Mutation(this)) {
466 const int num_nodes = graph->node_size();
467 node_index_by_name_.reserve(num_nodes);
468 nodes_.reserve(num_nodes);
469 for (NodeDef& node : *graph->mutable_node()) {
470 if (!AddUniqueNodeInternal(&node)) {
471 *status = errors::InvalidArgument(
472 kMutableGraphViewError, "graph has multiple nodes with the name '",
473 node.name(), "'.");
474 Reset();
475 return;
476 }
477 }
478 std::vector<std::vector<TensorId>> fanins;
479 Status s = CheckFaninsInternal(&fanins);
480 if (!s.ok()) {
481 *status = s;
482 Reset();
483 return;
484 }
485 AddFaninsInternal(&fanins);
486 mutation_.ResetInternal();
487 *status = Status::OK();
488 }
489
GetMutationBuilder()490 Mutation* MutableGraphView::GetMutationBuilder() { return &mutation_; }
491
AddUniqueNodeInternal(NodeDef * node)492 bool MutableGraphView::AddUniqueNodeInternal(NodeDef* node) {
493 const int node_index = node_index_by_name_.size();
494 auto it = node_index_by_name_.emplace(node->name(), node_index);
495 if (it.second) {
496 nodes_.emplace_back(this, node_index);
497 return true;
498 }
499 return false;
500 }
501
CheckFaninsInternal(std::vector<std::vector<TensorId>> * fanins)502 Status MutableGraphView::CheckFaninsInternal(
503 std::vector<std::vector<TensorId>>* fanins) {
504 const int num_nodes = nodes_.size();
505 fanins->reserve(num_nodes);
506 for (int i = 0; i < num_nodes; ++i) {
507 bool has_observed_control = false;
508 const NodeDef* node = nodes_[i].node();
509 const string& node_name = node->name();
510 std::vector<TensorId> node_fanins;
511 node_fanins.reserve(node->input_size());
512 for (const string& input : node->input()) {
513 TensorId fanin_id = ParseTensorName(input);
514 if (fanin_id.node() == node_name) {
515 return errors::InvalidArgument(kMutableGraphViewError, "node '",
516 node_name, "' has self cycle fanin '",
517 input, "'.");
518 }
519 bool is_control = IsTensorIdControl(fanin_id);
520 if (!is_control && has_observed_control) {
521 return errors::InvalidArgument(kMutableGraphViewError, "node '",
522 node_name, "' has regular fanin '",
523 input, "' after controlling fanins.");
524 }
525 if (!node_index_by_name_.contains(fanin_id.node())) {
526 return errors::InvalidArgument(kMutableGraphViewError, "node '",
527 node_name, "' has missing fanin '",
528 input, "'.");
529 }
530 if (is_control) {
531 has_observed_control = true;
532 }
533 node_fanins.push_back(std::move(fanin_id));
534 }
535 fanins->push_back(std::move(node_fanins));
536 }
537 return Status::OK();
538 }
539
// Links fanins and fanouts for every node, using the inputs already parsed and
// validated by CheckFaninsInternal (`(*fanins)[i]` holds node i's inputs).
//
// Regular inputs are recorded on both ends: a MutableFaninView in the
// producer's `regular_fanouts_by_port_[port]` and a MutableFanoutView in this
// node's `regular_fanins_`. Control inputs are recorded through the
// producer's `controlled_fanouts_` and this node's `controlling_fanins_`. In
// both cases the last constructor argument is the counterpart's index in the
// other node's list; for controlled fanouts it is later read as
// `fanin_index_`. `controlling_fanins_index_` maps a control fanin's node name
// to its position among this node's control inputs. `fanins_count_` counts
// each (producer NodeDef, port) pair.
//
// Duplicate control inputs (the same producer name seen twice) are removed
// from the NodeDef itself. Each duplicate is swapped to the tail of the input
// list and `(*fanins)[i]` is reordered in lockstep, so the relative order of
// the remaining control inputs can change. The tail is deleted after the scan.
void MutableGraphView::AddFaninsInternal(
    std::vector<std::vector<TensorId>>* fanins) {
  const int num_nodes = nodes_.size();
  for (int i = 0; i < num_nodes; ++i) {
    MutableNodeView& node_view = nodes_[i];
    NodeDef* node = node_view.node();
    std::vector<TensorId>& node_fanins = fanins->at(i);
    // NOTE(review): these string_views point into the node's input strings;
    // they remain valid only if SwapElements swaps element pointers rather
    // than string contents (RepeatedPtrField behavior; confirm).
    absl::flat_hash_set<absl::string_view> observed_controls;
    int pos = 0;
    const int last_idx = node_fanins.size() - 1;
    // Inputs at positions > last_pos are duplicate controls pending deletion.
    int last_pos = last_idx;
    node_view.fanins_count_.reserve(node->input_size());
    node_view.controlling_fanins_index_.reserve(node->input_size());
    while (pos <= last_pos) {
      const TensorId& fanin_id = node_fanins[pos];
      bool is_control = IsTensorIdControl(fanin_id);
      // Lookup always succeeds: CheckFaninsInternal rejected unknown names.
      const int fanin_node_index = node_index_by_name_[fanin_id.node()];
      MutableNodeView& fanin_node_view = nodes_[fanin_node_index];

      if (is_control) {
        if (gtl::InsertIfNotPresent(&observed_controls, fanin_id.node())) {
          fanin_node_view.controlled_fanouts_.emplace_back(
              this, i, Graph::kControlSlot,
              node_view.controlling_fanins_.size());
          node_view.controlling_fanins_.emplace_back(
              this, fanin_node_index, Graph::kControlSlot,
              fanin_node_view.controlled_fanouts_.size() - 1);
          IncrementFaninCount(
              &node_view.fanins_count_,
              {&graph_->node(fanin_node_index), Graph::kControlSlot});
          // Regular inputs all precede controls (validated earlier), so this
          // difference is the position among control inputs.
          node_view.controlling_fanins_index_.emplace(
              fanin_id.node(), pos - node_view.NumRegularFanins());
          ++pos;
        } else {
          // Duplicate control input: move it past `last_pos` and re-examine
          // whatever was swapped into `pos`. `pos` is not advanced.
          node->mutable_input()->SwapElements(pos, last_pos);
          std::swap(node_fanins[pos], node_fanins[last_pos]);
          --last_pos;
        }
      } else {
        int fanin_node_view_regular_fanouts_by_port_size =
            fanin_node_view.regular_fanouts_by_port_.size();
        if (fanin_node_view_regular_fanouts_by_port_size <
            fanin_id.index() + 1) {
          fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1);
        }
        auto& fanin_regular_fanouts =
            fanin_node_view.regular_fanouts_by_port_[fanin_id.index()];
        // regular_fanins_.size() is this input's position in the node; it is
        // read before the append below.
        fanin_regular_fanouts.emplace_back(this, i,
                                           node_view.regular_fanins_.size(),
                                           node_view.regular_fanins_.size());
        ++fanin_node_view.num_regular_fanouts_;
        node_view.regular_fanins_.emplace_back(
            this, fanin_node_index, fanin_id.index(),
            fanin_regular_fanouts.size() - 1);
        IncrementFaninCount(
            &node_view.fanins_count_,
            {&graph_->node(fanin_node_index), fanin_id.index()});
        ++pos;
      }
    }
    // Drop the duplicate control inputs collected at the tail.
    if (last_pos < last_idx) {
      node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
    }
  }
}
605
// Builds the post-mutation name table and partitions the pending node updates.
//
// On return, `*node_names` maps each name touched by the mutation to either:
//   - internal::kNodeNamePresent, when some updated or new node will use the
//     name; or
//   - the index of the existing node that currently holds the name but is
//     being removed or renamed away, when nobody claims the name.
//
// Non-empty diffs are split into two groups, both referring to positions in
// `mutation_.updated_nodes_`:
//   - `renamed_nodes`: paired with the index of the existing node that already
//     has the new name, or internal::kMissingIndex;
//   - `inplace_nodes`: diffs that keep the node's name.
// Node indices of empty diffs go to `empty_diff_node_indices`.
//
// Returns InvalidArgument if one name is claimed by more than one updated or
// new node.
Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes(
    absl::flat_hash_map<absl::string_view, int>* node_names,
    std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
    std::vector<int>* inplace_nodes,
    std::vector<int>* empty_diff_node_indices) {
  // For all nodes to be removed and renamed, mark their original names as
  // missing and put associated node index in graph.
  for (const auto& diff : mutation_.updated_nodes_) {
    if (diff.update_name) {
      const int index = diff.node_index;
      const string& node_name = nodes_[index].GetName();
      node_names->emplace(node_name, index);
    }
  }

  for (int node_index : mutation_.removed_nodes_) {
    const string& node_name = nodes_[node_index].GetName();
    node_names->emplace(node_name, node_index);
  }

  auto name_conflict = [](const absl::string_view node_name) {
    return errors::InvalidArgument(kMutableGraphViewApplyError,
                                   "multiple nodes with the name: '", node_name,
                                   "' exists in Mutation.");
  };

  // Partition updated nodes by if they will be renamed or not.
  const int num_updated_nodes = mutation_.updated_nodes_.size();
  renamed_nodes->reserve(num_updated_nodes);
  inplace_nodes->reserve(num_updated_nodes);
  empty_diff_node_indices->reserve(num_updated_nodes);
  for (int i = 0; i < num_updated_nodes; ++i) {
    auto& diff = mutation_.updated_nodes_[i];
    if (internal::IsEmpty(&diff)) {
      empty_diff_node_indices->emplace_back(diff.node_index);
      continue;
    }
    // Get name of updated node after potential mutation.
    const string& node_name =
        diff.update_name ? diff.name : nodes_[diff.node_index].GetName();
    auto it = node_names->insert({node_name, internal::kNodeNamePresent});
    if (!it.second) {
      if (it.first->second == internal::kNodeNamePresent) {
        // Another node in the mutation is already using this name, which will
        // result in a conflict.
        return name_conflict(node_name);
      } else {
        // Mark name as present (node was marked missing from either being
        // removed or renamed).
        it.first->second = internal::kNodeNamePresent;
      }
    }
    if (diff.update_name) {
      // Lookup new name of node in current graph. If a node has such name,
      // store its index for later lookups as this node will be overwritten.
      auto node_name_it = node_index_by_name_.find(node_name);
      const int overwritten_node_index =
          node_name_it != node_index_by_name_.end() ? node_name_it->second
                                                    : internal::kMissingIndex;
      renamed_nodes->emplace_back(i, overwritten_node_index);
    } else {
      inplace_nodes->push_back(i);
    }
  }

  // Get names of new nodes after potential mutation.
  for (const auto& new_node : mutation_.new_nodes_) {
    const string& node_name = new_node.node.name();
    auto it = node_names->insert({node_name, internal::kNodeNamePresent});
    if (it.second) {
      continue;
    }
    if (it.first->second == internal::kNodeNamePresent) {
      // Another node in the mutation is already using this name, which will
      // result in a conflict.
      return name_conflict(node_name);
    } else {
      // Mark name as present (node was marked missing from either being removed
      // or renamed).
      it.first->second = internal::kNodeNamePresent;
    }
  }

  return Status::OK();
}
691
// Checks that no surviving node still consumes a node whose name is vacated by
// the mutation.
//
// Each `node_names` entry that is not internal::kNodeNamePresent names an
// existing node being removed, or renamed away with its old name left
// unclaimed. Every regular and controlled fanout of such a node must be
// accounted for:
//   - A fanout node without a pending diff must itself be removed, or be
//     overwritten by a rename (collected from `renamed_nodes`).
//   - A fanout node with a pending diff must either drop the consuming regular
//     input (past the last input kept after `num_regular_inputs_to_remove`) or
//     update it. For a control edge, the diff must remove that controlling
//     input.
// Otherwise returns InvalidArgument naming the offending fanout.
Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
    const absl::flat_hash_map<absl::string_view, int>& node_names,
    const std::vector<RenamedOrOverwrittenNode>& renamed_nodes) {
  auto bad_fanout = [](absl::string_view fanout_node_name,
                       absl::string_view node_name) {
    return errors::InvalidArgument(
        kMutableGraphViewApplyError, "fanout '", fanout_node_name,
        "' exist for missing node '", node_name, "'.");
  };

  // Lookup nodes to be overwritten.
  std::vector<bool> overwritten_nodes(NumNodes());
  for (auto& renamed_node : renamed_nodes) {
    if (renamed_node.overwritten_node_index_ == internal::kMissingIndex) {
      continue;
    }
    overwritten_nodes[renamed_node.overwritten_node_index_] = true;
  }

  // Check if removed nodes and previous state of renamed nodes have no fanouts.
  for (const auto& node_name_state : node_names) {
    if (node_name_state.second == internal::kNodeNamePresent) {
      continue;
    }
    // Non-present entries hold the index of the node vacating the name.
    const MutableNodeView& node_view = nodes_[node_name_state.second];
    for (const auto& regular_fanouts : node_view.GetRegularFanouts()) {
      for (const auto& regular_fanout : regular_fanouts) {
        // Check all fanouts of a single port.
        MutableNodeView* fanout_view = regular_fanout.node_view();
        if (fanout_view->update_index_ == internal::kMissingIndex) {
          if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
            // Fanout node will be removed, this can be ignored.
            continue;
          } else if (!overwritten_nodes[fanout_view->node_index_]) {
            // Fanout is not updated or removed/overwritten.
            return bad_fanout(fanout_view->GetName(), node_name_state.first);
          }
        } else {
          auto& diff = mutation_.updated_nodes_[fanout_view->update_index_];
          const int last_index = fanout_view->NumRegularFanins() -
                                 diff.num_regular_inputs_to_remove - 1;
          if (regular_fanout.index() > last_index) {
            // Fanin of fanout is removed, this can be ignored.
            continue;
          }
          // Check if fanin is updated.
          if (diff.regular_inputs_to_update.find(regular_fanout.index()) ==
              diff.regular_inputs_to_update.end()) {
            return bad_fanout(fanout_view->GetName(), node_name_state.first);
          }
        }
      }
    }
    for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
      MutableNodeView* fanout_view = controlled_fanout.node_view();
      if (fanout_view->update_index_ == internal::kMissingIndex) {
        if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
          // Fanout node will be removed, this can be ignored.
          continue;
        } else if (!overwritten_nodes[fanout_view->node_index_]) {
          // Fanout is not updated or removed/overwritten.
          return bad_fanout(fanout_view->GetName(), node_name_state.first);
        }
      } else {
        auto& diff = mutation_.updated_nodes_[fanout_view->update_index_];
        // Check if controlling fanin is removed.
        if (diff.controlling_inputs_to_remove.find(
                controlled_fanout.fanin_index_) ==
            diff.controlling_inputs_to_remove.end()) {
          return bad_fanout(fanout_view->GetName(), node_name_state.first);
        }
      }
    }
  }

  return Status::OK();
}
769
CheckNodeNamesAndFanins(const absl::flat_hash_map<absl::string_view,int> & node_names,const std::vector<RenamedOrOverwrittenNode> & renamed_nodes,const std::vector<int> & inplace_nodes)770 Status MutableGraphView::CheckNodeNamesAndFanins(
771 const absl::flat_hash_map<absl::string_view, int>& node_names,
772 const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
773 const std::vector<int>& inplace_nodes) {
774 // Check if removed/missing node fanouts are valid.
775 TF_RETURN_IF_ERROR(
776 RemovedOrMissingNodeFanoutsWellFormed(node_names, renamed_nodes));
777
778 // Check if updated nodes and their fanins are valid.
779 for (auto& inplace_node : inplace_nodes) {
780 auto& diff = mutation_.updated_nodes_[inplace_node];
781 if (!internal::IsWellFormed(&diff, node_names)) {
782 return errors::InvalidArgument(
783 kMutableGraphViewApplyError, "inplace updated node '",
784 nodes_[diff.node_index].GetName(), "' is ill-formed.");
785 }
786 }
787 for (auto& renamed_node : renamed_nodes) {
788 auto& diff = mutation_.updated_nodes_[renamed_node.renamed_update_index_];
789 if (!internal::IsWellFormed(&diff, node_names)) {
790 return errors::InvalidArgument(
791 kMutableGraphViewApplyError, "renamed updated node '", diff.name,
792 "' ('", nodes_[diff.node_index].GetName(), "') is ill-formed.");
793 }
794 }
795
796 // Check if new nodes and their fanins are valid.
797 for (auto& new_node : mutation_.new_nodes_) {
798 if (!internal::IsWellFormed(&new_node, node_names)) {
799 return errors::InvalidArgument(kMutableGraphViewApplyError, "new node '",
800 new_node.node.name(), "' is ill-formed.");
801 }
802 }
803
804 return Status::OK();
805 }
806
CheckKernelRegisteredForNodes()807 Status MutableGraphView::CheckKernelRegisteredForNodes() {
808 Status s;
809 for (auto& diff : mutation_.updated_nodes_) {
810 if (internal::IsEmpty(&diff)) {
811 continue;
812 }
813
814 NodeDef* node = nodes_[diff.node_index].node();
815 diff.processed_attrs =
816 AttrValueMap(node->attr().begin(), node->attr().end());
817 for (const auto& attr_to_remove : diff.attrs_to_remove) {
818 (*diff.processed_attrs).erase(attr_to_remove);
819 }
820 for (const auto& attr_to_add : diff.attrs_to_add) {
821 gtl::InsertOrUpdate(&(*diff.processed_attrs), attr_to_add.first,
822 attr_to_add.second);
823 }
824 const string& device = diff.update_device ? diff.device : node->device();
825 DeviceNameUtils::ParsedName name;
826 if (device.empty() || !DeviceNameUtils::ParseFullName(device, &name) ||
827 !name.has_type) {
828 continue;
829 }
830 s = IsKernelRegisteredForNode(diff.update_name ? diff.name : node->name(),
831 node->has_experimental_debug_info(),
832 node->experimental_debug_info(),
833 diff.update_op ? diff.op : node->op(), device,
834 AttrSlice(&(*diff.processed_attrs)));
835 if (!s.ok()) {
836 LOG(WARNING) << s.error_message();
837 }
838 }
839 for (const auto& new_node_holder : mutation_.new_nodes_) {
840 const auto& new_node_def = new_node_holder.node;
841 DeviceNameUtils::ParsedName name;
842 if (new_node_def.device().empty() ||
843 !DeviceNameUtils::ParseFullName(new_node_def.device(), &name) ||
844 !name.has_type) {
845 continue;
846 }
847 s = IsKernelRegisteredForNode(new_node_def);
848 if (!s.ok()) {
849 LOG(WARNING) << s.error_message();
850 }
851 }
852 return Status::OK();
853 }
854
855 template <typename T>
ReplaceNodeFanouts(MutableNodeView * node,T * fanouts)856 void MutableGraphView::ReplaceNodeFanouts(MutableNodeView* node, T* fanouts) {
857 node->num_regular_fanouts_ = fanouts->num_regular_fanouts_;
858 node->regular_fanouts_by_port_ = std::move(fanouts->regular_fanouts_by_port_);
859 for (int i = 0, i_max = node->regular_fanouts_by_port_.size(); i < i_max;
860 ++i) {
861 for (int j = 0, j_max = node->regular_fanouts_by_port_[i].size(); j < j_max;
862 ++j) {
863 auto& fanout = node->regular_fanouts_by_port_[i][j];
864 auto* fanout_node_view = fanout.node_view();
865 auto& fanout_fanin = fanout_node_view->regular_fanins_[fanout.index()];
866 auto* fanout_fanins_count = &fanout_node_view->fanins_count_;
867 DecrementFaninCount(
868 fanout_fanins_count,
869 {&graph_->node(fanout_fanin.node_index_), fanout_fanin.index()});
870 fanout_fanin.node_index_ = node->node_index_;
871 IncrementFaninCount(
872 fanout_fanins_count,
873 {&graph_->node(node->node_index_), fanout_fanin.index()});
874 }
875 }
876 node->controlled_fanouts_ = std::move(fanouts->controlled_fanouts_);
877 for (int i = 0, i_max = node->controlled_fanouts_.size(); i < i_max; ++i) {
878 auto& fanout = node->controlled_fanouts_[i];
879 auto* fanout_node_view = fanout.node_view();
880 auto& fanout_fanin =
881 fanout_node_view->controlling_fanins_[fanout.fanin_index_];
882 auto* fanout_fanins_count = &fanout_node_view->fanins_count_;
883 DecrementFaninCount(
884 fanout_fanins_count,
885 {&graph_->node(fanout_fanin.node_index_), Graph::kControlSlot});
886 fanout_fanin.node_index_ = node->node_index_;
887 fanout_fanin.fanout_index_ = i;
888 IncrementFaninCount(fanout_fanins_count, {&graph_->node(node->node_index_),
889 Graph::kControlSlot});
890 }
891 }
892
FixRenamedNodes(std::vector<RenamedOrOverwrittenNode> * renamed_nodes,absl::flat_hash_map<string,NodeViewFanouts> * renamed_fanouts,std::vector<bool> * overwritten_name_removed_nodes)893 void MutableGraphView::FixRenamedNodes(
894 std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
895 absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
896 std::vector<bool>* overwritten_name_removed_nodes) {
897 // Extract all renamed node fanouts.
898 renamed_fanouts->reserve(renamed_nodes->size());
899 for (auto& renamed : *renamed_nodes) {
900 auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_];
901 // Remove node index by name from graph.
902 node_index_by_name_.erase(nodes_[diff.node_index].GetName());
903 MutableNodeView& renamed_node = nodes_[diff.node_index];
904 renamed_fanouts->try_emplace(
905 renamed_node.GetName(),
906 std::move(renamed_node.regular_fanouts_by_port_),
907 renamed_node.num_regular_fanouts_,
908 std::move(renamed_node.controlled_fanouts_));
909 }
910
911 // Replace renamed node fanouts with fanouts associated with updated name.
912 for (auto& renamed : *renamed_nodes) {
913 auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_];
914 MutableNodeView& renamed_node = nodes_[diff.node_index];
915 auto fanouts_it = renamed_fanouts->find(diff.name);
916 if (fanouts_it != renamed_fanouts->end()) {
917 // Another renamed node's fanout.
918 auto& fanouts = fanouts_it->second;
919 ReplaceNodeFanouts(&renamed_node, &fanouts);
920 renamed_fanouts->erase(fanouts_it);
921 // Node to be overwritten is being renamed, so it won't be overwritten.
922 renamed.overwritten_node_index_ = internal::kMissingIndex;
923 } else if (renamed.overwritten_node_index_ != internal::kMissingIndex) {
924 // Existing node in graph.
925 MutableNodeView& node_to_overwrite =
926 nodes_[renamed.overwritten_node_index_];
927 ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
928 node_index_by_name_.erase(node_to_overwrite.GetName());
929 if (mutation_.removed_nodes_.contains(node_to_overwrite.node_index_)) {
930 (*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
931 }
932 } else {
933 // No existing fanouts.
934 renamed_node.num_regular_fanouts_ = 0;
935 }
936
937 // Update node name.
938 renamed_node.node()->set_name(diff.name);
939 diff.update_name = false;
940 diff.name.clear();
941 // Rehash renamed nodes with updated name.
942 node_index_by_name_.emplace(renamed_node.GetName(), diff.node_index);
943 }
944 }
945
AddNewNodes(absl::flat_hash_map<string,NodeViewFanouts> * renamed_fanouts,std::vector<int> * new_node_indices)946 void MutableGraphView::AddNewNodes(
947 absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
948 std::vector<int>* new_node_indices) {
949 new_node_indices->reserve(mutation_.new_nodes_.size());
950 for (auto& new_node : mutation_.new_nodes_) {
951 int node_index;
952 auto graph_it = node_index_by_name_.find(new_node.node.name());
953 if (graph_it != node_index_by_name_.end()) {
954 // Overwrite existing node.
955 node_index = graph_it->second;
956 MutableNodeView& node_view = nodes_[node_index];
957 RemoveAllFaninFanoutInternal(&node_view);
958 auto* node_def = graph_->mutable_node(node_index);
959 node_def->mutable_op()->swap(*new_node.node.mutable_op());
960 node_def->mutable_device()->swap(*new_node.node.mutable_device());
961 node_def->mutable_input()->Clear();
962 node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
963 mutation_.removed_nodes_.erase(node_index);
964 } else {
965 // New node.
966 auto* new_node_def = graph_->add_node();
967 *new_node_def = std::move(new_node.node);
968 node_index = nodes_.size();
969 nodes_.emplace_back(this, node_index);
970 MutableNodeView& new_node_view = nodes_.back();
971 auto it = renamed_fanouts->find(new_node_view.GetName());
972 if (it != renamed_fanouts->end()) {
973 // Reuse fanouts of renamed node.
974 NodeViewFanouts& fanouts = it->second;
975 ReplaceNodeFanouts(&new_node_view, &fanouts);
976 renamed_fanouts->erase(it);
977 }
978 node_index_by_name_.emplace(new_node_view.GetName(), node_index);
979 }
980 new_node_indices->emplace_back(node_index);
981 }
982 }
983
FixRenamedFanouts(const absl::flat_hash_map<string,NodeViewFanouts> & renamed_fanouts)984 void MutableGraphView::FixRenamedFanouts(
985 const absl::flat_hash_map<string, NodeViewFanouts>& renamed_fanouts) {
986 // Leftover fanouts in renamed_fanouts are due to nodes not existing anymore
987 // or a node being renamed without another node taking its place. For these
988 // leftover fanouts, mark their respective fanin fanout_index_ to
989 // internal::kMissingIndex as an indicator so when it comes to updating or
990 // removing fanins inplace, nodes with the same index don't get affected and
991 // other fanouts are accidentally removed.
992 for (auto& renamed_fanout : renamed_fanouts) {
993 for (auto& regular_fanouts :
994 renamed_fanout.second.regular_fanouts_by_port_) {
995 for (auto& fanout : regular_fanouts) {
996 auto* fanout_node_view = fanout.node_view();
997 auto& fanin = fanout_node_view->regular_fanins_[fanout.index()];
998 fanout_node_view->fanins_count_.erase(
999 {fanin.node_view()->node(), fanin.index()});
1000 fanin.fanout_index_ = internal::kMissingIndex;
1001 }
1002 }
1003 for (auto& fanout : renamed_fanout.second.controlled_fanouts_) {
1004 auto* fanout_node_view = fanout.node_view();
1005 auto& fanin = fanout_node_view->controlling_fanins_[fanout.fanin_index_];
1006 fanout_node_view->fanins_count_.erase(
1007 {fanin.node_view()->node(), Graph::kControlSlot});
1008 fanout_node_view->controlling_fanins_index_.erase(renamed_fanout.first);
1009 fanin.fanout_index_ = internal::kMissingIndex;
1010 }
1011 }
1012 }
1013
RemoveRegularFaninFanoutInternal(MutableNodeView * node_view,int i)1014 inline void MutableGraphView::RemoveRegularFaninFanoutInternal(
1015 MutableNodeView* node_view, int i) {
1016 MutableFanoutView& fanin = node_view->regular_fanins_[i];
1017 // Fanin was marked as removed via FixRenamedFanouts.
1018 if (fanin.fanout_index_ == internal::kMissingIndex) {
1019 return;
1020 }
1021
1022 DecrementFaninCount(&node_view->fanins_count_,
1023 {&graph_->node(fanin.node_index_), fanin.index()});
1024 auto* fanin_node_view = fanin.node_view();
1025 auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin.index()];
1026 int fanouts_size = fanouts.size();
1027 if (fanin.fanout_index_ < fanouts_size - 1) {
1028 // Swap fanout with last fanout in vector, and update it's associated fanin
1029 // index.
1030 MutableFaninView& last_fanout = fanouts.back();
1031 last_fanout.node_view()
1032 ->regular_fanins_[last_fanout.index()]
1033 .fanout_index_ = fanin.fanout_index_;
1034 std::swap(last_fanout, fanouts[fanin.fanout_index_]);
1035 }
1036 // Remove fanout.
1037 fanouts.pop_back();
1038 --fanin.node_view()->num_regular_fanouts_;
1039
1040 // Resize fanouts. Fanouts may not be removed sequentially in relation to
1041 // output port, so trailing empty output ports may be left behind. It is
1042 // necessary to loop through all of the output ports to determine the maximum
1043 // output port before resizing.
1044 int last_fanout_index = fanin_node_view->regular_fanouts_by_port_.size();
1045 for (int i = fanin_node_view->regular_fanouts_by_port_.size() - 1; i >= 0;
1046 --i) {
1047 if (fanin_node_view->regular_fanouts_by_port_[i].empty()) {
1048 last_fanout_index = i;
1049 } else {
1050 break;
1051 }
1052 }
1053 int fanin_node_view_regular_fanouts_by_port_size =
1054 fanin_node_view->regular_fanouts_by_port_.size();
1055 if (last_fanout_index < fanin_node_view_regular_fanouts_by_port_size) {
1056 fanin_node_view->regular_fanouts_by_port_.resize(last_fanout_index);
1057 }
1058 }
1059
AddRegularFaninInternal(MutableNodeView * node_view,const SafeTensorId & fanin_id)1060 inline void MutableGraphView::AddRegularFaninInternal(
1061 MutableNodeView* node_view, const SafeTensorId& fanin_id) {
1062 MutableNodeView* fanin_node_view = GetNode(fanin_id.node());
1063 // Resize fanouts to include new output port index.
1064 int fanin_node_view_regular_fanouts_by_port_size =
1065 fanin_node_view->regular_fanouts_by_port_.size();
1066 if (fanin_node_view_regular_fanouts_by_port_size < fanin_id.index() + 1) {
1067 fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1);
1068 }
1069
1070 // Add node as fanout to fanin.
1071 auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()];
1072 fanouts.emplace_back(this, node_view->node_index(),
1073 node_view->regular_fanins_.size(),
1074 node_view->regular_fanins_.size());
1075 ++fanin_node_view->num_regular_fanouts_;
1076
1077 // Add fanin to node.
1078 node_view->regular_fanins_.emplace_back(this, fanin_node_view->node_index(),
1079 fanin_id.index(), fanouts.size() - 1);
1080 IncrementFaninCount(
1081 &node_view->fanins_count_,
1082 {&graph_->node(fanin_node_view->node_index()), fanin_id.index()});
1083 }
1084
UpdateRegularFaninInternal(MutableNodeView * node_view,const int i,const SafeTensorId & fanin_id)1085 inline void MutableGraphView::UpdateRegularFaninInternal(
1086 MutableNodeView* node_view, const int i, const SafeTensorId& fanin_id) {
1087 // Remove fanin.
1088 RemoveRegularFaninFanoutInternal(node_view, i);
1089
1090 MutableNodeView* fanin_node_view = GetNode(fanin_id.node());
1091 // Resize fanouts to include new output port index.
1092 int fanin_node_view_regular_fanouts_by_port_size =
1093 fanin_node_view->regular_fanouts_by_port_.size();
1094 if (fanin_node_view_regular_fanouts_by_port_size < fanin_id.index() + 1) {
1095 fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1);
1096 }
1097
1098 // Add node as fanout to fanin.
1099 auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()];
1100 fanouts.emplace_back(this, node_view->node_index(), i, i);
1101 ++fanin_node_view->num_regular_fanouts_;
1102
1103 // Replace fanin in node.
1104 node_view->regular_fanins_[i] =
1105 MutableFanoutView(this, fanin_node_view->node_index(), fanin_id.index(),
1106 fanouts.size() - 1);
1107 IncrementFaninCount(
1108 &node_view->fanins_count_,
1109 {&graph_->node(fanin_node_view->node_index()), fanin_id.index()});
1110 }
1111
RemoveControllingFaninFanoutInternal(MutableNodeView * node_view,int i)1112 inline void MutableGraphView::RemoveControllingFaninFanoutInternal(
1113 MutableNodeView* node_view, int i) {
1114 auto& control_to_remove = node_view->controlling_fanins_[i];
1115 if (control_to_remove.fanout_index_ != internal::kMissingIndex) {
1116 // Update internal state associated with node.
1117 node_view->fanins_count_.erase(
1118 {control_to_remove.node_view()->node(), Graph::kControlSlot});
1119 node_view->controlling_fanins_index_.erase(
1120 control_to_remove.node_view()->GetName());
1121
1122 // Remove controlled fanout from controlling fanin, via swapping last
1123 // controlled fanout in controlling fanin with controlled fanout to be
1124 // removed.
1125 auto* control_to_remove_view = control_to_remove.node_view();
1126 int control_to_remove_view_controlled_fanouts_size =
1127 control_to_remove_view->controlled_fanouts_.size();
1128 if (control_to_remove.fanout_index_ <
1129 control_to_remove_view_controlled_fanouts_size - 1) {
1130 auto& control_to_remove_view_last_control =
1131 control_to_remove_view->controlled_fanouts_.back();
1132 control_to_remove_view_last_control.node_view()
1133 ->controlling_fanins_[control_to_remove_view_last_control
1134 .fanin_index_]
1135 .fanout_index_ = control_to_remove.fanout_index_;
1136 std::swap(control_to_remove_view_last_control,
1137 control_to_remove_view
1138 ->controlled_fanouts_[control_to_remove.fanout_index_]);
1139 }
1140 control_to_remove_view->controlled_fanouts_.pop_back();
1141 }
1142 }
1143
RemoveControllingFaninInternal(MutableNodeView * node_view,const std::set<int> & indices_to_remove)1144 inline void MutableGraphView::RemoveControllingFaninInternal(
1145 MutableNodeView* node_view, const std::set<int>& indices_to_remove) {
1146 const int num_regular_fanins = node_view->NumRegularFanins();
1147 auto* mutable_input = node_view->node()->mutable_input();
1148 // Iterate in descending order so indices stay consistent.
1149 for (auto rit = indices_to_remove.rbegin(); rit != indices_to_remove.rend();
1150 ++rit) {
1151 const int control_index = *rit;
1152 RemoveControllingFaninFanoutInternal(node_view, control_index);
1153
1154 // Swap last controlling fanin in node with controlling fanin to be removed.
1155 int node_view_controlling_fanins_size =
1156 node_view->controlling_fanins_.size();
1157 if (control_index < node_view_controlling_fanins_size - 1) {
1158 auto& last_control = node_view->controlling_fanins_.back();
1159 auto* last_control_view = last_control.node_view();
1160 last_control_view->controlled_fanouts_[last_control.fanout_index_]
1161 .fanin_index_ = control_index;
1162 node_view->controlling_fanins_index_.find(last_control_view->GetName())
1163 ->second = control_index;
1164 mutable_input->SwapElements(
1165 num_regular_fanins + control_index,
1166 num_regular_fanins + node_view->NumControllingFanins() - 1);
1167 std::swap(last_control, node_view->controlling_fanins_[control_index]);
1168 }
1169 mutable_input->RemoveLast();
1170 node_view->controlling_fanins_.pop_back();
1171 }
1172 }
1173
AddControllingFaninInternal(MutableNodeView * node_view,absl::string_view fanin_node_name)1174 inline void MutableGraphView::AddControllingFaninInternal(
1175 MutableNodeView* node_view, absl::string_view fanin_node_name) {
1176 NodeDef* node = node_view->node();
1177 // Add controlling fanin to NodeDef.
1178 node->add_input(AsControlDependency(string(fanin_node_name)));
1179 MutableNodeView* fanin_node_view = GetNode(fanin_node_name);
1180 const int index = node_view->controlling_fanins_.size();
1181 fanin_node_view->controlled_fanouts_.emplace_back(
1182 this, node_view->node_index(), Graph::kControlSlot, index);
1183 node_view->controlling_fanins_.emplace_back(
1184 this, fanin_node_view->node_index(), Graph::kControlSlot,
1185 fanin_node_view->controlled_fanouts_.size() - 1);
1186 IncrementFaninCount(
1187 &node_view->fanins_count_,
1188 {&graph_->node(fanin_node_view->node_index()), Graph::kControlSlot});
1189 // Parse new fanin string for node name.
1190 TensorId tensor_id = ParseTensorName(node->input(node->input_size() - 1));
1191 node_view->controlling_fanins_index_.emplace(tensor_id.node(), index);
1192 }
1193
ApplyNodeUpdates()1194 void MutableGraphView::ApplyNodeUpdates() {
1195 for (auto& diff : mutation_.updated_nodes_) {
1196 if (internal::IsEmpty(&diff)) {
1197 continue;
1198 }
1199 MutableNodeView& node_view = nodes_[diff.node_index];
1200 diff.node_index = internal::kMissingIndex;
1201 // Clean up node view.
1202 node_view.update_index_ = internal::kMissingIndex;
1203
1204 NodeDef* node_def = node_view.node();
1205
1206 // Set updated fields and attributes of node.
1207 if (diff.update_op) {
1208 node_def->set_op(diff.op);
1209 }
1210 if (diff.update_device) {
1211 node_def->set_device(diff.device);
1212 }
1213 node_def->mutable_attr()->swap((*diff.processed_attrs));
1214
1215 // Updated fanins. Only one of `regular_inputs_to_remove_` or
1216 // `regular_inputs_to_add_` can be set.
1217 if (diff.num_regular_inputs_to_remove > 0) {
1218 // Truncate trailing regular fanins.
1219 const int first_index =
1220 node_view.NumRegularFanins() - diff.num_regular_inputs_to_remove;
1221 for (int i = first_index; i < node_view.NumRegularFanins(); ++i) {
1222 RemoveRegularFaninFanoutInternal(&node_view, i);
1223 }
1224 node_view.regular_fanins_.resize(first_index);
1225 node_def->mutable_input()->DeleteSubrange(
1226 node_view.NumRegularFanins(), diff.num_regular_inputs_to_remove);
1227 } else if (diff.num_regular_inputs_to_add > 0) {
1228 // Append regular fanins.
1229 node_def->mutable_input()->Reserve(node_def->mutable_input()->size() +
1230 diff.num_regular_inputs_to_add);
1231 int curr_index = node_view.NumRegularFanins();
1232 int curr_control_start = curr_index;
1233 for (const SafeTensorId& fanin : diff.regular_inputs_to_add) {
1234 AddRegularFaninInternal(&node_view, fanin);
1235 node_def->add_input(SafeTensorIdToString(fanin));
1236 node_def->mutable_input()->SwapElements(curr_index,
1237 node_def->input_size() - 1);
1238 if (curr_control_start == curr_index) {
1239 curr_control_start = node_def->input_size() - 1;
1240 }
1241 ++curr_index;
1242 }
1243 // Rotate shifted controlling fanins to match up with
1244 // `node_view.controlling_fanins_` as `num_regular_inputs_to_add_` may not
1245 // be a multiple of `num_regular_inputs_to_add_`. This is to prevent
1246 // rehashing controlling fanins in `node_view.controlling_fanins_index_`.
1247 if (node_view.NumControllingFanins() > 1 &&
1248 curr_control_start != node_view.NumRegularFanins()) {
1249 std::rotate(
1250 node_def->mutable_input()->begin() + node_view.NumRegularFanins(),
1251 node_def->mutable_input()->begin() + curr_control_start,
1252 node_def->mutable_input()->end());
1253 }
1254 }
1255
1256 for (const auto& update_fanin : diff.regular_inputs_to_update) {
1257 UpdateRegularFaninInternal(&node_view, update_fanin.first,
1258 update_fanin.second);
1259 node_def->set_input(update_fanin.first,
1260 SafeTensorIdToString(update_fanin.second));
1261 }
1262
1263 RemoveControllingFaninInternal(&node_view,
1264 diff.controlling_inputs_to_remove);
1265
1266 node_def->mutable_input()->Reserve(node_def->mutable_input()->size() +
1267 diff.controlling_inputs_to_add.size());
1268 for (const auto& control_to_add : diff.controlling_inputs_to_add) {
1269 AddControllingFaninInternal(&node_view, control_to_add);
1270 }
1271 }
1272 }
1273
SetNewNodesFanins(const std::vector<int> & new_node_indices)1274 void MutableGraphView::SetNewNodesFanins(
1275 const std::vector<int>& new_node_indices) {
1276 auto new_node = mutation_.new_nodes_.begin();
1277 for (const int new_node_index : new_node_indices) {
1278 MutableNodeView& new_node_view = nodes_[new_node_index];
1279 NodeDef* new_node_def = new_node_view.node();
1280 new_node_def->mutable_input()->Reserve(new_node->num_regular_fanins +
1281 new_node->controlling_fanins.size());
1282 for (const SafeTensorId& fanin : new_node->regular_fanins) {
1283 AddRegularFaninInternal(&new_node_view, fanin);
1284 new_node_def->add_input(SafeTensorIdToString(fanin));
1285 }
1286 for (const string& control_to_add : new_node->controlling_fanins) {
1287 AddControllingFaninInternal(&new_node_view, control_to_add);
1288 }
1289 ++new_node;
1290 }
1291 }
1292
RemoveAllFaninFanoutInternal(MutableNodeView * node_view)1293 inline void MutableGraphView::RemoveAllFaninFanoutInternal(
1294 MutableNodeView* node_view) {
1295 const int num_regular_fanins = node_view->NumRegularFanins();
1296 for (int i = 0; i < num_regular_fanins; ++i) {
1297 RemoveRegularFaninFanoutInternal(node_view, i);
1298 }
1299 std::vector<MutableFanoutView>().swap(node_view->regular_fanins_);
1300 const int num_controlling_fanins = node_view->NumControllingFanins();
1301 for (int i = 0; i < num_controlling_fanins; ++i) {
1302 RemoveControllingFaninFanoutInternal(node_view, i);
1303 }
1304 std::vector<MutableFanoutView>().swap(node_view->controlling_fanins_);
1305 }
1306
RemoveNodesInternal(const std::vector<RenamedOrOverwrittenNode> & renamed_nodes,const std::vector<bool> & overwritten_name_removed_nodes)1307 void MutableGraphView::RemoveNodesInternal(
1308 const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
1309 const std::vector<bool>& overwritten_name_removed_nodes) {
1310 // Get all nodes overwritten by renamed nodes and remove their fanins.
1311 std::vector<int> overwritten_nodes;
1312 overwritten_nodes.reserve(renamed_nodes.size());
1313 for (const auto& renamed : renamed_nodes) {
1314 if (renamed.overwritten_node_index_ != internal::kMissingIndex) {
1315 auto& node = nodes_[renamed.overwritten_node_index_];
1316 RemoveAllFaninFanoutInternal(&node);
1317 overwritten_nodes.emplace_back(renamed.overwritten_node_index_);
1318 }
1319 }
1320
1321 // Get all nodes explicitly marked for removal and remove their fanins.
1322 std::vector<int> node_indices_to_remove;
1323 node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
1324 overwritten_nodes.size());
1325 for (int node_index : mutation_.removed_nodes_) {
1326 auto& node = nodes_[node_index];
1327 RemoveAllFaninFanoutInternal(&node);
1328 node_indices_to_remove.push_back(node_index);
1329 if (!overwritten_name_removed_nodes[node_index]) {
1330 node_index_by_name_.erase(node.GetName());
1331 }
1332 }
1333 node_indices_to_remove.insert(node_indices_to_remove.end(),
1334 overwritten_nodes.begin(),
1335 overwritten_nodes.end());
1336 std::set<int> sorted_node_indices_to_remove(node_indices_to_remove.begin(),
1337 node_indices_to_remove.end());
1338
1339 // Iterate in descending order so indices stay consistent.
1340 for (auto rit = sorted_node_indices_to_remove.rbegin();
1341 rit != sorted_node_indices_to_remove.rend(); ++rit) {
1342 const int removed_node_index = *rit;
1343 MutableNodeView& last_node = nodes_.back();
1344 if (last_node.node_index_ > removed_node_index) {
1345 last_node.node_index_ = removed_node_index;
1346 for (auto& regular_fanin : last_node.regular_fanins_) {
1347 // Update fanouts of regular fanins with new index.
1348 regular_fanin.node_view()
1349 ->regular_fanouts_by_port_[regular_fanin.index()]
1350 [regular_fanin.fanout_index_]
1351 .node_index_ = removed_node_index;
1352 }
1353 for (auto& controlling_fanin : last_node.controlling_fanins_) {
1354 // Update fanouts of controlling fanins with new index.
1355 controlling_fanin.node_view()
1356 ->controlled_fanouts_[controlling_fanin.fanout_index_]
1357 .node_index_ = removed_node_index;
1358 }
1359 for (auto& regular_fanouts : last_node.regular_fanouts_by_port_) {
1360 for (auto& regular_fanout : regular_fanouts) {
1361 // Update fanins of regular fanouts.
1362 MutableNodeView* fanout_node_view = regular_fanout.node_view();
1363 fanout_node_view->regular_fanins_[regular_fanout.fanin_index_]
1364 .node_index_ = removed_node_index;
1365 }
1366 }
1367 for (auto& controlled_fanout : last_node.controlled_fanouts_) {
1368 // Update fanins of controlled fanouts.
1369 MutableNodeView* fanout_node_view = controlled_fanout.node_view();
1370 fanout_node_view->controlling_fanins_[controlled_fanout.fanin_index_]
1371 .node_index_ = removed_node_index;
1372 }
1373
1374 const int last_node_index = nodes_.size() - 1;
1375 std::swap(nodes_[last_node_index], nodes_[removed_node_index]);
1376 graph()->mutable_node()->SwapElements(last_node_index,
1377 removed_node_index);
1378 node_index_by_name_.find(nodes_[removed_node_index].GetName())->second =
1379 removed_node_index;
1380 }
1381 nodes_.pop_back();
1382 }
1383 if (!sorted_node_indices_to_remove.empty()) {
1384 const int current_size = graph()->node_size();
1385 const int num_to_remove = sorted_node_indices_to_remove.size();
1386 graph()->mutable_node()->DeleteSubrange(current_size - num_to_remove,
1387 num_to_remove);
1388 }
1389 }
1390
namespace {

// Sentinel value used by the topological sort.
constexpr int kTopologicalSortDone = -1;

// Prefix for every error message returned by
// MutableGraphView::SortTopologically.
constexpr char kMutableGraphViewSortTopologicallyError[] =
    "MutableGraphView::SortTopologically error: ";

// Visit state of a node during the DFS traversal.
enum TraversalState : uint8_t { PENDING, PROCESSING, PROCESSED };

// Phase of a simulated recursive call in the iterative DFS. `ENTER` means
// entering the call for a node; `EXIT` means returning from it after all of
// its fanouts have been visited.
enum RecursionStackState : bool { ENTER, EXIT };

// One frame of the explicit stack that replaces recursion in the DFS.
struct RecursionStackEntry {
  RecursionStackEntry(int index, RecursionStackState state)
      : node_index(index), recursion_state(state) {}

  const int node_index;
  const RecursionStackState recursion_state;
};

// A directed edge between two node indices.
struct Edge {
  Edge(int source, int target) : from(source), to(target) {}

  const int from;
  const int to;
};

}  // namespace
1425
SortTopologically(bool ignore_cycles,absl::Span<const TopologicalDependency> extra_dependencies)1426 Status MutableGraphView::SortTopologically(
1427 bool ignore_cycles,
1428 absl::Span<const TopologicalDependency> extra_dependencies) {
1429 if (!mutation_.updated_nodes_.empty() || !mutation_.new_nodes_.empty()) {
1430 // Cannot sort when there is an active mutation due to indices possibly
1431 // being changed or invalidated.
1432 return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
1433 "active mutation exists.");
1434 }
1435
1436 const int num_nodes = nodes_.size();
1437
1438 // Group extra dependencies by `from` node.
1439 absl::flat_hash_map<int, std::vector<int>> extra_dependencies_by_parent;
1440 for (const auto& extra_dependency : extra_dependencies) {
1441 if (extra_dependency.graph_view_ != this ||
1442 extra_dependency.from_ == extra_dependency.to_ ||
1443 extra_dependency.from_ < 0 || extra_dependency.from_ >= num_nodes ||
1444 extra_dependency.to_ < 0 || extra_dependency.to_ >= num_nodes) {
1445 return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
1446 "invalid extra dependencies.");
1447 }
1448 extra_dependencies_by_parent[extra_dependency.from_].push_back(
1449 extra_dependency.to_);
1450 }
1451
1452 // Reversed colored post-order DFS traversal. This does not fail on cycles,
1453 // but there are no guarantees on ordering within a cycle.
1454 std::vector<TraversalState> traversal_state(num_nodes, PENDING);
1455 int curr_pos = num_nodes - 1;
1456 std::vector<int> order(num_nodes);
1457 std::vector<Edge> edges_in_cycle;
1458
1459 auto push_onto_stack = [this](
1460 const int curr_index, const int fanout_index,
1461 std::vector<RecursionStackEntry>* recursion_stack,
1462 std::vector<TraversalState>* traversal_state,
1463 std::vector<Edge>* edges_in_cycle) {
1464 // Ignore NextIteration -> Merge connections to break control flow cycles.
1465 if (IsNextIteration(graph_->node(curr_index)) &&
1466 IsMerge(graph_->node(fanout_index))) {
1467 return;
1468 }
1469 auto& fanout_traversal_state = (*traversal_state)[fanout_index];
1470 if (fanout_traversal_state == PROCESSING) {
1471 // Cycle detected.
1472 edges_in_cycle->push_back({curr_index, fanout_index});
1473 } else if (fanout_traversal_state == PENDING) {
1474 // Unvisited node, simply add to stack for future traversal.
1475 recursion_stack->push_back({fanout_index, ENTER});
1476 }
1477 };
1478
1479 auto process_fanouts = [this, &extra_dependencies_by_parent,
1480 &push_onto_stack](
1481 const int curr_index,
1482 std::vector<RecursionStackEntry>* recursion_stack,
1483 std::vector<TraversalState>* traversal_state,
1484 std::vector<Edge>* edges_in_cycle) {
1485 const auto& node_view = nodes_[curr_index];
1486 // Regular fanouts.
1487 for (const auto& regular_fanouts_port_i : node_view.GetRegularFanouts()) {
1488 for (const auto& regular_fanout : regular_fanouts_port_i) {
1489 push_onto_stack(curr_index, regular_fanout.node_index_, recursion_stack,
1490 traversal_state, edges_in_cycle);
1491 }
1492 }
1493 // Controlled fanouts.
1494 for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
1495 push_onto_stack(curr_index, controlled_fanout.node_index_,
1496 recursion_stack, traversal_state, edges_in_cycle);
1497 }
1498 // Extra dependencies.
1499 auto it = extra_dependencies_by_parent.find(curr_index);
1500 if (it != extra_dependencies_by_parent.end()) {
1501 for (const auto& extra_fanout : it->second) {
1502 push_onto_stack(curr_index, extra_fanout, recursion_stack,
1503 traversal_state, edges_in_cycle);
1504 }
1505 }
1506 };
1507
1508 auto reversed_postorder_dfs =
1509 [&process_fanouts](const MutableNodeView& root_node_view,
1510 std::vector<int>* order,
1511 std::vector<TraversalState>* traversal_state,
1512 int* curr_pos, std::vector<Edge>* edges_in_cycle) {
1513 std::vector<RecursionStackEntry> recursion_stack;
1514 // Add the root to stack to start the traversal.
1515 const int root_index = root_node_view.node_index_;
1516 auto& root_traversal_state = (*traversal_state)[root_index];
1517 if (root_traversal_state == PENDING) {
1518 recursion_stack.push_back({root_index, ENTER});
1519 }
1520 while (!recursion_stack.empty()) {
1521 auto curr_entry = recursion_stack.back();
1522 recursion_stack.pop_back();
1523 const int curr_index = curr_entry.node_index;
1524 auto& curr_traversal_state = (*traversal_state)[curr_index];
1525 if (curr_traversal_state == PROCESSED) {
1526 // Node already processed which can be ignored.
1527 continue;
1528 } else if (curr_entry.recursion_state == EXIT) {
1529 // Node from recursion stack where all fanouts were visited.
1530 // Instead of adding node index to a vector, simply set what its
1531 // index would be, so there will not be a need for inversion later
1532 // on. The value set is in decending order so the reversed
1533 // post-order is returned.
1534 (*order)[curr_index] = *curr_pos;
1535 curr_traversal_state = PROCESSED;
1536 --(*curr_pos);
1537 } else {
1538 // Process current node and fanouts.
1539 curr_traversal_state = PROCESSING;
1540 recursion_stack.push_back({curr_index, EXIT});
1541 process_fanouts(curr_index, &recursion_stack, traversal_state,
1542 edges_in_cycle);
1543 }
1544 }
1545 };
1546
1547 // Determine sources to start DFS (nodes with no inputs) and unique fanout
1548 // nodes.
1549 for (int i = num_nodes - 1; i >= 0; --i) {
1550 auto& node = nodes_[i];
1551 if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
1552 reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
1553 &edges_in_cycle);
1554 }
1555 }
1556
1557 if (!ignore_cycles && !edges_in_cycle.empty()) {
1558 std::vector<string> edges_formatted;
1559 edges_formatted.reserve(edges_in_cycle.size());
1560 for (const auto& edge : edges_in_cycle) {
1561 edges_formatted.push_back(
1562 absl::StrCat("'", graph_->node(edge.from).name(), "' -> '",
1563 graph_->node(edge.to).name(), "'"));
1564 }
1565 const string edges_str =
1566 absl::StrCat("{", absl::StrJoin(edges_formatted, ", "), "}");
1567 return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
1568 "detected edge(s) creating cycle(s) ",
1569 edges_str, ".");
1570 }
1571 if (curr_pos != kTopologicalSortDone) {
1572 // Not all nodes were processed.
1573 if (!ignore_cycles) {
1574 return errors::InvalidArgument(
1575 kMutableGraphViewSortTopologicallyError,
1576 "was not able to sort all nodes topologically.");
1577 }
1578 // Otherwise process all nodes regardless of cycles.
1579 for (const auto& node : nodes_) {
1580 reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
1581 &edges_in_cycle);
1582 }
1583 }
1584
1585 // Permute nodes by reversed post-order DFS.
1586 std::vector<MutableNodeView> permuted_nodes(num_nodes);
1587 for (int i = 0; i < num_nodes; ++i) {
1588 permuted_nodes[order[i]] = std::move(nodes_[i]);
1589 }
1590 nodes_.swap(permuted_nodes);
1591
1592 // Fix up indices of MutableNodeViews.
1593 for (MutableNodeView& node_view : nodes_) {
1594 const int prev_node_index = node_view.node_index_;
1595 if (prev_node_index != order[prev_node_index]) {
1596 const string& node_name = graph_->node(prev_node_index).name();
1597 node_view.node_index_ = order[prev_node_index];
1598 node_index_by_name_.find(node_name)->second = node_view.node_index_;
1599 }
1600 for (MutableFanoutView& regular_fanin : node_view.regular_fanins_) {
1601 regular_fanin.node_index_ = order[regular_fanin.node_index_];
1602 }
1603 for (MutableFanoutView& controlling_fanin : node_view.controlling_fanins_) {
1604 controlling_fanin.node_index_ = order[controlling_fanin.node_index_];
1605 }
1606 for (std::vector<MutableFaninView>& regular_fanouts_port_i :
1607 node_view.regular_fanouts_by_port_) {
1608 for (MutableFaninView& regular_fanout : regular_fanouts_port_i) {
1609 regular_fanout.node_index_ = order[regular_fanout.node_index_];
1610 }
1611 }
1612 for (MutableFaninView& controlled_fanout : node_view.controlled_fanouts_) {
1613 controlled_fanout.node_index_ = order[controlled_fanout.node_index_];
1614 }
1615 }
1616
1617 // Permute graph NodeDefs.
1618 PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false);
1619
1620 return Status::OK();
1621 }
1622
ValidateInternal(absl::flat_hash_map<absl::string_view,int> * node_names,std::vector<RenamedOrOverwrittenNode> * renamed_nodes,std::vector<int> * inplace_nodes,std::vector<int> * empty_diff_node_indices)1623 inline Status MutableGraphView::ValidateInternal(
1624 absl::flat_hash_map<absl::string_view, int>* node_names,
1625 std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
1626 std::vector<int>* inplace_nodes,
1627 std::vector<int>* empty_diff_node_indices) {
1628 // Get node names and partition updated_nodes_ by if they are renamed or not,
1629 // skipping empty MutableNodeViewDiff.
1630 TF_RETURN_IF_ERROR(GetNodeNamesAndPartitionUpdatedNodes(
1631 node_names, renamed_nodes, inplace_nodes, empty_diff_node_indices));
1632
1633 // Check existence of fanins and validity (i.e. no self loops).
1634 TF_RETURN_IF_ERROR(
1635 CheckNodeNamesAndFanins(*node_names, *renamed_nodes, *inplace_nodes));
1636
1637 // Check if nodes after mutation have kernels registered.
1638 TF_RETURN_IF_ERROR(CheckKernelRegisteredForNodes());
1639
1640 return Status::OK();
1641 }
1642
ApplyMutationInternal()1643 Status MutableGraphView::ApplyMutationInternal() {
1644 // Node name -> node index mapping. If a node index is -1, the associated node
1645 // with key node name exists. Otherwise the node index is the node's index in
1646 // the graph.
1647 absl::flat_hash_map<absl::string_view, int> node_names;
1648 // Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are
1649 // renamed (and possibly have other fields mutated).
1650 std::vector<RenamedOrOverwrittenNode> renamed_nodes;
1651 // Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are
1652 // not renamed but have fields mutated.
1653 std::vector<int> inplace_nodes;
1654 // Indices of nodes in graph where MutableNodeViewDiff are empty.
1655 // `update_index_` of nodes associated to empty MutableNodeViewDiff should be
1656 // cleared after validation success.
1657 std::vector<int> empty_diff_node_indices;
1658
1659 // Check if this mutation is valid before applying, and partition
1660 // updated_nodes_ into inplace mutated nodes and renamed nodes.
1661 TF_RETURN_IF_ERROR(ValidateInternal(
1662 &node_names, &renamed_nodes, &inplace_nodes, &empty_diff_node_indices));
1663
1664 // Clear `update_index_` of MutableNodeView with empty associated
1665 // MutableNodeViewDiff.
1666 for (const int empty_diff_node_index : empty_diff_node_indices) {
1667 nodes_[empty_diff_node_index].update_index_ = internal::kMissingIndex;
1668 }
1669
1670 // Node name and associated fanouts.
1671 absl::flat_hash_map<string, NodeViewFanouts> renamed_fanouts;
1672 // Removed nodes where name was overwritten by a renamed node.
1673 std::vector<bool> overwritten_name_removed_nodes(nodes_.size());
1674 // Fix renaming of existing nodes by swapping fanouts and rehashing names.
1675 // This will also overwrite removed or unmodified nodes.
1676 FixRenamedNodes(&renamed_nodes, &renamed_fanouts,
1677 &overwritten_name_removed_nodes);
1678
1679 // Indices of nodes in graph where new nodes were inserted/appended. These
1680 // will be corresponding to `new_nodes_` in order.
1681 std::vector<int> new_node_indices;
1682 // Add new nodes, overwriting removed or unmodified nodes.
1683 AddNewNodes(&renamed_fanouts, &new_node_indices);
1684
1685 // For abandoned fanouts, mark their respective fanins so the original node
1686 // associated will not have their fanouts removed and be left in an
1687 // inconsistent state.
1688 FixRenamedFanouts(renamed_fanouts);
1689
1690 // Apply mutations to updated nodes (renamed nodes are treated as inplace
1691 // nodes as they have already been renamed). Removed nodes are ignored.
1692 ApplyNodeUpdates();
1693
1694 // Set fanins of new nodes.
1695 SetNewNodesFanins(new_node_indices);
1696
1697 // Remove overwritten nodes and updated nodes set to be removed.
1698 RemoveNodesInternal(renamed_nodes, overwritten_name_removed_nodes);
1699
1700 mutation_.ResetInternal();
1701
1702 mutation_.mutation_counter_++;
1703
1704 return Status::OK();
1705 }
1706
1707 } // namespace utils
1708 } // namespace grappler
1709 } // namespace tensorflow
1710