• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/grappler/op_types.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 namespace grappler {
40 
41 namespace {
42 
IsTensorIdPortValid(const TensorId & tensor_id)43 bool IsTensorIdPortValid(const TensorId& tensor_id) {
44   return tensor_id.index() >= Graph::kControlSlot;
45 }
46 
IsTensorIdRegular(const TensorId & tensor_id)47 bool IsTensorIdRegular(const TensorId& tensor_id) {
48   return tensor_id.index() > Graph::kControlSlot;
49 }
50 
IsTensorIdControlling(const TensorId & tensor_id)51 bool IsTensorIdControlling(const TensorId& tensor_id) {
52   return tensor_id.index() == Graph::kControlSlot;
53 }
54 
IsOutputPortControlling(const MutableGraphView::OutputPort & port)55 bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) {
56   return port.port_id == Graph::kControlSlot;
57 }
58 
59 // Determines if node is an Identity where it's first regular input is a Switch
60 // node.
IsIdentityConsumingSwitch(const MutableGraphView & graph,const NodeDef & node)61 bool IsIdentityConsumingSwitch(const MutableGraphView& graph,
62                                const NodeDef& node) {
63   if ((IsIdentity(node) || IsIdentityNSingleInput(node)) &&
64       node.input_size() > 0) {
65     TensorId tensor_id = ParseTensorName(node.input(0));
66     if (IsTensorIdControlling(tensor_id)) {
67       return false;
68     }
69 
70     NodeDef* input_node = graph.GetNode(tensor_id.node());
71     return IsSwitch(*input_node);
72   }
73   return false;
74 }
75 
76 // Determines if node input can be deduped by regular inputs when used as a
77 // control dependency. Specifically, if a node is an Identity that leads to a
78 // Switch node, when used as a control dependency, that control dependency
79 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,const NodeDef & control_node)80 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
81                                      const NodeDef& control_node) {
82   return !IsIdentityConsumingSwitch(graph, control_node);
83 }
84 
85 // Determines if node input can be deduped by regular inputs when used as a
86 // control dependency. Specifically, if a node is an Identity that leads to a
87 // Switch node, when used as a control dependency, that control dependency
88 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,absl::string_view control_node_name)89 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
90                                      absl::string_view control_node_name) {
91   NodeDef* control_node = graph.GetNode(control_node_name);
92   if (control_node == nullptr) {
93     return false;
94   }
95   return CanDedupControlWithRegularInput(graph, *control_node);
96 }
97 
HasRegularFaninNode(const MutableGraphView & graph,const NodeDef & node,absl::string_view fanin_node_name)98 bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
99                          absl::string_view fanin_node_name) {
100   const int num_regular_fanins =
101       graph.NumFanins(node, /*include_controlling_nodes=*/false);
102   for (int i = 0; i < num_regular_fanins; ++i) {
103     if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
104       return true;
105     }
106   }
107   return false;
108 }
109 
110 using FanoutsMap =
111     absl::flat_hash_map<MutableGraphView::OutputPort,
112                         absl::flat_hash_set<MutableGraphView::InputPort>>;
113 
SwapControlledFanoutInputs(const MutableGraphView & graph,const FanoutsMap::iterator & control_fanouts,absl::string_view to_node_name)114 void SwapControlledFanoutInputs(const MutableGraphView& graph,
115                                 const FanoutsMap::iterator& control_fanouts,
116                                 absl::string_view to_node_name) {
117   absl::string_view from_node_name(control_fanouts->first.node->name());
118   string control = TensorIdToString({to_node_name, Graph::kControlSlot});
119   for (const auto& control_fanout : control_fanouts->second) {
120     const int start = graph.NumFanins(*control_fanout.node,
121                                       /*include_controlling_nodes=*/false);
122     for (int i = start; i < control_fanout.node->input_size(); ++i) {
123       TensorId tensor_id = ParseTensorName(control_fanout.node->input(i));
124       if (tensor_id.node() == from_node_name) {
125         control_fanout.node->set_input(i, control);
126         break;
127       }
128     }
129   }
130 }
131 
SwapRegularFanoutInputs(FanoutsMap * fanouts,NodeDef * from_node,absl::string_view to_node_name,int max_port)132 void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node,
133                              absl::string_view to_node_name, int max_port) {
134   MutableGraphView::OutputPort port;
135   port.node = from_node;
136   for (int i = 0; i <= max_port; ++i) {
137     port.port_id = i;
138     auto it = fanouts->find(port);
139     if (it == fanouts->end()) {
140       continue;
141     }
142     string input = TensorIdToString({to_node_name, i});
143     for (const auto& fanout : it->second) {
144       fanout.node->set_input(fanout.port_id, input);
145     }
146   }
147 }
148 
149 using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
150 
SwapFanoutInputs(const MutableGraphView & graph,FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)151 void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts,
152                       MaxOutputPortsMap* max_output_ports, NodeDef* from_node,
153                       NodeDef* to_node) {
154   auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot});
155   if (from_control_fanouts != fanouts->end()) {
156     SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name());
157   }
158   auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot});
159   if (to_control_fanouts != fanouts->end()) {
160     SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name());
161   }
162   auto from_max_port = max_output_ports->find(from_node);
163   if (from_max_port != max_output_ports->end()) {
164     SwapRegularFanoutInputs(fanouts, from_node, to_node->name(),
165                             from_max_port->second);
166   }
167   auto to_max_port = max_output_ports->find(to_node);
168   if (to_max_port != max_output_ports->end()) {
169     SwapRegularFanoutInputs(fanouts, to_node, from_node->name(),
170                             to_max_port->second);
171   }
172 }
173 
SwapFanoutsMapValues(FanoutsMap * fanouts,const MutableGraphView::OutputPort & from_port,const FanoutsMap::iterator & from_fanouts,const MutableGraphView::OutputPort & to_port,const FanoutsMap::iterator & to_fanouts)174 void SwapFanoutsMapValues(FanoutsMap* fanouts,
175                           const MutableGraphView::OutputPort& from_port,
176                           const FanoutsMap::iterator& from_fanouts,
177                           const MutableGraphView::OutputPort& to_port,
178                           const FanoutsMap::iterator& to_fanouts) {
179   const bool from_exists = from_fanouts != fanouts->end();
180   const bool to_exists = to_fanouts != fanouts->end();
181 
182   if (from_exists && to_exists) {
183     std::swap(from_fanouts->second, to_fanouts->second);
184   } else if (from_exists) {
185     fanouts->emplace(to_port, std::move(from_fanouts->second));
186     fanouts->erase(from_port);
187   } else if (to_exists) {
188     fanouts->emplace(from_port, std::move(to_fanouts->second));
189     fanouts->erase(to_port);
190   }
191 }
192 
SwapRegularFanoutsAndMaxPortValues(FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)193 void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts,
194                                         MaxOutputPortsMap* max_output_ports,
195                                         NodeDef* from_node, NodeDef* to_node) {
196   auto from_max_port = max_output_ports->find(from_node);
197   auto to_max_port = max_output_ports->find(to_node);
198   bool from_exists = from_max_port != max_output_ports->end();
199   bool to_exists = to_max_port != max_output_ports->end();
200 
201   auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start,
202                                    int end) {
203     for (int i = start; i <= end; ++i) {
204       MutableGraphView::OutputPort from_port(from, i);
205       auto from_fanouts = fanouts->find(from_port);
206       if (from_fanouts != fanouts->end()) {
207         MutableGraphView::OutputPort to_port(to, i);
208         fanouts->emplace(to_port, std::move(from_fanouts->second));
209         fanouts->erase(from_port);
210       }
211     }
212   };
213 
214   if (from_exists && to_exists) {
215     const int from = from_max_port->second;
216     const int to = to_max_port->second;
217     const int shared = std::min(from, to);
218     for (int i = 0; i <= shared; ++i) {
219       MutableGraphView::OutputPort from_port(from_node, i);
220       auto from_fanouts = fanouts->find(from_port);
221       MutableGraphView::OutputPort to_port(to_node, i);
222       auto to_fanouts = fanouts->find(to_port);
223       SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port,
224                            to_fanouts);
225     }
226     if (to > from) {
227       forward_fanouts(to_node, from_node, shared + 1, to);
228     } else if (from > to) {
229       forward_fanouts(from_node, to_node, shared + 1, from);
230     }
231 
232     std::swap(from_max_port->second, to_max_port->second);
233   } else if (from_exists) {
234     forward_fanouts(from_node, to_node, 0, from_max_port->second);
235 
236     max_output_ports->emplace(to_node, from_max_port->second);
237     max_output_ports->erase(from_node);
238   } else if (to_exists) {
239     forward_fanouts(to_node, from_node, 0, to_max_port->second);
240 
241     max_output_ports->emplace(from_node, to_max_port->second);
242     max_output_ports->erase(to_node);
243   }
244 }
245 
HasFanoutValue(const FanoutsMap & fanouts,const FanoutsMap::iterator & it)246 bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) {
247   return it != fanouts.end() && !it->second.empty();
248 }
249 
MutationError(absl::string_view function_name,absl::string_view params,absl::string_view msg)250 Status MutationError(absl::string_view function_name, absl::string_view params,
251                      absl::string_view msg) {
252   return errors::InvalidArgument(absl::Substitute(
253       "MutableGraphView::$0($1) error: $2.", function_name, params, msg));
254 }
255 
256 using ErrorHandler = std::function<Status(absl::string_view)>;
257 
UpdateFanoutsError(absl::string_view from_node_name,absl::string_view to_node_name)258 ErrorHandler UpdateFanoutsError(absl::string_view from_node_name,
259                                 absl::string_view to_node_name) {
260   return [from_node_name, to_node_name](absl::string_view msg) {
261     string params = absl::Substitute("from_node_name='$0', to_node_name='$1'",
262                                      from_node_name, to_node_name);
263     return MutationError("UpdateFanouts", params, msg);
264   };
265 }
266 
CheckFaninIsRegular(const TensorId & fanin,ErrorHandler handler)267 Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) {
268   if (!IsTensorIdRegular(fanin)) {
269     return handler(absl::Substitute("fanin '$0' must be a regular tensor id",
270                                     fanin.ToString()));
271   }
272   return Status::OK();
273 }
274 
CheckFaninIsValid(const TensorId & fanin,ErrorHandler handler)275 Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) {
276   if (!IsTensorIdPortValid(fanin)) {
277     return handler(absl::Substitute("fanin '$0' must be a valid tensor id",
278                                     fanin.ToString()));
279   }
280   return Status::OK();
281 }
282 
CheckAddingFaninToSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)283 Status CheckAddingFaninToSelf(absl::string_view node_name,
284                               const TensorId& fanin, ErrorHandler handler) {
285   if (node_name == fanin.node()) {
286     return handler(
287         absl::Substitute("can't add fanin '$0' to self", fanin.ToString()));
288   }
289   return Status::OK();
290 }
291 
CheckRemovingFaninFromSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)292 Status CheckRemovingFaninFromSelf(absl::string_view node_name,
293                                   const TensorId& fanin, ErrorHandler handler) {
294   if (node_name == fanin.node()) {
295     return handler(absl::Substitute("can't remove fanin '$0' from self",
296                                     fanin.ToString()));
297   }
298   return Status::OK();
299 }
300 
NodeMissingErrorMsg(absl::string_view node_name)301 string NodeMissingErrorMsg(absl::string_view node_name) {
302   return absl::Substitute("node '$0' was not found", node_name);
303 }
304 
CheckNodeExists(absl::string_view node_name,NodeDef * node,ErrorHandler handler)305 Status CheckNodeExists(absl::string_view node_name, NodeDef* node,
306                        ErrorHandler handler) {
307   if (node == nullptr) {
308     return handler(NodeMissingErrorMsg(node_name));
309   }
310   return Status::OK();
311 }
312 
CheckPortRange(int port,int min,int max,ErrorHandler handler)313 Status CheckPortRange(int port, int min, int max, ErrorHandler handler) {
314   if (port < min || port > max) {
315     if (max < min) {
316       return handler("no available ports as node has no regular fanins");
317     }
318     return handler(
319         absl::Substitute("port must be in range [$0, $1]", min, max));
320   }
321   return Status::OK();
322 }
323 
SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name)324 string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) {
325   return absl::Substitute(
326       "can't swap node name '$0' as it will become a Switch control dependency",
327       node_name);
328 }
329 
GeneratedNameForIdentityConsumingSwitch(const MutableGraphView::OutputPort & fanin)330 string GeneratedNameForIdentityConsumingSwitch(
331     const MutableGraphView::OutputPort& fanin) {
332   return AddPrefixToNodeName(
333       absl::StrCat(fanin.node->name(), "_", fanin.port_id),
334       kMutableGraphViewCtrl);
335 }
336 
337 }  // namespace
338 
AddAndDedupFanouts(NodeDef * node)339 void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
340   // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
341   // exist, and all regular fanins come before controlling fanins.
342   absl::flat_hash_set<absl::string_view> fanins;
343   absl::flat_hash_set<absl::string_view> controlling_fanins;
344   int max_input_port = -1;
345   int pos = 0;
346   const int last_idx = node->input_size() - 1;
347   int last_pos = last_idx;
348   while (pos <= last_pos) {
349     TensorId tensor_id = ParseTensorName(node->input(pos));
350     absl::string_view input_node_name = tensor_id.node();
351     bool is_control_input = IsTensorIdControlling(tensor_id);
352     bool can_dedup_control_with_regular_input =
353         CanDedupControlWithRegularInput(*this, input_node_name);
354     bool can_dedup_control =
355         is_control_input && (can_dedup_control_with_regular_input ||
356                              controlling_fanins.contains(input_node_name));
357     if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
358         can_dedup_control) {
359       node->mutable_input()->SwapElements(pos, last_pos);
360       --last_pos;
361     } else {
362       OutputPort output(nodes()[input_node_name], tensor_id.index());
363 
364       if (is_control_input) {
365         fanouts()[output].emplace(node, Graph::kControlSlot);
366       } else {
367         max_input_port = pos;
368         max_regular_output_port()[output.node] =
369             std::max(max_regular_output_port()[output.node], output.port_id);
370         fanouts()[output].emplace(node, pos);
371       }
372       ++pos;
373     }
374     if (is_control_input) {
375       controlling_fanins.insert(input_node_name);
376     }
377   }
378 
379   if (last_pos < last_idx) {
380     node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
381   }
382 
383   if (max_input_port > -1) {
384     max_regular_input_port()[node] = max_input_port;
385   }
386 }
387 
UpdateMaxRegularOutputPortForRemovedFanin(const OutputPort & fanin,const absl::flat_hash_set<InputPort> & fanin_fanouts)388 void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin(
389     const OutputPort& fanin,
390     const absl::flat_hash_set<InputPort>& fanin_fanouts) {
391   int max_port = max_regular_output_port()[fanin.node];
392   if (!fanin_fanouts.empty() || max_port != fanin.port_id) {
393     return;
394   }
395   bool updated_max_port = false;
396   for (int i = fanin.port_id - 1; i >= 0; --i) {
397     OutputPort fanin_port(fanin.node, i);
398     if (!fanouts()[fanin_port].empty()) {
399       max_regular_output_port()[fanin.node] = i;
400       updated_max_port = true;
401       break;
402     }
403   }
404   if (!updated_max_port) {
405     max_regular_output_port().erase(fanin.node);
406   }
407 }
408 
UpdateMaxRegularOutputPortForAddedFanin(const OutputPort & fanin)409 void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin(
410     const OutputPort& fanin) {
411   if (max_regular_output_port()[fanin.node] < fanin.port_id) {
412     max_regular_output_port()[fanin.node] = fanin.port_id;
413   }
414 }
415 
416 const absl::flat_hash_set<MutableGraphView::InputPort>&
GetFanout(const GraphView::OutputPort & port) const417 MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
418   return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
419                                                 port.port_id));
420 }
421 
GetFanin(const GraphView::InputPort & port) const422 absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
423     const GraphView::InputPort& port) const {
424   return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
425                                               port.port_id));
426 }
427 
GetRegularFanin(const GraphView::InputPort & port) const428 const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
429     const GraphView::InputPort& port) const {
430   return GetRegularFanin(MutableGraphView::InputPort(
431       const_cast<NodeDef*>(port.node), port.port_id));
432 }
433 
AddNode(NodeDef && node)434 NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
435   auto* node_in_graph = graph()->add_node();
436   *node_in_graph = std::move(node);
437 
438   AddUniqueNodeOrDie(node_in_graph);
439 
440   AddAndDedupFanouts(node_in_graph);
441   return node_in_graph;
442 }
443 
AddSubgraph(GraphDef && subgraph)444 Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
445   // 1. Add all new functions and check that functions with the same name
446   // have identical definition.
447   const int function_size = subgraph.library().function_size();
448   if (function_size > 0) {
449     absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs;
450     for (const FunctionDef& fdef : graph()->library().function()) {
451       graph_fdefs.emplace(fdef.signature().name(), &fdef);
452     }
453 
454     for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) {
455       const auto graph_fdef = graph_fdefs.find(fdef.signature().name());
456 
457       if (graph_fdef == graph_fdefs.end()) {
458         VLOG(3) << "Add new function definition: " << fdef.signature().name();
459         graph()->mutable_library()->add_function()->Swap(&fdef);
460       } else {
461         if (!FunctionDefsEqual(fdef, *graph_fdef->second)) {
462           return MutationError(
463               "AddSubgraph",
464               absl::Substitute("function_size=$0", function_size),
465               absl::StrCat(
466                   "Found different function definition with the same name: ",
467                   fdef.signature().name()));
468         }
469       }
470     }
471   }
472 
473   // 2. Add all nodes to the underlying graph.
474   int node_size_before = graph()->node_size();
475 
476   for (NodeDef& node : *subgraph.mutable_node()) {
477     auto* node_in_graph = graph()->add_node();
478     node_in_graph->Swap(&node);
479     TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
480   }
481 
482   // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
483   // fanins actually exists in the graph, and there is already TODO for that.
484 
485   for (int i = node_size_before; i < graph()->node_size(); ++i) {
486     NodeDef* node = graph()->mutable_node(i);
487     AddAndDedupFanouts(node);
488   }
489 
490   return Status::OK();
491 }
492 
UpdateNode(absl::string_view node_name,absl::string_view op,absl::string_view device,absl::Span<const std::pair<string,AttrValue>> attrs)493 Status MutableGraphView::UpdateNode(
494     absl::string_view node_name, absl::string_view op, absl::string_view device,
495     absl::Span<const std::pair<string, AttrValue>> attrs) {
496   auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
497     std::vector<string> attr_strs;
498     attr_strs.reserve(attrs.size());
499     for (const auto& attr : attrs) {
500       string attr_str = absl::Substitute("('$0', $1)", attr.first,
501                                          attr.second.ShortDebugString());
502       attr_strs.push_back(attr_str);
503     }
504     string params =
505         absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
506                          node_name, op, device, absl::StrJoin(attr_strs, ", "));
507     return MutationError("UpdateNodeOp", params, msg);
508   };
509 
510   NodeDef* node = GetNode(node_name);
511   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
512 
513   MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
514   auto control_fanouts = GetFanout(control_port);
515   if (op == "Switch" && !control_fanouts.empty()) {
516     return error_status(
517         "can't change node op to Switch when node drives a control dependency "
518         "(alternatively, we could add the identity node needed, but it seems "
519         "like an unlikely event and probably a mistake)");
520   }
521 
522   if (node->device() != device) {
523     node->set_device(string(device));
524   }
525   node->mutable_attr()->clear();
526   for (const auto& attr : attrs) {
527     (*node->mutable_attr())[attr.first] = attr.second;
528   }
529 
530   if (node->op() == op) {
531     return Status::OK();
532   }
533 
534   node->set_op(string(op));
535 
536   if (CanDedupControlWithRegularInput(*this, *node)) {
537     for (const auto& control_fanout : control_fanouts) {
538       if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
539         RemoveControllingFaninInternal(control_fanout.node, node);
540       }
541     }
542   }
543 
544   return Status::OK();
545 }
546 
UpdateNodeName(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)547 Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name,
548                                         absl::string_view to_node_name,
549                                         bool update_fanouts) {
550   auto error_status = [from_node_name, to_node_name,
551                        update_fanouts](absl::string_view msg) {
552     string params = absl::Substitute(
553         "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
554         from_node_name, to_node_name, update_fanouts);
555     return MutationError("UpdateNodeName", params, msg);
556   };
557 
558   NodeDef* node = GetNode(from_node_name);
559   TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status));
560 
561   if (node->name() == to_node_name) {
562     return Status::OK();
563   }
564   if (HasNode(to_node_name)) {
565     return error_status(
566         "can't update node name because new node name is in use");
567   }
568   auto max_output_port = max_regular_output_port().find(node);
569   const bool has_max_output_port =
570       max_output_port != max_regular_output_port().end();
571   auto control_fanouts = fanouts().find({node, Graph::kControlSlot});
572 
573   if (update_fanouts) {
574     SwapControlledFanoutInputs(*this, control_fanouts, to_node_name);
575     if (has_max_output_port) {
576       SwapRegularFanoutInputs(&fanouts(), node, to_node_name,
577                               max_output_port->second);
578     }
579   } else if (has_max_output_port ||
580              HasFanoutValue(fanouts(), control_fanouts)) {
581     return error_status("can't update node name because node has fanouts");
582   }
583 
584   nodes().erase(node->name());
585   node->set_name(string(to_node_name));
586   nodes().emplace(node->name(), node);
587   return Status::OK();
588 }
589 
SwapNodeNames(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)590 Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
591                                        absl::string_view to_node_name,
592                                        bool update_fanouts) {
593   auto error_status = [from_node_name, to_node_name,
594                        update_fanouts](absl::string_view msg) {
595     string params = absl::Substitute(
596         "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
597         from_node_name, to_node_name, update_fanouts);
598     return MutationError("SwapNodeNames", params, msg);
599   };
600 
601   NodeDef* from_node = GetNode(from_node_name);
602   TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
603   if (from_node_name == to_node_name) {
604     return Status::OK();
605   }
606   NodeDef* to_node = GetNode(to_node_name);
607   TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));
608 
609   auto swap_names = [this, from_node, to_node]() {
610     nodes().erase(from_node->name());
611     nodes().erase(to_node->name());
612     std::swap(*from_node->mutable_name(), *to_node->mutable_name());
613     nodes().emplace(from_node->name(), from_node);
614     nodes().emplace(to_node->name(), to_node);
615   };
616 
617   if (update_fanouts) {
618     SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
619                      to_node);
620     swap_names();
621     return Status::OK();
622   }
623 
624   bool from_is_switch = IsSwitch(*from_node);
625   MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
626   auto to_control_fanouts = fanouts().find(to_control);
627   if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
628     return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
629   }
630 
631   bool to_is_switch = IsSwitch(*to_node);
632   MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
633   auto from_control_fanouts = fanouts().find(from_control);
634   if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
635     return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
636   }
637 
638   // Swap node names.
639   swap_names();
640 
641   // Swap controlling fanouts.
642   //
643   // Note: To and from control fanout iterators are still valid as no mutations
644   // has been performed on fanouts().
645   SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
646                        to_control, to_control_fanouts);
647 
648   // Swap regular fanouts.
649   SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
650                                      from_node, to_node);
651 
652   // Update fanins to remove self loops.
653   auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
654     for (int i = 0; i < node->input_size(); ++i) {
655       TensorId tensor_id = ParseTensorName(node->input(i));
656       if (tensor_id.node() == node->name()) {
657         const int idx = tensor_id.index();
658         const int node_idx =
659             IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;
660 
661         MutableGraphView::OutputPort from_fanin(node, idx);
662         absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
663         from_fanouts->erase({node, node_idx});
664         UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);
665 
666         MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
667         fanouts()[to_fanin].insert({node, node_idx});
668         UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
669         node->set_input(i, TensorIdToString({old_node_name, idx}));
670       }
671     }
672   };
673   update_fanins(from_node, to_node->name());
674   update_fanins(to_node, from_node->name());
675 
676   // Dedup control dependencies.
677   auto dedup_control_fanouts =
678       [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
679         if (CanDedupControlWithRegularInput(*this, *node) &&
680             control_fanouts != fanouts().end()) {
681           for (auto it = control_fanouts->second.begin();
682                it != control_fanouts->second.end();) {
683             // Advance `it` before invalidation from removal.
684             const auto& control_fanout = *it++;
685             if (HasRegularFaninNode(*this, *control_fanout.node,
686                                     node->name())) {
687               RemoveControllingFaninInternal(control_fanout.node, node);
688             }
689           }
690         }
691       };
692   auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
693     OutputPort port;
694     port.node = node;
695     const int max_port =
696         gtl::FindWithDefault(max_regular_output_port(), node, -1);
697     for (int i = 0; i <= max_port; ++i) {
698       port.port_id = i;
699       auto it = fanouts().find(port);
700       if (it == fanouts().end()) {
701         continue;
702       }
703       for (const auto& fanout : it->second) {
704         auto fanout_controls =
705             fanouts().find({fanout.node, Graph::kControlSlot});
706         dedup_control_fanouts(fanout.node, fanout_controls);
707       }
708     }
709   };
710 
711   if (!from_is_switch) {
712     if (to_is_switch) {
713       dedup_switch_control(from_node);
714     } else {
715       // Fetch iterator again as the original iterator might have been
716       // invalidated by container rehash triggered due to mutations.
717       auto from_control_fanouts = fanouts().find(from_control);
718       dedup_control_fanouts(from_node, from_control_fanouts);
719     }
720   }
721   if (!to_is_switch) {
722     if (from_is_switch) {
723       dedup_switch_control(to_node);
724     } else {
725       // Fetch iterator again as the original iterator might have been
726       // invalidated by container rehash triggered due to mutations.
727       auto to_control_fanouts = fanouts().find(to_control);
728       dedup_control_fanouts(to_node, to_control_fanouts);
729     }
730   }
731 
732   return Status::OK();
733 }
734 
UpdateFanouts(absl::string_view from_node_name,absl::string_view to_node_name)735 Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
736                                        absl::string_view to_node_name) {
737   NodeDef* from_node = GetNode(from_node_name);
738   TF_RETURN_IF_ERROR(
739       CheckNodeExists(from_node_name, from_node,
740                       UpdateFanoutsError(from_node_name, to_node_name)));
741   NodeDef* to_node = GetNode(to_node_name);
742   TF_RETURN_IF_ERROR(CheckNodeExists(
743       to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name)));
744 
745   return UpdateFanoutsInternal(from_node, to_node);
746 }
747 
UpdateFanoutsInternal(NodeDef * from_node,NodeDef * to_node)748 Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node,
749                                                NodeDef* to_node) {
750   VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.",
751                               from_node->name(), to_node->name());
752   if (from_node == to_node) {
753     return Status::OK();
754   }
755 
756   // Update internal state with the new output_port->input_port edge.
757   const auto add_edge = [this](const OutputPort& output_port,
758                                const InputPort& input_port) {
759     fanouts()[output_port].insert(input_port);
760   };
761 
762   // Remove invalidated edge from the internal state.
763   const auto remove_edge = [this](const OutputPort& output_port,
764                                   const InputPort& input_port) {
765     fanouts()[output_port].erase(input_port);
766   };
767 
768   // For the control fanouts we do not know the input index in a NodeDef,
769   // so we have to traverse all control inputs.
770 
771   auto control_fanouts =
772       GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
773 
774   bool to_node_is_switch = IsSwitch(*to_node);
775   for (const InputPort& control_port : control_fanouts) {
776     // Node can't be control dependency of itself.
777     if (control_port.node == to_node) continue;
778 
779     // Can't add Switch node as a control dependency.
780     if (to_node_is_switch) {
781       // Trying to add a Switch as a control dependency, which if allowed will
782       // make the graph invalid.
783       return UpdateFanoutsError(from_node->name(), to_node->name())(
784           absl::Substitute("can't update fanouts to node '$0' as it will "
785                            "become a Switch control dependency",
786                            to_node->name()));
787     }
788 
789     NodeDef* node = control_port.node;
790     RemoveControllingFaninInternal(node, from_node);
791     AddFaninInternal(node, {to_node, Graph::kControlSlot});
792   }
793 
794   // First we update regular fanouts. For the regular fanouts
795   // `input_port:port_id` is the input index in NodeDef.
796 
797   auto regular_edges =
798       GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
799 
800   // Maximum index of the `from_node` output tensor that is still used as an
801   // input to some other node.
802   int keep_max_regular_output_port = -1;
803 
804   for (const Edge& edge : regular_edges) {
805     const OutputPort output_port = edge.src;
806     const InputPort input_port = edge.dst;
807 
808     // If the `to_node` reads from the `from_node`, skip this edge (see
809     // AddAndUpdateFanoutsWithoutSelfLoops test for an example).
810     if (input_port.node == to_node) {
811       keep_max_regular_output_port =
812           std::max(keep_max_regular_output_port, output_port.port_id);
813       continue;
814     }
815 
816     // Update input at destination node.
817     input_port.node->set_input(
818         input_port.port_id,
819         TensorIdToString({to_node->name(), output_port.port_id}));
820 
821     // Remove old edge between the `from_node` and the fanout node.
822     remove_edge(output_port, input_port);
823     // Add an edge between the `to_node` and new fanout node.
824     add_edge(OutputPort(to_node, output_port.port_id), input_port);
825     // Dedup control dependency.
826     if (CanDedupControlWithRegularInput(*this, *to_node)) {
827       RemoveControllingFaninInternal(input_port.node, to_node);
828     }
829   }
830 
831   // Because we update all regular fanouts of `from_node`, we can just copy
832   // the value `num_regular_outputs`.
833   max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
834 
835   // Check if all fanouts were updated to read from the `to_node`.
836   if (keep_max_regular_output_port >= 0) {
837     max_regular_output_port()[from_node] = keep_max_regular_output_port;
838   } else {
839     max_regular_output_port().erase(from_node);
840   }
841 
842   return Status::OK();
843 }
844 
AddFaninInternal(NodeDef * node,const OutputPort & fanin)845 bool MutableGraphView::AddFaninInternal(NodeDef* node,
846                                         const OutputPort& fanin) {
847   int num_regular_fanins =
848       NumFanins(*node, /*include_controlling_nodes=*/false);
849   bool input_is_control = IsOutputPortControlling(fanin);
850   bool can_dedup_control_with_regular_input =
851       CanDedupControlWithRegularInput(*this, *fanin.node);
852   // Don't add duplicate control dependencies.
853   if (input_is_control) {
854     const int start =
855         can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
856     for (int i = start; i < node->input_size(); ++i) {
857       if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
858         return false;
859       }
860     }
861   }
862 
863   InputPort input;
864   input.node = node;
865   input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;
866 
867   node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
868   if (!input_is_control) {
869     const int last_node_input = node->input_size() - 1;
870     // If there are control dependencies in node, move newly inserted fanin to
871     // be before such control dependencies.
872     if (num_regular_fanins < last_node_input) {
873       node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
874     }
875   }
876 
877   fanouts()[fanin].insert(input);
878   if (max_regular_output_port()[fanin.node] < fanin.port_id) {
879     max_regular_output_port()[fanin.node] = fanin.port_id;
880   }
881 
882   // Update max input port and dedup control dependencies.
883   if (!input_is_control) {
884     max_regular_input_port()[node] = num_regular_fanins;
885     if (can_dedup_control_with_regular_input) {
886       RemoveControllingFaninInternal(node, fanin.node);
887     }
888   }
889 
890   return true;
891 }
892 
AddRegularFanin(absl::string_view node_name,const TensorId & fanin)893 Status MutableGraphView::AddRegularFanin(absl::string_view node_name,
894                                          const TensorId& fanin) {
895   auto error_status = [node_name, fanin](absl::string_view msg) {
896     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
897                                      fanin.ToString());
898     return MutationError("AddRegularFanin", params, msg);
899   };
900 
901   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
902   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
903   NodeDef* node = GetNode(node_name);
904   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
905   NodeDef* fanin_node = GetNode(fanin.node());
906   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
907 
908   AddFaninInternal(node, {fanin_node, fanin.index()});
909   return Status::OK();
910 }
911 
AddRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)912 Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
913                                                int port,
914                                                const TensorId& fanin) {
915   auto error_status = [node_name, port, fanin](absl::string_view msg) {
916     string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
917                                      node_name, port, fanin.ToString());
918     return MutationError("AddRegularFaninByPort", params, msg);
919   };
920 
921   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
922   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
923   NodeDef* node = GetNode(node_name);
924   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
925   const int num_regular_fanins =
926       NumFanins(*node, /*include_controlling_nodes=*/false);
927   TF_RETURN_IF_ERROR(
928       CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
929   NodeDef* fanin_node = GetNode(fanin.node());
930   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
931 
932   const int last_node_input = node->input_size();
933   node->add_input(TensorIdToString(fanin));
934   node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
935   for (int i = num_regular_fanins - 1; i >= port; --i) {
936     TensorId tensor_id = ParseTensorName(node->input(i));
937     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
938     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
939     fanouts_set->erase({node, i});
940     fanouts_set->insert({node, i + 1});
941     node->mutable_input()->SwapElements(i, i + 1);
942   }
943 
944   OutputPort fanin_port(fanin_node, fanin.index());
945   fanouts()[fanin_port].insert({node, port});
946   UpdateMaxRegularOutputPortForAddedFanin(fanin_port);
947 
948   max_regular_input_port()[node] = num_regular_fanins;
949   if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
950     RemoveControllingFaninInternal(node, fanin_node);
951   }
952 
953   return Status::OK();
954 }
955 
GetControllingFaninToAdd(absl::string_view node_name,const OutputPort & fanin,string * error_msg)956 NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
957                                                     const OutputPort& fanin,
958                                                     string* error_msg) {
959   if (!IsSwitch(*fanin.node)) {
960     return fanin.node;
961   } else {
962     if (IsOutputPortControlling(fanin)) {
963       // Can't add a Switch node control dependency.
964       TensorId tensor_id(fanin.node->name(), fanin.port_id);
965       *error_msg = absl::Substitute(
966           "can't add fanin '$0' as it will become a Switch control dependency",
967           tensor_id.ToString());
968       return nullptr;
969     }
970     // We can't anchor control dependencies directly on the switch node: unlike
971     // other nodes only one of the outputs of the switch node will be generated
972     // when the switch node is executed, and we need to make sure the control
973     // dependency is only triggered when the corresponding output is triggered.
974     // We start by looking for an identity node connected to the output of the
975     // switch node, and use it to anchor the control dependency.
976     for (const auto& fanout : GetFanout(fanin)) {
977       if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
978         if (fanout.node->name() == node_name) {
979           *error_msg =
980               absl::Substitute("can't add found fanin '$0' to self",
981                                AsControlDependency(fanout.node->name()));
982           return nullptr;
983         }
984         return fanout.node;
985       }
986     }
987 
988     // No node found, check if node to be created is itself.
989     if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
990       *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
991                                     AsControlDependency(string(node_name)));
992     }
993   }
994   return nullptr;
995 }
996 
GetOrCreateIdentityConsumingSwitch(const OutputPort & fanin)997 NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
998     const OutputPort& fanin) {
999   // We haven't found an existing node where we can anchor the control
1000   // dependency: add a new identity node.
1001   string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
1002   NodeDef* identity_node = GetNode(identity_name);
1003   if (identity_node == nullptr) {
1004     NodeDef new_node;
1005     new_node.set_name(identity_name);
1006     new_node.set_op("Identity");
1007     new_node.set_device(fanin.node->device());
1008     (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
1009     new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
1010     identity_node = AddNode(std::move(new_node));
1011   }
1012   return identity_node;
1013 }
1014 
AddControllingFanin(absl::string_view node_name,const TensorId & fanin)1015 Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
1016                                              const TensorId& fanin) {
1017   auto error_status = [node_name, fanin](absl::string_view msg) {
1018     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1019                                      fanin.ToString());
1020     return MutationError("AddControllingFanin", params, msg);
1021   };
1022 
1023   TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status));
1024   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1025   NodeDef* node = GetNode(node_name);
1026   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1027   NodeDef* fanin_node = GetNode(fanin.node());
1028   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1029 
1030   OutputPort fanin_port(fanin_node, fanin.index());
1031 
1032   string error_msg = "";
1033   NodeDef* control_node = GetControllingFaninToAdd(
1034       node_name, {fanin_node, fanin.index()}, &error_msg);
1035   if (!error_msg.empty()) {
1036     return error_status(error_msg);
1037   }
1038   if (control_node == nullptr) {
1039     control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
1040   }
1041   AddFaninInternal(node, {control_node, Graph::kControlSlot});
1042 
1043   return Status::OK();
1044 }
1045 
RemoveRegularFaninInternal(NodeDef * node,const OutputPort & fanin)1046 bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
1047                                                   const OutputPort& fanin) {
1048   auto remove_input = [this, node](const OutputPort& fanin_port,
1049                                    int node_input_port, bool update_max_port) {
1050     InputPort input(node, node_input_port);
1051 
1052     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1053     fanouts_set->erase(input);
1054     if (update_max_port) {
1055       UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
1056     }
1057     return fanouts_set;
1058   };
1059 
1060   auto mutable_inputs = node->mutable_input();
1061   bool modified = false;
1062   const int num_regular_fanins =
1063       NumFanins(*node, /*include_controlling_nodes=*/false);
1064   int i;
1065   int curr_pos = 0;
1066   for (i = 0; i < num_regular_fanins; ++i) {
1067     TensorId tensor_id = ParseTensorName(node->input(i));
1068     if (tensor_id.node() == fanin.node->name() &&
1069         tensor_id.index() == fanin.port_id) {
1070       remove_input(fanin, i, /*update_max_port=*/true);
1071       modified = true;
1072     } else if (modified) {
1073       // Regular inputs will need to have their ports updated.
1074       OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1075       auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
1076       fanouts_set->insert({node, curr_pos});
1077       // Shift inputs to be retained.
1078       mutable_inputs->SwapElements(i, curr_pos);
1079       ++curr_pos;
1080     } else {
1081       // Skip inputs to be retained until first modification.
1082       ++curr_pos;
1083     }
1084   }
1085 
1086   if (modified) {
1087     const int last_regular_input_port = curr_pos - 1;
1088     if (last_regular_input_port < 0) {
1089       max_regular_input_port().erase(node);
1090     } else {
1091       max_regular_input_port()[node] = last_regular_input_port;
1092     }
1093     if (curr_pos < i) {
1094       // Remove fanins from node inputs.
1095       mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
1096     }
1097   }
1098 
1099   return modified;
1100 }
1101 
RemoveRegularFanin(absl::string_view node_name,const TensorId & fanin)1102 Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name,
1103                                             const TensorId& fanin) {
1104   auto error_status = [node_name, fanin](absl::string_view msg) {
1105     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1106                                      fanin.ToString());
1107     return MutationError("RemoveRegularFanin", params, msg);
1108   };
1109 
1110   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1111   TF_RETURN_IF_ERROR(
1112       CheckRemovingFaninFromSelf(node_name, fanin, error_status));
1113   NodeDef* node = GetNode(node_name);
1114   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1115   NodeDef* fanin_node = GetNode(fanin.node());
1116   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1117 
1118   RemoveRegularFaninInternal(node, {fanin_node, fanin.index()});
1119   return Status::OK();
1120 }
1121 
RemoveRegularFaninByPort(absl::string_view node_name,int port)1122 Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name,
1123                                                   int port) {
1124   auto error_status = [node_name, port](absl::string_view msg) {
1125     string params =
1126         absl::Substitute("node_name='$0', port=$1", node_name, port);
1127     return MutationError("RemoveRegularFaninByPort", params, msg);
1128   };
1129 
1130   NodeDef* node = GetNode(node_name);
1131   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1132   const int last_regular_fanin_port =
1133       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1134   TF_RETURN_IF_ERROR(
1135       CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1136 
1137   TensorId tensor_id = ParseTensorName(node->input(port));
1138   OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1139   fanouts()[fanin_port].erase({node, port});
1140   auto mutable_inputs = node->mutable_input();
1141   for (int i = port + 1; i <= last_regular_fanin_port; ++i) {
1142     TensorId tensor_id = ParseTensorName(node->input(i));
1143     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1144     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1145     fanouts_set->erase({node, i});
1146     fanouts_set->insert({node, i - 1});
1147     mutable_inputs->SwapElements(i - 1, i);
1148   }
1149   const int last_node_input = node->input_size() - 1;
1150   if (last_regular_fanin_port < last_node_input) {
1151     mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input);
1152   }
1153   mutable_inputs->RemoveLast();
1154 
1155   const int updated_last_regular_input_port = last_regular_fanin_port - 1;
1156   if (updated_last_regular_input_port < 0) {
1157     max_regular_input_port().erase(node);
1158   } else {
1159     max_regular_input_port()[node] = updated_last_regular_input_port;
1160   }
1161 
1162   return Status::OK();
1163 }
1164 
RemoveControllingFaninInternal(NodeDef * node,NodeDef * fanin_node)1165 bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node,
1166                                                       NodeDef* fanin_node) {
1167   for (int i = node->input_size() - 1; i >= 0; --i) {
1168     TensorId tensor_id = ParseTensorName(node->input(i));
1169     if (tensor_id.index() > Graph::kControlSlot) {
1170       break;
1171     }
1172     if (tensor_id.node() == fanin_node->name()) {
1173       fanouts()[{fanin_node, Graph::kControlSlot}].erase(
1174           {node, Graph::kControlSlot});
1175       node->mutable_input()->SwapElements(i, node->input_size() - 1);
1176       node->mutable_input()->RemoveLast();
1177       return true;
1178     }
1179   }
1180   return false;
1181 }
1182 
RemoveControllingFanin(absl::string_view node_name,absl::string_view fanin_node_name)1183 Status MutableGraphView::RemoveControllingFanin(
1184     absl::string_view node_name, absl::string_view fanin_node_name) {
1185   auto error_status = [node_name, fanin_node_name](absl::string_view msg) {
1186     string params = absl::Substitute("node_name='$0', fanin_node_name='$1'",
1187                                      node_name, fanin_node_name);
1188     return MutationError("RemoveControllingFanin", params, msg);
1189   };
1190 
1191   TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf(
1192       node_name, {fanin_node_name, Graph::kControlSlot}, error_status));
1193   NodeDef* node = GetNode(node_name);
1194   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1195   NodeDef* fanin_node = GetNode(fanin_node_name);
1196   TF_RETURN_IF_ERROR(
1197       CheckNodeExists(fanin_node_name, fanin_node, error_status));
1198 
1199   RemoveControllingFaninInternal(node, fanin_node);
1200   return Status::OK();
1201 }
1202 
RemoveAllFanins(absl::string_view node_name,bool keep_controlling_fanins)1203 Status MutableGraphView::RemoveAllFanins(absl::string_view node_name,
1204                                          bool keep_controlling_fanins) {
1205   NodeDef* node = GetNode(node_name);
1206   if (node == nullptr) {
1207     string params =
1208         absl::Substitute("node_name='$0', keep_controlling_fanins=$1",
1209                          node_name, keep_controlling_fanins);
1210     return MutationError("RemoveAllFanins", params,
1211                          NodeMissingErrorMsg(node_name));
1212   }
1213 
1214   if (node->input().empty()) {
1215     return Status::OK();
1216   }
1217 
1218   const int num_regular_fanins =
1219       NumFanins(*node, /*include_controlling_nodes=*/false);
1220   RemoveFaninsInternal(node, keep_controlling_fanins);
1221   if (keep_controlling_fanins) {
1222     if (num_regular_fanins == 0) {
1223       return Status::OK();
1224     } else if (num_regular_fanins < node->input_size()) {
1225       node->mutable_input()->DeleteSubrange(0, num_regular_fanins);
1226     } else {
1227       node->clear_input();
1228     }
1229   } else {
1230     node->clear_input();
1231   }
1232   return Status::OK();
1233 }
1234 
UpdateFanin(absl::string_view node_name,const TensorId & from_fanin,const TensorId & to_fanin)1235 Status MutableGraphView::UpdateFanin(absl::string_view node_name,
1236                                      const TensorId& from_fanin,
1237                                      const TensorId& to_fanin) {
1238   auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
1239     string params =
1240         absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
1241                          node_name, from_fanin.ToString(), to_fanin.ToString());
1242     return MutationError("UpdateFanin", params, msg);
1243   };
1244 
1245   TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
1246   TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
1247   NodeDef* node = GetNode(node_name);
1248   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1249   NodeDef* from_fanin_node = GetNode(from_fanin.node());
1250   TF_RETURN_IF_ERROR(
1251       CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
1252   NodeDef* to_fanin_node = GetNode(to_fanin.node());
1253   TF_RETURN_IF_ERROR(
1254       CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));
1255 
1256   // When replacing a non control dependency fanin with a control dependency, or
1257   // vice versa, remove and add, so ports can be updated properly in fanout(s).
1258   bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
1259   if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
1260     // Can't add Switch node as a control dependency.
1261     return error_status(
1262         absl::Substitute("can't update to fanin '$0' as it will become a "
1263                          "Switch control dependency",
1264                          to_fanin.ToString()));
1265   }
1266   if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
1267     return error_status("can't update fanin to or from self");
1268   }
1269 
1270   if (from_fanin == to_fanin) {
1271     return Status::OK();
1272   }
1273 
1274   bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
1275   if (from_fanin_is_control || to_fanin_is_control) {
1276     bool modified = false;
1277     if (from_fanin_is_control) {
1278       modified |= RemoveControllingFaninInternal(node, from_fanin_node);
1279     } else {
1280       modified |= RemoveRegularFaninInternal(
1281           node, {from_fanin_node, from_fanin.index()});
1282     }
1283     if (modified) {
1284       AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
1285     }
1286     return Status::OK();
1287   }
1288 
1289   // In place mutation of regular fanins, requires no shifting of ports.
1290   string to_fanin_string = TensorIdToString(to_fanin);
1291   const int num_regular_fanins =
1292       NumFanins(*node, /*include_controlling_nodes=*/false);
1293   bool modified = false;
1294   absl::flat_hash_set<InputPort>* from_fanin_port_fanouts = nullptr;
1295   absl::flat_hash_set<InputPort>* to_fanin_port_fanouts = nullptr;
1296   for (int i = 0; i < num_regular_fanins; ++i) {
1297     if (ParseTensorName(node->input(i)) == from_fanin) {
1298       InputPort input(node, i);
1299       if (from_fanin_port_fanouts == nullptr) {
1300         OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1301         from_fanin_port_fanouts = &fanouts()[from_fanin_port];
1302       }
1303       from_fanin_port_fanouts->erase(input);
1304 
1305       if (to_fanin_port_fanouts == nullptr) {
1306         OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
1307         to_fanin_port_fanouts = &fanouts()[to_fanin_port];
1308       }
1309       to_fanin_port_fanouts->insert(input);
1310 
1311       node->set_input(i, to_fanin_string);
1312       modified = true;
1313     }
1314   }
1315 
1316   // Dedup control dependencies and update max regular output ports.
1317   if (modified) {
1318     UpdateMaxRegularOutputPortForRemovedFanin(
1319         {from_fanin_node, from_fanin.index()}, *from_fanin_port_fanouts);
1320     if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
1321       max_regular_output_port()[to_fanin_node] = to_fanin.index();
1322     }
1323     if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
1324       RemoveControllingFaninInternal(node, to_fanin_node);
1325     }
1326   }
1327 
1328   return Status::OK();
1329 }
1330 
UpdateRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)1331 Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name,
1332                                                   int port,
1333                                                   const TensorId& fanin) {
1334   auto error_status = [node_name, port, fanin](absl::string_view msg) {
1335     string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
1336                                      node_name, port, fanin.ToString());
1337     return MutationError("UpdateRegularFaninByPort", params, msg);
1338   };
1339 
1340   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1341   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1342   NodeDef* node = GetNode(node_name);
1343   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1344   const int last_regular_fanin_port =
1345       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1346   TF_RETURN_IF_ERROR(
1347       CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1348   NodeDef* fanin_node = GetNode(fanin.node());
1349   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1350 
1351   TensorId tensor_id = ParseTensorName(node->input(port));
1352   if (tensor_id == fanin) {
1353     return Status::OK();
1354   }
1355 
1356   InputPort input(node, port);
1357   OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1358   absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port];
1359   from_fanouts->erase(input);
1360   UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts);
1361 
1362   OutputPort to_fanin_port(fanin_node, fanin.index());
1363   fanouts()[to_fanin_port].insert(input);
1364   UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port);
1365 
1366   node->set_input(port, TensorIdToString(fanin));
1367 
1368   if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
1369     RemoveControllingFaninInternal(node, fanin_node);
1370   }
1371 
1372   return Status::OK();
1373 }
1374 
SwapRegularFaninsByPorts(absl::string_view node_name,int from_port,int to_port)1375 Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name,
1376                                                   int from_port, int to_port) {
1377   auto error_status = [node_name, from_port, to_port](absl::string_view msg) {
1378     string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2",
1379                                      node_name, from_port, to_port);
1380     return MutationError("SwapRegularFaninsByPorts", params, msg);
1381   };
1382 
1383   NodeDef* node = GetNode(node_name);
1384   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1385   const int last_regular_fanin_port =
1386       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1387   TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0,
1388                                     last_regular_fanin_port, error_status));
1389   TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port,
1390                                     error_status));
1391 
1392   if (from_port == to_port) {
1393     return Status::OK();
1394   }
1395   TensorId from_fanin = ParseTensorName(node->input(from_port));
1396   TensorId to_fanin = ParseTensorName(node->input(to_port));
1397   if (from_fanin == to_fanin) {
1398     return Status::OK();
1399   }
1400 
1401   InputPort from_input(node, from_port);
1402   InputPort to_input(node, to_port);
1403   NodeDef* from_fanin_node = GetNode(from_fanin.node());
1404   absl::flat_hash_set<InputPort>* from_fanouts =
1405       &fanouts()[{from_fanin_node, from_fanin.index()}];
1406   from_fanouts->erase(from_input);
1407   from_fanouts->insert(to_input);
1408   NodeDef* to_fanin_node = GetNode(to_fanin.node());
1409   absl::flat_hash_set<InputPort>* to_fanouts =
1410       &fanouts()[{to_fanin_node, to_fanin.index()}];
1411   to_fanouts->erase(to_input);
1412   to_fanouts->insert(from_input);
1413 
1414   node->mutable_input()->SwapElements(from_port, to_port);
1415 
1416   return Status::OK();
1417 }
1418 
UpdateAllRegularFaninsToControlling(absl::string_view node_name)1419 Status MutableGraphView::UpdateAllRegularFaninsToControlling(
1420     absl::string_view node_name) {
1421   auto error_status = [node_name](absl::string_view msg) {
1422     string params = absl::Substitute("node_name='$0'", node_name);
1423     return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
1424   };
1425 
1426   NodeDef* node = GetNode(node_name);
1427   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1428 
1429   const int num_regular_fanins =
1430       NumFanins(*node, /*include_controlling_nodes=*/false);
1431   std::vector<OutputPort> regular_fanins;
1432   regular_fanins.reserve(num_regular_fanins);
1433   std::vector<NodeDef*> controlling_fanins;
1434   controlling_fanins.reserve(num_regular_fanins);
1435 
1436   // Get all regular fanins and derive controlling fanins.
1437   for (int i = 0; i < num_regular_fanins; ++i) {
1438     TensorId tensor_id = ParseTensorName(node->input(i));
1439     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1440 
1441     string error_msg = "";
1442     NodeDef* control_node =
1443         GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
1444     if (!error_msg.empty()) {
1445       return error_status(error_msg);
1446     }
1447 
1448     regular_fanins.push_back(fanin_port);
1449     controlling_fanins.push_back(control_node);
1450   }
1451 
1452   // Replace regular fanins with controlling fanins and dedup.
1453   int pos = 0;
1454   InputPort input_port(node, Graph::kControlSlot);
1455   absl::flat_hash_set<absl::string_view> controls;
1456   for (int i = 0; i < num_regular_fanins; ++i) {
1457     OutputPort fanin_port = regular_fanins[i];
1458     NodeDef* control = controlling_fanins[i];
1459     if (control == nullptr) {
1460       control = GetOrCreateIdentityConsumingSwitch(fanin_port);
1461     }
1462     fanouts()[fanin_port].erase({node, i});
1463     if (controls.contains(control->name())) {
1464       continue;
1465     }
1466     controls.insert(control->name());
1467     node->set_input(pos, AsControlDependency(control->name()));
1468     fanouts()[{control, Graph::kControlSlot}].insert(input_port);
1469     ++pos;
1470   }
1471 
1472   // Shift existing controlling fanins and dedup.
1473   for (int i = num_regular_fanins; i < node->input_size(); ++i) {
1474     TensorId tensor_id = ParseTensorName(node->input(i));
1475     if (controls.contains(tensor_id.node())) {
1476       continue;
1477     }
1478     controls.insert(tensor_id.node());
1479     node->mutable_input()->SwapElements(pos, i);
1480     ++pos;
1481   }
1482 
1483   // Remove duplicate controls and leftover regular fanins.
1484   node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
1485   max_regular_input_port().erase(node);
1486 
1487   return Status::OK();
1488 }
1489 
CheckNodesCanBeDeleted(const absl::flat_hash_set<string> & nodes_to_delete)1490 Status MutableGraphView::CheckNodesCanBeDeleted(
1491     const absl::flat_hash_set<string>& nodes_to_delete) {
1492   std::vector<string> missing_nodes;
1493   std::vector<string> nodes_with_fanouts;
1494   for (const string& node_name_to_delete : nodes_to_delete) {
1495     NodeDef* node = GetNode(node_name_to_delete);
1496     if (node == nullptr) {
1497       // Can't delete missing node.
1498       missing_nodes.push_back(node_name_to_delete);
1499       continue;
1500     }
1501     const int max_port = gtl::FindWithDefault(max_regular_output_port(), node,
1502                                               Graph::kControlSlot);
1503     for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1504       auto it = fanouts().find({node, i});
1505       bool has_retained_fanout = false;
1506       if (it != fanouts().end()) {
1507         for (const auto& fanout : it->second) {
1508           // Check if fanouts are of nodes to be deleted, and if so, they can be
1509           // ignored, as they will be removed also.
1510           if (!nodes_to_delete.contains(fanout.node->name())) {
1511             // Removing node will leave graph in an invalid state.
1512             has_retained_fanout = true;
1513             break;
1514           }
1515         }
1516       }
1517       if (has_retained_fanout) {
1518         nodes_with_fanouts.push_back(node_name_to_delete);
1519         break;
1520       }
1521     }
1522   }
1523 
1524   // Error message can get quite long, so we only show the first 5 node names.
1525   auto sort_and_sample = [](std::vector<string>* s) {
1526     constexpr int kMaxNodeNames = 5;
1527     std::sort(s->begin(), s->end());
1528     if (s->size() > kMaxNodeNames) {
1529       return absl::StrCat(
1530           absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", "), ", ...");
1531     }
1532     return absl::StrJoin(*s, ", ");
1533   };
1534 
1535   if (!missing_nodes.empty()) {
1536     VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0].",
1537                                 sort_and_sample(&missing_nodes));
1538   }
1539   if (!nodes_with_fanouts.empty()) {
1540     std::vector<string> input_node_names(nodes_to_delete.begin(),
1541                                          nodes_to_delete.end());
1542     string params = absl::Substitute("nodes_to_delete={$0}",
1543                                      sort_and_sample(&input_node_names));
1544     string error_msg =
1545         absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]",
1546                          sort_and_sample(&nodes_with_fanouts));
1547     return MutationError("DeleteNodes", params, error_msg);
1548   }
1549 
1550   return Status::OK();
1551 }
1552 
DeleteNodes(const absl::flat_hash_set<string> & nodes_to_delete)1553 Status MutableGraphView::DeleteNodes(
1554     const absl::flat_hash_set<string>& nodes_to_delete) {
1555   TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete));
1556 
1557   // Find nodes in internal state and delete.
1558   for (const string& node_name_to_delete : nodes_to_delete) {
1559     NodeDef* node = GetNode(node_name_to_delete);
1560     if (node != nullptr) {
1561       RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false);
1562       RemoveFanoutsInternal(node);
1563     }
1564   }
1565   for (const string& node_name_to_delete : nodes_to_delete) {
1566     nodes().erase(node_name_to_delete);
1567   }
1568 
1569   // Find nodes in graph and delete by partitioning into nodes to retain and
1570   // nodes to delete based on input set of nodes to delete by name.
1571   // TODO(lyandy): Use a node name->idx hashmap if this is a performance
1572   // bottleneck.
1573   int pos = 0;
1574   const int last_idx = graph()->node_size() - 1;
1575   int last_pos = last_idx;
1576   while (pos <= last_pos) {
1577     if (nodes_to_delete.contains(graph()->node(pos).name())) {
1578       graph()->mutable_node()->SwapElements(pos, last_pos);
1579       --last_pos;
1580     } else {
1581       ++pos;
1582     }
1583   }
1584   if (last_pos < last_idx) {
1585     graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
1586   }
1587 
1588   return Status::OK();
1589 }
1590 
RemoveFaninsInternal(NodeDef * deleted_node,bool keep_controlling_fanins)1591 void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
1592                                             bool keep_controlling_fanins) {
1593   for (int i = 0; i < deleted_node->input_size(); ++i) {
1594     TensorId tensor_id = ParseTensorName(deleted_node->input(i));
1595     bool is_control = IsTensorIdControlling(tensor_id);
1596     if (keep_controlling_fanins && is_control) {
1597       break;
1598     }
1599     OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
1600 
1601     InputPort input;
1602     input.node = deleted_node;
1603     input.port_id = is_control ? Graph::kControlSlot : i;
1604 
1605     auto it = fanouts().find(fanin);
1606     if (it != fanouts().end()) {
1607       absl::flat_hash_set<InputPort>* fanouts_set = &it->second;
1608       fanouts_set->erase(input);
1609       UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set);
1610     }
1611   }
1612   max_regular_input_port().erase(deleted_node);
1613 }
1614 
RemoveFanoutsInternal(NodeDef * deleted_node)1615 void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) {
1616   const int max_port =
1617       gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1);
1618   for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1619     fanouts().erase({deleted_node, i});
1620   }
1621   max_regular_output_port().erase(deleted_node);
1622 }
1623 
1624 }  // end namespace grappler
1625 }  // end namespace tensorflow
1626