1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ 18 19 #include <map> 20 #include <memory> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/node_def.pb.h" 26 #include "tensorflow/core/framework/op_def.pb.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/protobuf/meta_graph.pb.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 namespace graph_analyzer { 33 34 class GenNode; 35 36 // To find nodes by name. 37 using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>; 38 39 // One node in the graph, in the form convenient for traversal and generation of 40 // subgraphs. It refers to the original NodeDef protobuf for most information 41 // and adds the extra enrichment. 42 // 43 // The graph building is 2-stage: first match a GenNode with each NodeDef and 44 // collect them into a map that finds them by name, then process the map, 45 // deep-parse the underlying NodeDefs and connect the GenNodes together. 46 class GenNode { 47 public: 48 // Will keep the pointer, so the underlying object must not be deleted while 49 // GenNode is alive. 50 explicit GenNode(const NodeDef* node); 51 52 // Access wrappers. name()53 const string& name() const { return node_->name(); } opcode()54 const string& opcode() const { return node_->op(); } node_def()55 const NodeDef* node_def() const { return node_; } 56 57 // Parse the inputs of this node and update the map accordingly, creating the 58 // links (i.e. edges, connections between nodes) in itself and in the nodes 59 // it's linked to (the map itself is unchanged, only the nodes in it are 60 // updated). 61 Status ParseInputs(const GenNodeMap* map); 62 63 // Does the full 2-stage build of the graph. The map should be initially 64 // empty. The map keeps pointers to the nodes in source, so the source must 65 // not be destroyed before the map. 66 static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map); 67 68 // The enrichment that constitutes the point of this class. 69 70 // Representation of a connection on a node. 71 class Port { 72 public: 73 // A port may be inbound or outbound. 74 // Negative ids (canonically -1) mean a control port. Port(bool inbound,int32_t id)75 Port(bool inbound, int32_t id) : value_(id << 1) { 76 if (inbound) { 77 value_ |= 1; 78 } 79 } 80 Port(const Port&) = default; 81 Port& operator=(const Port&) = default; 82 IsInbound()83 bool IsInbound() const { return (value_ & 0x1); } 84 IsControl()85 bool IsControl() const { return (value_ < 0); } 86 Id()87 int32_t Id() const { 88 // Arithmetic shift preserves the sign. 89 return (value_ >> 1); 90 } 91 92 // Integer type used to represent the encoded port value. 93 using IntPort = int32_t; 94 95 // Returns the encoded form of this port, so that it can be used 96 // as various map indexes. Encoded()97 IntPort Encoded() const { return value_; } 98 Decode(IntPort encoded)99 static Port Decode(IntPort encoded) { return Port(encoded); } 100 101 bool operator==(const Port& other) const { return value_ == other.value_; } 102 bool operator<(const Port& other) const { return value_ < other.value_; } 103 104 struct Hasher { operatorHasher105 size_t operator()(const Port& port) const noexcept { 106 return hasher(port.Encoded()); 107 } 108 std::hash<int32_t> hasher; 109 }; 110 111 // Convenient for printing. I've really wanted it to be implicit but 112 // ClangTidy insists on making it explicit. 113 explicit operator string() const; 114 115 private: Port(IntPort value)116 explicit Port(IntPort value) : value_(value) {} 117 118 IntPort value_; 119 }; 120 121 struct LinkTarget { 122 GenNode* node; // Node where this link points. 123 Port port; // Port on the remote side of this link. 124 LinkTargetLinkTarget125 LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {} 126 }; 127 // All the links that are connected to the same port of this node 128 // are collected in one vector. A link is an edge of the graph that connects 129 // 2 nodes. Each of the connected nodes has its own perspective on the link, 130 // seeing its local port, remote port and the remote node. The direction of 131 // the link is encoded in the ports, one port is always incoming and another 132 // one outgoing. 133 using LinkTargetVector = std::vector<LinkTarget>; 134 // Both inputs and outputs are stored in the same map. 135 using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>; 136 137 // Access to the link map. links()138 const LinkMap& links() const { return links_; } 139 140 // Check whether the port is an input (including the controls) with multiple 141 // connections. Such inputs get handled in a special way when building the 142 // subgraphs, in an "all or nothing" fashion. 143 bool IsMultiInput(Port port) const; 144 145 // When building the subgraphs, must include either all non-control inputs of 146 // this node into the subgraph or none of them. This happens when at least one 147 // of the inputs is a multi-input (or if the opcode is commutative, thus 148 // treating all the inputs as one multi-input). AllInputsOrNone()149 bool AllInputsOrNone() const { return all_inputs_or_none_; } 150 151 private: 152 const NodeDef* node_; 153 // Becomes valid only after ParseInputs(). 154 const OpDef* op_; 155 156 // The opcode has a complicated structure of input args, with multi-input args 157 // that are not commutative. This means that to make sense, the subgraphs that 158 // include this node must also include either all its inputs or none of them. 159 bool all_inputs_or_none_ = false; 160 161 LinkMap links_; 162 }; 163 164 } // end namespace graph_analyzer 165 } // end namespace grappler 166 } // end namespace tensorflow 167 168 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ 169