1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/grappler/op_types.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace tensorflow {
39 namespace grappler {
40
41 namespace {
42
IsTensorIdPortValid(const TensorId & tensor_id)43 bool IsTensorIdPortValid(const TensorId& tensor_id) {
44 return tensor_id.index() >= Graph::kControlSlot;
45 }
46
IsTensorIdRegular(const TensorId & tensor_id)47 bool IsTensorIdRegular(const TensorId& tensor_id) {
48 return tensor_id.index() > Graph::kControlSlot;
49 }
50
IsTensorIdControlling(const TensorId & tensor_id)51 bool IsTensorIdControlling(const TensorId& tensor_id) {
52 return tensor_id.index() == Graph::kControlSlot;
53 }
54
IsOutputPortControlling(const MutableGraphView::OutputPort & port)55 bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) {
56 return port.port_id == Graph::kControlSlot;
57 }
58
59 // Determines if node is an Identity where it's first regular input is a Switch
60 // node.
IsIdentityConsumingSwitch(const MutableGraphView & graph,const NodeDef & node)61 bool IsIdentityConsumingSwitch(const MutableGraphView& graph,
62 const NodeDef& node) {
63 if ((IsIdentity(node) || IsIdentityNSingleInput(node)) &&
64 node.input_size() > 0) {
65 TensorId tensor_id = ParseTensorName(node.input(0));
66 if (IsTensorIdControlling(tensor_id)) {
67 return false;
68 }
69
70 NodeDef* input_node = graph.GetNode(tensor_id.node());
71 return IsSwitch(*input_node);
72 }
73 return false;
74 }
75
76 // Determines if node input can be deduped by regular inputs when used as a
77 // control dependency. Specifically, if a node is an Identity that leads to a
78 // Switch node, when used as a control dependency, that control dependency
79 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,const NodeDef & control_node)80 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
81 const NodeDef& control_node) {
82 return !IsIdentityConsumingSwitch(graph, control_node);
83 }
84
85 // Determines if node input can be deduped by regular inputs when used as a
86 // control dependency. Specifically, if a node is an Identity that leads to a
87 // Switch node, when used as a control dependency, that control dependency
88 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,absl::string_view control_node_name)89 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
90 absl::string_view control_node_name) {
91 NodeDef* control_node = graph.GetNode(control_node_name);
92 DCHECK(control_node != nullptr)
93 << "Didn't find a node for control dependency: " << control_node_name;
94 return CanDedupControlWithRegularInput(graph, *control_node);
95 }
96
HasRegularFaninNode(const MutableGraphView & graph,const NodeDef & node,absl::string_view fanin_node_name)97 bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
98 absl::string_view fanin_node_name) {
99 const int num_regular_fanins =
100 graph.NumFanins(node, /*include_controlling_nodes=*/false);
101 for (int i = 0; i < num_regular_fanins; ++i) {
102 if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
103 return true;
104 }
105 }
106 return false;
107 }
108
// Maps an output port of a node to the set of input ports that read from it.
using FanoutsMap =
    absl::flat_hash_map<MutableGraphView::OutputPort,
                        absl::flat_hash_set<MutableGraphView::InputPort>>;
