• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/grappler/op_types.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 namespace grappler {
40 
41 namespace {
42 
IsTensorIdPortValid(const TensorId & tensor_id)43 bool IsTensorIdPortValid(const TensorId& tensor_id) {
44   return tensor_id.index() >= Graph::kControlSlot;
45 }
46 
IsTensorIdRegular(const TensorId & tensor_id)47 bool IsTensorIdRegular(const TensorId& tensor_id) {
48   return tensor_id.index() > Graph::kControlSlot;
49 }
50 
IsTensorIdControlling(const TensorId & tensor_id)51 bool IsTensorIdControlling(const TensorId& tensor_id) {
52   return tensor_id.index() == Graph::kControlSlot;
53 }
54 
IsOutputPortControlling(const MutableGraphView::OutputPort & port)55 bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) {
56   return port.port_id == Graph::kControlSlot;
57 }
58 
59 // Determines if node is an Identity where it's first regular input is a Switch
60 // node.
IsIdentityConsumingSwitch(const MutableGraphView & graph,const NodeDef & node)61 bool IsIdentityConsumingSwitch(const MutableGraphView& graph,
62                                const NodeDef& node) {
63   if ((IsIdentity(node) || IsIdentityNSingleInput(node)) &&
64       node.input_size() > 0) {
65     TensorId tensor_id = ParseTensorName(node.input(0));
66     if (IsTensorIdControlling(tensor_id)) {
67       return false;
68     }
69 
70     NodeDef* input_node = graph.GetNode(tensor_id.node());
71     return IsSwitch(*input_node);
72   }
73   return false;
74 }
75 
76 // Determines if node input can be deduped by regular inputs when used as a
77 // control dependency. Specifically, if a node is an Identity that leads to a
78 // Switch node, when used as a control dependency, that control dependency
79 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,const NodeDef & control_node)80 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
81                                      const NodeDef& control_node) {
82   return !IsIdentityConsumingSwitch(graph, control_node);
83 }
84 
85 // Determines if node input can be deduped by regular inputs when used as a
86 // control dependency. Specifically, if a node is an Identity that leads to a
87 // Switch node, when used as a control dependency, that control dependency
88 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,absl::string_view control_node_name)89 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
90                                      absl::string_view control_node_name) {
91   NodeDef* control_node = graph.GetNode(control_node_name);
92   DCHECK(control_node != nullptr)
93       << "Didn't find a node for control dependency: " << control_node_name;
94   return CanDedupControlWithRegularInput(graph, *control_node);
95 }
96 
HasRegularFaninNode(const MutableGraphView & graph,const NodeDef & node,absl::string_view fanin_node_name)97 bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
98                          absl::string_view fanin_node_name) {
99   const int num_regular_fanins =
100       graph.NumFanins(node, /*include_controlling_nodes=*/false);
101   for (int i = 0; i < num_regular_fanins; ++i) {
102     if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
103       return true;
104     }
105   }
106   return false;
107 }
108 
109 using FanoutsMap =
110     absl::flat_hash_map<MutableGraphView::OutputPort,
111                         absl::flat_hash_set<MutableGraphView::InputPort>>;
112 
SwapControlledFanoutInputs(const MutableGraphView & graph,const FanoutsMap::iterator & control_fanouts,absl::string_view to_node_name)113 void SwapControlledFanoutInputs(const MutableGraphView& graph,
114                                 const FanoutsMap::iterator& control_fanouts,
115                                 absl::string_view to_node_name) {
116   absl::string_view from_node_name(control_fanouts->first.node->name());
117   string control = TensorIdToString({to_node_name, Graph::kControlSlot});
118   for (const auto& control_fanout : control_fanouts->second) {
119     const int start = graph.NumFanins(*control_fanout.node,
120                                       /*include_controlling_nodes=*/false);
121     for (int i = start; i < control_fanout.node->input_size(); ++i) {
122       TensorId tensor_id = ParseTensorName(control_fanout.node->input(i));
123       if (tensor_id.node() == from_node_name) {
124         control_fanout.node->set_input(i, control);
125         break;
126       }
127     }
128   }
129 }
130 
SwapRegularFanoutInputs(FanoutsMap * fanouts,NodeDef * from_node,absl::string_view to_node_name,int max_port)131 void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node,
132                              absl::string_view to_node_name, int max_port) {
133   MutableGraphView::OutputPort port;
134   port.node = from_node;
135   for (int i = 0; i <= max_port; ++i) {
136     port.port_id = i;
137     auto it = fanouts->find(port);
138     if (it == fanouts->end()) {
139       continue;
140     }
141     string input = TensorIdToString({to_node_name, i});
142     for (const auto& fanout : it->second) {
143       fanout.node->set_input(fanout.port_id, input);
144     }
145   }
146 }
147 
148 using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
149 
SwapFanoutInputs(const MutableGraphView & graph,FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)150 void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts,
151                       MaxOutputPortsMap* max_output_ports, NodeDef* from_node,
152                       NodeDef* to_node) {
153   auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot});
154   if (from_control_fanouts != fanouts->end()) {
155     SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name());
156   }
157   auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot});
158   if (to_control_fanouts != fanouts->end()) {
159     SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name());
160   }
161   auto from_max_port = max_output_ports->find(from_node);
162   if (from_max_port != max_output_ports->end()) {
163     SwapRegularFanoutInputs(fanouts, from_node, to_node->name(),
164                             from_max_port->second);
165   }
166   auto to_max_port = max_output_ports->find(to_node);
167   if (to_max_port != max_output_ports->end()) {
168     SwapRegularFanoutInputs(fanouts, to_node, from_node->name(),
169                             to_max_port->second);
170   }
171 }
172 
SwapFanoutsMapValues(FanoutsMap * fanouts,const MutableGraphView::OutputPort & from_port,const FanoutsMap::iterator & from_fanouts,const MutableGraphView::OutputPort & to_port,const FanoutsMap::iterator & to_fanouts)173 void SwapFanoutsMapValues(FanoutsMap* fanouts,
174                           const MutableGraphView::OutputPort& from_port,
175                           const FanoutsMap::iterator& from_fanouts,
176                           const MutableGraphView::OutputPort& to_port,
177                           const FanoutsMap::iterator& to_fanouts) {
178   const bool from_exists = from_fanouts != fanouts->end();
179   const bool to_exists = to_fanouts != fanouts->end();
180 
181   if (from_exists && to_exists) {
182     std::swap(from_fanouts->second, to_fanouts->second);
183   } else if (from_exists) {
184     fanouts->emplace(to_port, std::move(from_fanouts->second));
185     fanouts->erase(from_port);
186   } else if (to_exists) {
187     fanouts->emplace(from_port, std::move(to_fanouts->second));
188     fanouts->erase(to_port);
189   }
190 }
191 
SwapRegularFanoutsAndMaxPortValues(FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)192 void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts,
193                                         MaxOutputPortsMap* max_output_ports,
194                                         NodeDef* from_node, NodeDef* to_node) {
195   auto from_max_port = max_output_ports->find(from_node);
196   auto to_max_port = max_output_ports->find(to_node);
197   bool from_exists = from_max_port != max_output_ports->end();
198   bool to_exists = to_max_port != max_output_ports->end();
199 
200   auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start,
201                                    int end) {
202     for (int i = start; i <= end; ++i) {
203       MutableGraphView::OutputPort from_port(from, i);
204       auto from_fanouts = fanouts->find(from_port);
205       if (from_fanouts != fanouts->end()) {
206         MutableGraphView::OutputPort to_port(to, i);
207         fanouts->emplace(to_port, std::move(from_fanouts->second));
208         fanouts->erase(from_port);
209       }
210     }
211   };
212 
213   if (from_exists && to_exists) {
214     const int from = from_max_port->second;
215     const int to = to_max_port->second;
216     const int shared = std::min(from, to);
217     for (int i = 0; i <= shared; ++i) {
218       MutableGraphView::OutputPort from_port(from_node, i);
219       auto from_fanouts = fanouts->find(from_port);
220       MutableGraphView::OutputPort to_port(to_node, i);
221       auto to_fanouts = fanouts->find(to_port);
222       SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port,
223                            to_fanouts);
224     }
225     if (to > from) {
226       forward_fanouts(to_node, from_node, shared + 1, to);
227     } else if (from > to) {
228       forward_fanouts(from_node, to_node, shared + 1, from);
229     }
230 
231     std::swap(from_max_port->second, to_max_port->second);
232   } else if (from_exists) {
233     forward_fanouts(from_node, to_node, 0, from_max_port->second);
234 
235     max_output_ports->emplace(to_node, from_max_port->second);
236     max_output_ports->erase(from_node);
237   } else if (to_exists) {
238     forward_fanouts(to_node, from_node, 0, to_max_port->second);
239 
240     max_output_ports->emplace(from_node, to_max_port->second);
241     max_output_ports->erase(to_node);
242   }
243 }
244 
HasFanoutValue(const FanoutsMap & fanouts,const FanoutsMap::iterator & it)245 bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) {
246   return it != fanouts.end() && !it->second.empty();
247 }
248 
MutationError(absl::string_view function_name,absl::string_view params,absl::string_view msg)249 Status MutationError(absl::string_view function_name, absl::string_view params,
250                      absl::string_view msg) {
251   return errors::InvalidArgument(absl::Substitute(
252       "MutableGraphView::$0($1) error: $2.", function_name, params, msg));
253 }
254 
255 using ErrorHandler = std::function<Status(absl::string_view)>;
256 
UpdateFanoutsError(absl::string_view from_node_name,absl::string_view to_node_name)257 ErrorHandler UpdateFanoutsError(absl::string_view from_node_name,
258                                 absl::string_view to_node_name) {
259   return [from_node_name, to_node_name](absl::string_view msg) {
260     string params = absl::Substitute("from_node_name='$0', to_node_name='$1'",
261                                      from_node_name, to_node_name);
262     return MutationError("UpdateFanouts", params, msg);
263   };
264 }
265 
CheckFaninIsRegular(const TensorId & fanin,ErrorHandler handler)266 Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) {
267   if (!IsTensorIdRegular(fanin)) {
268     return handler(absl::Substitute("fanin '$0' must be a regular tensor id",
269                                     fanin.ToString()));
270   }
271   return Status::OK();
272 }
273 
CheckFaninIsValid(const TensorId & fanin,ErrorHandler handler)274 Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) {
275   if (!IsTensorIdPortValid(fanin)) {
276     return handler(absl::Substitute("fanin '$0' must be a valid tensor id",
277                                     fanin.ToString()));
278   }
279   return Status::OK();
280 }
281 
CheckAddingFaninToSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)282 Status CheckAddingFaninToSelf(absl::string_view node_name,
283                               const TensorId& fanin, ErrorHandler handler) {
284   if (node_name == fanin.node()) {
285     return handler(
286         absl::Substitute("can't add fanin '$0' to self", fanin.ToString()));
287   }
288   return Status::OK();
289 }
290 
CheckRemovingFaninFromSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)291 Status CheckRemovingFaninFromSelf(absl::string_view node_name,
292                                   const TensorId& fanin, ErrorHandler handler) {
293   if (node_name == fanin.node()) {
294     return handler(absl::Substitute("can't remove fanin '$0' from self",
295                                     fanin.ToString()));
296   }
297   return Status::OK();
298 }
299 
NodeMissingErrorMsg(absl::string_view node_name)300 string NodeMissingErrorMsg(absl::string_view node_name) {
301   return absl::Substitute("node '$0' was not found", node_name);
302 }
303 
CheckNodeExists(absl::string_view node_name,NodeDef * node,ErrorHandler handler)304 Status CheckNodeExists(absl::string_view node_name, NodeDef* node,
305                        ErrorHandler handler) {
306   if (node == nullptr) {
307     return handler(NodeMissingErrorMsg(node_name));
308   }
309   return Status::OK();
310 }
311 
CheckPortRange(int port,int min,int max,ErrorHandler handler)312 Status CheckPortRange(int port, int min, int max, ErrorHandler handler) {
313   if (port < min || port > max) {
314     if (max < min) {
315       return handler("no available ports as node has no regular fanins");
316     }
317     return handler(
318         absl::Substitute("port must be in range [$0, $1]", min, max));
319   }
320   return Status::OK();
321 }
322 
SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name)323 string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) {
324   return absl::Substitute(
325       "can't swap node name '$0' as it will become a Switch control dependency",
326       node_name);
327 }
328 
GeneratedNameForIdentityConsumingSwitch(const MutableGraphView::OutputPort & fanin)329 string GeneratedNameForIdentityConsumingSwitch(
330     const MutableGraphView::OutputPort& fanin) {
331   return AddPrefixToNodeName(
332       absl::StrCat(fanin.node->name(), "_", fanin.port_id),
333       kMutableGraphViewCtrl);
334 }
335 
336 }  // namespace
337 
AddAndDedupFanouts(NodeDef * node)338 void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
339   // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
340   // exist, and all regular fanins come before controlling fanins.
341   absl::flat_hash_set<absl::string_view> fanins;
342   absl::flat_hash_set<absl::string_view> controlling_fanins;
343   int max_input_port = -1;
344   int pos = 0;
345   const int last_idx = node->input_size() - 1;
346   int last_pos = last_idx;
347   while (pos <= last_pos) {
348     TensorId tensor_id = ParseTensorName(node->input(pos));
349     absl::string_view input_node_name = tensor_id.node();
350     bool is_control_input = IsTensorIdControlling(tensor_id);
351     bool can_dedup_control_with_regular_input =
352         CanDedupControlWithRegularInput(*this, input_node_name);
353     bool can_dedup_control =
354         is_control_input && (can_dedup_control_with_regular_input ||
355                              (!can_dedup_control_with_regular_input &&
356                               controlling_fanins.contains(input_node_name)));
357     if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
358         can_dedup_control) {
359       node->mutable_input()->SwapElements(pos, last_pos);
360       --last_pos;
361     } else {
362       OutputPort output(nodes()[input_node_name], tensor_id.index());
363 
364       if (is_control_input) {
365         fanouts()[output].emplace(node, Graph::kControlSlot);
366       } else {
367         max_input_port = pos;
368         max_regular_output_port()[output.node] =
369             std::max(max_regular_output_port()[output.node], output.port_id);
370         fanouts()[output].emplace(node, pos);
371       }
372       ++pos;
373     }
374     if (is_control_input) {
375       controlling_fanins.insert(input_node_name);
376     }
377   }
378 
379   if (last_pos < last_idx) {
380     node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
381   }
382 
383   if (max_input_port > -1) {
384     max_regular_input_port()[node] = max_input_port;
385   }
386 }
387 
UpdateMaxRegularOutputPortForRemovedFanin(const OutputPort & fanin,const absl::flat_hash_set<InputPort> & fanin_fanouts)388 void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin(
389     const OutputPort& fanin,
390     const absl::flat_hash_set<InputPort>& fanin_fanouts) {
391   int max_port = max_regular_output_port()[fanin.node];
392   if (!fanin_fanouts.empty() || max_port != fanin.port_id) {
393     return;
394   }
395   bool updated_max_port = false;
396   for (int i = fanin.port_id - 1; i >= 0; --i) {
397     OutputPort fanin_port(fanin.node, i);
398     if (!fanouts()[fanin_port].empty()) {
399       max_regular_output_port()[fanin.node] = i;
400       updated_max_port = true;
401       break;
402     }
403   }
404   if (!updated_max_port) {
405     max_regular_output_port().erase(fanin.node);
406   }
407 }
408 
UpdateMaxRegularOutputPortForAddedFanin(const OutputPort & fanin)409 void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin(
410     const OutputPort& fanin) {
411   if (max_regular_output_port()[fanin.node] < fanin.port_id) {
412     max_regular_output_port()[fanin.node] = fanin.port_id;
413   }
414 }
415 
416 const absl::flat_hash_set<MutableGraphView::InputPort>&
GetFanout(const GraphView::OutputPort & port) const417 MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
418   return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
419                                                 port.port_id));
420 }
421 
GetFanin(const GraphView::InputPort & port) const422 absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
423     const GraphView::InputPort& port) const {
424   return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
425                                               port.port_id));
426 }
427 
GetRegularFanin(const GraphView::InputPort & port) const428 const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
429     const GraphView::InputPort& port) const {
430   return GetRegularFanin(MutableGraphView::InputPort(
431       const_cast<NodeDef*>(port.node), port.port_id));
432 }
433 
AddNode(NodeDef && node)434 NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
435   auto* node_in_graph = graph()->add_node();
436   *node_in_graph = std::move(node);
437 
438   AddUniqueNodeOrDie(node_in_graph);
439 
440   AddAndDedupFanouts(node_in_graph);
441   return node_in_graph;
442 }
443 
AddSubgraph(GraphDef && subgraph)444 Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
445   // 1. Add all new functions and check that functions with the same name
446   // have identical definition.
447   const int function_size = subgraph.library().function_size();
448   if (function_size > 0) {
449     absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs;
450     for (const FunctionDef& fdef : graph()->library().function()) {
451       graph_fdefs.emplace(fdef.signature().name(), &fdef);
452     }
453 
454     for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) {
455       const auto graph_fdef = graph_fdefs.find(fdef.signature().name());
456 
457       if (graph_fdef == graph_fdefs.end()) {
458         VLOG(3) << "Add new function definition: " << fdef.signature().name();
459         graph()->mutable_library()->add_function()->Swap(&fdef);
460       } else {
461         if (!FunctionDefsEqual(fdef, *graph_fdef->second)) {
462           return MutationError(
463               "AddSubgraph",
464               absl::Substitute("function_size=$0", function_size),
465               absl::StrCat(
466                   "Found different function definition with the same name: ",
467                   fdef.signature().name()));
468         }
469       }
470     }
471   }
472 
473   // 2. Add all nodes to the underlying graph.
474   int node_size_before = graph()->node_size();
475 
476   for (NodeDef& node : *subgraph.mutable_node()) {
477     auto* node_in_graph = graph()->add_node();
478     node_in_graph->Swap(&node);
479     TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
480   }
481 
482   // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
483   // fanins actually exists in the graph, and there is already TODO for that.
484 
485   for (int i = node_size_before; i < graph()->node_size(); ++i) {
486     NodeDef* node = graph()->mutable_node(i);
487     AddAndDedupFanouts(node);
488   }
489 
490   return Status::OK();
491 }
492 
UpdateNode(absl::string_view node_name,absl::string_view op,absl::string_view device,absl::Span<const std::pair<string,AttrValue>> attrs)493 Status MutableGraphView::UpdateNode(
494     absl::string_view node_name, absl::string_view op, absl::string_view device,
495     absl::Span<const std::pair<string, AttrValue>> attrs) {
496   auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
497     std::vector<string> attr_strs;
498     attr_strs.reserve(attrs.size());
499     for (const auto& attr : attrs) {
500       string attr_str = absl::Substitute("('$0', $1)", attr.first,
501                                          attr.second.ShortDebugString());
502       attr_strs.push_back(attr_str);
503     }
504     string params =
505         absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
506                          node_name, op, device, absl::StrJoin(attr_strs, ", "));
507     return MutationError("UpdateNodeOp", params, msg);
508   };
509 
510   NodeDef* node = GetNode(node_name);
511   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
512 
513   MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
514   auto control_fanouts = GetFanout(control_port);
515   if (op == "Switch" && !control_fanouts.empty()) {
516     return error_status(
517         "can't change node op to Switch when node drives a control dependency "
518         "(alternatively, we could add the identity node needed, but it seems "
519         "like an unlikely event and probably a mistake)");
520   }
521 
522   if (node->device() != device) {
523     node->set_device(string(device));
524   }
525   node->mutable_attr()->clear();
526   for (const auto& attr : attrs) {
527     (*node->mutable_attr())[attr.first] = attr.second;
528   }
529 
530   if (node->op() == op) {
531     return Status::OK();
532   }
533 
534   node->set_op(string(op));
535 
536   if (CanDedupControlWithRegularInput(*this, *node)) {
537     for (const auto& control_fanout : control_fanouts) {
538       if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
539         RemoveControllingFaninInternal(control_fanout.node, node);
540       }
541     }
542   }
543 
544   return Status::OK();
545 }
546 
UpdateNodeName(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)547 Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name,
548                                         absl::string_view to_node_name,
549                                         bool update_fanouts) {
550   auto error_status = [from_node_name, to_node_name,
551                        update_fanouts](absl::string_view msg) {
552     string params = absl::Substitute(
553         "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
554         from_node_name, to_node_name, update_fanouts);
555     return MutationError("UpdateNodeName", params, msg);
556   };
557 
558   NodeDef* node = GetNode(from_node_name);
559   TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status));
560 
561   if (node->name() == to_node_name) {
562     return Status::OK();
563   }
564   if (HasNode(to_node_name)) {
565     return error_status(
566         "can't update node name because new node name is in use");
567   }
568   auto max_output_port = max_regular_output_port().find(node);
569   const bool has_max_output_port =
570       max_output_port != max_regular_output_port().end();
571   auto control_fanouts = fanouts().find({node, Graph::kControlSlot});
572 
573   if (update_fanouts) {
574     SwapControlledFanoutInputs(*this, control_fanouts, to_node_name);
575     if (has_max_output_port) {
576       SwapRegularFanoutInputs(&fanouts(), node, to_node_name,
577                               max_output_port->second);
578     }
579   } else if (has_max_output_port ||
580              HasFanoutValue(fanouts(), control_fanouts)) {
581     return error_status("can't update node name because node has fanouts");
582   }
583 
584   nodes().erase(node->name());
585   node->set_name(string(to_node_name));
586   nodes().emplace(node->name(), node);
587   return Status::OK();
588 }
589 
SwapNodeNames(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)590 Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
591                                        absl::string_view to_node_name,
592                                        bool update_fanouts) {
593   auto error_status = [from_node_name, to_node_name,
594                        update_fanouts](absl::string_view msg) {
595     string params = absl::Substitute(
596         "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
597         from_node_name, to_node_name, update_fanouts);
598     return MutationError("SwapNodeNames", params, msg);
599   };
600 
601   NodeDef* from_node = GetNode(from_node_name);
602   TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
603   if (from_node_name == to_node_name) {
604     return Status::OK();
605   }
606   NodeDef* to_node = GetNode(to_node_name);
607   TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));
608 
609   auto swap_names = [this, from_node, to_node]() {
610     nodes().erase(from_node->name());
611     nodes().erase(to_node->name());
612     std::swap(*from_node->mutable_name(), *to_node->mutable_name());
613     nodes().emplace(from_node->name(), from_node);
614     nodes().emplace(to_node->name(), to_node);
615   };
616 
617   if (update_fanouts) {
618     SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
619                      to_node);
620     swap_names();
621     return Status::OK();
622   }
623 
624   bool from_is_switch = IsSwitch(*from_node);
625   MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
626   auto to_control_fanouts = fanouts().find(to_control);
627   if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
628     return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
629   }
630 
631   bool to_is_switch = IsSwitch(*to_node);
632   MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
633   auto from_control_fanouts = fanouts().find(from_control);
634   if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
635     return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
636   }
637 
638   // Swap node names.
639   swap_names();
640 
641   // Swap controlling fanouts.
642   //
643   // Note: To and from control fanout iterators are still valid as no mutations
644   // has been performed on fanouts().
645   SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
646                        to_control, to_control_fanouts);
647 
648   // Swap regular fanouts.
649   SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
650                                      from_node, to_node);
651 
652   // Update fanins to remove self loops.
653   auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
654     for (int i = 0; i < node->input_size(); ++i) {
655       TensorId tensor_id = ParseTensorName(node->input(i));
656       if (tensor_id.node() == node->name()) {
657         const int idx = tensor_id.index();
658         const int node_idx =
659             IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;
660 
661         MutableGraphView::OutputPort from_fanin(node, idx);
662         absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
663         from_fanouts->erase({node, node_idx});
664         UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);
665 
666         MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
667         fanouts()[to_fanin].insert({node, node_idx});
668         UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
669         node->set_input(i, TensorIdToString({old_node_name, idx}));
670       }
671     }
672   };
673   update_fanins(from_node, to_node->name());
674   update_fanins(to_node, from_node->name());
675 
676   // Dedup control dependencies.
677   auto dedup_control_fanouts =
678       [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
679         if (CanDedupControlWithRegularInput(*this, *node) &&
680             control_fanouts != fanouts().end()) {
681           for (const auto& control_fanout : control_fanouts->second) {
682             if (HasRegularFaninNode(*this, *control_fanout.node,
683                                     node->name())) {
684               RemoveControllingFaninInternal(control_fanout.node, node);
685             }
686           }
687         }
688       };
689   auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
690     OutputPort port;
691     port.node = node;
692     const int max_port =
693         gtl::FindWithDefault(max_regular_output_port(), node, -1);
694     for (int i = 0; i <= max_port; ++i) {
695       port.port_id = i;
696       auto it = fanouts().find(port);
697       if (it == fanouts().end()) {
698         continue;
699       }
700       for (const auto& fanout : it->second) {
701         auto fanout_controls =
702             fanouts().find({fanout.node, Graph::kControlSlot});
703         dedup_control_fanouts(fanout.node, fanout_controls);
704       }
705     }
706   };
707 
708   if (!from_is_switch) {
709     if (to_is_switch) {
710       dedup_switch_control(from_node);
711     } else {
712       // Fetch iterator again as the original iterator might have been
713       // invalidated by container rehash triggered due to mutations.
714       auto from_control_fanouts = fanouts().find(from_control);
715       dedup_control_fanouts(from_node, from_control_fanouts);
716     }
717   }
718   if (!to_is_switch) {
719     if (from_is_switch) {
720       dedup_switch_control(to_node);
721     } else {
722       // Fetch iterator again as the original iterator might have been
723       // invalidated by container rehash triggered due to mutations.
724       auto to_control_fanouts = fanouts().find(to_control);
725       dedup_control_fanouts(to_node, to_control_fanouts);
726     }
727   }
728 
729   return Status::OK();
730 }
731 
UpdateFanouts(absl::string_view from_node_name,absl::string_view to_node_name)732 Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
733                                        absl::string_view to_node_name) {
734   NodeDef* from_node = GetNode(from_node_name);
735   TF_RETURN_IF_ERROR(
736       CheckNodeExists(from_node_name, from_node,
737                       UpdateFanoutsError(from_node_name, to_node_name)));
738   NodeDef* to_node = GetNode(to_node_name);
739   TF_RETURN_IF_ERROR(CheckNodeExists(
740       to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name)));
741 
742   return UpdateFanoutsInternal(from_node, to_node);
743 }
744 
UpdateFanoutsInternal(NodeDef * from_node,NodeDef * to_node)745 Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node,
746                                                NodeDef* to_node) {
747   VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.",
748                               from_node->name(), to_node->name());
749   if (from_node == to_node) {
750     return Status::OK();
751   }
752 
753   // Update internal state with the new output_port->input_port edge.
754   const auto add_edge = [this](const OutputPort& output_port,
755                                const InputPort& input_port) {
756     fanouts()[output_port].insert(input_port);
757   };
758 
759   // Remove invalidated edge from the internal state.
760   const auto remove_edge = [this](const OutputPort& output_port,
761                                   const InputPort& input_port) {
762     fanouts()[output_port].erase(input_port);
763   };
764 
765   // For the control fanouts we do not know the input index in a NodeDef,
766   // so we have to traverse all control inputs.
767 
768   auto control_fanouts =
769       GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
770 
771   bool to_node_is_switch = IsSwitch(*to_node);
772   for (const InputPort& control_port : control_fanouts) {
773     // Node can't be control dependency of itself.
774     if (control_port.node == to_node) continue;
775 
776     // Can't add Switch node as a control dependency.
777     if (to_node_is_switch) {
778       // Trying to add a Switch as a control dependency, which if allowed will
779       // make the graph invalid.
780       return UpdateFanoutsError(from_node->name(), to_node->name())(
781           absl::Substitute("can't update fanouts to node '$0' as it will "
782                            "become a Switch control dependency",
783                            to_node->name()));
784     }
785 
786     NodeDef* node = control_port.node;
787     RemoveControllingFaninInternal(node, from_node);
788     AddFaninInternal(node, {to_node, Graph::kControlSlot});
789   }
790 
791   // First we update regular fanouts. For the regular fanouts
792   // `input_port:port_id` is the input index in NodeDef.
793 
794   auto regular_edges =
795       GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
796 
797   // Maximum index of the `from_node` output tensor that is still used as an
798   // input to some other node.
799   int keep_max_regular_output_port = -1;
800 
801   for (const Edge& edge : regular_edges) {
802     const OutputPort output_port = edge.src;
803     const InputPort input_port = edge.dst;
804 
805     // If the `to_node` reads from the `from_node`, skip this edge (see
806     // AddAndUpdateFanoutsWithoutSelfLoops test for an example).
807     if (input_port.node == to_node) {
808       keep_max_regular_output_port =
809           std::max(keep_max_regular_output_port, output_port.port_id);
810       continue;
811     }
812 
813     // Update input at destination node.
814     input_port.node->set_input(
815         input_port.port_id,
816         TensorIdToString({to_node->name(), output_port.port_id}));
817 
818     // Remove old edge between the `from_node` and the fanout node.
819     remove_edge(output_port, input_port);
820     // Add an edge between the `to_node` and new fanout node.
821     add_edge(OutputPort(to_node, output_port.port_id), input_port);
822     // Dedup control dependency.
823     if (CanDedupControlWithRegularInput(*this, *to_node)) {
824       RemoveControllingFaninInternal(input_port.node, to_node);
825     }
826   }
827 
828   // Because we update all regular fanouts of `from_node`, we can just copy
829   // the value `num_regular_outputs`.
830   max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
831 
832   // Check if all fanouts were updated to read from the `to_node`.
833   if (keep_max_regular_output_port >= 0) {
834     max_regular_output_port()[from_node] = keep_max_regular_output_port;
835   } else {
836     max_regular_output_port().erase(from_node);
837   }
838 
839   return Status::OK();
840 }
841 
AddFaninInternal(NodeDef * node,const OutputPort & fanin)842 bool MutableGraphView::AddFaninInternal(NodeDef* node,
843                                         const OutputPort& fanin) {
844   int num_regular_fanins =
845       NumFanins(*node, /*include_controlling_nodes=*/false);
846   bool input_is_control = IsOutputPortControlling(fanin);
847   bool can_dedup_control_with_regular_input =
848       CanDedupControlWithRegularInput(*this, *fanin.node);
849   // Don't add duplicate control dependencies.
850   if (input_is_control) {
851     const int start =
852         can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
853     for (int i = start; i < node->input_size(); ++i) {
854       if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
855         return false;
856       }
857     }
858   }
859 
860   InputPort input;
861   input.node = node;
862   input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;
863 
864   node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
865   if (!input_is_control) {
866     const int last_node_input = node->input_size() - 1;
867     // If there are control dependencies in node, move newly inserted fanin to
868     // be before such control dependencies.
869     if (num_regular_fanins < last_node_input) {
870       node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
871     }
872   }
873 
874   fanouts()[fanin].insert(input);
875   if (max_regular_output_port()[fanin.node] < fanin.port_id) {
876     max_regular_output_port()[fanin.node] = fanin.port_id;
877   }
878 
879   // Update max input port and dedup control dependencies.
880   if (!input_is_control) {
881     max_regular_input_port()[node] = num_regular_fanins;
882     if (can_dedup_control_with_regular_input) {
883       RemoveControllingFaninInternal(node, fanin.node);
884     }
885   }
886 
887   return true;
888 }
889 
AddRegularFanin(absl::string_view node_name,const TensorId & fanin)890 Status MutableGraphView::AddRegularFanin(absl::string_view node_name,
891                                          const TensorId& fanin) {
892   auto error_status = [node_name, fanin](absl::string_view msg) {
893     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
894                                      fanin.ToString());
895     return MutationError("AddRegularFanin", params, msg);
896   };
897 
898   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
899   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
900   NodeDef* node = GetNode(node_name);
901   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
902   NodeDef* fanin_node = GetNode(fanin.node());
903   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
904 
905   AddFaninInternal(node, {fanin_node, fanin.index()});
906   return Status::OK();
907 }
908 
AddRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)909 Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
910                                                int port,
911                                                const TensorId& fanin) {
912   auto error_status = [node_name, port, fanin](absl::string_view msg) {
913     string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
914                                      node_name, port, fanin.ToString());
915     return MutationError("AddRegularFaninByPort", params, msg);
916   };
917 
918   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
919   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
920   NodeDef* node = GetNode(node_name);
921   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
922   const int num_regular_fanins =
923       NumFanins(*node, /*include_controlling_nodes=*/false);
924   TF_RETURN_IF_ERROR(
925       CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
926   NodeDef* fanin_node = GetNode(fanin.node());
927   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
928 
929   const int last_node_input = node->input_size();
930   node->add_input(TensorIdToString(fanin));
931   node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
932   for (int i = num_regular_fanins - 1; i >= port; --i) {
933     TensorId tensor_id = ParseTensorName(node->input(i));
934     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
935     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
936     fanouts_set->erase({node, i});
937     fanouts_set->insert({node, i + 1});
938     node->mutable_input()->SwapElements(i, i + 1);
939   }
940 
941   OutputPort fanin_port(fanin_node, fanin.index());
942   fanouts()[fanin_port].insert({node, port});
943   UpdateMaxRegularOutputPortForAddedFanin(fanin_port);
944 
945   max_regular_input_port()[node] = num_regular_fanins;
946   if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
947     RemoveControllingFaninInternal(node, fanin_node);
948   }
949 
950   return Status::OK();
951 }
952 
GetControllingFaninToAdd(absl::string_view node_name,const OutputPort & fanin,string * error_msg)953 NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
954                                                     const OutputPort& fanin,
955                                                     string* error_msg) {
956   if (!IsSwitch(*fanin.node)) {
957     return fanin.node;
958   } else {
959     if (IsOutputPortControlling(fanin)) {
960       // Can't add a Switch node control dependency.
961       TensorId tensor_id(fanin.node->name(), fanin.port_id);
962       *error_msg = absl::Substitute(
963           "can't add fanin '$0' as it will become a Switch control dependency",
964           tensor_id.ToString());
965       return nullptr;
966     }
967     // We can't anchor control dependencies directly on the switch node: unlike
968     // other nodes only one of the outputs of the switch node will be generated
969     // when the switch node is executed, and we need to make sure the control
970     // dependency is only triggered when the corresponding output is triggered.
971     // We start by looking for an identity node connected to the output of the
972     // switch node, and use it to anchor the control dependency.
973     for (const auto& fanout : GetFanout(fanin)) {
974       if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
975         if (fanout.node->name() == node_name) {
976           *error_msg =
977               absl::Substitute("can't add found fanin '$0' to self",
978                                AsControlDependency(fanout.node->name()));
979           return nullptr;
980         }
981         return fanout.node;
982       }
983     }
984 
985     // No node found, check if node to be created is itself.
986     if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
987       *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
988                                     AsControlDependency(string(node_name)));
989     }
990   }
991   return nullptr;
992 }
993 
GetOrCreateIdentityConsumingSwitch(const OutputPort & fanin)994 NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
995     const OutputPort& fanin) {
996   // We haven't found an existing node where we can anchor the control
997   // dependency: add a new identity node.
998   string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
999   NodeDef* identity_node = GetNode(identity_name);
1000   if (identity_node == nullptr) {
1001     NodeDef new_node;
1002     new_node.set_name(identity_name);
1003     new_node.set_op("Identity");
1004     new_node.set_device(fanin.node->device());
1005     (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
1006     new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
1007     identity_node = AddNode(std::move(new_node));
1008   }
1009   return identity_node;
1010 }
1011 
AddControllingFanin(absl::string_view node_name,const TensorId & fanin)1012 Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
1013                                              const TensorId& fanin) {
1014   auto error_status = [node_name, fanin](absl::string_view msg) {
1015     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1016                                      fanin.ToString());
1017     return MutationError("AddControllingFanin", params, msg);
1018   };
1019 
1020   TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status));
1021   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1022   NodeDef* node = GetNode(node_name);
1023   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1024   NodeDef* fanin_node = GetNode(fanin.node());
1025   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1026 
1027   OutputPort fanin_port(fanin_node, fanin.index());
1028 
1029   string error_msg = "";
1030   NodeDef* control_node = GetControllingFaninToAdd(
1031       node_name, {fanin_node, fanin.index()}, &error_msg);
1032   if (!error_msg.empty()) {
1033     return error_status(error_msg);
1034   }
1035   if (control_node == nullptr) {
1036     control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
1037   }
1038   AddFaninInternal(node, {control_node, Graph::kControlSlot});
1039 
1040   return Status::OK();
1041 }
1042 
RemoveRegularFaninInternal(NodeDef * node,const OutputPort & fanin)1043 bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
1044                                                   const OutputPort& fanin) {
1045   auto remove_input = [this, node](const OutputPort& fanin_port,
1046                                    int node_input_port, bool update_max_port) {
1047     InputPort input(node, node_input_port);
1048 
1049     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1050     fanouts_set->erase(input);
1051     if (update_max_port) {
1052       UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
1053     }
1054     return fanouts_set;
1055   };
1056 
1057   auto mutable_inputs = node->mutable_input();
1058   bool modified = false;
1059   const int num_regular_fanins =
1060       NumFanins(*node, /*include_controlling_nodes=*/false);
1061   int i;
1062   int curr_pos = 0;
1063   for (i = 0; i < num_regular_fanins; ++i) {
1064     TensorId tensor_id = ParseTensorName(node->input(i));
1065     if (tensor_id.node() == fanin.node->name() &&
1066         tensor_id.index() == fanin.port_id) {
1067       remove_input(fanin, i, /*update_max_port=*/true);
1068       modified = true;
1069     } else if (modified) {
1070       // Regular inputs will need to have their ports updated.
1071       OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1072       auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
1073       fanouts_set->insert({node, curr_pos});
1074       // Shift inputs to be retained.
1075       mutable_inputs->SwapElements(i, curr_pos);
1076       ++curr_pos;
1077     } else {
1078       // Skip inputs to be retained until first modification.
1079       ++curr_pos;
1080     }
1081   }
1082 
1083   if (modified) {
1084     const int last_regular_input_port = curr_pos - 1;
1085     if (last_regular_input_port < 0) {
1086       max_regular_input_port().erase(node);
1087     } else {
1088       max_regular_input_port()[node] = last_regular_input_port;
1089     }
1090     if (curr_pos < i) {
1091       // Remove fanins from node inputs.
1092       mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
1093     }
1094   }
1095 
1096   return modified;
1097 }
1098 
RemoveRegularFanin(absl::string_view node_name,const TensorId & fanin)1099 Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name,
1100                                             const TensorId& fanin) {
1101   auto error_status = [node_name, fanin](absl::string_view msg) {
1102     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1103                                      fanin.ToString());
1104     return MutationError("RemoveRegularFanin", params, msg);
1105   };
1106 
1107   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1108   TF_RETURN_IF_ERROR(
1109       CheckRemovingFaninFromSelf(node_name, fanin, error_status));
1110   NodeDef* node = GetNode(node_name);
1111   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1112   NodeDef* fanin_node = GetNode(fanin.node());
1113   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1114 
1115   RemoveRegularFaninInternal(node, {fanin_node, fanin.index()});
1116   return Status::OK();
1117 }
1118 
RemoveRegularFaninByPort(absl::string_view node_name,int port)1119 Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name,
1120                                                   int port) {
1121   auto error_status = [node_name, port](absl::string_view msg) {
1122     string params =
1123         absl::Substitute("node_name='$0', port=$1", node_name, port);
1124     return MutationError("RemoveRegularFaninByPort", params, msg);
1125   };
1126 
1127   NodeDef* node = GetNode(node_name);
1128   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1129   const int last_regular_fanin_port =
1130       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1131   TF_RETURN_IF_ERROR(
1132       CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1133 
1134   TensorId tensor_id = ParseTensorName(node->input(port));
1135   OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1136   fanouts()[fanin_port].erase({node, port});
1137   auto mutable_inputs = node->mutable_input();
1138   for (int i = port + 1; i <= last_regular_fanin_port; ++i) {
1139     TensorId tensor_id = ParseTensorName(node->input(i));
1140     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1141     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1142     fanouts_set->erase({node, i});
1143     fanouts_set->insert({node, i - 1});
1144     mutable_inputs->SwapElements(i - 1, i);
1145   }
1146   const int last_node_input = node->input_size() - 1;
1147   if (last_regular_fanin_port < last_node_input) {
1148     mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input);
1149   }
1150   mutable_inputs->RemoveLast();
1151 
1152   const int updated_last_regular_input_port = last_regular_fanin_port - 1;
1153   if (updated_last_regular_input_port < 0) {
1154     max_regular_input_port().erase(node);
1155   } else {
1156     max_regular_input_port()[node] = updated_last_regular_input_port;
1157   }
1158 
1159   return Status::OK();
1160 }
1161 
RemoveControllingFaninInternal(NodeDef * node,NodeDef * fanin_node)1162 bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node,
1163                                                       NodeDef* fanin_node) {
1164   for (int i = node->input_size() - 1; i >= 0; --i) {
1165     TensorId tensor_id = ParseTensorName(node->input(i));
1166     if (tensor_id.index() > Graph::kControlSlot) {
1167       break;
1168     }
1169     if (tensor_id.node() == fanin_node->name()) {
1170       fanouts()[{fanin_node, Graph::kControlSlot}].erase(
1171           {node, Graph::kControlSlot});
1172       node->mutable_input()->SwapElements(i, node->input_size() - 1);
1173       node->mutable_input()->RemoveLast();
1174       return true;
1175     }
1176   }
1177   return false;
1178 }
1179 
RemoveControllingFanin(absl::string_view node_name,absl::string_view fanin_node_name)1180 Status MutableGraphView::RemoveControllingFanin(
1181     absl::string_view node_name, absl::string_view fanin_node_name) {
1182   auto error_status = [node_name, fanin_node_name](absl::string_view msg) {
1183     string params = absl::Substitute("node_name='$0', fanin_node_name='$1'",
1184                                      node_name, fanin_node_name);
1185     return MutationError("RemoveControllingFanin", params, msg);
1186   };
1187 
1188   TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf(
1189       node_name, {fanin_node_name, Graph::kControlSlot}, error_status));
1190   NodeDef* node = GetNode(node_name);
1191   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1192   NodeDef* fanin_node = GetNode(fanin_node_name);
1193   TF_RETURN_IF_ERROR(
1194       CheckNodeExists(fanin_node_name, fanin_node, error_status));
1195 
1196   RemoveControllingFaninInternal(node, fanin_node);
1197   return Status::OK();
1198 }
1199 
RemoveAllFanins(absl::string_view node_name,bool keep_controlling_fanins)1200 Status MutableGraphView::RemoveAllFanins(absl::string_view node_name,
1201                                          bool keep_controlling_fanins) {
1202   NodeDef* node = GetNode(node_name);
1203   if (node == nullptr) {
1204     string params =
1205         absl::Substitute("node_name='$0', keep_controlling_fanins=$1",
1206                          node_name, keep_controlling_fanins);
1207     return MutationError("RemoveAllFanins", params,
1208                          NodeMissingErrorMsg(node_name));
1209   }
1210 
1211   if (node->input().empty()) {
1212     return Status::OK();
1213   }
1214 
1215   const int num_regular_fanins =
1216       NumFanins(*node, /*include_controlling_nodes=*/false);
1217   RemoveFaninsInternal(node, keep_controlling_fanins);
1218   if (keep_controlling_fanins) {
1219     if (num_regular_fanins == 0) {
1220       return Status::OK();
1221     } else if (num_regular_fanins < node->input_size()) {
1222       node->mutable_input()->DeleteSubrange(0, num_regular_fanins);
1223     } else {
1224       node->clear_input();
1225     }
1226   } else {
1227     node->clear_input();
1228   }
1229   return Status::OK();
1230 }
1231 
UpdateFanin(absl::string_view node_name,const TensorId & from_fanin,const TensorId & to_fanin)1232 Status MutableGraphView::UpdateFanin(absl::string_view node_name,
1233                                      const TensorId& from_fanin,
1234                                      const TensorId& to_fanin) {
1235   auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
1236     string params =
1237         absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
1238                          node_name, from_fanin.ToString(), to_fanin.ToString());
1239     return MutationError("UpdateFanin", params, msg);
1240   };
1241 
1242   TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
1243   TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
1244   NodeDef* node = GetNode(node_name);
1245   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1246   NodeDef* from_fanin_node = GetNode(from_fanin.node());
1247   TF_RETURN_IF_ERROR(
1248       CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
1249   NodeDef* to_fanin_node = GetNode(to_fanin.node());
1250   TF_RETURN_IF_ERROR(
1251       CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));
1252 
1253   // When replacing a non control dependency fanin with a control dependency, or
1254   // vice versa, remove and add, so ports can be updated properly in fanout(s).
1255   bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
1256   if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
1257     // Can't add Switch node as a control dependency.
1258     return error_status(
1259         absl::Substitute("can't update to fanin '$0' as it will become a "
1260                          "Switch control dependency",
1261                          to_fanin.ToString()));
1262   }
1263   if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
1264     return error_status("can't update fanin to or from self");
1265   }
1266 
1267   if (from_fanin == to_fanin) {
1268     return Status::OK();
1269   }
1270 
1271   bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
1272   if (from_fanin_is_control || to_fanin_is_control) {
1273     bool modified = false;
1274     if (from_fanin_is_control) {
1275       modified |= RemoveControllingFaninInternal(node, from_fanin_node);
1276     } else {
1277       modified |= RemoveRegularFaninInternal(
1278           node, {from_fanin_node, from_fanin.index()});
1279     }
1280     if (modified) {
1281       AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
1282     }
1283     return Status::OK();
1284   }
1285 
1286   // In place mutation of regular fanins, requires no shifting of ports.
1287   string to_fanin_string = TensorIdToString(to_fanin);
1288   const int num_regular_fanins =
1289       NumFanins(*node, /*include_controlling_nodes=*/false);
1290   bool modified = false;
1291   absl::flat_hash_set<InputPort>* from_fanin_port_fanouts = nullptr;
1292   absl::flat_hash_set<InputPort>* to_fanin_port_fanouts = nullptr;
1293   for (int i = 0; i < num_regular_fanins; ++i) {
1294     if (ParseTensorName(node->input(i)) == from_fanin) {
1295       InputPort input(node, i);
1296       if (from_fanin_port_fanouts == nullptr) {
1297         OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1298         from_fanin_port_fanouts = &fanouts()[from_fanin_port];
1299       }
1300       from_fanin_port_fanouts->erase(input);
1301 
1302       if (to_fanin_port_fanouts == nullptr) {
1303         OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
1304         to_fanin_port_fanouts = &fanouts()[to_fanin_port];
1305       }
1306       to_fanin_port_fanouts->insert(input);
1307 
1308       node->set_input(i, to_fanin_string);
1309       modified = true;
1310     }
1311   }
1312 
1313   // Dedup control dependencies and update max regular output ports.
1314   if (modified) {
1315     UpdateMaxRegularOutputPortForRemovedFanin(
1316         {from_fanin_node, from_fanin.index()}, *from_fanin_port_fanouts);
1317     if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
1318       max_regular_output_port()[to_fanin_node] = to_fanin.index();
1319     }
1320     if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
1321       RemoveControllingFaninInternal(node, to_fanin_node);
1322     }
1323   }
1324 
1325   return Status::OK();
1326 }
1327 
UpdateRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)1328 Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name,
1329                                                   int port,
1330                                                   const TensorId& fanin) {
1331   auto error_status = [node_name, port, fanin](absl::string_view msg) {
1332     string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
1333                                      node_name, port, fanin.ToString());
1334     return MutationError("UpdateRegularFaninByPort", params, msg);
1335   };
1336 
1337   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1338   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1339   NodeDef* node = GetNode(node_name);
1340   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1341   const int last_regular_fanin_port =
1342       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1343   TF_RETURN_IF_ERROR(
1344       CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1345   NodeDef* fanin_node = GetNode(fanin.node());
1346   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1347 
1348   TensorId tensor_id = ParseTensorName(node->input(port));
1349   if (tensor_id == fanin) {
1350     return Status::OK();
1351   }
1352 
1353   InputPort input(node, port);
1354   OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1355   absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port];
1356   from_fanouts->erase(input);
1357   UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts);
1358 
1359   OutputPort to_fanin_port(fanin_node, fanin.index());
1360   fanouts()[to_fanin_port].insert(input);
1361   UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port);
1362 
1363   node->set_input(port, TensorIdToString(fanin));
1364 
1365   if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
1366     RemoveControllingFaninInternal(node, fanin_node);
1367   }
1368 
1369   return Status::OK();
1370 }
1371 
SwapRegularFaninsByPorts(absl::string_view node_name,int from_port,int to_port)1372 Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name,
1373                                                   int from_port, int to_port) {
1374   auto error_status = [node_name, from_port, to_port](absl::string_view msg) {
1375     string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2",
1376                                      node_name, from_port, to_port);
1377     return MutationError("SwapRegularFaninsByPorts", params, msg);
1378   };
1379 
1380   NodeDef* node = GetNode(node_name);
1381   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1382   const int last_regular_fanin_port =
1383       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1384   TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0,
1385                                     last_regular_fanin_port, error_status));
1386   TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port,
1387                                     error_status));
1388 
1389   if (from_port == to_port) {
1390     return Status::OK();
1391   }
1392   TensorId from_fanin = ParseTensorName(node->input(from_port));
1393   TensorId to_fanin = ParseTensorName(node->input(to_port));
1394   if (from_fanin == to_fanin) {
1395     return Status::OK();
1396   }
1397 
1398   InputPort from_input(node, from_port);
1399   InputPort to_input(node, to_port);
1400   NodeDef* from_fanin_node = GetNode(from_fanin.node());
1401   absl::flat_hash_set<InputPort>* from_fanouts =
1402       &fanouts()[{from_fanin_node, from_fanin.index()}];
1403   from_fanouts->erase(from_input);
1404   from_fanouts->insert(to_input);
1405   NodeDef* to_fanin_node = GetNode(to_fanin.node());
1406   absl::flat_hash_set<InputPort>* to_fanouts =
1407       &fanouts()[{to_fanin_node, to_fanin.index()}];
1408   to_fanouts->erase(to_input);
1409   to_fanouts->insert(from_input);
1410 
1411   node->mutable_input()->SwapElements(from_port, to_port);
1412 
1413   return Status::OK();
1414 }
1415 
UpdateAllRegularFaninsToControlling(absl::string_view node_name)1416 Status MutableGraphView::UpdateAllRegularFaninsToControlling(
1417     absl::string_view node_name) {
1418   auto error_status = [node_name](absl::string_view msg) {
1419     string params = absl::Substitute("node_name='$0'", node_name);
1420     return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
1421   };
1422 
1423   NodeDef* node = GetNode(node_name);
1424   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1425 
1426   const int num_regular_fanins =
1427       NumFanins(*node, /*include_controlling_nodes=*/false);
1428   std::vector<OutputPort> regular_fanins;
1429   regular_fanins.reserve(num_regular_fanins);
1430   std::vector<NodeDef*> controlling_fanins;
1431   controlling_fanins.reserve(num_regular_fanins);
1432 
1433   // Get all regular fanins and derive controlling fanins.
1434   for (int i = 0; i < num_regular_fanins; ++i) {
1435     TensorId tensor_id = ParseTensorName(node->input(i));
1436     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1437 
1438     string error_msg = "";
1439     NodeDef* control_node =
1440         GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
1441     if (!error_msg.empty()) {
1442       return error_status(error_msg);
1443     }
1444 
1445     regular_fanins.push_back(fanin_port);
1446     controlling_fanins.push_back(control_node);
1447   }
1448 
1449   // Replace regular fanins with controlling fanins and dedup.
1450   int pos = 0;
1451   InputPort input_port(node, Graph::kControlSlot);
1452   absl::flat_hash_set<absl::string_view> controls;
1453   for (int i = 0; i < num_regular_fanins; ++i) {
1454     OutputPort fanin_port = regular_fanins[i];
1455     NodeDef* control = controlling_fanins[i];
1456     if (control == nullptr) {
1457       control = GetOrCreateIdentityConsumingSwitch(fanin_port);
1458     }
1459     fanouts()[fanin_port].erase({node, i});
1460     if (controls.contains(control->name())) {
1461       continue;
1462     }
1463     controls.insert(control->name());
1464     node->set_input(pos, AsControlDependency(control->name()));
1465     fanouts()[{control, Graph::kControlSlot}].insert(input_port);
1466     ++pos;
1467   }
1468 
1469   // Shift existing controlling fanins and dedup.
1470   for (int i = num_regular_fanins; i < node->input_size(); ++i) {
1471     TensorId tensor_id = ParseTensorName(node->input(i));
1472     if (controls.contains(tensor_id.node())) {
1473       continue;
1474     }
1475     controls.insert(tensor_id.node());
1476     node->mutable_input()->SwapElements(pos, i);
1477     ++pos;
1478   }
1479 
1480   // Remove duplicate controls and leftover regular fanins.
1481   node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
1482   max_regular_input_port().erase(node);
1483 
1484   return Status::OK();
1485 }
1486 
CheckNodesCanBeDeleted(const absl::flat_hash_set<string> & nodes_to_delete)1487 Status MutableGraphView::CheckNodesCanBeDeleted(
1488     const absl::flat_hash_set<string>& nodes_to_delete) {
1489   std::vector<string> missing_nodes;
1490   std::vector<string> nodes_with_fanouts;
1491   for (const string& node_name_to_delete : nodes_to_delete) {
1492     NodeDef* node = GetNode(node_name_to_delete);
1493     if (node == nullptr) {
1494       // Can't delete missing node.
1495       missing_nodes.push_back(node_name_to_delete);
1496       continue;
1497     }
1498     const int max_port = gtl::FindWithDefault(max_regular_output_port(), node,
1499                                               Graph::kControlSlot);
1500     for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1501       auto it = fanouts().find({node, i});
1502       bool has_retained_fanout = false;
1503       if (it != fanouts().end()) {
1504         for (const auto& fanout : it->second) {
1505           // Check if fanouts are of nodes to be deleted, and if so, they can be
1506           // ignored, as they will be removed also.
1507           if (!nodes_to_delete.contains(fanout.node->name())) {
1508             // Removing node will leave graph in an invalid state.
1509             has_retained_fanout = true;
1510             break;
1511           }
1512         }
1513       }
1514       if (has_retained_fanout) {
1515         nodes_with_fanouts.push_back(node_name_to_delete);
1516         break;
1517       }
1518     }
1519   }
1520 
1521   // Error message can get quite long, so we only show the first 5 node names.
1522   auto sort_and_sample = [](std::vector<string>* s) {
1523     constexpr int kMaxNodeNames = 5;
1524     std::sort(s->begin(), s->end());
1525     if (s->size() > kMaxNodeNames) {
1526       return absl::StrCat(
1527           absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", "), ", ...");
1528     }
1529     return absl::StrJoin(*s, ", ");
1530   };
1531 
1532   if (!missing_nodes.empty()) {
1533     VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0].",
1534                                 sort_and_sample(&missing_nodes));
1535   }
1536   if (!nodes_with_fanouts.empty()) {
1537     std::vector<string> input_node_names(nodes_to_delete.begin(),
1538                                          nodes_to_delete.end());
1539     string params = absl::Substitute("nodes_to_delete={$0}",
1540                                      sort_and_sample(&input_node_names));
1541     string error_msg =
1542         absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]",
1543                          sort_and_sample(&nodes_with_fanouts));
1544     return MutationError("DeleteNodes", params, error_msg);
1545   }
1546 
1547   return Status::OK();
1548 }
1549 
DeleteNodes(const absl::flat_hash_set<string> & nodes_to_delete)1550 Status MutableGraphView::DeleteNodes(
1551     const absl::flat_hash_set<string>& nodes_to_delete) {
1552   TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete));
1553 
1554   // Find nodes in internal state and delete.
1555   for (const string& node_name_to_delete : nodes_to_delete) {
1556     NodeDef* node = GetNode(node_name_to_delete);
1557     if (node != nullptr) {
1558       RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false);
1559       RemoveFanoutsInternal(node);
1560     }
1561   }
1562   for (const string& node_name_to_delete : nodes_to_delete) {
1563     nodes().erase(node_name_to_delete);
1564   }
1565 
1566   // Find nodes in graph and delete by partitioning into nodes to retain and
1567   // nodes to delete based on input set of nodes to delete by name.
1568   // TODO(lyandy): Use a node name->idx hashmap if this is a performance
1569   // bottleneck.
1570   int pos = 0;
1571   const int last_idx = graph()->node_size() - 1;
1572   int last_pos = last_idx;
1573   while (pos <= last_pos) {
1574     if (nodes_to_delete.contains(graph()->node(pos).name())) {
1575       graph()->mutable_node()->SwapElements(pos, last_pos);
1576       --last_pos;
1577     } else {
1578       ++pos;
1579     }
1580   }
1581   if (last_pos < last_idx) {
1582     graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
1583   }
1584 
1585   return Status::OK();
1586 }
1587 
RemoveFaninsInternal(NodeDef * deleted_node,bool keep_controlling_fanins)1588 void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
1589                                             bool keep_controlling_fanins) {
1590   for (int i = 0; i < deleted_node->input_size(); ++i) {
1591     TensorId tensor_id = ParseTensorName(deleted_node->input(i));
1592     bool is_control = IsTensorIdControlling(tensor_id);
1593     if (keep_controlling_fanins && is_control) {
1594       break;
1595     }
1596     OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
1597 
1598     InputPort input;
1599     input.node = deleted_node;
1600     input.port_id = is_control ? Graph::kControlSlot : i;
1601 
1602     auto it = fanouts().find(fanin);
1603     if (it != fanouts().end()) {
1604       absl::flat_hash_set<InputPort>* fanouts_set = &it->second;
1605       fanouts_set->erase(input);
1606       UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set);
1607     }
1608   }
1609   max_regular_input_port().erase(deleted_node);
1610 }
1611 
RemoveFanoutsInternal(NodeDef * deleted_node)1612 void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) {
1613   const int max_port =
1614       gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1);
1615   for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1616     fanouts().erase({deleted_node, i});
1617   }
1618   max_regular_output_port().erase(deleted_node);
1619 }
1620 
1621 }  // end namespace grappler
1622 }  // end namespace tensorflow
1623