/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_

#include <functional>
#include <set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/graph_view_internal.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace grappler {
namespace utils {

class NodeView;

class GraphView;

// FaninView is a helper class to represent fanouts of a node. This holds a
// pointer to GraphView, the index of the node being represented from GraphView,
// and the input index (hence is labeled as Fanin).
44 class FaninView : public internal::NodeIndexAndPortIndex<NodeView, GraphView> { 45 public: FaninView()46 FaninView() : NodeIndexAndPortIndex() {} 47 FaninView(GraphView * graph_view,int node_index,int port_index)48 FaninView(GraphView* graph_view, int node_index, int port_index) 49 : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} 50 51 FaninView(NodeView* node_view, int index); 52 53 private: 54 friend class NodeView; 55 friend class GraphView; 56 }; 57 58 // FanoutView is a helper class to represent fanins of a node. This holds a 59 // pointer to GraphView, the index of the node being represented from GraphView, 60 // and the output index (hence is labeled as Fanout). 61 class FanoutView : public internal::NodeIndexAndPortIndex<NodeView, GraphView> { 62 public: FanoutView()63 FanoutView() : NodeIndexAndPortIndex() {} 64 FanoutView(GraphView * graph_view,int node_index,int port_index)65 FanoutView(GraphView* graph_view, int node_index, int port_index) 66 : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} 67 68 FanoutView(NodeView* node_view, int index); 69 70 private: 71 friend class NodeView; 72 friend class GraphView; 73 }; 74 75 // Immutable NodeView that keeps the constness of the NodeDef. This allows for 76 // lookups of fanins and fanouts, and traversals of the graph, but no mutations. 77 // No dedupping of fanins will be performed on the node to preserve it's 78 // constness. 79 class NodeView : public internal::NodeViewInternal<FaninView, FanoutView, 80 GraphView, true> { 81 public: NodeView(GraphView * graph_view,int node_index)82 explicit NodeView(GraphView* graph_view, int node_index) 83 : NodeViewInternal(graph_view, node_index) {} 84 NodeView()85 NodeView() : NodeViewInternal() {} 86 87 ~NodeView() override = default; 88 89 NodeView(NodeView&&) = default; 90 NodeView& operator=(NodeView&&) = default; 91 92 const NodeDef* node() const override; 93 94 // Checks if a fanin exists for the node. 
95 bool HasFanin(const FanoutView& fanin) const override; 96 97 // Checks if a fanout exists for the node. 98 bool HasFanout(const FaninView& fanout) const override; 99 100 private: 101 inline const FanoutView& GetMissingFanin() const override; 102 103 inline const std::vector<FaninView>& GetMissingFanout() const override; 104 105 absl::flat_hash_set<internal::NodeDefAndPortIndex> fanins_set_; 106 107 friend class FaninView; 108 friend class FanoutView; 109 friend class GraphView; 110 }; 111 112 // Immutable GraphView that keeps the constness of the GraphDef. This allows 113 // for lookups and traversals of the graph, but no mutations. 114 class GraphView : public internal::GraphViewInternal<NodeView, FaninView, 115 FanoutView, true> { 116 public: 117 explicit GraphView(const GraphDef* graph, Status* status); 118 ~GraphView() override = default; 119 120 private: 121 bool AddUniqueNodeInternal(const NodeDef* node); 122 123 Status CheckAndAddFaninsInternal(NodeView* node_view); 124 125 friend class NodeView; 126 }; 127 128 class MutableNodeView; 129 130 class MutableGraphView; 131 132 class Mutation; 133 134 // MutableFaninView is a helper class to represent fanouts of a node. This holds 135 // a pointer to MutableGraphView, the index of the node from MutableGraphView 136 // being mutated, and the input index (hence is labeled as Fanin). 
137 class MutableFaninView 138 : public internal::NodeIndexAndPortIndex<MutableNodeView, 139 MutableGraphView> { 140 public: MutableFaninView()141 MutableFaninView() : NodeIndexAndPortIndex() {} 142 MutableFaninView(MutableGraphView * graph_view,int node_index,int port_index)143 MutableFaninView(MutableGraphView* graph_view, int node_index, int port_index) 144 : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} 145 MutableFaninView(MutableGraphView * graph_view,int node_index,int port_index,int fanin_index)146 explicit MutableFaninView(MutableGraphView* graph_view, int node_index, 147 int port_index, int fanin_index) 148 : NodeIndexAndPortIndex(graph_view, node_index, port_index), 149 fanin_index_(fanin_index) { 150 // TODO(lyandy): Remove once constructor is not public. 151 DCHECK(port_index < 0 || port_index == fanin_index); 152 } 153 154 MutableFaninView(MutableNodeView* node_view, int index); 155 156 private: 157 // Index of associated fanin in fanout's underlying MutableNodeView. For 158 // regular fanouts, this will be the same as port_index (index of the 159 // associated fanin in MutableNodeView::regular_fanins_). For controlled 160 // fanouts, this will be the index of the associated fanin in 161 // MutableNodeView::controlling_fanins_. 162 int fanin_index_ = internal::kMissingIndex; 163 164 friend class MutableNodeView; 165 friend class MutableGraphView; 166 friend class Mutation; 167 }; 168 169 // MutableFanoutView is a helper class to represent fanins of a node. This holds 170 // a pointer to MutableGraphView, the index of the node from MutableGraphView 171 // being mutated, and the output index (hence is labeled as Fanout). 
172 class MutableFanoutView 173 : public internal::NodeIndexAndPortIndex<MutableNodeView, 174 MutableGraphView> { 175 public: MutableFanoutView()176 MutableFanoutView() : NodeIndexAndPortIndex() {} 177 MutableFanoutView(MutableGraphView * graph_view,int node_index,int port_index)178 MutableFanoutView(MutableGraphView* graph_view, int node_index, 179 int port_index) 180 : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} 181 MutableFanoutView(MutableGraphView * graph_view,int node_index,int port_index,int fanout_index)182 explicit MutableFanoutView(MutableGraphView* graph_view, int node_index, 183 int port_index, int fanout_index) 184 : NodeIndexAndPortIndex(graph_view, node_index, port_index), 185 fanout_index_(fanout_index) {} 186 187 MutableFanoutView(MutableNodeView* node_view, int index); 188 189 private: 190 // Index of associated fanout in fanin's underlying MutableNodeView. For 191 // regular fanins, this will be the index of the associated fanout in 192 // MutableNodeView::regular_fanouts_by_port_[port_index]. For controlled 193 // fanins, this will be the index of the associated fanout in 194 // MutableNodeView::controlled_fanouts_. 195 int fanout_index_ = internal::kMissingIndex; 196 197 friend class MutableNodeView; 198 friend class MutableGraphView; 199 friend class Mutation; 200 }; 201 202 // Mutable NodeView that holds a mutable NodeDef. This allows for lookups of 203 // fanins and fanouts, and traversals of the graph. Control dependencies will be 204 // dedupped among other control dependencies on initialization via 205 // MutableGraphView. Mutations should be handled via MutableGraphView and not 206 // directly on the mutable NodeDef. 
207 class MutableNodeView 208 : public internal::NodeViewInternal<MutableFaninView, MutableFanoutView, 209 MutableGraphView, false> { 210 public: MutableNodeView(MutableGraphView * graph_view,int node_index)211 explicit MutableNodeView(MutableGraphView* graph_view, int node_index) 212 : NodeViewInternal(graph_view, node_index) {} 213 MutableNodeView()214 MutableNodeView() : NodeViewInternal() {} 215 216 ~MutableNodeView() override = default; 217 218 MutableNodeView(MutableNodeView&&) = default; 219 MutableNodeView& operator=(MutableNodeView&&) = default; 220 221 NodeDef* node() const override; 222 223 // Checks if a fanin exists for the node. 224 bool HasFanin(const MutableFanoutView& fanin) const override; 225 226 // Checks if a fanout exists for the node. 227 bool HasFanout(const MutableFaninView& fanout) const override; 228 229 private: 230 inline const MutableFanoutView& GetMissingFanin() const override; 231 232 inline const std::vector<MutableFaninView>& GetMissingFanout() const override; 233 234 absl::flat_hash_map<internal::NodeDefAndPortIndex, int> fanins_count_; 235 absl::flat_hash_map<absl::string_view, int> controlling_fanins_index_; 236 // Index of associated MutableNodeViewDiff in Mutation::updated_nodes_. 237 // If this is -1, there exists no MutableNodeViewDiff for this node. 
238 int update_index_ = internal::kMissingIndex; 239 240 friend class MutableFaninView; 241 friend class MutableFanoutView; 242 friend class MutableGraphView; 243 friend class Mutation; 244 }; 245 246 class MutationNewNode { 247 public: MutationNewNode()248 MutationNewNode() {} 249 250 private: MutationNewNode(Mutation * mutation,int mutation_counter,int index)251 explicit MutationNewNode(Mutation* mutation, int mutation_counter, int index) 252 : mutation_(mutation), 253 mutation_counter_(mutation_counter), 254 index_(index) {} 255 256 Mutation* mutation_ = nullptr; 257 int mutation_counter_ = internal::kMissingSlot; 258 int index_ = internal::kMissingIndex; 259 260 friend class Mutation; 261 }; 262 263 // Mutation is a helper class that allows rewrites of MutableGraphView. This 264 // should not be initialized or be used directly. 265 // Note, if a node is renamed to another node, or a new node is created with the 266 // same name as an existing node, the node with the same name originally in the 267 // graph will be overwritten. 268 class Mutation { 269 public: 270 // Create a new node to be added to the graph. If the node's fanins are not 271 // well formed (self loops, control dependencies between regular fanins), the 272 // `status` will be set. 273 MutationNewNode AddNode(NodeDef&& node, Status* status); 274 275 // Remove an existing node in the graph. 276 void RemoveNode(MutableNodeView* node); 277 278 // Update the name of an existing node. 279 void UpdateNodeName(MutableNodeView* node, absl::string_view name); 280 281 // Update the name of a new node. 282 void UpdateNodeName(const MutationNewNode& node, absl::string_view name); 283 284 // Update the op of an existing node. 285 void UpdateNodeOp(MutableNodeView* node, absl::string_view op); 286 287 // Update the op of a new node. 288 void UpdateNodeOp(const MutationNewNode& node, absl::string_view op); 289 290 // Update the device of an existing node. 
291 void UpdateNodeDevice(MutableNodeView* node, absl::string_view device); 292 293 // Update the device of a new node. 294 void UpdateNodeDevice(const MutationNewNode& node, absl::string_view device); 295 296 // Add or replace regular fanin `fanin` at `index` for an existing node. 297 void AddOrUpdateRegularFanin(MutableNodeView* node, int index, 298 const TensorId& fanin); 299 300 // Add or replace regular fanin `fanin` at `index` for a new node. 301 void AddOrUpdateRegularFanin(const MutationNewNode& node, int index, 302 const TensorId& fanin); 303 304 // Remove regular fanin at `index` for an existing node. 305 void RemoveRegularFanin(MutableNodeView* node, int index); 306 307 // Remove regular fanin at `index` for a new node. 308 void RemoveRegularFanin(const MutationNewNode& node, int index); 309 310 // Add controlling fanin `fanin_node_name` for an existing node. 311 void AddControllingFanin(MutableNodeView* node, 312 absl::string_view fanin_node_name); 313 314 // Add controlling fanin `fanin_node_name` for a new node. 315 void AddControllingFanin(const MutationNewNode& node, 316 absl::string_view fanin_node_name); 317 318 // Remove controlling fanin `fanin_node_name` for an existing node. 319 void RemoveControllingFanin(MutableNodeView* node, 320 absl::string_view fanin_node_name); 321 322 // Remove controlling fanin `fanin_node_name` for a new node. 323 void RemoveControllingFanin(const MutationNewNode& node, 324 absl::string_view fanin_node_name); 325 326 // Add or replace attribute `attr_name` with `attr_value` for an existing 327 // node. 328 void AddOrUpdateNodeAttr(MutableNodeView* node, absl::string_view attr_name, 329 const AttrValue& attr_value); 330 331 // Add or replace attribute `attr_name` with `attr_value` for a new node. 332 void AddOrUpdateNodeAttr(const MutationNewNode& node, 333 absl::string_view attr_name, 334 const AttrValue& attr_value); 335 336 // Remove attribute `attr_name` for an existing node. 
337 void RemoveNodeAttr(MutableNodeView* node, absl::string_view attr_name); 338 339 // Remove attribute `attr_name` for a new node. 340 void RemoveNodeAttr(const MutationNewNode& node, absl::string_view attr_name); 341 342 // Reset and clear mutation. 343 void Reset(); 344 345 // Applies the Mutation to the graph. If the mutation is valid, the graph will 346 // be modified. Otherwise an error status will be returned and the graph will 347 // not be modified. 348 Status Apply(); 349 350 private: 351 explicit Mutation(MutableGraphView* graph_view); 352 353 void ResetInternal(); 354 355 using MutableNodeViewDiff = internal::NodeViewDiff<MutableGraphView>; 356 357 // Adds a mutation to the `node`. Mutation function `mutate_fn` must return 358 // `true` if it actually does any mutations. If it returns `false` mutation 359 // will be ignored. 360 void AddMutation(MutableNodeView* node, 361 std::function<bool(MutableNodeViewDiff*)> mutate_fn); 362 363 MutableGraphView* graph_view_ = nullptr; 364 int mutation_counter_ = 0; 365 std::vector<MutableNodeViewDiff> updated_nodes_; 366 absl::flat_hash_set<int> removed_nodes_; 367 368 using MutationNewNodeHolder = internal::NewNode<MutableGraphView>; 369 std::vector<MutationNewNodeHolder> new_nodes_; 370 371 friend class MutableGraphView; 372 }; 373 374 // Mutable GraphView that holds a mutable GraphDef. This allows for lookups and 375 // traversals of the graph. Control dependencies will be dedupped among other 376 // control dependencies on initialization. Mutations should be handled using 377 // this API instead of directly on the GraphDef/NodeDef. 378 // Note, after a mutation, pointers of MutableNodeView's from MutableGraphView 379 // may be invalidated. 
380 class MutableGraphView 381 : public internal::GraphViewInternal<MutableNodeView, MutableFaninView, 382 MutableFanoutView, false> { 383 public: 384 explicit MutableGraphView(GraphDef* graph, Status* status); 385 ~MutableGraphView() override = default; 386 387 // Returns a Mutation (builder) that can be used to modify MutableGraphView. 388 Mutation* GetMutationBuilder(); 389 390 // Helper class representing an extra dependency for topological sorting. 391 class TopologicalDependency { 392 public: TopologicalDependency(const MutableNodeView * from_node,const MutableNodeView * to_node)393 TopologicalDependency(const MutableNodeView* from_node, 394 const MutableNodeView* to_node) { 395 if (from_node->graph_view_ == to_node->graph_view_) { 396 graph_view_ = from_node->graph_view_; 397 from_ = from_node->node_index_; 398 to_ = to_node->node_index_; 399 } 400 } 401 402 private: 403 MutableGraphView* graph_view_ = nullptr; 404 int from_ = internal::kMissingIndex; 405 int to_ = internal::kMissingIndex; 406 407 friend class MutableGraphView; 408 }; 409 410 // Sorts graph topologically in-place. If `ignore_cycles` is set, a 411 // topological like sorting will be performed when there are cycles. Otherwise 412 // if a cycle is detected or if the graph cannot be sorted, an error will be 413 // returned. 414 Status SortTopologically( 415 bool ignore_cycles, 416 absl::Span<const TopologicalDependency> extra_dependencies); 417 418 private: 419 bool AddUniqueNodeInternal(NodeDef* node); 420 421 Status CheckFaninsInternal(std::vector<std::vector<TensorId>>* fanins); 422 423 void AddFaninsInternal(std::vector<std::vector<TensorId>>* fanins); 424 425 // RenamedOrOverwrittenNode holds a index to Mutation::updated_nodes_ for a 426 // renamed node, alongside a potential overwritten node index in the actual 427 // graph. If the renamed node is not overwriting any existing nodes, 428 // `overwritten_node_index_` will be set to `internal::kMissingIndex`. 
429 class RenamedOrOverwrittenNode { 430 public: RenamedOrOverwrittenNode(int renamed_update_index,int overwritten_node_index)431 RenamedOrOverwrittenNode(int renamed_update_index, 432 int overwritten_node_index) 433 : renamed_update_index_(renamed_update_index), 434 overwritten_node_index_(overwritten_node_index) {} 435 436 private: 437 int renamed_update_index_; 438 int overwritten_node_index_; 439 440 friend class MutableGraphView; 441 }; 442 443 Status GetNodeNamesAndPartitionUpdatedNodes( 444 absl::flat_hash_map<absl::string_view, int>* node_names, 445 std::vector<RenamedOrOverwrittenNode>* renamed_nodes, 446 std::vector<int>* inplace_nodes, 447 std::vector<int>* empty_diff_node_indices); 448 449 Status RemovedOrMissingNodeFanoutsWellFormed( 450 const absl::flat_hash_map<absl::string_view, int>& node_names, 451 const std::vector<RenamedOrOverwrittenNode>& renamed_nodes); 452 453 Status CheckNodeNamesAndFanins( 454 const absl::flat_hash_map<absl::string_view, int>& node_names, 455 const std::vector<RenamedOrOverwrittenNode>& renamed_nodes, 456 const std::vector<int>& inplace_nodes); 457 458 Status CheckKernelRegisteredForNodes(); 459 460 // Helper class to move fanouts around. 
461 class NodeViewFanouts { 462 public: NodeViewFanouts(std::vector<std::vector<MutableFaninView>> && regular_fanouts_by_port,int num_regular_fanouts,std::vector<MutableFaninView> controlled_fanouts)463 NodeViewFanouts( 464 std::vector<std::vector<MutableFaninView>>&& regular_fanouts_by_port, 465 int num_regular_fanouts, 466 std::vector<MutableFaninView> controlled_fanouts) 467 : regular_fanouts_by_port_(std::move(regular_fanouts_by_port)), 468 num_regular_fanouts_(num_regular_fanouts), 469 controlled_fanouts_(std::move(controlled_fanouts)) {} 470 471 private: 472 std::vector<std::vector<MutableFaninView>> regular_fanouts_by_port_; 473 int num_regular_fanouts_ = 0; 474 std::vector<MutableFaninView> controlled_fanouts_; 475 476 friend class MutableGraphView; 477 }; 478 479 template <typename T> 480 void ReplaceNodeFanouts(MutableNodeView* node, T* fanouts); 481 482 void FixRenamedNodes( 483 std::vector<RenamedOrOverwrittenNode>* renamed_nodes, 484 absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts, 485 std::vector<bool>* overwritten_name_removed_nodes); 486 487 void AddNewNodes( 488 absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts, 489 std::vector<int>* new_node_indices); 490 491 void FixRenamedFanouts( 492 const absl::flat_hash_map<string, NodeViewFanouts>& renamed_fanouts); 493 494 inline void RemoveRegularFaninFanoutInternal(MutableNodeView* node_view, 495 int i); 496 497 inline void AddRegularFaninInternal(MutableNodeView* node_view, 498 const SafeTensorId& fanin_id); 499 500 inline void UpdateRegularFaninInternal(MutableNodeView* node_view, 501 const int i, 502 const SafeTensorId& fanin_id); 503 504 inline void RemoveControllingFaninFanoutInternal(MutableNodeView* node_view, 505 int i); 506 507 inline void RemoveControllingFaninInternal( 508 MutableNodeView* node_view, const std::set<int>& indices_to_remove); 509 510 inline void AddControllingFaninInternal(MutableNodeView* node_view, 511 absl::string_view fanin_node_name); 512 513 void 
ApplyNodeUpdates(); 514 515 void SetNewNodesFanins(const std::vector<int>& new_node_indices); 516 517 inline void RemoveAllFaninFanoutInternal(MutableNodeView* node_view); 518 519 void RemoveNodesInternal( 520 const std::vector<RenamedOrOverwrittenNode>& renamed_nodes, 521 const std::vector<bool>& overwritten_name_removed_nodes); 522 523 inline Status ValidateInternal( 524 absl::flat_hash_map<absl::string_view, int>* node_names, 525 std::vector<RenamedOrOverwrittenNode>* renamed_nodes, 526 std::vector<int>* inplace_nodes, 527 std::vector<int>* empty_diff_node_indices); 528 529 Status ApplyMutationInternal(); 530 531 Mutation mutation_; 532 533 friend class MutableNodeView; 534 friend class Mutation; 535 }; 536 537 } // namespace utils 538 } // namespace grappler 539 } // namespace tensorflow 540 541 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ 542