1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/grappler/op_types.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace tensorflow {
39 namespace grappler {
40
41 namespace {
42
IsTensorIdPortValid(const TensorId & tensor_id)43 bool IsTensorIdPortValid(const TensorId& tensor_id) {
44 return tensor_id.index() >= Graph::kControlSlot;
45 }
46
IsTensorIdRegular(const TensorId & tensor_id)47 bool IsTensorIdRegular(const TensorId& tensor_id) {
48 return tensor_id.index() > Graph::kControlSlot;
49 }
50
IsTensorIdControlling(const TensorId & tensor_id)51 bool IsTensorIdControlling(const TensorId& tensor_id) {
52 return tensor_id.index() == Graph::kControlSlot;
53 }
54
IsOutputPortControlling(const MutableGraphView::OutputPort & port)55 bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) {
56 return port.port_id == Graph::kControlSlot;
57 }
58
59 // Determines if node is an Identity where it's first regular input is a Switch
60 // node.
IsIdentityConsumingSwitch(const MutableGraphView & graph,const NodeDef & node)61 bool IsIdentityConsumingSwitch(const MutableGraphView& graph,
62 const NodeDef& node) {
63 if ((IsIdentity(node) || IsIdentityNSingleInput(node)) &&
64 node.input_size() > 0) {
65 TensorId tensor_id = ParseTensorName(node.input(0));
66 if (IsTensorIdControlling(tensor_id)) {
67 return false;
68 }
69
70 NodeDef* input_node = graph.GetNode(tensor_id.node());
71 return IsSwitch(*input_node);
72 }
73 return false;
74 }
75
76 // Determines if node input can be deduped by regular inputs when used as a
77 // control dependency. Specifically, if a node is an Identity that leads to a
78 // Switch node, when used as a control dependency, that control dependency
79 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,const NodeDef & control_node)80 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
81 const NodeDef& control_node) {
82 return !IsIdentityConsumingSwitch(graph, control_node);
83 }
84
85 // Determines if node input can be deduped by regular inputs when used as a
86 // control dependency. Specifically, if a node is an Identity that leads to a
87 // Switch node, when used as a control dependency, that control dependency
88 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,absl::string_view control_node_name)89 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
90 absl::string_view control_node_name) {
91 NodeDef* control_node = graph.GetNode(control_node_name);
92 if (control_node == nullptr) {
93 return false;
94 }
95 return CanDedupControlWithRegularInput(graph, *control_node);
96 }
97
HasRegularFaninNode(const MutableGraphView & graph,const NodeDef & node,absl::string_view fanin_node_name)98 bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
99 absl::string_view fanin_node_name) {
100 const int num_regular_fanins =
101 graph.NumFanins(node, /*include_controlling_nodes=*/false);
102 for (int i = 0; i < num_regular_fanins; ++i) {
103 if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
104 return true;
105 }
106 }
107 return false;
108 }
109
// Map from an output port (producer node + port id) to the set of input ports
// that consume it. The helpers below receive the view's fanouts through a
// pointer or iterator of this type.
110 using FanoutsMap =
111 absl::flat_hash_map<MutableGraphView::OutputPort,
112 absl::flat_hash_set<MutableGraphView::InputPort>>;
113
SwapControlledFanoutInputs(const MutableGraphView & graph,const FanoutsMap::iterator & control_fanouts,absl::string_view to_node_name)114 void SwapControlledFanoutInputs(const MutableGraphView& graph,
115 const FanoutsMap::iterator& control_fanouts,
116 absl::string_view to_node_name) {
117 absl::string_view from_node_name(control_fanouts->first.node->name());
118 string control = TensorIdToString({to_node_name, Graph::kControlSlot});
119 for (const auto& control_fanout : control_fanouts->second) {
120 const int start = graph.NumFanins(*control_fanout.node,
121 /*include_controlling_nodes=*/false);
122 for (int i = start; i < control_fanout.node->input_size(); ++i) {
123 TensorId tensor_id = ParseTensorName(control_fanout.node->input(i));
124 if (tensor_id.node() == from_node_name) {
125 control_fanout.node->set_input(i, control);
126 break;
127 }
128 }
129 }
130 }
131
SwapRegularFanoutInputs(FanoutsMap * fanouts,NodeDef * from_node,absl::string_view to_node_name,int max_port)132 void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node,
133 absl::string_view to_node_name, int max_port) {
134 MutableGraphView::OutputPort port;
135 port.node = from_node;
136 for (int i = 0; i <= max_port; ++i) {
137 port.port_id = i;
138 auto it = fanouts->find(port);
139 if (it == fanouts->end()) {
140 continue;
141 }
142 string input = TensorIdToString({to_node_name, i});
143 for (const auto& fanout : it->second) {
144 fanout.node->set_input(fanout.port_id, input);
145 }
146 }
147 }
148
// Map from a node to an int port number; the helpers below treat the value as
// the highest regular output port of that node to visit when walking fanouts.
149 using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
150
SwapFanoutInputs(const MutableGraphView & graph,FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)151 void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts,
152 MaxOutputPortsMap* max_output_ports, NodeDef* from_node,
153 NodeDef* to_node) {
154 auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot});
155 if (from_control_fanouts != fanouts->end()) {
156 SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name());
157 }
158 auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot});
159 if (to_control_fanouts != fanouts->end()) {
160 SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name());
161 }
162 auto from_max_port = max_output_ports->find(from_node);
163 if (from_max_port != max_output_ports->end()) {
164 SwapRegularFanoutInputs(fanouts, from_node, to_node->name(),
165 from_max_port->second);
166 }
167 auto to_max_port = max_output_ports->find(to_node);
168 if (to_max_port != max_output_ports->end()) {
169 SwapRegularFanoutInputs(fanouts, to_node, from_node->name(),
170 to_max_port->second);
171 }
172 }
173
SwapFanoutsMapValues(FanoutsMap * fanouts,const MutableGraphView::OutputPort & from_port,const FanoutsMap::iterator & from_fanouts,const MutableGraphView::OutputPort & to_port,const FanoutsMap::iterator & to_fanouts)174 void SwapFanoutsMapValues(FanoutsMap* fanouts,
175 const MutableGraphView::OutputPort& from_port,
176 const FanoutsMap::iterator& from_fanouts,
177 const MutableGraphView::OutputPort& to_port,
178 const FanoutsMap::iterator& to_fanouts) {
179 const bool from_exists = from_fanouts != fanouts->end();
180 const bool to_exists = to_fanouts != fanouts->end();
181
182 if (from_exists && to_exists) {
183 std::swap(from_fanouts->second, to_fanouts->second);
184 } else if (from_exists) {
185 fanouts->emplace(to_port, std::move(from_fanouts->second));
186 fanouts->erase(from_port);
187 } else if (to_exists) {
188 fanouts->emplace(from_port, std::move(to_fanouts->second));
189 fanouts->erase(to_port);
190 }
191 }
192
SwapRegularFanoutsAndMaxPortValues(FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)193 void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts,
194 MaxOutputPortsMap* max_output_ports,
195 NodeDef* from_node, NodeDef* to_node) {
196 auto from_max_port = max_output_ports->find(from_node);
197 auto to_max_port = max_output_ports->find(to_node);
198 bool from_exists = from_max_port != max_output_ports->end();
199 bool to_exists = to_max_port != max_output_ports->end();
200
201 auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start,
202 int end) {
203 for (int i = start; i <= end; ++i) {
204 MutableGraphView::OutputPort from_port(from, i);
205 auto from_fanouts = fanouts->find(from_port);
206 if (from_fanouts != fanouts->end()) {
207 MutableGraphView::OutputPort to_port(to, i);
208 fanouts->emplace(to_port, std::move(from_fanouts->second));
209 fanouts->erase(from_port);
210 }
211 }
212 };
213
214 if (from_exists && to_exists) {
215 const int from = from_max_port->second;
216 const int to = to_max_port->second;
217 const int shared = std::min(from, to);
218 for (int i = 0; i <= shared; ++i) {
219 MutableGraphView::OutputPort from_port(from_node, i);
220 auto from_fanouts = fanouts->find(from_port);
221 MutableGraphView::OutputPort to_port(to_node, i);
222 auto to_fanouts = fanouts->find(to_port);
223 SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port,
224 to_fanouts);
225 }
226 if (to > from) {
227 forward_fanouts(to_node, from_node, shared + 1, to);
228 } else if (from > to) {
229 forward_fanouts(from_node, to_node, shared + 1, from);
230 }
231
232 std::swap(from_max_port->second, to_max_port->second);
233 } else if (from_exists) {
234 forward_fanouts(from_node, to_node, 0, from_max_port->second);
235
236 max_output_ports->emplace(to_node, from_max_port->second);
237 max_output_ports->erase(from_node);
238 } else if (to_exists) {
239 forward_fanouts(to_node, from_node, 0, to_max_port->second);
240
241 max_output_ports->emplace(from_node, to_max_port->second);
242 max_output_ports->erase(to_node);
243 }
244 }
245
HasFanoutValue(const FanoutsMap & fanouts,const FanoutsMap::iterator & it)246 bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) {
247 return it != fanouts.end() && !it->second.empty();
248 }
249
MutationError(absl::string_view function_name,absl::string_view params,absl::string_view msg)250 Status MutationError(absl::string_view function_name, absl::string_view params,
251 absl::string_view msg) {
252 return errors::InvalidArgument(absl::Substitute(
253 "MutableGraphView::$0($1) error: $2.", function_name, params, msg));
254 }
255
// Callback that turns an error message into a Status. The Check* helpers below
// invoke it with the message describing the failed check.
256 using ErrorHandler = std::function<Status(absl::string_view)>;
257
UpdateFanoutsError(absl::string_view from_node_name,absl::string_view to_node_name)258 ErrorHandler UpdateFanoutsError(absl::string_view from_node_name,
259 absl::string_view to_node_name) {
260 return [from_node_name, to_node_name](absl::string_view msg) {
261 string params = absl::Substitute("from_node_name='$0', to_node_name='$1'",
262 from_node_name, to_node_name);
263 return MutationError("UpdateFanouts", params, msg);
264 };
265 }
266
CheckFaninIsRegular(const TensorId & fanin,ErrorHandler handler)267 Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) {
268 if (!IsTensorIdRegular(fanin)) {
269 return handler(absl::Substitute("fanin '$0' must be a regular tensor id",
270 fanin.ToString()));
271 }
272 return Status::OK();
273 }
274
CheckFaninIsValid(const TensorId & fanin,ErrorHandler handler)275 Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) {
276 if (!IsTensorIdPortValid(fanin)) {
277 return handler(absl::Substitute("fanin '$0' must be a valid tensor id",
278 fanin.ToString()));
279 }
280 return Status::OK();
281 }
282
CheckAddingFaninToSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)283 Status CheckAddingFaninToSelf(absl::string_view node_name,
284 const TensorId& fanin, ErrorHandler handler) {
285 if (node_name == fanin.node()) {
286 return handler(
287 absl::Substitute("can't add fanin '$0' to self", fanin.ToString()));
288 }
289 return Status::OK();
290 }
291
CheckRemovingFaninFromSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)292 Status CheckRemovingFaninFromSelf(absl::string_view node_name,
293 const TensorId& fanin, ErrorHandler handler) {
294 if (node_name == fanin.node()) {
295 return handler(absl::Substitute("can't remove fanin '$0' from self",
296 fanin.ToString()));
297 }
298 return Status::OK();
299 }
300
NodeMissingErrorMsg(absl::string_view node_name)301 string NodeMissingErrorMsg(absl::string_view node_name) {
302 return absl::Substitute("node '$0' was not found", node_name);
303 }
304
CheckNodeExists(absl::string_view node_name,NodeDef * node,ErrorHandler handler)305 Status CheckNodeExists(absl::string_view node_name, NodeDef* node,
306 ErrorHandler handler) {
307 if (node == nullptr) {
308 return handler(NodeMissingErrorMsg(node_name));
309 }
310 return Status::OK();
311 }
312
CheckPortRange(int port,int min,int max,ErrorHandler handler)313 Status CheckPortRange(int port, int min, int max, ErrorHandler handler) {
314 if (port < min || port > max) {
315 if (max < min) {
316 return handler("no available ports as node has no regular fanins");
317 }
318 return handler(
319 absl::Substitute("port must be in range [$0, $1]", min, max));
320 }
321 return Status::OK();
322 }
323
SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name)324 string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) {
325 return absl::Substitute(
326 "can't swap node name '$0' as it will become a Switch control dependency",
327 node_name);
328 }
329
GeneratedNameForIdentityConsumingSwitch(const MutableGraphView::OutputPort & fanin)330 string GeneratedNameForIdentityConsumingSwitch(
331 const MutableGraphView::OutputPort& fanin) {
332 return AddPrefixToNodeName(
333 absl::StrCat(fanin.node->name(), "_", fanin.port_id),
334 kMutableGraphViewCtrl);
335 }
336
337 } // namespace
338
AddAndDedupFanouts(NodeDef * node)339 void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
340 // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
341 // exist, and all regular fanins come before controlling fanins.
342 absl::flat_hash_set<absl::string_view> fanins;
343 absl::flat_hash_set<absl::string_view> controlling_fanins;
344 int max_input_port = -1;
345 int pos = 0;
346 const int last_idx = node->input_size() - 1;
347 int last_pos = last_idx;
348 while (pos <= last_pos) {
349 TensorId tensor_id = ParseTensorName(node->input(pos));
350 absl::string_view input_node_name = tensor_id.node();
351 bool is_control_input = IsTensorIdControlling(tensor_id);
352 bool can_dedup_control_with_regular_input =
353 CanDedupControlWithRegularInput(*this, input_node_name);
354 bool can_dedup_control =
355 is_control_input && (can_dedup_control_with_regular_input ||
356 controlling_fanins.contains(input_node_name));
357 if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
358 can_dedup_control) {
359 node->mutable_input()->SwapElements(pos, last_pos);
360 --last_pos;
361 } else {
362 OutputPort output(nodes()[input_node_name], tensor_id.index());
363
364 if (is_control_input) {
365 fanouts()[output].emplace(node, Graph::kControlSlot);
366 } else {
367 max_input_port = pos;
368 max_regular_output_port()[output.node] =
369 std::max(max_regular_output_port()[output.node], output.port_id);
370 fanouts()[output].emplace(node, pos);
371 }
372 ++pos;
373 }
374 if (is_control_input) {
375 controlling_fanins.insert(input_node_name);
376 }
377 }
378
379 if (last_pos < last_idx) {
380 node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
381 }
382
383 if (max_input_port > -1) {
384 max_regular_input_port()[node] = max_input_port;
385 }
386 }
387
UpdateMaxRegularOutputPortForRemovedFanin(const OutputPort & fanin,const absl::flat_hash_set<InputPort> & fanin_fanouts)388 void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin(
389 const OutputPort& fanin,
390 const absl::flat_hash_set<InputPort>& fanin_fanouts) {
391 int max_port = max_regular_output_port()[fanin.node];
392 if (!fanin_fanouts.empty() || max_port != fanin.port_id) {
393 return;
394 }
395 bool updated_max_port = false;
396 for (int i = fanin.port_id - 1; i >= 0; --i) {
397 OutputPort fanin_port(fanin.node, i);
398 if (!fanouts()[fanin_port].empty()) {
399 max_regular_output_port()[fanin.node] = i;
400 updated_max_port = true;
401 break;
402 }
403 }
404 if (!updated_max_port) {
405 max_regular_output_port().erase(fanin.node);
406 }
407 }
408
UpdateMaxRegularOutputPortForAddedFanin(const OutputPort & fanin)409 void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin(
410 const OutputPort& fanin) {
411 if (max_regular_output_port()[fanin.node] < fanin.port_id) {
412 max_regular_output_port()[fanin.node] = fanin.port_id;
413 }
414 }
415
416 const absl::flat_hash_set<MutableGraphView::InputPort>&
GetFanout(const GraphView::OutputPort & port) const417 MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
418 return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
419 port.port_id));
420 }
421
GetFanin(const GraphView::InputPort & port) const422 absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
423 const GraphView::InputPort& port) const {
424 return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
425 port.port_id));
426 }
427
GetRegularFanin(const GraphView::InputPort & port) const428 const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
429 const GraphView::InputPort& port) const {
430 return GetRegularFanin(MutableGraphView::InputPort(
431 const_cast<NodeDef*>(port.node), port.port_id));
432 }
433
AddNode(NodeDef && node)434 NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
435 auto* node_in_graph = graph()->add_node();
436 *node_in_graph = std::move(node);
437
438 AddUniqueNodeOrDie(node_in_graph);
439
440 AddAndDedupFanouts(node_in_graph);
441 return node_in_graph;
442 }
443
AddSubgraph(GraphDef && subgraph)444 Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
445 // 1. Add all new functions and check that functions with the same name
446 // have identical definition.
447 const int function_size = subgraph.library().function_size();
448 if (function_size > 0) {
449 absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs;
450 for (const FunctionDef& fdef : graph()->library().function()) {
451 graph_fdefs.emplace(fdef.signature().name(), &fdef);
452 }
453
454 for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) {
455 const auto graph_fdef = graph_fdefs.find(fdef.signature().name());
456
457 if (graph_fdef == graph_fdefs.end()) {
458 VLOG(3) << "Add new function definition: " << fdef.signature().name();
459 graph()->mutable_library()->add_function()->Swap(&fdef);
460 } else {
461 if (!FunctionDefsEqual(fdef, *graph_fdef->second)) {
462 return MutationError(
463 "AddSubgraph",
464 absl::Substitute("function_size=$0", function_size),
465 absl::StrCat(
466 "Found different function definition with the same name: ",
467 fdef.signature().name()));
468 }
469 }
470 }
471 }
472
473 // 2. Add all nodes to the underlying graph.
474 int node_size_before = graph()->node_size();
475
476 for (NodeDef& node : *subgraph.mutable_node()) {
477 auto* node_in_graph = graph()->add_node();
478 node_in_graph->Swap(&node);
479 TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
480 }
481
482 // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
483 // fanins actually exists in the graph, and there is already TODO for that.
484
485 for (int i = node_size_before; i < graph()->node_size(); ++i) {
486 NodeDef* node = graph()->mutable_node(i);
487 AddAndDedupFanouts(node);
488 }
489
490 return Status::OK();
491 }
492
UpdateNode(absl::string_view node_name,absl::string_view op,absl::string_view device,absl::Span<const std::pair<string,AttrValue>> attrs)493 Status MutableGraphView::UpdateNode(
494 absl::string_view node_name, absl::string_view op, absl::string_view device,
495 absl::Span<const std::pair<string, AttrValue>> attrs) {
496 auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
497 std::vector<string> attr_strs;
498 attr_strs.reserve(attrs.size());
499 for (const auto& attr : attrs) {
500 string attr_str = absl::Substitute("('$0', $1)", attr.first,
501 attr.second.ShortDebugString());
502 attr_strs.push_back(attr_str);
503 }
504 string params =
505 absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
506 node_name, op, device, absl::StrJoin(attr_strs, ", "));
507 return MutationError("UpdateNodeOp", params, msg);
508 };
509
510 NodeDef* node = GetNode(node_name);
511 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
512
513 MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
514 auto control_fanouts = GetFanout(control_port);
515 if (op == "Switch" && !control_fanouts.empty()) {
516 return error_status(
517 "can't change node op to Switch when node drives a control dependency "
518 "(alternatively, we could add the identity node needed, but it seems "
519 "like an unlikely event and probably a mistake)");
520 }
521
522 if (node->device() != device) {
523 node->set_device(string(device));
524 }
525 node->mutable_attr()->clear();
526 for (const auto& attr : attrs) {
527 (*node->mutable_attr())[attr.first] = attr.second;
528 }
529
530 if (node->op() == op) {
531 return Status::OK();
532 }
533
534 node->set_op(string(op));
535
536 if (CanDedupControlWithRegularInput(*this, *node)) {
537 for (const auto& control_fanout : control_fanouts) {
538 if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
539 RemoveControllingFaninInternal(control_fanout.node, node);
540 }
541 }
542 }
543
544 return Status::OK();
545 }
546
UpdateNodeName(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)547 Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name,
548 absl::string_view to_node_name,
549 bool update_fanouts) {
550 auto error_status = [from_node_name, to_node_name,
551 update_fanouts](absl::string_view msg) {
552 string params = absl::Substitute(
553 "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
554 from_node_name, to_node_name, update_fanouts);
555 return MutationError("UpdateNodeName", params, msg);
556 };
557
558 NodeDef* node = GetNode(from_node_name);
559 TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status));
560
561 if (node->name() == to_node_name) {
562 return Status::OK();
563 }
564 if (HasNode(to_node_name)) {
565 return error_status(
566 "can't update node name because new node name is in use");
567 }
568 auto max_output_port = max_regular_output_port().find(node);
569 const bool has_max_output_port =
570 max_output_port != max_regular_output_port().end();
571 auto control_fanouts = fanouts().find({node, Graph::kControlSlot});
572
573 if (update_fanouts) {
574 SwapControlledFanoutInputs(*this, control_fanouts, to_node_name);
575 if (has_max_output_port) {
576 SwapRegularFanoutInputs(&fanouts(), node, to_node_name,
577 max_output_port->second);
578 }
579 } else if (has_max_output_port ||
580 HasFanoutValue(fanouts(), control_fanouts)) {
581 return error_status("can't update node name because node has fanouts");
582 }
583
584 nodes().erase(node->name());
585 node->set_name(string(to_node_name));
586 nodes().emplace(node->name(), node);
587 return Status::OK();
588 }
589
SwapNodeNames(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)590 Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
591 absl::string_view to_node_name,
592 bool update_fanouts) {
593 auto error_status = [from_node_name, to_node_name,
594 update_fanouts](absl::string_view msg) {
595 string params = absl::Substitute(
596 "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
597 from_node_name, to_node_name, update_fanouts);
598 return MutationError("SwapNodeNames", params, msg);
599 };
600
601 NodeDef* from_node = GetNode(from_node_name);
602 TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
603 if (from_node_name == to_node_name) {
604 return Status::OK();
605 }
606 NodeDef* to_node = GetNode(to_node_name);
607 TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));
608
609 auto swap_names = [this, from_node, to_node]() {
610 nodes().erase(from_node->name());
611 nodes().erase(to_node->name());
612 std::swap(*from_node->mutable_name(), *to_node->mutable_name());
613 nodes().emplace(from_node->name(), from_node);
614 nodes().emplace(to_node->name(), to_node);
615 };
616
617 if (update_fanouts) {
618 SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
619 to_node);
620 swap_names();
621 return Status::OK();
622 }
623
624 bool from_is_switch = IsSwitch(*from_node);
625 MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
626 auto to_control_fanouts = fanouts().find(to_control);
627 if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
628 return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
629 }
630
631 bool to_is_switch = IsSwitch(*to_node);
632 MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
633 auto from_control_fanouts = fanouts().find(from_control);
634 if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
635 return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
636 }
637
638 // Swap node names.
639 swap_names();
640
641 // Swap controlling fanouts.
642 //
643 // Note: To and from control fanout iterators are still valid as no mutations
644 // has been performed on fanouts().
645 SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
646 to_control, to_control_fanouts);
647
648 // Swap regular fanouts.
649 SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
650 from_node, to_node);
651
652 // Update fanins to remove self loops.
653 auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
654 for (int i = 0; i < node->input_size(); ++i) {
655 TensorId tensor_id = ParseTensorName(node->input(i));
656 if (tensor_id.node() == node->name()) {
657 const int idx = tensor_id.index();
658 const int node_idx =
659 IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;
660
661 MutableGraphView::OutputPort from_fanin(node, idx);
662 absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
663 from_fanouts->erase({node, node_idx});
664 UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);
665
666 MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
667 fanouts()[to_fanin].insert({node, node_idx});
668 UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
669 node->set_input(i, TensorIdToString({old_node_name, idx}));
670 }
671 }
672 };
673 update_fanins(from_node, to_node->name());
674 update_fanins(to_node, from_node->name());
675
676 // Dedup control dependencies.
677 auto dedup_control_fanouts =
678 [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
679 if (CanDedupControlWithRegularInput(*this, *node) &&
680 control_fanouts != fanouts().end()) {
681 for (auto it = control_fanouts->second.begin();
682 it != control_fanouts->second.end();) {
683 // Advance `it` before invalidation from removal.
684 const auto& control_fanout = *it++;
685 if (HasRegularFaninNode(*this, *control_fanout.node,
686 node->name())) {
687 RemoveControllingFaninInternal(control_fanout.node, node);
688 }
689 }
690 }
691 };
692 auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
693 OutputPort port;
694 port.node = node;
695 const int max_port =
696 gtl::FindWithDefault(max_regular_output_port(), node, -1);
697 for (int i = 0; i <= max_port; ++i) {
698 port.port_id = i;
699 auto it = fanouts().find(port);
700 if (it == fanouts().end()) {
701 continue;
702 }
703 for (const auto& fanout : it->second) {
704 auto fanout_controls =
705 fanouts().find({fanout.node, Graph::kControlSlot});
706 dedup_control_fanouts(fanout.node, fanout_controls);
707 }
708 }
709 };
710
711 if (!from_is_switch) {
712 if (to_is_switch) {
713 dedup_switch_control(from_node);
714 } else {
715 // Fetch iterator again as the original iterator might have been
716 // invalidated by container rehash triggered due to mutations.
717 auto from_control_fanouts = fanouts().find(from_control);
718 dedup_control_fanouts(from_node, from_control_fanouts);
719 }
720 }
721 if (!to_is_switch) {
722 if (from_is_switch) {
723 dedup_switch_control(to_node);
724 } else {
725 // Fetch iterator again as the original iterator might have been
726 // invalidated by container rehash triggered due to mutations.
727 auto to_control_fanouts = fanouts().find(to_control);
728 dedup_control_fanouts(to_node, to_control_fanouts);
729 }
730 }
731
732 return Status::OK();
733 }
734
UpdateFanouts(absl::string_view from_node_name,absl::string_view to_node_name)735 Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
736 absl::string_view to_node_name) {
737 NodeDef* from_node = GetNode(from_node_name);
738 TF_RETURN_IF_ERROR(
739 CheckNodeExists(from_node_name, from_node,
740 UpdateFanoutsError(from_node_name, to_node_name)));
741 NodeDef* to_node = GetNode(to_node_name);
742 TF_RETURN_IF_ERROR(CheckNodeExists(
743 to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name)));
744
745 return UpdateFanoutsInternal(from_node, to_node);
746 }
747
UpdateFanoutsInternal(NodeDef * from_node,NodeDef * to_node)748 Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node,
749 NodeDef* to_node) {
750 VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.",
751 from_node->name(), to_node->name());
752 if (from_node == to_node) {
753 return Status::OK();
754 }
755
756 // Update internal state with the new output_port->input_port edge.
757 const auto add_edge = [this](const OutputPort& output_port,
758 const InputPort& input_port) {
759 fanouts()[output_port].insert(input_port);
760 };
761
762 // Remove invalidated edge from the internal state.
763 const auto remove_edge = [this](const OutputPort& output_port,
764 const InputPort& input_port) {
765 fanouts()[output_port].erase(input_port);
766 };
767
768 // For the control fanouts we do not know the input index in a NodeDef,
769 // so we have to traverse all control inputs.
770
771 auto control_fanouts =
772 GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
773
774 bool to_node_is_switch = IsSwitch(*to_node);
775 for (const InputPort& control_port : control_fanouts) {
776 // Node can't be control dependency of itself.
777 if (control_port.node == to_node) continue;
778
779 // Can't add Switch node as a control dependency.
780 if (to_node_is_switch) {
781 // Trying to add a Switch as a control dependency, which if allowed will
782 // make the graph invalid.
783 return UpdateFanoutsError(from_node->name(), to_node->name())(
784 absl::Substitute("can't update fanouts to node '$0' as it will "
785 "become a Switch control dependency",
786 to_node->name()));
787 }
788
789 NodeDef* node = control_port.node;
790 RemoveControllingFaninInternal(node, from_node);
791 AddFaninInternal(node, {to_node, Graph::kControlSlot});
792 }
793
794 // First we update regular fanouts. For the regular fanouts
795 // `input_port:port_id` is the input index in NodeDef.
796
797 auto regular_edges =
798 GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
799
800 // Maximum index of the `from_node` output tensor that is still used as an
801 // input to some other node.
802 int keep_max_regular_output_port = -1;
803
804 for (const Edge& edge : regular_edges) {
805 const OutputPort output_port = edge.src;
806 const InputPort input_port = edge.dst;
807
808 // If the `to_node` reads from the `from_node`, skip this edge (see
809 // AddAndUpdateFanoutsWithoutSelfLoops test for an example).
810 if (input_port.node == to_node) {
811 keep_max_regular_output_port =
812 std::max(keep_max_regular_output_port, output_port.port_id);
813 continue;
814 }
815
816 // Update input at destination node.
817 input_port.node->set_input(
818 input_port.port_id,
819 TensorIdToString({to_node->name(), output_port.port_id}));
820
821 // Remove old edge between the `from_node` and the fanout node.
822 remove_edge(output_port, input_port);
823 // Add an edge between the `to_node` and new fanout node.
824 add_edge(OutputPort(to_node, output_port.port_id), input_port);
825 // Dedup control dependency.
826 if (CanDedupControlWithRegularInput(*this, *to_node)) {
827 RemoveControllingFaninInternal(input_port.node, to_node);
828 }
829 }
830
831 // Because we update all regular fanouts of `from_node`, we can just copy
832 // the value `num_regular_outputs`.
833 max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
834
835 // Check if all fanouts were updated to read from the `to_node`.
836 if (keep_max_regular_output_port >= 0) {
837 max_regular_output_port()[from_node] = keep_max_regular_output_port;
838 } else {
839 max_regular_output_port().erase(from_node);
840 }
841
842 return Status::OK();
843 }
844
AddFaninInternal(NodeDef * node,const OutputPort & fanin)845 bool MutableGraphView::AddFaninInternal(NodeDef* node,
846 const OutputPort& fanin) {
847 int num_regular_fanins =
848 NumFanins(*node, /*include_controlling_nodes=*/false);
849 bool input_is_control = IsOutputPortControlling(fanin);
850 bool can_dedup_control_with_regular_input =
851 CanDedupControlWithRegularInput(*this, *fanin.node);
852 // Don't add duplicate control dependencies.
853 if (input_is_control) {
854 const int start =
855 can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
856 for (int i = start; i < node->input_size(); ++i) {
857 if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
858 return false;
859 }
860 }
861 }
862
863 InputPort input;
864 input.node = node;
865 input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;
866
867 node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
868 if (!input_is_control) {
869 const int last_node_input = node->input_size() - 1;
870 // If there are control dependencies in node, move newly inserted fanin to
871 // be before such control dependencies.
872 if (num_regular_fanins < last_node_input) {
873 node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
874 }
875 }
876
877 fanouts()[fanin].insert(input);
878 if (max_regular_output_port()[fanin.node] < fanin.port_id) {
879 max_regular_output_port()[fanin.node] = fanin.port_id;
880 }
881
882 // Update max input port and dedup control dependencies.
883 if (!input_is_control) {
884 max_regular_input_port()[node] = num_regular_fanins;
885 if (can_dedup_control_with_regular_input) {
886 RemoveControllingFaninInternal(node, fanin.node);
887 }
888 }
889
890 return true;
891 }
892
AddRegularFanin(absl::string_view node_name,const TensorId & fanin)893 Status MutableGraphView::AddRegularFanin(absl::string_view node_name,
894 const TensorId& fanin) {
895 auto error_status = [node_name, fanin](absl::string_view msg) {
896 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
897 fanin.ToString());
898 return MutationError("AddRegularFanin", params, msg);
899 };
900
901 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
902 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
903 NodeDef* node = GetNode(node_name);
904 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
905 NodeDef* fanin_node = GetNode(fanin.node());
906 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
907
908 AddFaninInternal(node, {fanin_node, fanin.index()});
909 return Status::OK();
910 }
911
AddRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)912 Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
913 int port,
914 const TensorId& fanin) {
915 auto error_status = [node_name, port, fanin](absl::string_view msg) {
916 string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
917 node_name, port, fanin.ToString());
918 return MutationError("AddRegularFaninByPort", params, msg);
919 };
920
921 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
922 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
923 NodeDef* node = GetNode(node_name);
924 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
925 const int num_regular_fanins =
926 NumFanins(*node, /*include_controlling_nodes=*/false);
927 TF_RETURN_IF_ERROR(
928 CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
929 NodeDef* fanin_node = GetNode(fanin.node());
930 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
931
932 const int last_node_input = node->input_size();
933 node->add_input(TensorIdToString(fanin));
934 node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
935 for (int i = num_regular_fanins - 1; i >= port; --i) {
936 TensorId tensor_id = ParseTensorName(node->input(i));
937 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
938 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
939 fanouts_set->erase({node, i});
940 fanouts_set->insert({node, i + 1});
941 node->mutable_input()->SwapElements(i, i + 1);
942 }
943
944 OutputPort fanin_port(fanin_node, fanin.index());
945 fanouts()[fanin_port].insert({node, port});
946 UpdateMaxRegularOutputPortForAddedFanin(fanin_port);
947
948 max_regular_input_port()[node] = num_regular_fanins;
949 if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
950 RemoveControllingFaninInternal(node, fanin_node);
951 }
952
953 return Status::OK();
954 }
955
GetControllingFaninToAdd(absl::string_view node_name,const OutputPort & fanin,string * error_msg)956 NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
957 const OutputPort& fanin,
958 string* error_msg) {
959 if (!IsSwitch(*fanin.node)) {
960 return fanin.node;
961 } else {
962 if (IsOutputPortControlling(fanin)) {
963 // Can't add a Switch node control dependency.
964 TensorId tensor_id(fanin.node->name(), fanin.port_id);
965 *error_msg = absl::Substitute(
966 "can't add fanin '$0' as it will become a Switch control dependency",
967 tensor_id.ToString());
968 return nullptr;
969 }
970 // We can't anchor control dependencies directly on the switch node: unlike
971 // other nodes only one of the outputs of the switch node will be generated
972 // when the switch node is executed, and we need to make sure the control
973 // dependency is only triggered when the corresponding output is triggered.
974 // We start by looking for an identity node connected to the output of the
975 // switch node, and use it to anchor the control dependency.
976 for (const auto& fanout : GetFanout(fanin)) {
977 if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
978 if (fanout.node->name() == node_name) {
979 *error_msg =
980 absl::Substitute("can't add found fanin '$0' to self",
981 AsControlDependency(fanout.node->name()));
982 return nullptr;
983 }
984 return fanout.node;
985 }
986 }
987
988 // No node found, check if node to be created is itself.
989 if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
990 *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
991 AsControlDependency(string(node_name)));
992 }
993 }
994 return nullptr;
995 }
996
GetOrCreateIdentityConsumingSwitch(const OutputPort & fanin)997 NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
998 const OutputPort& fanin) {
999 // We haven't found an existing node where we can anchor the control
1000 // dependency: add a new identity node.
1001 string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
1002 NodeDef* identity_node = GetNode(identity_name);
1003 if (identity_node == nullptr) {
1004 NodeDef new_node;
1005 new_node.set_name(identity_name);
1006 new_node.set_op("Identity");
1007 new_node.set_device(fanin.node->device());
1008 (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
1009 new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
1010 identity_node = AddNode(std::move(new_node));
1011 }
1012 return identity_node;
1013 }
1014
AddControllingFanin(absl::string_view node_name,const TensorId & fanin)1015 Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
1016 const TensorId& fanin) {
1017 auto error_status = [node_name, fanin](absl::string_view msg) {
1018 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1019 fanin.ToString());
1020 return MutationError("AddControllingFanin", params, msg);
1021 };
1022
1023 TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status));
1024 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1025 NodeDef* node = GetNode(node_name);
1026 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1027 NodeDef* fanin_node = GetNode(fanin.node());
1028 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1029
1030 OutputPort fanin_port(fanin_node, fanin.index());
1031
1032 string error_msg = "";
1033 NodeDef* control_node = GetControllingFaninToAdd(
1034 node_name, {fanin_node, fanin.index()}, &error_msg);
1035 if (!error_msg.empty()) {
1036 return error_status(error_msg);
1037 }
1038 if (control_node == nullptr) {
1039 control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
1040 }
1041 AddFaninInternal(node, {control_node, Graph::kControlSlot});
1042
1043 return Status::OK();
1044 }
1045
RemoveRegularFaninInternal(NodeDef * node,const OutputPort & fanin)1046 bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
1047 const OutputPort& fanin) {
1048 auto remove_input = [this, node](const OutputPort& fanin_port,
1049 int node_input_port, bool update_max_port) {
1050 InputPort input(node, node_input_port);
1051
1052 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1053 fanouts_set->erase(input);
1054 if (update_max_port) {
1055 UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
1056 }
1057 return fanouts_set;
1058 };
1059
1060 auto mutable_inputs = node->mutable_input();
1061 bool modified = false;
1062 const int num_regular_fanins =
1063 NumFanins(*node, /*include_controlling_nodes=*/false);
1064 int i;
1065 int curr_pos = 0;
1066 for (i = 0; i < num_regular_fanins; ++i) {
1067 TensorId tensor_id = ParseTensorName(node->input(i));
1068 if (tensor_id.node() == fanin.node->name() &&
1069 tensor_id.index() == fanin.port_id) {
1070 remove_input(fanin, i, /*update_max_port=*/true);
1071 modified = true;
1072 } else if (modified) {
1073 // Regular inputs will need to have their ports updated.
1074 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1075 auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
1076 fanouts_set->insert({node, curr_pos});
1077 // Shift inputs to be retained.
1078 mutable_inputs->SwapElements(i, curr_pos);
1079 ++curr_pos;
1080 } else {
1081 // Skip inputs to be retained until first modification.
1082 ++curr_pos;
1083 }
1084 }
1085
1086 if (modified) {
1087 const int last_regular_input_port = curr_pos - 1;
1088 if (last_regular_input_port < 0) {
1089 max_regular_input_port().erase(node);
1090 } else {
1091 max_regular_input_port()[node] = last_regular_input_port;
1092 }
1093 if (curr_pos < i) {
1094 // Remove fanins from node inputs.
1095 mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
1096 }
1097 }
1098
1099 return modified;
1100 }
1101
RemoveRegularFanin(absl::string_view node_name,const TensorId & fanin)1102 Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name,
1103 const TensorId& fanin) {
1104 auto error_status = [node_name, fanin](absl::string_view msg) {
1105 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1106 fanin.ToString());
1107 return MutationError("RemoveRegularFanin", params, msg);
1108 };
1109
1110 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1111 TF_RETURN_IF_ERROR(
1112 CheckRemovingFaninFromSelf(node_name, fanin, error_status));
1113 NodeDef* node = GetNode(node_name);
1114 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1115 NodeDef* fanin_node = GetNode(fanin.node());
1116 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1117
1118 RemoveRegularFaninInternal(node, {fanin_node, fanin.index()});
1119 return Status::OK();
1120 }
1121
RemoveRegularFaninByPort(absl::string_view node_name,int port)1122 Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name,
1123 int port) {
1124 auto error_status = [node_name, port](absl::string_view msg) {
1125 string params =
1126 absl::Substitute("node_name='$0', port=$1", node_name, port);
1127 return MutationError("RemoveRegularFaninByPort", params, msg);
1128 };
1129
1130 NodeDef* node = GetNode(node_name);
1131 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1132 const int last_regular_fanin_port =
1133 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1134 TF_RETURN_IF_ERROR(
1135 CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1136
1137 TensorId tensor_id = ParseTensorName(node->input(port));
1138 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1139 fanouts()[fanin_port].erase({node, port});
1140 auto mutable_inputs = node->mutable_input();
1141 for (int i = port + 1; i <= last_regular_fanin_port; ++i) {
1142 TensorId tensor_id = ParseTensorName(node->input(i));
1143 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1144 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1145 fanouts_set->erase({node, i});
1146 fanouts_set->insert({node, i - 1});
1147 mutable_inputs->SwapElements(i - 1, i);
1148 }
1149 const int last_node_input = node->input_size() - 1;
1150 if (last_regular_fanin_port < last_node_input) {
1151 mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input);
1152 }
1153 mutable_inputs->RemoveLast();
1154
1155 const int updated_last_regular_input_port = last_regular_fanin_port - 1;
1156 if (updated_last_regular_input_port < 0) {
1157 max_regular_input_port().erase(node);
1158 } else {
1159 max_regular_input_port()[node] = updated_last_regular_input_port;
1160 }
1161
1162 return Status::OK();
1163 }
1164
RemoveControllingFaninInternal(NodeDef * node,NodeDef * fanin_node)1165 bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node,
1166 NodeDef* fanin_node) {
1167 for (int i = node->input_size() - 1; i >= 0; --i) {
1168 TensorId tensor_id = ParseTensorName(node->input(i));
1169 if (tensor_id.index() > Graph::kControlSlot) {
1170 break;
1171 }
1172 if (tensor_id.node() == fanin_node->name()) {
1173 fanouts()[{fanin_node, Graph::kControlSlot}].erase(
1174 {node, Graph::kControlSlot});
1175 node->mutable_input()->SwapElements(i, node->input_size() - 1);
1176 node->mutable_input()->RemoveLast();
1177 return true;
1178 }
1179 }
1180 return false;
1181 }
1182
RemoveControllingFanin(absl::string_view node_name,absl::string_view fanin_node_name)1183 Status MutableGraphView::RemoveControllingFanin(
1184 absl::string_view node_name, absl::string_view fanin_node_name) {
1185 auto error_status = [node_name, fanin_node_name](absl::string_view msg) {
1186 string params = absl::Substitute("node_name='$0', fanin_node_name='$1'",
1187 node_name, fanin_node_name);
1188 return MutationError("RemoveControllingFanin", params, msg);
1189 };
1190
1191 TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf(
1192 node_name, {fanin_node_name, Graph::kControlSlot}, error_status));
1193 NodeDef* node = GetNode(node_name);
1194 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1195 NodeDef* fanin_node = GetNode(fanin_node_name);
1196 TF_RETURN_IF_ERROR(
1197 CheckNodeExists(fanin_node_name, fanin_node, error_status));
1198
1199 RemoveControllingFaninInternal(node, fanin_node);
1200 return Status::OK();
1201 }
1202
RemoveAllFanins(absl::string_view node_name,bool keep_controlling_fanins)1203 Status MutableGraphView::RemoveAllFanins(absl::string_view node_name,
1204 bool keep_controlling_fanins) {
1205 NodeDef* node = GetNode(node_name);
1206 if (node == nullptr) {
1207 string params =
1208 absl::Substitute("node_name='$0', keep_controlling_fanins=$1",
1209 node_name, keep_controlling_fanins);
1210 return MutationError("RemoveAllFanins", params,
1211 NodeMissingErrorMsg(node_name));
1212 }
1213
1214 if (node->input().empty()) {
1215 return Status::OK();
1216 }
1217
1218 const int num_regular_fanins =
1219 NumFanins(*node, /*include_controlling_nodes=*/false);
1220 RemoveFaninsInternal(node, keep_controlling_fanins);
1221 if (keep_controlling_fanins) {
1222 if (num_regular_fanins == 0) {
1223 return Status::OK();
1224 } else if (num_regular_fanins < node->input_size()) {
1225 node->mutable_input()->DeleteSubrange(0, num_regular_fanins);
1226 } else {
1227 node->clear_input();
1228 }
1229 } else {
1230 node->clear_input();
1231 }
1232 return Status::OK();
1233 }
1234
UpdateFanin(absl::string_view node_name,const TensorId & from_fanin,const TensorId & to_fanin)1235 Status MutableGraphView::UpdateFanin(absl::string_view node_name,
1236 const TensorId& from_fanin,
1237 const TensorId& to_fanin) {
1238 auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
1239 string params =
1240 absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
1241 node_name, from_fanin.ToString(), to_fanin.ToString());
1242 return MutationError("UpdateFanin", params, msg);
1243 };
1244
1245 TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
1246 TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
1247 NodeDef* node = GetNode(node_name);
1248 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1249 NodeDef* from_fanin_node = GetNode(from_fanin.node());
1250 TF_RETURN_IF_ERROR(
1251 CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
1252 NodeDef* to_fanin_node = GetNode(to_fanin.node());
1253 TF_RETURN_IF_ERROR(
1254 CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));
1255
1256 // When replacing a non control dependency fanin with a control dependency, or
1257 // vice versa, remove and add, so ports can be updated properly in fanout(s).
1258 bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
1259 if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
1260 // Can't add Switch node as a control dependency.
1261 return error_status(
1262 absl::Substitute("can't update to fanin '$0' as it will become a "
1263 "Switch control dependency",
1264 to_fanin.ToString()));
1265 }
1266 if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
1267 return error_status("can't update fanin to or from self");
1268 }
1269
1270 if (from_fanin == to_fanin) {
1271 return Status::OK();
1272 }
1273
1274 bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
1275 if (from_fanin_is_control || to_fanin_is_control) {
1276 bool modified = false;
1277 if (from_fanin_is_control) {
1278 modified |= RemoveControllingFaninInternal(node, from_fanin_node);
1279 } else {
1280 modified |= RemoveRegularFaninInternal(
1281 node, {from_fanin_node, from_fanin.index()});
1282 }
1283 if (modified) {
1284 AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
1285 }
1286 return Status::OK();
1287 }
1288
1289 // In place mutation of regular fanins, requires no shifting of ports.
1290 string to_fanin_string = TensorIdToString(to_fanin);
1291 const int num_regular_fanins =
1292 NumFanins(*node, /*include_controlling_nodes=*/false);
1293 bool modified = false;
1294 absl::flat_hash_set<InputPort>* from_fanin_port_fanouts = nullptr;
1295 absl::flat_hash_set<InputPort>* to_fanin_port_fanouts = nullptr;
1296 for (int i = 0; i < num_regular_fanins; ++i) {
1297 if (ParseTensorName(node->input(i)) == from_fanin) {
1298 InputPort input(node, i);
1299 if (from_fanin_port_fanouts == nullptr) {
1300 OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1301 from_fanin_port_fanouts = &fanouts()[from_fanin_port];
1302 }
1303 from_fanin_port_fanouts->erase(input);
1304
1305 if (to_fanin_port_fanouts == nullptr) {
1306 OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
1307 to_fanin_port_fanouts = &fanouts()[to_fanin_port];
1308 }
1309 to_fanin_port_fanouts->insert(input);
1310
1311 node->set_input(i, to_fanin_string);
1312 modified = true;
1313 }
1314 }
1315
1316 // Dedup control dependencies and update max regular output ports.
1317 if (modified) {
1318 UpdateMaxRegularOutputPortForRemovedFanin(
1319 {from_fanin_node, from_fanin.index()}, *from_fanin_port_fanouts);
1320 if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
1321 max_regular_output_port()[to_fanin_node] = to_fanin.index();
1322 }
1323 if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
1324 RemoveControllingFaninInternal(node, to_fanin_node);
1325 }
1326 }
1327
1328 return Status::OK();
1329 }
1330
UpdateRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)1331 Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name,
1332 int port,
1333 const TensorId& fanin) {
1334 auto error_status = [node_name, port, fanin](absl::string_view msg) {
1335 string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
1336 node_name, port, fanin.ToString());
1337 return MutationError("UpdateRegularFaninByPort", params, msg);
1338 };
1339
1340 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1341 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1342 NodeDef* node = GetNode(node_name);
1343 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1344 const int last_regular_fanin_port =
1345 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1346 TF_RETURN_IF_ERROR(
1347 CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1348 NodeDef* fanin_node = GetNode(fanin.node());
1349 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1350
1351 TensorId tensor_id = ParseTensorName(node->input(port));
1352 if (tensor_id == fanin) {
1353 return Status::OK();
1354 }
1355
1356 InputPort input(node, port);
1357 OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1358 absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port];
1359 from_fanouts->erase(input);
1360 UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts);
1361
1362 OutputPort to_fanin_port(fanin_node, fanin.index());
1363 fanouts()[to_fanin_port].insert(input);
1364 UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port);
1365
1366 node->set_input(port, TensorIdToString(fanin));
1367
1368 if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
1369 RemoveControllingFaninInternal(node, fanin_node);
1370 }
1371
1372 return Status::OK();
1373 }
1374
SwapRegularFaninsByPorts(absl::string_view node_name,int from_port,int to_port)1375 Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name,
1376 int from_port, int to_port) {
1377 auto error_status = [node_name, from_port, to_port](absl::string_view msg) {
1378 string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2",
1379 node_name, from_port, to_port);
1380 return MutationError("SwapRegularFaninsByPorts", params, msg);
1381 };
1382
1383 NodeDef* node = GetNode(node_name);
1384 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1385 const int last_regular_fanin_port =
1386 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1387 TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0,
1388 last_regular_fanin_port, error_status));
1389 TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port,
1390 error_status));
1391
1392 if (from_port == to_port) {
1393 return Status::OK();
1394 }
1395 TensorId from_fanin = ParseTensorName(node->input(from_port));
1396 TensorId to_fanin = ParseTensorName(node->input(to_port));
1397 if (from_fanin == to_fanin) {
1398 return Status::OK();
1399 }
1400
1401 InputPort from_input(node, from_port);
1402 InputPort to_input(node, to_port);
1403 NodeDef* from_fanin_node = GetNode(from_fanin.node());
1404 absl::flat_hash_set<InputPort>* from_fanouts =
1405 &fanouts()[{from_fanin_node, from_fanin.index()}];
1406 from_fanouts->erase(from_input);
1407 from_fanouts->insert(to_input);
1408 NodeDef* to_fanin_node = GetNode(to_fanin.node());
1409 absl::flat_hash_set<InputPort>* to_fanouts =
1410 &fanouts()[{to_fanin_node, to_fanin.index()}];
1411 to_fanouts->erase(to_input);
1412 to_fanouts->insert(from_input);
1413
1414 node->mutable_input()->SwapElements(from_port, to_port);
1415
1416 return Status::OK();
1417 }
1418
UpdateAllRegularFaninsToControlling(absl::string_view node_name)1419 Status MutableGraphView::UpdateAllRegularFaninsToControlling(
1420 absl::string_view node_name) {
1421 auto error_status = [node_name](absl::string_view msg) {
1422 string params = absl::Substitute("node_name='$0'", node_name);
1423 return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
1424 };
1425
1426 NodeDef* node = GetNode(node_name);
1427 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1428
1429 const int num_regular_fanins =
1430 NumFanins(*node, /*include_controlling_nodes=*/false);
1431 std::vector<OutputPort> regular_fanins;
1432 regular_fanins.reserve(num_regular_fanins);
1433 std::vector<NodeDef*> controlling_fanins;
1434 controlling_fanins.reserve(num_regular_fanins);
1435
1436 // Get all regular fanins and derive controlling fanins.
1437 for (int i = 0; i < num_regular_fanins; ++i) {
1438 TensorId tensor_id = ParseTensorName(node->input(i));
1439 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1440
1441 string error_msg = "";
1442 NodeDef* control_node =
1443 GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
1444 if (!error_msg.empty()) {
1445 return error_status(error_msg);
1446 }
1447
1448 regular_fanins.push_back(fanin_port);
1449 controlling_fanins.push_back(control_node);
1450 }
1451
1452 // Replace regular fanins with controlling fanins and dedup.
1453 int pos = 0;
1454 InputPort input_port(node, Graph::kControlSlot);
1455 absl::flat_hash_set<absl::string_view> controls;
1456 for (int i = 0; i < num_regular_fanins; ++i) {
1457 OutputPort fanin_port = regular_fanins[i];
1458 NodeDef* control = controlling_fanins[i];
1459 if (control == nullptr) {
1460 control = GetOrCreateIdentityConsumingSwitch(fanin_port);
1461 }
1462 fanouts()[fanin_port].erase({node, i});
1463 if (controls.contains(control->name())) {
1464 continue;
1465 }
1466 controls.insert(control->name());
1467 node->set_input(pos, AsControlDependency(control->name()));
1468 fanouts()[{control, Graph::kControlSlot}].insert(input_port);
1469 ++pos;
1470 }
1471
1472 // Shift existing controlling fanins and dedup.
1473 for (int i = num_regular_fanins; i < node->input_size(); ++i) {
1474 TensorId tensor_id = ParseTensorName(node->input(i));
1475 if (controls.contains(tensor_id.node())) {
1476 continue;
1477 }
1478 controls.insert(tensor_id.node());
1479 node->mutable_input()->SwapElements(pos, i);
1480 ++pos;
1481 }
1482
1483 // Remove duplicate controls and leftover regular fanins.
1484 node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
1485 max_regular_input_port().erase(node);
1486
1487 return Status::OK();
1488 }
1489
CheckNodesCanBeDeleted(const absl::flat_hash_set<string> & nodes_to_delete)1490 Status MutableGraphView::CheckNodesCanBeDeleted(
1491 const absl::flat_hash_set<string>& nodes_to_delete) {
1492 std::vector<string> missing_nodes;
1493 std::vector<string> nodes_with_fanouts;
1494 for (const string& node_name_to_delete : nodes_to_delete) {
1495 NodeDef* node = GetNode(node_name_to_delete);
1496 if (node == nullptr) {
1497 // Can't delete missing node.
1498 missing_nodes.push_back(node_name_to_delete);
1499 continue;
1500 }
1501 const int max_port = gtl::FindWithDefault(max_regular_output_port(), node,
1502 Graph::kControlSlot);
1503 for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1504 auto it = fanouts().find({node, i});
1505 bool has_retained_fanout = false;
1506 if (it != fanouts().end()) {
1507 for (const auto& fanout : it->second) {
1508 // Check if fanouts are of nodes to be deleted, and if so, they can be
1509 // ignored, as they will be removed also.
1510 if (!nodes_to_delete.contains(fanout.node->name())) {
1511 // Removing node will leave graph in an invalid state.
1512 has_retained_fanout = true;
1513 break;
1514 }
1515 }
1516 }
1517 if (has_retained_fanout) {
1518 nodes_with_fanouts.push_back(node_name_to_delete);
1519 break;
1520 }
1521 }
1522 }
1523
1524 // Error message can get quite long, so we only show the first 5 node names.
1525 auto sort_and_sample = [](std::vector<string>* s) {
1526 constexpr int kMaxNodeNames = 5;
1527 std::sort(s->begin(), s->end());
1528 if (s->size() > kMaxNodeNames) {
1529 return absl::StrCat(
1530 absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", "), ", ...");
1531 }
1532 return absl::StrJoin(*s, ", ");
1533 };
1534
1535 if (!missing_nodes.empty()) {
1536 VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0].",
1537 sort_and_sample(&missing_nodes));
1538 }
1539 if (!nodes_with_fanouts.empty()) {
1540 std::vector<string> input_node_names(nodes_to_delete.begin(),
1541 nodes_to_delete.end());
1542 string params = absl::Substitute("nodes_to_delete={$0}",
1543 sort_and_sample(&input_node_names));
1544 string error_msg =
1545 absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]",
1546 sort_and_sample(&nodes_with_fanouts));
1547 return MutationError("DeleteNodes", params, error_msg);
1548 }
1549
1550 return Status::OK();
1551 }
1552
DeleteNodes(const absl::flat_hash_set<string> & nodes_to_delete)1553 Status MutableGraphView::DeleteNodes(
1554 const absl::flat_hash_set<string>& nodes_to_delete) {
1555 TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete));
1556
1557 // Find nodes in internal state and delete.
1558 for (const string& node_name_to_delete : nodes_to_delete) {
1559 NodeDef* node = GetNode(node_name_to_delete);
1560 if (node != nullptr) {
1561 RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false);
1562 RemoveFanoutsInternal(node);
1563 }
1564 }
1565 for (const string& node_name_to_delete : nodes_to_delete) {
1566 nodes().erase(node_name_to_delete);
1567 }
1568
1569 // Find nodes in graph and delete by partitioning into nodes to retain and
1570 // nodes to delete based on input set of nodes to delete by name.
1571 // TODO(lyandy): Use a node name->idx hashmap if this is a performance
1572 // bottleneck.
1573 int pos = 0;
1574 const int last_idx = graph()->node_size() - 1;
1575 int last_pos = last_idx;
1576 while (pos <= last_pos) {
1577 if (nodes_to_delete.contains(graph()->node(pos).name())) {
1578 graph()->mutable_node()->SwapElements(pos, last_pos);
1579 --last_pos;
1580 } else {
1581 ++pos;
1582 }
1583 }
1584 if (last_pos < last_idx) {
1585 graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
1586 }
1587
1588 return Status::OK();
1589 }
1590
RemoveFaninsInternal(NodeDef * deleted_node,bool keep_controlling_fanins)1591 void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
1592 bool keep_controlling_fanins) {
1593 for (int i = 0; i < deleted_node->input_size(); ++i) {
1594 TensorId tensor_id = ParseTensorName(deleted_node->input(i));
1595 bool is_control = IsTensorIdControlling(tensor_id);
1596 if (keep_controlling_fanins && is_control) {
1597 break;
1598 }
1599 OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
1600
1601 InputPort input;
1602 input.node = deleted_node;
1603 input.port_id = is_control ? Graph::kControlSlot : i;
1604
1605 auto it = fanouts().find(fanin);
1606 if (it != fanouts().end()) {
1607 absl::flat_hash_set<InputPort>* fanouts_set = &it->second;
1608 fanouts_set->erase(input);
1609 UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set);
1610 }
1611 }
1612 max_regular_input_port().erase(deleted_node);
1613 }
1614
RemoveFanoutsInternal(NodeDef * deleted_node)1615 void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) {
1616 const int max_port =
1617 gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1);
1618 for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1619 fanouts().erase({deleted_node, i});
1620 }
1621 max_regular_output_port().erase(deleted_node);
1622 }
1623
1624 } // end namespace grappler
1625 } // end namespace tensorflow
1626