112
SwapControlledFanoutInputs(const MutableGraphView & graph,const FanoutsMap::iterator & control_fanouts,absl::string_view to_node_name)113 void SwapControlledFanoutInputs(const MutableGraphView& graph,
114 const FanoutsMap::iterator& control_fanouts,
115 absl::string_view to_node_name) {
116 absl::string_view from_node_name(control_fanouts->first.node->name());
117 string control = TensorIdToString({to_node_name, Graph::kControlSlot});
118 for (const auto& control_fanout : control_fanouts->second) {
119 const int start = graph.NumFanins(*control_fanout.node,
120 /*include_controlling_nodes=*/false);
121 for (int i = start; i < control_fanout.node->input_size(); ++i) {
122 TensorId tensor_id = ParseTensorName(control_fanout.node->input(i));
123 if (tensor_id.node() == from_node_name) {
124 control_fanout.node->set_input(i, control);
125 break;
126 }
127 }
128 }
129 }
130
SwapRegularFanoutInputs(FanoutsMap * fanouts,NodeDef * from_node,absl::string_view to_node_name,int max_port)131 void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node,
132 absl::string_view to_node_name, int max_port) {
133 MutableGraphView::OutputPort port;
134 port.node = from_node;
135 for (int i = 0; i <= max_port; ++i) {
136 port.port_id = i;
137 auto it = fanouts->find(port);
138 if (it == fanouts->end()) {
139 continue;
140 }
141 string input = TensorIdToString({to_node_name, i});
142 for (const auto& fanout : it->second) {
143 fanout.node->set_input(fanout.port_id, input);
144 }
145 }
146 }
147
// Maps a node to the highest regular output port index tracked for it.
using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
149
SwapFanoutInputs(const MutableGraphView & graph,FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)150 void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts,
151 MaxOutputPortsMap* max_output_ports, NodeDef* from_node,
152 NodeDef* to_node) {
153 auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot});
154 if (from_control_fanouts != fanouts->end()) {
155 SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name());
156 }
157 auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot});
158 if (to_control_fanouts != fanouts->end()) {
159 SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name());
160 }
161 auto from_max_port = max_output_ports->find(from_node);
162 if (from_max_port != max_output_ports->end()) {
163 SwapRegularFanoutInputs(fanouts, from_node, to_node->name(),
164 from_max_port->second);
165 }
166 auto to_max_port = max_output_ports->find(to_node);
167 if (to_max_port != max_output_ports->end()) {
168 SwapRegularFanoutInputs(fanouts, to_node, from_node->name(),
169 to_max_port->second);
170 }
171 }
172
SwapFanoutsMapValues(FanoutsMap * fanouts,const MutableGraphView::OutputPort & from_port,const FanoutsMap::iterator & from_fanouts,const MutableGraphView::OutputPort & to_port,const FanoutsMap::iterator & to_fanouts)173 void SwapFanoutsMapValues(FanoutsMap* fanouts,
174 const MutableGraphView::OutputPort& from_port,
175 const FanoutsMap::iterator& from_fanouts,
176 const MutableGraphView::OutputPort& to_port,
177 const FanoutsMap::iterator& to_fanouts) {
178 const bool from_exists = from_fanouts != fanouts->end();
179 const bool to_exists = to_fanouts != fanouts->end();
180
181 if (from_exists && to_exists) {
182 std::swap(from_fanouts->second, to_fanouts->second);
183 } else if (from_exists) {
184 fanouts->emplace(to_port, std::move(from_fanouts->second));
185 fanouts->erase(from_port);
186 } else if (to_exists) {
187 fanouts->emplace(from_port, std::move(to_fanouts->second));
188 fanouts->erase(to_port);
189 }
190 }
191
SwapRegularFanoutsAndMaxPortValues(FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)192 void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts,
193 MaxOutputPortsMap* max_output_ports,
194 NodeDef* from_node, NodeDef* to_node) {
195 auto from_max_port = max_output_ports->find(from_node);
196 auto to_max_port = max_output_ports->find(to_node);
197 bool from_exists = from_max_port != max_output_ports->end();
198 bool to_exists = to_max_port != max_output_ports->end();
199
200 auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start,
201 int end) {
202 for (int i = start; i <= end; ++i) {
203 MutableGraphView::OutputPort from_port(from, i);
204 auto from_fanouts = fanouts->find(from_port);
205 if (from_fanouts != fanouts->end()) {
206 MutableGraphView::OutputPort to_port(to, i);
207 fanouts->emplace(to_port, std::move(from_fanouts->second));
208 fanouts->erase(from_port);
209 }
210 }
211 };
212
213 if (from_exists && to_exists) {
214 const int from = from_max_port->second;
215 const int to = to_max_port->second;
216 const int shared = std::min(from, to);
217 for (int i = 0; i <= shared; ++i) {
218 MutableGraphView::OutputPort from_port(from_node, i);
219 auto from_fanouts = fanouts->find(from_port);
220 MutableGraphView::OutputPort to_port(to_node, i);
221 auto to_fanouts = fanouts->find(to_port);
222 SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port,
223 to_fanouts);
224 }
225 if (to > from) {
226 forward_fanouts(to_node, from_node, shared + 1, to);
227 } else if (from > to) {
228 forward_fanouts(from_node, to_node, shared + 1, from);
229 }
230
231 std::swap(from_max_port->second, to_max_port->second);
232 } else if (from_exists) {
233 forward_fanouts(from_node, to_node, 0, from_max_port->second);
234
235 max_output_ports->emplace(to_node, from_max_port->second);
236 max_output_ports->erase(from_node);
237 } else if (to_exists) {
238 forward_fanouts(to_node, from_node, 0, to_max_port->second);
239
240 max_output_ports->emplace(from_node, to_max_port->second);
241 max_output_ports->erase(to_node);
242 }
243 }
244
HasFanoutValue(const FanoutsMap & fanouts,const FanoutsMap::iterator & it)245 bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) {
246 return it != fanouts.end() && !it->second.empty();
247 }
248
MutationError(absl::string_view function_name,absl::string_view params,absl::string_view msg)249 Status MutationError(absl::string_view function_name, absl::string_view params,
250 absl::string_view msg) {
251 return errors::InvalidArgument(absl::Substitute(
252 "MutableGraphView::$0($1) error: $2.", function_name, params, msg));
253 }
254
// Callback that turns an error message into a Status. The Check* helpers below
// invoke it when their check fails.
using ErrorHandler = std::function<Status(absl::string_view)>;
256
UpdateFanoutsError(absl::string_view from_node_name,absl::string_view to_node_name)257 ErrorHandler UpdateFanoutsError(absl::string_view from_node_name,
258 absl::string_view to_node_name) {
259 return [from_node_name, to_node_name](absl::string_view msg) {
260 string params = absl::Substitute("from_node_name='$0', to_node_name='$1'",
261 from_node_name, to_node_name);
262 return MutationError("UpdateFanouts", params, msg);
263 };
264 }
265
CheckFaninIsRegular(const TensorId & fanin,ErrorHandler handler)266 Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) {
267 if (!IsTensorIdRegular(fanin)) {
268 return handler(absl::Substitute("fanin '$0' must be a regular tensor id",
269 fanin.ToString()));
270 }
271 return Status::OK();
272 }
273
CheckFaninIsValid(const TensorId & fanin,ErrorHandler handler)274 Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) {
275 if (!IsTensorIdPortValid(fanin)) {
276 return handler(absl::Substitute("fanin '$0' must be a valid tensor id",
277 fanin.ToString()));
278 }
279 return Status::OK();
280 }
281
CheckAddingFaninToSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)282 Status CheckAddingFaninToSelf(absl::string_view node_name,
283 const TensorId& fanin, ErrorHandler handler) {
284 if (node_name == fanin.node()) {
285 return handler(
286 absl::Substitute("can't add fanin '$0' to self", fanin.ToString()));
287 }
288 return Status::OK();
289 }
290
CheckRemovingFaninFromSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)291 Status CheckRemovingFaninFromSelf(absl::string_view node_name,
292 const TensorId& fanin, ErrorHandler handler) {
293 if (node_name == fanin.node()) {
294 return handler(absl::Substitute("can't remove fanin '$0' from self",
295 fanin.ToString()));
296 }
297 return Status::OK();
298 }
299
NodeMissingErrorMsg(absl::string_view node_name)300 string NodeMissingErrorMsg(absl::string_view node_name) {
301 return absl::Substitute("node '$0' was not found", node_name);
302 }
303
CheckNodeExists(absl::string_view node_name,NodeDef * node,ErrorHandler handler)304 Status CheckNodeExists(absl::string_view node_name, NodeDef* node,
305 ErrorHandler handler) {
306 if (node == nullptr) {
307 return handler(NodeMissingErrorMsg(node_name));
308 }
309 return Status::OK();
310 }
311
CheckPortRange(int port,int min,int max,ErrorHandler handler)312 Status CheckPortRange(int port, int min, int max, ErrorHandler handler) {
313 if (port < min || port > max) {
314 if (max < min) {
315 return handler("no available ports as node has no regular fanins");
316 }
317 return handler(
318 absl::Substitute("port must be in range [$0, $1]", min, max));
319 }
320 return Status::OK();
321 }
322
SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name)323 string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) {
324 return absl::Substitute(
325 "can't swap node name '$0' as it will become a Switch control dependency",
326 node_name);
327 }
328
GeneratedNameForIdentityConsumingSwitch(const MutableGraphView::OutputPort & fanin)329 string GeneratedNameForIdentityConsumingSwitch(
330 const MutableGraphView::OutputPort& fanin) {
331 return AddPrefixToNodeName(
332 absl::StrCat(fanin.node->name(), "_", fanin.port_id),
333 kMutableGraphViewCtrl);
334 }
335
336 } // namespace
337
AddAndDedupFanouts(NodeDef * node)338 void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
339 // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
340 // exist, and all regular fanins come before controlling fanins.
341 absl::flat_hash_set<absl::string_view> fanins;
342 absl::flat_hash_set<absl::string_view> controlling_fanins;
343 int max_input_port = -1;
344 int pos = 0;
345 const int last_idx = node->input_size() - 1;
346 int last_pos = last_idx;
347 while (pos <= last_pos) {
348 TensorId tensor_id = ParseTensorName(node->input(pos));
349 absl::string_view input_node_name = tensor_id.node();
350 bool is_control_input = IsTensorIdControlling(tensor_id);
351 bool can_dedup_control_with_regular_input =
352 CanDedupControlWithRegularInput(*this, input_node_name);
353 bool can_dedup_control =
354 is_control_input && (can_dedup_control_with_regular_input ||
355 (!can_dedup_control_with_regular_input &&
356 controlling_fanins.contains(input_node_name)));
357 if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
358 can_dedup_control) {
359 node->mutable_input()->SwapElements(pos, last_pos);
360 --last_pos;
361 } else {
362 OutputPort output(nodes()[input_node_name], tensor_id.index());
363
364 if (is_control_input) {
365 fanouts()[output].emplace(node, Graph::kControlSlot);
366 } else {
367 max_input_port = pos;
368 max_regular_output_port()[output.node] =
369 std::max(max_regular_output_port()[output.node], output.port_id);
370 fanouts()[output].emplace(node, pos);
371 }
372 ++pos;
373 }
374 if (is_control_input) {
375 controlling_fanins.insert(input_node_name);
376 }
377 }
378
379 if (last_pos < last_idx) {
380 node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
381 }
382
383 if (max_input_port > -1) {
384 max_regular_input_port()[node] = max_input_port;
385 }
386 }
387
UpdateMaxRegularOutputPortForRemovedFanin(const OutputPort & fanin,const absl::flat_hash_set<InputPort> & fanin_fanouts)388 void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin(
389 const OutputPort& fanin,
390 const absl::flat_hash_set<InputPort>& fanin_fanouts) {
391 int max_port = max_regular_output_port()[fanin.node];
392 if (!fanin_fanouts.empty() || max_port != fanin.port_id) {
393 return;
394 }
395 bool updated_max_port = false;
396 for (int i = fanin.port_id - 1; i >= 0; --i) {
397 OutputPort fanin_port(fanin.node, i);
398 if (!fanouts()[fanin_port].empty()) {
399 max_regular_output_port()[fanin.node] = i;
400 updated_max_port = true;
401 break;
402 }
403 }
404 if (!updated_max_port) {
405 max_regular_output_port().erase(fanin.node);
406 }
407 }
408
UpdateMaxRegularOutputPortForAddedFanin(const OutputPort & fanin)409 void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin(
410 const OutputPort& fanin) {
411 if (max_regular_output_port()[fanin.node] < fanin.port_id) {
412 max_regular_output_port()[fanin.node] = fanin.port_id;
413 }
414 }
415
416 const absl::flat_hash_set<MutableGraphView::InputPort>&
GetFanout(const GraphView::OutputPort & port) const417 MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
418 return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
419 port.port_id));
420 }
421
GetFanin(const GraphView::InputPort & port) const422 absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
423 const GraphView::InputPort& port) const {
424 return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
425 port.port_id));
426 }
427
GetRegularFanin(const GraphView::InputPort & port) const428 const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
429 const GraphView::InputPort& port) const {
430 return GetRegularFanin(MutableGraphView::InputPort(
431 const_cast<NodeDef*>(port.node), port.port_id));
432 }
433
AddNode(NodeDef && node)434 NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
435 auto* node_in_graph = graph()->add_node();
436 *node_in_graph = std::move(node);
437
438 AddUniqueNodeOrDie(node_in_graph);
439
440 AddAndDedupFanouts(node_in_graph);
441 return node_in_graph;
442 }
443
AddSubgraph(GraphDef && subgraph)444 Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
445 // 1. Add all new functions and check that functions with the same name
446 // have identical definition.
447 const int function_size = subgraph.library().function_size();
448 if (function_size > 0) {
449 absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs;
450 for (const FunctionDef& fdef : graph()->library().function()) {
451 graph_fdefs.emplace(fdef.signature().name(), &fdef);
452 }
453
454 for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) {
455 const auto graph_fdef = graph_fdefs.find(fdef.signature().name());
456
457 if (graph_fdef == graph_fdefs.end()) {
458 VLOG(3) << "Add new function definition: " << fdef.signature().name();
459 graph()->mutable_library()->add_function()->Swap(&fdef);
460 } else {
461 if (!FunctionDefsEqual(fdef, *graph_fdef->second)) {
462 return MutationError(
463 "AddSubgraph",
464 absl::Substitute("function_size=$0", function_size),
465 absl::StrCat(
466 "Found different function definition with the same name: ",
467 fdef.signature().name()));
468 }
469 }
470 }
471 }
472
473 // 2. Add all nodes to the underlying graph.
474 int node_size_before = graph()->node_size();
475
476 for (NodeDef& node : *subgraph.mutable_node()) {
477 auto* node_in_graph = graph()->add_node();
478 node_in_graph->Swap(&node);
479 TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
480 }
481
482 // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
483 // fanins actually exists in the graph, and there is already TODO for that.
484
485 for (int i = node_size_before; i < graph()->node_size(); ++i) {
486 NodeDef* node = graph()->mutable_node(i);
487 AddAndDedupFanouts(node);
488 }
489
490 return Status::OK();
491 }
492
UpdateNode(absl::string_view node_name,absl::string_view op,absl::string_view device,absl::Span<const std::pair<string,AttrValue>> attrs)493 Status MutableGraphView::UpdateNode(
494 absl::string_view node_name, absl::string_view op, absl::string_view device,
495 absl::Span<const std::pair<string, AttrValue>> attrs) {
496 auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
497 std::vector<string> attr_strs;
498 attr_strs.reserve(attrs.size());
499 for (const auto& attr : attrs) {
500 string attr_str = absl::Substitute("('$0', $1)", attr.first,
501 attr.second.ShortDebugString());
502 attr_strs.push_back(attr_str);
503 }
504 string params =
505 absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
506 node_name, op, device, absl::StrJoin(attr_strs, ", "));
507 return MutationError("UpdateNodeOp", params, msg);
508 };
509
510 NodeDef* node = GetNode(node_name);
511 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
512
513 MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
514 auto control_fanouts = GetFanout(control_port);
515 if (op == "Switch" && !control_fanouts.empty()) {
516 return error_status(
517 "can't change node op to Switch when node drives a control dependency "
518 "(alternatively, we could add the identity node needed, but it seems "
519 "like an unlikely event and probably a mistake)");
520 }
521
522 if (node->device() != device) {
523 node->set_device(string(device));
524 }
525 node->mutable_attr()->clear();
526 for (const auto& attr : attrs) {
527 (*node->mutable_attr())[attr.first] = attr.second;
528 }
529
530 if (node->op() == op) {
531 return Status::OK();
532 }
533
534 node->set_op(string(op));
535
536 if (CanDedupControlWithRegularInput(*this, *node)) {
537 for (const auto& control_fanout : control_fanouts) {
538 if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
539 RemoveControllingFaninInternal(control_fanout.node, node);
540 }
541 }
542 }
543
544 return Status::OK();
545 }
546
UpdateNodeName(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)547 Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name,
548 absl::string_view to_node_name,
549 bool update_fanouts) {
550 auto error_status = [from_node_name, to_node_name,
551 update_fanouts](absl::string_view msg) {
552 string params = absl::Substitute(
553 "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
554 from_node_name, to_node_name, update_fanouts);
555 return MutationError("UpdateNodeName", params, msg);
556 };
557
558 NodeDef* node = GetNode(from_node_name);
559 TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status));
560
561 if (node->name() == to_node_name) {
562 return Status::OK();
563 }
564 if (HasNode(to_node_name)) {
565 return error_status(
566 "can't update node name because new node name is in use");
567 }
568 auto max_output_port = max_regular_output_port().find(node);
569 const bool has_max_output_port =
570 max_output_port != max_regular_output_port().end();
571 auto control_fanouts = fanouts().find({node, Graph::kControlSlot});
572
573 if (update_fanouts) {
574 SwapControlledFanoutInputs(*this, control_fanouts, to_node_name);
575 if (has_max_output_port) {
576 SwapRegularFanoutInputs(&fanouts(), node, to_node_name,
577 max_output_port->second);
578 }
579 } else if (has_max_output_port ||
580 HasFanoutValue(fanouts(), control_fanouts)) {
581 return error_status("can't update node name because node has fanouts");
582 }
583
584 nodes().erase(node->name());
585 node->set_name(string(to_node_name));
586 nodes().emplace(node->name(), node);
587 return Status::OK();
588 }
589
SwapNodeNames(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)590 Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
591 absl::string_view to_node_name,
592 bool update_fanouts) {
593 auto error_status = [from_node_name, to_node_name,
594 update_fanouts](absl::string_view msg) {
595 string params = absl::Substitute(
596 "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
597 from_node_name, to_node_name, update_fanouts);
598 return MutationError("SwapNodeNames", params, msg);
599 };
600
601 NodeDef* from_node = GetNode(from_node_name);
602 TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
603 if (from_node_name == to_node_name) {
604 return Status::OK();
605 }
606 NodeDef* to_node = GetNode(to_node_name);
607 TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));
608
609 auto swap_names = [this, from_node, to_node]() {
610 nodes().erase(from_node->name());
611 nodes().erase(to_node->name());
612 std::swap(*from_node->mutable_name(), *to_node->mutable_name());
613 nodes().emplace(from_node->name(), from_node);
614 nodes().emplace(to_node->name(), to_node);
615 };
616
617 if (update_fanouts) {
618 SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
619 to_node);
620 swap_names();
621 return Status::OK();
622 }
623
624 bool from_is_switch = IsSwitch(*from_node);
625 MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
626 auto to_control_fanouts = fanouts().find(to_control);
627 if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
628 return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
629 }
630
631 bool to_is_switch = IsSwitch(*to_node);
632 MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
633 auto from_control_fanouts = fanouts().find(from_control);
634 if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
635 return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
636 }
637
638 // Swap node names.
639 swap_names();
640
641 // Swap controlling fanouts.
642 //
643 // Note: To and from control fanout iterators are still valid as no mutations
644 // has been performed on fanouts().
645 SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
646 to_control, to_control_fanouts);
647
648 // Swap regular fanouts.
649 SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
650 from_node, to_node);
651
652 // Update fanins to remove self loops.
653 auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
654 for (int i = 0; i < node->input_size(); ++i) {
655 TensorId tensor_id = ParseTensorName(node->input(i));
656 if (tensor_id.node() == node->name()) {
657 const int idx = tensor_id.index();
658 const int node_idx =
659 IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;
660
661 MutableGraphView::OutputPort from_fanin(node, idx);
662 absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
663 from_fanouts->erase({node, node_idx});
664 UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);
665
666 MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
667 fanouts()[to_fanin].insert({node, node_idx});
668 UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
669 node->set_input(i, TensorIdToString({old_node_name, idx}));
670 }
671 }
672 };
673 update_fanins(from_node, to_node->name());
674 update_fanins(to_node, from_node->name());
675
676 // Dedup control dependencies.
677 auto dedup_control_fanouts =
678 [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
679 if (CanDedupControlWithRegularInput(*this, *node) &&
680 control_fanouts != fanouts().end()) {
681 for (const auto& control_fanout : control_fanouts->second) {
682 if (HasRegularFaninNode(*this, *control_fanout.node,
683 node->name())) {
684 RemoveControllingFaninInternal(control_fanout.node, node);
685 }
686 }
687 }
688 };
689 auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
690 OutputPort port;
691 port.node = node;
692 const int max_port =
693 gtl::FindWithDefault(max_regular_output_port(), node, -1);
694 for (int i = 0; i <= max_port; ++i) {
695 port.port_id = i;
696 auto it = fanouts().find(port);
697 if (it == fanouts().end()) {
698 continue;
699 }
700 for (const auto& fanout : it->second) {
701 auto fanout_controls =
702 fanouts().find({fanout.node, Graph::kControlSlot});
703 dedup_control_fanouts(fanout.node, fanout_controls);
704 }
705 }
706 };
707
708 if (!from_is_switch) {
709 if (to_is_switch) {
710 dedup_switch_control(from_node);
711 } else {
712 // Fetch iterator again as the original iterator might have been
713 // invalidated by container rehash triggered due to mutations.
714 auto from_control_fanouts = fanouts().find(from_control);
715 dedup_control_fanouts(from_node, from_control_fanouts);
716 }
717 }
718 if (!to_is_switch) {
719 if (from_is_switch) {
720 dedup_switch_control(to_node);
721 } else {
722 // Fetch iterator again as the original iterator might have been
723 // invalidated by container rehash triggered due to mutations.
724 auto to_control_fanouts = fanouts().find(to_control);
725 dedup_control_fanouts(to_node, to_control_fanouts);
726 }
727 }
728
729 return Status::OK();
730 }
731
UpdateFanouts(absl::string_view from_node_name,absl::string_view to_node_name)732 Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
733 absl::string_view to_node_name) {
734 NodeDef* from_node = GetNode(from_node_name);
735 TF_RETURN_IF_ERROR(
736 CheckNodeExists(from_node_name, from_node,
737 UpdateFanoutsError(from_node_name, to_node_name)));
738 NodeDef* to_node = GetNode(to_node_name);
739 TF_RETURN_IF_ERROR(CheckNodeExists(
740 to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name)));
741
742 return UpdateFanoutsInternal(from_node, to_node);
743 }
744
UpdateFanoutsInternal(NodeDef * from_node,NodeDef * to_node)745 Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node,
746 NodeDef* to_node) {
747 VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.",
748 from_node->name(), to_node->name());
749 if (from_node == to_node) {
750 return Status::OK();
751 }
752
753 // Update internal state with the new output_port->input_port edge.
754 const auto add_edge = [this](const OutputPort& output_port,
755 const InputPort& input_port) {
756 fanouts()[output_port].insert(input_port);
757 };
758
759 // Remove invalidated edge from the internal state.
760 const auto remove_edge = [this](const OutputPort& output_port,
761 const InputPort& input_port) {
762 fanouts()[output_port].erase(input_port);
763 };
764
765 // For the control fanouts we do not know the input index in a NodeDef,
766 // so we have to traverse all control inputs.
767
768 auto control_fanouts =
769 GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
770
771 bool to_node_is_switch = IsSwitch(*to_node);
772 for (const InputPort& control_port : control_fanouts) {
773 // Node can't be control dependency of itself.
774 if (control_port.node == to_node) continue;
775
776 // Can't add Switch node as a control dependency.
777 if (to_node_is_switch) {
778 // Trying to add a Switch as a control dependency, which if allowed will
779 // make the graph invalid.
780 return UpdateFanoutsError(from_node->name(), to_node->name())(
781 absl::Substitute("can't update fanouts to node '$0' as it will "
782 "become a Switch control dependency",
783 to_node->name()));
784 }
785
786 NodeDef* node = control_port.node;
787 RemoveControllingFaninInternal(node, from_node);
788 AddFaninInternal(node, {to_node, Graph::kControlSlot});
789 }
790
791 // First we update regular fanouts. For the regular fanouts
792 // `input_port:port_id` is the input index in NodeDef.
793
794 auto regular_edges =
795 GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
796
797 // Maximum index of the `from_node` output tensor that is still used as an
798 // input to some other node.
799 int keep_max_regular_output_port = -1;
800
801 for (const Edge& edge : regular_edges) {
802 const OutputPort output_port = edge.src;
803 const InputPort input_port = edge.dst;
804
805 // If the `to_node` reads from the `from_node`, skip this edge (see
806 // AddAndUpdateFanoutsWithoutSelfLoops test for an example).
807 if (input_port.node == to_node) {
808 keep_max_regular_output_port =
809 std::max(keep_max_regular_output_port, output_port.port_id);
810 continue;
811 }
812
813 // Update input at destination node.
814 input_port.node->set_input(
815 input_port.port_id,
816 TensorIdToString({to_node->name(), output_port.port_id}));
817
818 // Remove old edge between the `from_node` and the fanout node.
819 remove_edge(output_port, input_port);
820 // Add an edge between the `to_node` and new fanout node.
821 add_edge(OutputPort(to_node, output_port.port_id), input_port);
822 // Dedup control dependency.
823 if (CanDedupControlWithRegularInput(*this, *to_node)) {
824 RemoveControllingFaninInternal(input_port.node, to_node);
825 }
826 }
827
828 // Because we update all regular fanouts of `from_node`, we can just copy
829 // the value `num_regular_outputs`.
830 max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
831
832 // Check if all fanouts were updated to read from the `to_node`.
833 if (keep_max_regular_output_port >= 0) {
834 max_regular_output_port()[from_node] = keep_max_regular_output_port;
835 } else {
836 max_regular_output_port().erase(from_node);
837 }
838
839 return Status::OK();
840 }
841
AddFaninInternal(NodeDef * node,const OutputPort & fanin)842 bool MutableGraphView::AddFaninInternal(NodeDef* node,
843 const OutputPort& fanin) {
844 int num_regular_fanins =
845 NumFanins(*node, /*include_controlling_nodes=*/false);
846 bool input_is_control = IsOutputPortControlling(fanin);
847 bool can_dedup_control_with_regular_input =
848 CanDedupControlWithRegularInput(*this, *fanin.node);
849 // Don't add duplicate control dependencies.
850 if (input_is_control) {
851 const int start =
852 can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
853 for (int i = start; i < node->input_size(); ++i) {
854 if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
855 return false;
856 }
857 }
858 }
859
860 InputPort input;
861 input.node = node;
862 input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;
863
864 node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
865 if (!input_is_control) {
866 const int last_node_input = node->input_size() - 1;
867 // If there are control dependencies in node, move newly inserted fanin to
868 // be before such control dependencies.
869 if (num_regular_fanins < last_node_input) {
870 node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
871 }
872 }
873
874 fanouts()[fanin].insert(input);
875 if (max_regular_output_port()[fanin.node] < fanin.port_id) {
876 max_regular_output_port()[fanin.node] = fanin.port_id;
877 }
878
879 // Update max input port and dedup control dependencies.
880 if (!input_is_control) {
881 max_regular_input_port()[node] = num_regular_fanins;
882 if (can_dedup_control_with_regular_input) {
883 RemoveControllingFaninInternal(node, fanin.node);
884 }
885 }
886
887 return true;
888 }
889
AddRegularFanin(absl::string_view node_name,const TensorId & fanin)890 Status MutableGraphView::AddRegularFanin(absl::string_view node_name,
891 const TensorId& fanin) {
892 auto error_status = [node_name, fanin](absl::string_view msg) {
893 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
894 fanin.ToString());
895 return MutationError("AddRegularFanin", params, msg);
896 };
897
898 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
899 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
900 NodeDef* node = GetNode(node_name);
901 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
902 NodeDef* fanin_node = GetNode(fanin.node());
903 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
904
905 AddFaninInternal(node, {fanin_node, fanin.index()});
906 return Status::OK();
907 }
908
AddRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)909 Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
910 int port,
911 const TensorId& fanin) {
912 auto error_status = [node_name, port, fanin](absl::string_view msg) {
913 string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
914 node_name, port, fanin.ToString());
915 return MutationError("AddRegularFaninByPort", params, msg);
916 };
917
918 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
919 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
920 NodeDef* node = GetNode(node_name);
921 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
922 const int num_regular_fanins =
923 NumFanins(*node, /*include_controlling_nodes=*/false);
924 TF_RETURN_IF_ERROR(
925 CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
926 NodeDef* fanin_node = GetNode(fanin.node());
927 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
928
929 const int last_node_input = node->input_size();
930 node->add_input(TensorIdToString(fanin));
931 node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
932 for (int i = num_regular_fanins - 1; i >= port; --i) {
933 TensorId tensor_id = ParseTensorName(node->input(i));
934 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
935 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
936 fanouts_set->erase({node, i});
937 fanouts_set->insert({node, i + 1});
938 node->mutable_input()->SwapElements(i, i + 1);
939 }
940
941 OutputPort fanin_port(fanin_node, fanin.index());
942 fanouts()[fanin_port].insert({node, port});
943 UpdateMaxRegularOutputPortForAddedFanin(fanin_port);
944
945 max_regular_input_port()[node] = num_regular_fanins;
946 if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
947 RemoveControllingFaninInternal(node, fanin_node);
948 }
949
950 return Status::OK();
951 }
952
GetControllingFaninToAdd(absl::string_view node_name,const OutputPort & fanin,string * error_msg)953 NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
954 const OutputPort& fanin,
955 string* error_msg) {
956 if (!IsSwitch(*fanin.node)) {
957 return fanin.node;
958 } else {
959 if (IsOutputPortControlling(fanin)) {
960 // Can't add a Switch node control dependency.
961 TensorId tensor_id(fanin.node->name(), fanin.port_id);
962 *error_msg = absl::Substitute(
963 "can't add fanin '$0' as it will become a Switch control dependency",
964 tensor_id.ToString());
965 return nullptr;
966 }
967 // We can't anchor control dependencies directly on the switch node: unlike
968 // other nodes only one of the outputs of the switch node will be generated
969 // when the switch node is executed, and we need to make sure the control
970 // dependency is only triggered when the corresponding output is triggered.
971 // We start by looking for an identity node connected to the output of the
972 // switch node, and use it to anchor the control dependency.
973 for (const auto& fanout : GetFanout(fanin)) {
974 if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
975 if (fanout.node->name() == node_name) {
976 *error_msg =
977 absl::Substitute("can't add found fanin '$0' to self",
978 AsControlDependency(fanout.node->name()));
979 return nullptr;
980 }
981 return fanout.node;
982 }
983 }
984
985 // No node found, check if node to be created is itself.
986 if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
987 *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
988 AsControlDependency(string(node_name)));
989 }
990 }
991 return nullptr;
992 }
993
GetOrCreateIdentityConsumingSwitch(const OutputPort & fanin)994 NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
995 const OutputPort& fanin) {
996 // We haven't found an existing node where we can anchor the control
997 // dependency: add a new identity node.
998 string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
999 NodeDef* identity_node = GetNode(identity_name);
1000 if (identity_node == nullptr) {
1001 NodeDef new_node;
1002 new_node.set_name(identity_name);
1003 new_node.set_op("Identity");
1004 new_node.set_device(fanin.node->device());
1005 (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
1006 new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
1007 identity_node = AddNode(std::move(new_node));
1008 }
1009 return identity_node;
1010 }
1011
AddControllingFanin(absl::string_view node_name,const TensorId & fanin)1012 Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
1013 const TensorId& fanin) {
1014 auto error_status = [node_name, fanin](absl::string_view msg) {
1015 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1016 fanin.ToString());
1017 return MutationError("AddControllingFanin", params, msg);
1018 };
1019
1020 TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status));
1021 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1022 NodeDef* node = GetNode(node_name);
1023 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1024 NodeDef* fanin_node = GetNode(fanin.node());
1025 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1026
1027 OutputPort fanin_port(fanin_node, fanin.index());
1028
1029 string error_msg = "";
1030 NodeDef* control_node = GetControllingFaninToAdd(
1031 node_name, {fanin_node, fanin.index()}, &error_msg);
1032 if (!error_msg.empty()) {
1033 return error_status(error_msg);
1034 }
1035 if (control_node == nullptr) {
1036 control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
1037 }
1038 AddFaninInternal(node, {control_node, Graph::kControlSlot});
1039
1040 return Status::OK();
1041 }
1042
RemoveRegularFaninInternal(NodeDef * node,const OutputPort & fanin)1043 bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
1044 const OutputPort& fanin) {
1045 auto remove_input = [this, node](const OutputPort& fanin_port,
1046 int node_input_port, bool update_max_port) {
1047 InputPort input(node, node_input_port);
1048
1049 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1050 fanouts_set->erase(input);
1051 if (update_max_port) {
1052 UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
1053 }
1054 return fanouts_set;
1055 };
1056
1057 auto mutable_inputs = node->mutable_input();
1058 bool modified = false;
1059 const int num_regular_fanins =
1060 NumFanins(*node, /*include_controlling_nodes=*/false);
1061 int i;
1062 int curr_pos = 0;
1063 for (i = 0; i < num_regular_fanins; ++i) {
1064 TensorId tensor_id = ParseTensorName(node->input(i));
1065 if (tensor_id.node() == fanin.node->name() &&
1066 tensor_id.index() == fanin.port_id) {
1067 remove_input(fanin, i, /*update_max_port=*/true);
1068 modified = true;
1069 } else if (modified) {
1070 // Regular inputs will need to have their ports updated.
1071 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1072 auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
1073 fanouts_set->insert({node, curr_pos});
1074 // Shift inputs to be retained.
1075 mutable_inputs->SwapElements(i, curr_pos);
1076 ++curr_pos;
1077 } else {
1078 // Skip inputs to be retained until first modification.
1079 ++curr_pos;
1080 }
1081 }
1082
1083 if (modified) {
1084 const int last_regular_input_port = curr_pos - 1;
1085 if (last_regular_input_port < 0) {
1086 max_regular_input_port().erase(node);
1087 } else {
1088 max_regular_input_port()[node] = last_regular_input_port;
1089 }
1090 if (curr_pos < i) {
1091 // Remove fanins from node inputs.
1092 mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
1093 }
1094 }
1095
1096 return modified;
1097 }
1098
RemoveRegularFanin(absl::string_view node_name,const TensorId & fanin)1099 Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name,
1100 const TensorId& fanin) {
1101 auto error_status = [node_name, fanin](absl::string_view msg) {
1102 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1103 fanin.ToString());
1104 return MutationError("RemoveRegularFanin", params, msg);
1105 };
1106
1107 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1108 TF_RETURN_IF_ERROR(
1109 CheckRemovingFaninFromSelf(node_name, fanin, error_status));
1110 NodeDef* node = GetNode(node_name);
1111 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1112 NodeDef* fanin_node = GetNode(fanin.node());
1113 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1114
1115 RemoveRegularFaninInternal(node, {fanin_node, fanin.index()});
1116 return Status::OK();
1117 }
1118
RemoveRegularFaninByPort(absl::string_view node_name,int port)1119 Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name,
1120 int port) {
1121 auto error_status = [node_name, port](absl::string_view msg) {
1122 string params =
1123 absl::Substitute("node_name='$0', port=$1", node_name, port);
1124 return MutationError("RemoveRegularFaninByPort", params, msg);
1125 };
1126
1127 NodeDef* node = GetNode(node_name);
1128 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1129 const int last_regular_fanin_port =
1130 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1131 TF_RETURN_IF_ERROR(
1132 CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1133
1134 TensorId tensor_id = ParseTensorName(node->input(port));
1135 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1136 fanouts()[fanin_port].erase({node, port});
1137 auto mutable_inputs = node->mutable_input();
1138 for (int i = port + 1; i <= last_regular_fanin_port; ++i) {
1139 TensorId tensor_id = ParseTensorName(node->input(i));
1140 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1141 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1142 fanouts_set->erase({node, i});
1143 fanouts_set->insert({node, i - 1});
1144 mutable_inputs->SwapElements(i - 1, i);
1145 }
1146 const int last_node_input = node->input_size() - 1;
1147 if (last_regular_fanin_port < last_node_input) {
1148 mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input);
1149 }
1150 mutable_inputs->RemoveLast();
1151
1152 const int updated_last_regular_input_port = last_regular_fanin_port - 1;
1153 if (updated_last_regular_input_port < 0) {
1154 max_regular_input_port().erase(node);
1155 } else {
1156 max_regular_input_port()[node] = updated_last_regular_input_port;
1157 }
1158
1159 return Status::OK();
1160 }
1161
RemoveControllingFaninInternal(NodeDef * node,NodeDef * fanin_node)1162 bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node,
1163 NodeDef* fanin_node) {
1164 for (int i = node->input_size() - 1; i >= 0; --i) {
1165 TensorId tensor_id = ParseTensorName(node->input(i));
1166 if (tensor_id.index() > Graph::kControlSlot) {
1167 break;
1168 }
1169 if (tensor_id.node() == fanin_node->name()) {
1170 fanouts()[{fanin_node, Graph::kControlSlot}].erase(
1171 {node, Graph::kControlSlot});
1172 node->mutable_input()->SwapElements(i, node->input_size() - 1);
1173 node->mutable_input()->RemoveLast();
1174 return true;
1175 }
1176 }
1177 return false;
1178 }
1179
RemoveControllingFanin(absl::string_view node_name,absl::string_view fanin_node_name)1180 Status MutableGraphView::RemoveControllingFanin(
1181 absl::string_view node_name, absl::string_view fanin_node_name) {
1182 auto error_status = [node_name, fanin_node_name](absl::string_view msg) {
1183 string params = absl::Substitute("node_name='$0', fanin_node_name='$1'",
1184 node_name, fanin_node_name);
1185 return MutationError("RemoveControllingFanin", params, msg);
1186 };
1187
1188 TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf(
1189 node_name, {fanin_node_name, Graph::kControlSlot}, error_status));
1190 NodeDef* node = GetNode(node_name);
1191 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1192 NodeDef* fanin_node = GetNode(fanin_node_name);
1193 TF_RETURN_IF_ERROR(
1194 CheckNodeExists(fanin_node_name, fanin_node, error_status));
1195
1196 RemoveControllingFaninInternal(node, fanin_node);
1197 return Status::OK();
1198 }
1199
RemoveAllFanins(absl::string_view node_name,bool keep_controlling_fanins)1200 Status MutableGraphView::RemoveAllFanins(absl::string_view node_name,
1201 bool keep_controlling_fanins) {
1202 NodeDef* node = GetNode(node_name);
1203 if (node == nullptr) {
1204 string params =
1205 absl::Substitute("node_name='$0', keep_controlling_fanins=$1",
1206 node_name, keep_controlling_fanins);
1207 return MutationError("RemoveAllFanins", params,
1208 NodeMissingErrorMsg(node_name));
1209 }
1210
1211 if (node->input().empty()) {
1212 return Status::OK();
1213 }
1214
1215 const int num_regular_fanins =
1216 NumFanins(*node, /*include_controlling_nodes=*/false);
1217 RemoveFaninsInternal(node, keep_controlling_fanins);
1218 if (keep_controlling_fanins) {
1219 if (num_regular_fanins == 0) {
1220 return Status::OK();
1221 } else if (num_regular_fanins < node->input_size()) {
1222 node->mutable_input()->DeleteSubrange(0, num_regular_fanins);
1223 } else {
1224 node->clear_input();
1225 }
1226 } else {
1227 node->clear_input();
1228 }
1229 return Status::OK();
1230 }
1231
UpdateFanin(absl::string_view node_name,const TensorId & from_fanin,const TensorId & to_fanin)1232 Status MutableGraphView::UpdateFanin(absl::string_view node_name,
1233 const TensorId& from_fanin,
1234 const TensorId& to_fanin) {
1235 auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
1236 string params =
1237 absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
1238 node_name, from_fanin.ToString(), to_fanin.ToString());
1239 return MutationError("UpdateFanin", params, msg);
1240 };
1241
1242 TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
1243 TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
1244 NodeDef* node = GetNode(node_name);
1245 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1246 NodeDef* from_fanin_node = GetNode(from_fanin.node());
1247 TF_RETURN_IF_ERROR(
1248 CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
1249 NodeDef* to_fanin_node = GetNode(to_fanin.node());
1250 TF_RETURN_IF_ERROR(
1251 CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));
1252
1253 // When replacing a non control dependency fanin with a control dependency, or
1254 // vice versa, remove and add, so ports can be updated properly in fanout(s).
1255 bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
1256 if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
1257 // Can't add Switch node as a control dependency.
1258 return error_status(
1259 absl::Substitute("can't update to fanin '$0' as it will become a "
1260 "Switch control dependency",
1261 to_fanin.ToString()));
1262 }
1263 if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
1264 return error_status("can't update fanin to or from self");
1265 }
1266
1267 if (from_fanin == to_fanin) {
1268 return Status::OK();
1269 }
1270
1271 bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
1272 if (from_fanin_is_control || to_fanin_is_control) {
1273 bool modified = false;
1274 if (from_fanin_is_control) {
1275 modified |= RemoveControllingFaninInternal(node, from_fanin_node);
1276 } else {
1277 modified |= RemoveRegularFaninInternal(
1278 node, {from_fanin_node, from_fanin.index()});
1279 }
1280 if (modified) {
1281 AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
1282 }
1283 return Status::OK();
1284 }
1285
1286 // In place mutation of regular fanins, requires no shifting of ports.
1287 string to_fanin_string = TensorIdToString(to_fanin);
1288 const int num_regular_fanins =
1289 NumFanins(*node, /*include_controlling_nodes=*/false);
1290 bool modified = false;
1291 absl::flat_hash_set<InputPort>* from_fanin_port_fanouts = nullptr;
1292 absl::flat_hash_set<InputPort>* to_fanin_port_fanouts = nullptr;
1293 for (int i = 0; i < num_regular_fanins; ++i) {
1294 if (ParseTensorName(node->input(i)) == from_fanin) {
1295 InputPort input(node, i);
1296 if (from_fanin_port_fanouts == nullptr) {
1297 OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1298 from_fanin_port_fanouts = &fanouts()[from_fanin_port];
1299 }
1300 from_fanin_port_fanouts->erase(input);
1301
1302 if (to_fanin_port_fanouts == nullptr) {
1303 OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
1304 to_fanin_port_fanouts = &fanouts()[to_fanin_port];
1305 }
1306 to_fanin_port_fanouts->insert(input);
1307
1308 node->set_input(i, to_fanin_string);
1309 modified = true;
1310 }
1311 }
1312
1313 // Dedup control dependencies and update max regular output ports.
1314 if (modified) {
1315 UpdateMaxRegularOutputPortForRemovedFanin(
1316 {from_fanin_node, from_fanin.index()}, *from_fanin_port_fanouts);
1317 if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
1318 max_regular_output_port()[to_fanin_node] = to_fanin.index();
1319 }
1320 if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
1321 RemoveControllingFaninInternal(node, to_fanin_node);
1322 }
1323 }
1324
1325 return Status::OK();
1326 }
1327
UpdateRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)1328 Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name,
1329 int port,
1330 const TensorId& fanin) {
1331 auto error_status = [node_name, port, fanin](absl::string_view msg) {
1332 string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
1333 node_name, port, fanin.ToString());
1334 return MutationError("UpdateRegularFaninByPort", params, msg);
1335 };
1336
1337 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1338 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1339 NodeDef* node = GetNode(node_name);
1340 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1341 const int last_regular_fanin_port =
1342 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1343 TF_RETURN_IF_ERROR(
1344 CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1345 NodeDef* fanin_node = GetNode(fanin.node());
1346 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1347
1348 TensorId tensor_id = ParseTensorName(node->input(port));
1349 if (tensor_id == fanin) {
1350 return Status::OK();
1351 }
1352
1353 InputPort input(node, port);
1354 OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1355 absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port];
1356 from_fanouts->erase(input);
1357 UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts);
1358
1359 OutputPort to_fanin_port(fanin_node, fanin.index());
1360 fanouts()[to_fanin_port].insert(input);
1361 UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port);
1362
1363 node->set_input(port, TensorIdToString(fanin));
1364
1365 if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
1366 RemoveControllingFaninInternal(node, fanin_node);
1367 }
1368
1369 return Status::OK();
1370 }
1371
SwapRegularFaninsByPorts(absl::string_view node_name,int from_port,int to_port)1372 Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name,
1373 int from_port, int to_port) {
1374 auto error_status = [node_name, from_port, to_port](absl::string_view msg) {
1375 string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2",
1376 node_name, from_port, to_port);
1377 return MutationError("SwapRegularFaninsByPorts", params, msg);
1378 };
1379
1380 NodeDef* node = GetNode(node_name);
1381 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1382 const int last_regular_fanin_port =
1383 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1384 TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0,
1385 last_regular_fanin_port, error_status));
1386 TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port,
1387 error_status));
1388
1389 if (from_port == to_port) {
1390 return Status::OK();
1391 }
1392 TensorId from_fanin = ParseTensorName(node->input(from_port));
1393 TensorId to_fanin = ParseTensorName(node->input(to_port));
1394 if (from_fanin == to_fanin) {
1395 return Status::OK();
1396 }
1397
1398 InputPort from_input(node, from_port);
1399 InputPort to_input(node, to_port);
1400 NodeDef* from_fanin_node = GetNode(from_fanin.node());
1401 absl::flat_hash_set<InputPort>* from_fanouts =
1402 &fanouts()[{from_fanin_node, from_fanin.index()}];
1403 from_fanouts->erase(from_input);
1404 from_fanouts->insert(to_input);
1405 NodeDef* to_fanin_node = GetNode(to_fanin.node());
1406 absl::flat_hash_set<InputPort>* to_fanouts =
1407 &fanouts()[{to_fanin_node, to_fanin.index()}];
1408 to_fanouts->erase(to_input);
1409 to_fanouts->insert(from_input);
1410
1411 node->mutable_input()->SwapElements(from_port, to_port);
1412
1413 return Status::OK();
1414 }
1415
UpdateAllRegularFaninsToControlling(absl::string_view node_name)1416 Status MutableGraphView::UpdateAllRegularFaninsToControlling(
1417 absl::string_view node_name) {
1418 auto error_status = [node_name](absl::string_view msg) {
1419 string params = absl::Substitute("node_name='$0'", node_name);
1420 return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
1421 };
1422
1423 NodeDef* node = GetNode(node_name);
1424 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1425
1426 const int num_regular_fanins =
1427 NumFanins(*node, /*include_controlling_nodes=*/false);
1428 std::vector<OutputPort> regular_fanins;
1429 regular_fanins.reserve(num_regular_fanins);
1430 std::vector<NodeDef*> controlling_fanins;
1431 controlling_fanins.reserve(num_regular_fanins);
1432
1433 // Get all regular fanins and derive controlling fanins.
1434 for (int i = 0; i < num_regular_fanins; ++i) {
1435 TensorId tensor_id = ParseTensorName(node->input(i));
1436 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1437
1438 string error_msg = "";
1439 NodeDef* control_node =
1440 GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
1441 if (!error_msg.empty()) {
1442 return error_status(error_msg);
1443 }
1444
1445 regular_fanins.push_back(fanin_port);
1446 controlling_fanins.push_back(control_node);
1447 }
1448
1449 // Replace regular fanins with controlling fanins and dedup.
1450 int pos = 0;
1451 InputPort input_port(node, Graph::kControlSlot);
1452 absl::flat_hash_set<absl::string_view> controls;
1453 for (int i = 0; i < num_regular_fanins; ++i) {
1454 OutputPort fanin_port = regular_fanins[i];
1455 NodeDef* control = controlling_fanins[i];
1456 if (control == nullptr) {
1457 control = GetOrCreateIdentityConsumingSwitch(fanin_port);
1458 }
1459 fanouts()[fanin_port].erase({node, i});
1460 if (controls.contains(control->name())) {
1461 continue;
1462 }
1463 controls.insert(control->name());
1464 node->set_input(pos, AsControlDependency(control->name()));
1465 fanouts()[{control, Graph::kControlSlot}].insert(input_port);
1466 ++pos;
1467 }
1468
1469 // Shift existing controlling fanins and dedup.
1470 for (int i = num_regular_fanins; i < node->input_size(); ++i) {
1471 TensorId tensor_id = ParseTensorName(node->input(i));
1472 if (controls.contains(tensor_id.node())) {
1473 continue;
1474 }
1475 controls.insert(tensor_id.node());
1476 node->mutable_input()->SwapElements(pos, i);
1477 ++pos;
1478 }
1479
1480 // Remove duplicate controls and leftover regular fanins.
1481 node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
1482 max_regular_input_port().erase(node);
1483
1484 return Status::OK();
1485 }
1486
CheckNodesCanBeDeleted(const absl::flat_hash_set<string> & nodes_to_delete)1487 Status MutableGraphView::CheckNodesCanBeDeleted(
1488 const absl::flat_hash_set<string>& nodes_to_delete) {
1489 std::vector<string> missing_nodes;
1490 std::vector<string> nodes_with_fanouts;
1491 for (const string& node_name_to_delete : nodes_to_delete) {
1492 NodeDef* node = GetNode(node_name_to_delete);
1493 if (node == nullptr) {
1494 // Can't delete missing node.
1495 missing_nodes.push_back(node_name_to_delete);
1496 continue;
1497 }
1498 const int max_port = gtl::FindWithDefault(max_regular_output_port(), node,
1499 Graph::kControlSlot);
1500 for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1501 auto it = fanouts().find({node, i});
1502 bool has_retained_fanout = false;
1503 if (it != fanouts().end()) {
1504 for (const auto& fanout : it->second) {
1505 // Check if fanouts are of nodes to be deleted, and if so, they can be
1506 // ignored, as they will be removed also.
1507 if (!nodes_to_delete.contains(fanout.node->name())) {
1508 // Removing node will leave graph in an invalid state.
1509 has_retained_fanout = true;
1510 break;
1511 }
1512 }
1513 }
1514 if (has_retained_fanout) {
1515 nodes_with_fanouts.push_back(node_name_to_delete);
1516 break;
1517 }
1518 }
1519 }
1520
1521 // Error message can get quite long, so we only show the first 5 node names.
1522 auto sort_and_sample = [](std::vector<string>* s) {
1523 constexpr int kMaxNodeNames = 5;
1524 std::sort(s->begin(), s->end());
1525 if (s->size() > kMaxNodeNames) {
1526 return absl::StrCat(
1527 absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", "), ", ...");
1528 }
1529 return absl::StrJoin(*s, ", ");
1530 };
1531
1532 if (!missing_nodes.empty()) {
1533 VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0].",
1534 sort_and_sample(&missing_nodes));
1535 }
1536 if (!nodes_with_fanouts.empty()) {
1537 std::vector<string> input_node_names(nodes_to_delete.begin(),
1538 nodes_to_delete.end());
1539 string params = absl::Substitute("nodes_to_delete={$0}",
1540 sort_and_sample(&input_node_names));
1541 string error_msg =
1542 absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]",
1543 sort_and_sample(&nodes_with_fanouts));
1544 return MutationError("DeleteNodes", params, error_msg);
1545 }
1546
1547 return Status::OK();
1548 }
1549
DeleteNodes(const absl::flat_hash_set<string> & nodes_to_delete)1550 Status MutableGraphView::DeleteNodes(
1551 const absl::flat_hash_set<string>& nodes_to_delete) {
1552 TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete));
1553
1554 // Find nodes in internal state and delete.
1555 for (const string& node_name_to_delete : nodes_to_delete) {
1556 NodeDef* node = GetNode(node_name_to_delete);
1557 if (node != nullptr) {
1558 RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false);
1559 RemoveFanoutsInternal(node);
1560 }
1561 }
1562 for (const string& node_name_to_delete : nodes_to_delete) {
1563 nodes().erase(node_name_to_delete);
1564 }
1565
1566 // Find nodes in graph and delete by partitioning into nodes to retain and
1567 // nodes to delete based on input set of nodes to delete by name.
1568 // TODO(lyandy): Use a node name->idx hashmap if this is a performance
1569 // bottleneck.
1570 int pos = 0;
1571 const int last_idx = graph()->node_size() - 1;
1572 int last_pos = last_idx;
1573 while (pos <= last_pos) {
1574 if (nodes_to_delete.contains(graph()->node(pos).name())) {
1575 graph()->mutable_node()->SwapElements(pos, last_pos);
1576 --last_pos;
1577 } else {
1578 ++pos;
1579 }
1580 }
1581 if (last_pos < last_idx) {
1582 graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
1583 }
1584
1585 return Status::OK();
1586 }
1587
RemoveFaninsInternal(NodeDef * deleted_node,bool keep_controlling_fanins)1588 void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
1589 bool keep_controlling_fanins) {
1590 for (int i = 0; i < deleted_node->input_size(); ++i) {
1591 TensorId tensor_id = ParseTensorName(deleted_node->input(i));
1592 bool is_control = IsTensorIdControlling(tensor_id);
1593 if (keep_controlling_fanins && is_control) {
1594 break;
1595 }
1596 OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
1597
1598 InputPort input;
1599 input.node = deleted_node;
1600 input.port_id = is_control ? Graph::kControlSlot : i;
1601
1602 auto it = fanouts().find(fanin);
1603 if (it != fanouts().end()) {
1604 absl::flat_hash_set<InputPort>* fanouts_set = &it->second;
1605 fanouts_set->erase(input);
1606 UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set);
1607 }
1608 }
1609 max_regular_input_port().erase(deleted_node);
1610 }
1611
RemoveFanoutsInternal(NodeDef * deleted_node)1612 void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) {
1613 const int max_port =
1614 gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1);
1615 for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1616 fanouts().erase({deleted_node, i});
1617 }
1618 max_regular_output_port().erase(deleted_node);
1619 }
1620
1621 } // end namespace grappler
1622 } // end namespace tensorflow
1623