1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_
18
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/hash/hash.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/graph/tensor_id.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30
31 namespace tensorflow {
32 namespace grappler {
33 namespace utils {
34 namespace internal {
35
// Port index carried by a fanin/fanout that does not exist (used as the
// default port of NodeIndexAndPortIndex and by EmptyTensorId()).
constexpr int kMissingSlot{-2};
// Node index used when no node is referenced. Also used as a `control_index`
// sentinel by the controlling-fanin diff helpers (presumably meaning "not a
// controlling fanin of the original node").
constexpr int kMissingIndex{-1};
// Value stored in an updated-node-names map meaning that the name will be
// present after the pending mutation is applied.
constexpr int kNodeNamePresent{-1};
39
40 // NodeIndexAndPortIndex is a helper class that represents fanins and fanouts
41 // of a node.
42 template <typename NodeViewT, typename GraphViewT>
43 class NodeIndexAndPortIndex {
44 public:
NodeIndexAndPortIndex()45 NodeIndexAndPortIndex()
46 : graph_view_(nullptr),
47 node_index_(kMissingIndex),
48 port_index_(kMissingSlot) {}
NodeIndexAndPortIndex(GraphViewT * graph_view,int node_index,int port_index)49 NodeIndexAndPortIndex(GraphViewT* graph_view, int node_index, int port_index)
50 : graph_view_(graph_view),
51 node_index_(node_index),
52 port_index_(port_index) {}
53
54 bool operator==(const NodeIndexAndPortIndex& other) const {
55 return port_index_ == other.port_index_ &&
56 node_index_ == other.node_index_ && graph_view_ == other.graph_view_;
57 }
58
59 template <typename Hash>
AbslHashValue(Hash h,const NodeIndexAndPortIndex & n)60 friend Hash AbslHashValue(Hash h, const NodeIndexAndPortIndex& n) {
61 return Hash::combine(std::move(h), n.node_index_, n.port_index_);
62 }
63
64 // Returns NodeView from `graph_view_` at `node_index_`.
node_view()65 NodeViewT* node_view() const {
66 if (graph_view_ == nullptr) {
67 return nullptr;
68 }
69 return graph_view_->GetNode(node_index_);
70 }
71
72 // Returns node index in graph.
node_index()73 int node_index() const { return node_index_; }
74
75 // Returns input/output port index.
index()76 int index() const { return port_index_; }
77
78 protected:
79 GraphViewT* graph_view_;
80 int node_index_;
81 int port_index_;
82 };
83
84 // NodeDefAndPortIndex is a helper class that represents fanins hashed with
85 // pointer stability using the fanin's NodeDef.
86 class NodeDefAndPortIndex {
87 public:
NodeDefAndPortIndex(const NodeDef * node_def,int port_index)88 NodeDefAndPortIndex(const NodeDef* node_def, int port_index)
89 : node_def_(node_def), port_index_(port_index) {}
90
91 bool operator==(const NodeDefAndPortIndex& other) const {
92 return node_def_ == other.node_def_ && port_index_ == other.port_index_;
93 }
94
95 template <typename Hash>
AbslHashValue(Hash h,const NodeDefAndPortIndex & n)96 friend Hash AbslHashValue(Hash h, const NodeDefAndPortIndex& n) {
97 return Hash::combine(std::move(h), n.node_def_, n.port_index_);
98 }
99
100 private:
101 const NodeDef* node_def_;
102 int port_index_;
103 };
104
105 // NodeViewInternal is a helper class to simplify graph traversal. It creates
106 // a view of a node and associated fanins and fanouts from the NodeDef
107 // protocol buffer.
108 //
109 // There are two public classes implementing NodeViewInternal:
110 //
111 // - NodeView: constructed from `const NodeDef` and doesn't allow mutating the
112 // underlying node.
113 // - MutableNodeView: constructed from `NodeDef` and allows mutating the
114 // underlying node.
115 //
116 // --------------------------- !!! WARNING !!! ---------------------------------
117 // Modifying the node outside of implementations of NodeViewInternal
118 // (i.e. modifying inputs of the NodeDef directly) may leave the NodeView
119 // in an inconsistent/invalid state.
120 // -----------------------------------------------------------------------------
121 //
122 template <typename FaninViewT, typename FanoutViewT, typename GraphViewT,
123 bool IsConst>
124 class NodeViewInternal {
125 private:
126 using NodeDefT =
127 typename std::conditional<IsConst, const NodeDef, NodeDef>::type;
128
129 public:
NodeViewInternal(GraphViewT * graph_view,int node_index)130 explicit NodeViewInternal(GraphViewT* graph_view, int node_index)
131 : graph_view_(graph_view),
132 node_index_(node_index),
133 attrs_(AttrSlice(graph_view->graph()->node(node_index))) {}
134
NodeViewInternal()135 NodeViewInternal()
136 : graph_view_(nullptr), node_index_(kMissingIndex), attrs_(AttrSlice()) {}
137
~NodeViewInternal()138 virtual ~NodeViewInternal() {}
139
140 NodeViewInternal(NodeViewInternal&&) = default;
141 NodeViewInternal& operator=(NodeViewInternal&&) = default;
142
143 bool operator==(const NodeViewInternal& other) const {
144 return node_index_ == other.node_index_ && graph_view_ == other.graph_view_;
145 }
146
147 template <typename Hash>
AbslHashValue(Hash h,const NodeViewInternal & n)148 friend Hash AbslHashValue(Hash h, const NodeViewInternal& n) {
149 return Hash::combine(std::move(h), n.node_index_);
150 }
151
152 // Returns NodeDef of view.
153 virtual NodeDefT* node() const = 0;
154
155 // Returns index of node in GraphDef/GraphView.
node_index()156 int node_index() const { return node_index_; }
157
158 // Returns the name of the node.
GetName()159 const string& GetName() const { return node()->name(); }
160
161 // Returns the op of the node.
GetOp()162 const string& GetOp() const { return node()->op(); }
163
164 // Returns the device set for the node.
GetDevice()165 const string& GetDevice() const { return node()->device(); }
166
167 // Returns all regular fanins, based on ordering in the node.
GetRegularFanins()168 const std::vector<FanoutViewT>& GetRegularFanins() const {
169 return regular_fanins_;
170 }
171
172 // Returns a regular fanin based on input index. If no such fanin exist, a
173 // missing fanin is returned, with no NodeView set and an index of -2.
GetRegularFanin(int i)174 const FanoutViewT& GetRegularFanin(int i) const {
175 int regular_fanins_size = regular_fanins_.size();
176 if (i < 0 || i >= regular_fanins_size) {
177 return GetMissingFanin();
178 }
179 return regular_fanins_[i];
180 }
181
182 // Returns all controlling fanins, based on ordering in the node.
GetControllingFanins()183 const std::vector<FanoutViewT>& GetControllingFanins() const {
184 return controlling_fanins_;
185 }
186
187 // Returns all regular fanouts.
GetRegularFanouts()188 const std::vector<std::vector<FaninViewT>>& GetRegularFanouts() const {
189 return regular_fanouts_by_port_;
190 }
191
192 // Returns a regular fanout(s) based on output index. If no such output index
193 // exists, no fanouts will be returned.
GetRegularFanout(int i)194 const std::vector<FaninViewT>& GetRegularFanout(int i) const {
195 int regular_fanouts_by_port_size = regular_fanouts_by_port_.size();
196 if (i < 0 || i >= regular_fanouts_by_port_size) {
197 return GetMissingFanout();
198 }
199 return regular_fanouts_by_port_[i];
200 }
201
202 // Returns all controlled fanouts.
GetControlledFanouts()203 const std::vector<FaninViewT>& GetControlledFanouts() const {
204 return controlled_fanouts_;
205 }
206
207 // Returns the number of regular fanins.
NumRegularFanins()208 int NumRegularFanins() const { return regular_fanins_.size(); }
209
210 // Returns the number of controlling fanins.
NumControllingFanins()211 int NumControllingFanins() const { return controlling_fanins_.size(); }
212
213 // Returns the number of regular fanouts.
NumRegularFanouts()214 int NumRegularFanouts() const { return num_regular_fanouts_; }
215
216 // Returns the number of controlled fanouts.
NumControlledFanouts()217 int NumControlledFanouts() const { return controlled_fanouts_.size(); }
218
219 // Checks if a fanin exists for the node.
220 virtual bool HasFanin(const FanoutViewT& fanin) const = 0;
221
222 // Checks if a fanout exists for the node.
223 virtual bool HasFanout(const FaninViewT& fanout) const = 0;
224
225 // Returns an attribute of the node by key. If no attribute for such key
226 // exists, a `nullptr` is returned.
GetAttr(absl::string_view attr_name)227 const AttrValue* GetAttr(absl::string_view attr_name) const {
228 return attrs_.Find(attr_name);
229 }
230
231 // Returns all attributes of the node.
GetAttrs()232 const AttrSlice& GetAttrs() const { return attrs_; }
233
234 // Returns the number of attributes in the node.
NumAttrs()235 int NumAttrs() const { return attrs_.size(); }
236
237 // Checks if an attribute exist in the node.
HasAttr(absl::string_view attr_name)238 bool HasAttr(absl::string_view attr_name) const {
239 return attrs_.Find(attr_name) != nullptr;
240 }
241
242 protected:
243 virtual inline const FanoutViewT& GetMissingFanin() const = 0;
244 virtual inline const std::vector<FaninViewT>& GetMissingFanout() const = 0;
245
246 std::vector<FanoutViewT> regular_fanins_;
247 std::vector<FanoutViewT> controlling_fanins_;
248 std::vector<std::vector<FaninViewT>> regular_fanouts_by_port_;
249 int num_regular_fanouts_ = 0;
250 std::vector<FaninViewT> controlled_fanouts_;
251
252 GraphViewT* graph_view_;
253 int node_index_;
254 AttrSlice attrs_;
255 };
256
257 // GraphViewInternal is a helper class to simplify graph traversal. It creates
258 // a view of the nodes and associated fanins and fanouts from the GraphDef
259 // protocol buffer.
260 //
261 // There are two public classes implementing GraphViewInternal:
262 //
263 // - GraphView: constructed from `const GraphDef` and doesn't allow mutating
264 // the underlying graph and its nodes.
265 // - MutableGraphView: constructed from `GraphDef` and allows mutating the
266 // underlying graph and its nodes.
267 //
268 // --------------------------- !!! WARNING !!! ---------------------------------
269 // Modifying the graph outside of implementations of GraphViewInternal
270 // (i.e. removing nodes from the GraphDef directly) may lead to
271 // segfaults! Guaranteed by absl::string_view!
272 // -----------------------------------------------------------------------------
273 //
274 template <typename NodeViewT, typename FaninViewT, typename FanoutViewT,
275 bool IsConst>
276 class GraphViewInternal {
277 private:
278 using GraphDefT =
279 typename std::conditional<IsConst, const GraphDef, GraphDef>::type;
280
281 public:
GraphViewInternal(GraphDefT * graph)282 explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
~GraphViewInternal()283 virtual ~GraphViewInternal() {}
284
285 bool operator==(const GraphViewInternal& other) const {
286 return graph_ == other.graph_;
287 }
288
graph()289 GraphDefT* graph() const { return graph_; }
290
291 // Finds node by index in the graph. If no such node exists in the graph, a
292 // `nullptr` is returned.
GetNode(int node_index)293 const NodeViewT* GetNode(int node_index) const {
294 int nodes_size = nodes_.size();
295 if (node_index < 0 || node_index >= nodes_size) {
296 return nullptr;
297 }
298 return &nodes_[node_index];
299 }
300
GetNode(int node_index)301 NodeViewT* GetNode(int node_index) {
302 int nodes_size = nodes_.size();
303 if (node_index < 0 || node_index >= nodes_size) {
304 return nullptr;
305 }
306 return &nodes_[node_index];
307 }
308
309 // Finds node by name. If no such node exists in the graph, a `nullptr` is
310 // returned.
GetNode(absl::string_view node_name)311 const NodeViewT* GetNode(absl::string_view node_name) const {
312 auto it = node_index_by_name_.find(node_name);
313 if (it == node_index_by_name_.end()) {
314 return nullptr;
315 }
316 return &nodes_[it->second];
317 }
318
GetNode(absl::string_view node_name)319 NodeViewT* GetNode(absl::string_view node_name) {
320 auto it = node_index_by_name_.find(node_name);
321 if (it == node_index_by_name_.end()) {
322 return nullptr;
323 }
324 return &nodes_[it->second];
325 }
326
327 // Returns all nodes (as NodeView) in the graph.
GetNodes()328 const std::vector<NodeViewT>& GetNodes() const { return nodes_; }
329
330 // Checks if a node by name exists in the graph.
HasNode(absl::string_view node_name)331 bool HasNode(absl::string_view node_name) const {
332 return node_index_by_name_.contains(node_name);
333 }
334
335 // Returns the number of nodes in the graph.
NumNodes()336 int NumNodes() const { return nodes_.size(); }
337
338 protected:
339 // Reset allocated node vector and node map in case of failure.
Reset()340 void Reset() {
341 std::vector<NodeViewT>().swap(nodes_);
342 absl::flat_hash_map<absl::string_view, int>().swap(node_index_by_name_);
343 }
344
345 // nodes_[i] is a view of graph_.{mutable_}node(i).
346 std::vector<NodeViewT> nodes_;
347 absl::flat_hash_map<absl::string_view, int> node_index_by_name_;
348 GraphDefT* graph_;
349 const FanoutViewT missing_fanin_;
350 const std::vector<FaninViewT> missing_fanout_;
351 };
352
EmptyTensorId()353 inline SafeTensorId EmptyTensorId() {
354 return SafeTensorId("", internal::kMissingSlot);
355 }
356
IsEmptyTensorId(const TensorId tensor_id)357 inline bool IsEmptyTensorId(const TensorId tensor_id) {
358 return tensor_id.node().empty() &&
359 tensor_id.index() == internal::kMissingSlot;
360 }
361
362 // NodeViewDiff is a helper struct holding changes to be made to an existing
363 // node in GraphViewT. This should not be initialized or be used directly.
364 template <typename GraphViewT>
365 struct NodeViewDiff {
NodeViewDiffNodeViewDiff366 explicit NodeViewDiff(GraphViewT* graph_view, int node_index)
367 : graph_view(graph_view), node_index(node_index) {}
368
369 GraphViewT* graph_view;
370 int node_index;
371 string name;
372 bool update_name = false;
373 string op;
374 bool update_op = false;
375 string device;
376 bool update_device = false;
377 // Fanins to append after existing regular fanins.
378 std::vector<SafeTensorId> regular_inputs_to_add;
379 // Number of fanins to be appended. This is used for a quick comparison with
380 // `regular_inputs_to_add` for if there will be any missing inputs in the
381 // updated node.
382 int num_regular_inputs_to_add = 0;
383 // Fanins to update inplace.
384 std::map<int, SafeTensorId> regular_inputs_to_update;
385 // Fanins from end of regular fanins to remove. This keeps track of existing
386 // regular fanins in the original node to remove.
387 std::vector<bool> regular_inputs_to_remove;
388 // Number of fanins marked for removal. This is used for a quick comparison
389 // with `regular_inputs_to_remove` for if there will be any missing inputs
390 // in the updated node.
391 int num_regular_inputs_to_remove = 0;
392 absl::flat_hash_set<string> controlling_inputs_to_add;
393 std::set<int> controlling_inputs_to_remove;
394 absl::flat_hash_map<string, AttrValue> attrs_to_add;
395 absl::flat_hash_set<string> attrs_to_remove;
396 // AttrValueMap constructor and destructor are very expensive, we will
397 // initialize it lazily only if needed.
398 absl::optional<AttrValueMap> processed_attrs;
399 };
400
401 // Updates node name. If `name` is the same as the name in the original node,
402 // the field will be cleared in the diff.
403 template <typename GraphViewT>
UpdateName(NodeViewDiff<GraphViewT> * diff,absl::string_view name)404 inline bool UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
405 if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) {
406 diff->name.clear();
407 diff->update_name = false;
408 } else {
409 diff->name = string(name);
410 diff->update_name = true;
411 }
412 return true;
413 }
414
415 // Updates node op. If `op` is the same as the op in the original node, the
416 // field will be cleared in the diff.
417 template <typename GraphViewT>
UpdateOp(NodeViewDiff<GraphViewT> * diff,absl::string_view op)418 inline bool UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
419 if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) {
420 diff->op.clear();
421 diff->update_op = false;
422 } else {
423 diff->op = string(op);
424 diff->update_op = true;
425 }
426 return true;
427 }
428
429 // Updates node device. If `device` is the same as the device in the original
430 // node, the field will be cleared in the diff.
431 template <typename GraphViewT>
UpdateDevice(NodeViewDiff<GraphViewT> * diff,absl::string_view device)432 inline bool UpdateDevice(NodeViewDiff<GraphViewT>* diff,
433 absl::string_view device) {
434 if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) {
435 diff->device.clear();
436 diff->update_device = false;
437 } else {
438 diff->device = string(device);
439 diff->update_device = true;
440 }
441 return true;
442 }
443
// Stores `value` at position `i` of `*v`.
//
// If `i` lies past the end, the vector is first grown and the gap is filled
// with `default_value` (callers later detect such gaps). Returns true when a
// new element was appended, or when an existing element was overwritten whose
// previous value was `default_value`; returns false when a non-default element
// was overwritten. `i` is expected to be non-negative.
template <typename T, typename U>
inline bool AddOrUpdateAtIndex(std::vector<T>* v, int i, const U& value,
                               const T& default_value) {
  const int current_size = v->size();
  if (i < current_size) {
    // In range: overwrite, reporting whether the slot used to be a gap.
    const bool was_gap = (*v)[i] == default_value;
    (*v)[i] = {value};
    return was_gap;
  }
  if (i > current_size) {
    // Pad with `default_value` up to (but excluding) position `i`.
    v->reserve(i + 1);
    v->resize(i, default_value);
  }
  v->push_back({value});
  return true;
}
469
470 // Checks if a node with name `node_name` will exist in the final mutated graph.
471 template <typename GraphViewT>
CheckNodeNameExists(absl::string_view node_name,const absl::flat_hash_map<absl::string_view,int> & updated_node_names,const GraphViewT * graph_view)472 inline bool CheckNodeNameExists(
473 absl::string_view node_name,
474 const absl::flat_hash_map<absl::string_view, int>& updated_node_names,
475 const GraphViewT* graph_view) {
476 auto it = updated_node_names.find(node_name);
477 if (it != updated_node_names.end()) {
478 return it->second == kNodeNamePresent;
479 }
480 return graph_view->HasNode(node_name);
481 }
482
483 // Adds or updates regular fanin at `index` of regular fanins. If `index` is
484 // less than the number of regular fanins in the original node, the fanin at
485 // `index` in the original node will be updated with `fanin` if the fanin
486 // differs. If `index` is greater than or equal to the number of regular fanins,
487 // `fanin` will be added beyond the end of regular fanins at `index`.
488 template <typename GraphViewT>
AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT> * diff,int index,const TensorId & fanin)489 inline bool AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
490 const TensorId& fanin) {
491 if (index < 0) {
492 // Not a valid index for regular fanins.
493 return false;
494 }
495 auto* node_view = diff->graph_view->GetNode(diff->node_index);
496 const int num_regular_fanins = node_view->NumRegularFanins();
497 if (index < num_regular_fanins) { // Updating existing fanins.
498 // Calculate (relative) index from end of regular fanins, from absolute
499 // index from beginning of regular fanins.
500 const int relative_removal_index = num_regular_fanins - index - 1;
501 // Check if at relative index fanin was already marked for removal.
502 int diff_regular_inputs_to_remove_size =
503 diff->regular_inputs_to_remove.size();
504 if (relative_removal_index < diff_regular_inputs_to_remove_size &&
505 diff->regular_inputs_to_remove[relative_removal_index]) {
506 // Unmark fanin for removal.
507 diff->regular_inputs_to_remove[relative_removal_index] = false;
508 --diff->num_regular_inputs_to_remove;
509 }
510 const auto& existing_fanin = node_view->GetRegularFanin(index);
511 if (existing_fanin.index() != fanin.index() ||
512 existing_fanin.node_view()->GetName() != fanin.node()) {
513 // Update fanin if it is different from original fanin in node.
514 gtl::InsertOrUpdate(&diff->regular_inputs_to_update, index,
515 SafeTensorId(fanin));
516 }
517 } else {
518 // Add fanin beyond current fanin range.
519 const int relative_add_index = index - num_regular_fanins;
520 if (AddOrUpdateAtIndex(&diff->regular_inputs_to_add, relative_add_index,
521 fanin, EmptyTensorId())) {
522 // New fanin was added.
523 ++diff->num_regular_inputs_to_add;
524 }
525 }
526 return true;
527 }
528
529 // Remove regular fanin at `index` of regular fanins. This can remove existing
530 // fanins and updated/added fanins via AddOrUpdateRegularFanins.
531 template <typename GraphViewT>
RemoveRegularFanin(NodeViewDiff<GraphViewT> * diff,int index)532 inline bool RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
533 if (index < 0) {
534 // Not a valid index for regular fanins.
535 return false;
536 }
537 auto* node_view = diff->graph_view->GetNode(diff->node_index);
538 const int num_regular_fanins = node_view->NumRegularFanins();
539 if (index < num_regular_fanins) { // Removing existing fanins.
540 // Remove updated fanin if it exists.
541 diff->regular_inputs_to_update.erase(index);
542 // Calculate (relative) index from end of regular fanins, from absolute
543 // index from beginning of regular fanins.
544 const int relative_removal_index = num_regular_fanins - index - 1;
545 if (AddOrUpdateAtIndex(&diff->regular_inputs_to_remove,
546 relative_removal_index,
547 /*value=*/true, /*default_value=*/false)) {
548 ++diff->num_regular_inputs_to_remove;
549 }
550 } else {
551 // Relative index from end of regular fanins.
552 const int relative_add_index = index - num_regular_fanins;
553 int diff_regular_inputs_to_add_size = diff->regular_inputs_to_add.size();
554 if (relative_add_index >= diff_regular_inputs_to_add_size ||
555 IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) {
556 // At relative index, appended regular fanin was already marked for
557 // removal.
558 return false;
559 }
560 // Remove added fanin.
561 diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId();
562 --diff->num_regular_inputs_to_add;
563 }
564 return true;
565 }
566
567 // Adds controlling fanin. If the controlling fanin already exists in the
568 // original node, it will be dedupped. If the controlling fanin is marked for
569 // removal, this will reverse it.
570 template <typename GraphViewT>
AddControllingFanin(NodeViewDiff<GraphViewT> * diff,int control_index,absl::string_view fanin_node_name)571 inline bool AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
572 int control_index,
573 absl::string_view fanin_node_name) {
574 if (control_index == kMissingIndex) {
575 diff->controlling_inputs_to_add.emplace(fanin_node_name);
576 } else {
577 diff->controlling_inputs_to_remove.erase(control_index);
578 }
579 return true;
580 }
581
582 // Remove controlling fanin. If the controlling fanin does not exist in the
583 // original node and diff, nothing will happen. If the controlling fanin exists
584 // in the diff, it will be removed. Otherwise the controlling fanin will be
585 // marked for removal from the original node.
586 template <typename GraphViewT>
RemoveControllingFanin(NodeViewDiff<GraphViewT> * diff,int control_index,absl::string_view fanin_node_name)587 inline bool RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
588 int control_index,
589 absl::string_view fanin_node_name) {
590 if (control_index == kMissingIndex) {
591 diff->controlling_inputs_to_add.erase(fanin_node_name);
592 } else {
593 diff->controlling_inputs_to_remove.emplace(control_index);
594 }
595 return true;
596 }
597
598 // Adds or updates an attribute by name. If an attribute exist in the original
599 // node or diff (including those marked for removal), this will overwrite it.
600 template <typename GraphViewT>
AddOrUpdateAttribute(NodeViewDiff<GraphViewT> * diff,absl::string_view attr_name,const AttrValue & attr_value)601 inline bool AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff,
602 absl::string_view attr_name,
603 const AttrValue& attr_value) {
604 diff->attrs_to_add.empty() ? 0 : diff->attrs_to_remove.erase(attr_name);
605 gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value);
606 return true;
607 }
608
609 // Removes an attribute by name. If an attribute exist in the original node or
610 // diff, this will remove it.
611 template <typename GraphViewT>
RemoveAttribute(NodeViewDiff<GraphViewT> * diff,absl::string_view attr_name)612 inline bool RemoveAttribute(NodeViewDiff<GraphViewT>* diff,
613 absl::string_view attr_name) {
614 const size_t num_erased =
615 diff->attrs_to_add.empty() ? 0 : diff->attrs_to_add.erase(attr_name);
616 auto* node_view = diff->graph_view->GetNode(diff->node_index);
617 if (node_view->HasAttr(attr_name)) {
618 diff->attrs_to_remove.emplace(attr_name);
619 return true;
620 }
621 return num_erased > 0;
622 }
623
// Shrinks `*v` by dropping its trailing run of elements equal to `value`.
// Elements before the first non-matching one (scanning from the back) are
// untouched, and the vector is only resized when something was dropped.
template <typename T>
inline void ResizeByTrimmingEndForValue(std::vector<T>* v, const T& value) {
  auto kept = v->size();
  while (kept > 0 && (*v)[kept - 1] == value) {
    --kept;
  }
  if (kept < v->size()) {
    v->resize(kept);
  }
}
640
641 // Checks if any changes are set in the diff.
642 template <typename GraphViewT>
IsEmpty(NodeViewDiff<GraphViewT> * diff)643 inline bool IsEmpty(NodeViewDiff<GraphViewT>* diff) {
644 ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false);
645 ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId());
646 return !diff->update_name && !diff->update_op && !diff->update_device &&
647 diff->regular_inputs_to_add.empty() &&
648 diff->regular_inputs_to_update.empty() &&
649 diff->regular_inputs_to_remove.empty() &&
650 diff->controlling_inputs_to_add.empty() &&
651 diff->controlling_inputs_to_remove.empty() &&
652 diff->attrs_to_add.empty() && diff->attrs_to_remove.empty();
653 }
654
655 // Resets and clears existing diff.
656 template <typename GraphViewT>
Reset(NodeViewDiff<GraphViewT> * diff)657 inline void Reset(NodeViewDiff<GraphViewT>* diff) {
658 diff->name.clear();
659 diff->update_name = false;
660 diff->op.clear();
661 diff->update_op = false;
662 diff->device.clear();
663 diff->update_device = false;
664 std::vector<SafeTensorId>().swap(diff->regular_inputs_to_add);
665 diff->num_regular_inputs_to_add = false;
666 std::map<int, SafeTensorId>().swap(diff->regular_inputs_to_update);
667 std::vector<bool>().swap(diff->regular_inputs_to_remove);
668 diff->num_regular_inputs_to_remove = 0;
669 absl::flat_hash_set<string>().swap(diff->controlling_inputs_to_add);
670 std::set<int>().swap(diff->controlling_inputs_to_remove);
671 absl::flat_hash_map<string, AttrValue>().swap(diff->attrs_to_add);
672 absl::flat_hash_set<string>().swap(diff->attrs_to_remove);
673 }
674
675 // Checks if changes to node will result in a valid node.
676 template <typename GraphViewT>
IsWellFormed(NodeViewDiff<GraphViewT> * diff,const absl::flat_hash_map<absl::string_view,int> & updated_node_names)677 inline bool IsWellFormed(
678 NodeViewDiff<GraphViewT>* diff,
679 const absl::flat_hash_map<absl::string_view, int>& updated_node_names) {
680 ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false);
681 ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId());
682 int diff_regular_inputs_to_add_size = diff->regular_inputs_to_add.size();
683 if (diff_regular_inputs_to_add_size != diff->num_regular_inputs_to_add) {
684 // Missing regular fanins in between appended fanins.
685 return false;
686 } else if (diff->num_regular_inputs_to_add > 0 &&
687 !diff->regular_inputs_to_remove.empty()) {
688 // Appending new fanins while removing existing fanins, resulting in missing
689 // regular fanins in between.
690 return false;
691 } else if (static_cast<int>(diff->regular_inputs_to_remove.size()) !=
692 diff->num_regular_inputs_to_remove) {
693 // Regular fanins exist in between removed fanins.
694 return false;
695 }
696 auto* node_view = diff->graph_view->GetNode(diff->node_index);
697 const string& node_name =
698 diff->update_name ? diff->name : node_view->GetName();
699 auto invalid_node_name = [&](absl::string_view fanin_node_name) -> bool {
700 return fanin_node_name == node_name ||
701 !CheckNodeNameExists(fanin_node_name, updated_node_names,
702 diff->graph_view);
703 };
704
705 // Check if nodes of all updated and new fanins exist (from name) and if such
706 // fanins do not introduce self loops. Note, this will not check for if
707 // unmodified fanins exist.
708 if (diff->update_name) {
709 // If name of node was changed in node, check all fanins. Updated fanins are
710 // checked for existence and self loops. Unmodified fanins are checked for
711 // self loops.
712 // `regular_inputs_to_update`, `controlling_inputs_to_remove` are sorted,
713 // so iterators from these maps/sets can be incremented alongside iteration
714 // and be used for comparisons.
715 const int last_index =
716 node_view->NumRegularFanins() - diff->num_regular_inputs_to_remove - 1;
717 auto regular_to_update_it = diff->regular_inputs_to_update.begin();
718 for (int i = 0; i <= last_index; ++i) {
719 if (regular_to_update_it != diff->regular_inputs_to_update.end() &&
720 regular_to_update_it->first < i) {
721 ++regular_to_update_it;
722 }
723 if (regular_to_update_it != diff->regular_inputs_to_update.end() &&
724 regular_to_update_it->first == i) {
725 if (invalid_node_name(regular_to_update_it->second.node())) {
726 return false;
727 }
728 } else {
729 const string& regular_name =
730 node_view->GetRegularFanin(i).node_view()->GetName();
731 if (regular_name == node_name) {
732 return false;
733 }
734 }
735 }
736
737 auto& controls = node_view->GetControllingFanins();
738 const int num_controls = controls.size();
739 auto control_to_remove_it = diff->controlling_inputs_to_remove.begin();
740 for (int i = 0; i < num_controls; ++i) {
741 if (control_to_remove_it != diff->controlling_inputs_to_remove.end() &&
742 *control_to_remove_it < i) {
743 ++control_to_remove_it;
744 }
745 if (control_to_remove_it != diff->controlling_inputs_to_remove.end() &&
746 *control_to_remove_it == i) {
747 // Control dependency marked for removal, can be ignored.
748 continue;
749 } else if (controls[i].node_view()->GetName() == node_name) {
750 return false;
751 }
752 }
753 } else {
754 // Name of node was not changed, check only updated fanins under the
755 // assumption prior fanins were valid.
756 for (const auto& updated : diff->regular_inputs_to_update) {
757 const string& fanin_name = updated.second.node();
758 if (invalid_node_name(fanin_name)) {
759 return false;
760 }
761 }
762 }
763 // Check appended regular fanins.
764 for (const auto& regular : diff->regular_inputs_to_add) {
765 if (invalid_node_name(regular.node())) {
766 return false;
767 }
768 }
769 // Check new controlling fanins.
770 for (const auto& control : diff->controlling_inputs_to_add) {
771 if (invalid_node_name(control)) {
772 return false;
773 }
774 }
775
776 return true;
777 }
778
779 // NewNode is a helper struct holding a new node to be added to a GraphViewT.
780 // This should not be initialized or be used directly.
781 template <typename GraphViewT>
782 struct NewNode {
NewNodeNewNode783 explicit NewNode(GraphViewT* graph_view, NodeDef&& node)
784 : graph_view(graph_view), node(std::move(node)) {}
785
786 GraphViewT* graph_view;
787 NodeDef node;
788 std::vector<SafeTensorId> regular_fanins;
789 int num_regular_fanins = 0;
790 absl::flat_hash_set<string> controlling_fanins;
791 };
792
793 // Updates new node name.
794 template <typename GraphViewT>
UpdateName(NewNode<GraphViewT> * new_node,absl::string_view name)795 inline void UpdateName(NewNode<GraphViewT>* new_node, absl::string_view name) {
796 if (name.empty()) {
797 new_node->node.clear_name();
798 } else {
799 new_node->node.set_name(string(name));
800 }
801 }
802
803 // Updates new node op.
804 template <typename GraphViewT>
UpdateOp(NewNode<GraphViewT> * new_node,absl::string_view op)805 inline void UpdateOp(NewNode<GraphViewT>* new_node, absl::string_view op) {
806 if (op.empty()) {
807 new_node->node.clear_op();
808 } else {
809 new_node->node.set_op(string(op));
810 }
811 }
812
813 // Updates new node device.
814 template <typename GraphViewT>
UpdateDevice(NewNode<GraphViewT> * new_node,absl::string_view device)815 inline void UpdateDevice(NewNode<GraphViewT>* new_node,
816 absl::string_view device) {
817 if (device.empty()) {
818 new_node->node.clear_device();
819 } else {
820 new_node->node.set_device(string(device));
821 }
822 }
823
824 // Adds or updates regular fanin at `index` of regular fanins in the new node.
825 // If another fanin already exists at `index`, it will be replaced with `fanin`.
826 template <typename GraphViewT>
AddOrUpdateRegularFanin(NewNode<GraphViewT> * new_node,int index,const TensorId & fanin)827 inline void AddOrUpdateRegularFanin(NewNode<GraphViewT>* new_node, int index,
828 const TensorId& fanin) {
829 if (index < 0) {
830 // Not a valid index for regular fanins.
831 return;
832 } else if (AddOrUpdateAtIndex(&new_node->regular_fanins, index, fanin,
833 EmptyTensorId())) {
834 ++new_node->num_regular_fanins;
835 }
836 }
837
838 // Remove regular fanin at `index` of regular fanins in the new node. This can
839 // remove existing fanins and updated/added fanins via AddOrUpdateRegularFanins.
840 template <typename GraphViewT>
RemoveRegularFanin(NewNode<GraphViewT> * new_node,int index)841 inline void RemoveRegularFanin(NewNode<GraphViewT>* new_node, int index) {
842 int new_node_regular_fanins_size = new_node->regular_fanins.size();
843 if (index < 0 || index >= new_node_regular_fanins_size ||
844 IsEmptyTensorId(new_node->regular_fanins[index])) {
845 return;
846 }
847 new_node->regular_fanins[index] = EmptyTensorId();
848 --new_node->num_regular_fanins;
849 }
850
851 // Adds controlling fanin to new node.
852 template <typename GraphViewT>
AddControllingFanin(NewNode<GraphViewT> * new_node,absl::string_view fanin_node_name)853 inline void AddControllingFanin(NewNode<GraphViewT>* new_node,
854 absl::string_view fanin_node_name) {
855 new_node->controlling_fanins.emplace(fanin_node_name);
856 }
857
858 // Removes controlling fanin to new node.
859 template <typename GraphViewT>
RemoveControllingFanin(NewNode<GraphViewT> * new_node,absl::string_view fanin_node_name)860 inline void RemoveControllingFanin(NewNode<GraphViewT>* new_node,
861 absl::string_view fanin_node_name) {
862 new_node->controlling_fanins.erase(fanin_node_name);
863 }
864
865 // Adds or updates an attribute by name to a new node.
866 template <typename GraphViewT>
AddOrUpdateAttribute(NewNode<GraphViewT> * new_node,absl::string_view attr_name,const AttrValue & attr_value)867 inline void AddOrUpdateAttribute(NewNode<GraphViewT>* new_node,
868 absl::string_view attr_name,
869 const AttrValue& attr_value) {
870 gtl::InsertOrUpdate(new_node->node.mutable_attr(), string(attr_name),
871 attr_value);
872 }
873
874 // Removes an attribute by name to a new node.
875 template <typename GraphViewT>
RemoveAttribute(NewNode<GraphViewT> * new_node,absl::string_view attr_name)876 inline void RemoveAttribute(NewNode<GraphViewT>* new_node,
877 absl::string_view attr_name) {
878 new_node->node.mutable_attr()->erase(string(attr_name));
879 }
880
881 // Checks if current state of new node is a valid node.
882 template <typename GraphViewT>
IsWellFormed(NewNode<GraphViewT> * new_node,const absl::flat_hash_map<absl::string_view,int> & updated_node_names)883 inline bool IsWellFormed(
884 NewNode<GraphViewT>* new_node,
885 const absl::flat_hash_map<absl::string_view, int>& updated_node_names) {
886 ResizeByTrimmingEndForValue(&new_node->regular_fanins, EmptyTensorId());
887 int new_node_regular_fanins_size = new_node->regular_fanins.size();
888 if (new_node_regular_fanins_size != new_node->num_regular_fanins) {
889 return false;
890 }
891
892 const string& node_name = new_node->node.name();
893 auto invalid_node_name = [new_node, updated_node_names,
894 node_name](absl::string_view fanin_node_name) {
895 return fanin_node_name == node_name ||
896 !CheckNodeNameExists(fanin_node_name, updated_node_names,
897 new_node->graph_view);
898 };
899 // Check if nodes of all fanins exist (from name) and if fanins do not
900 // introduce self loops.
901 for (const auto& regular : new_node->regular_fanins) {
902 if (invalid_node_name(regular.node())) {
903 return false;
904 }
905 }
906 for (const auto& control : new_node->controlling_fanins) {
907 if (invalid_node_name(control)) {
908 return false;
909 }
910 }
911
912 return true;
913 }
914
915 } // namespace internal
916 } // namespace utils
917 } // namespace grappler
918 } // namespace tensorflow
919
920 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_
921