1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
17 #include "absl/memory/memory.h"
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
21 #include "tensorflow/core/grappler/op_types.h"
22 #include "tensorflow/core/grappler/utils.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_analyzer {
27 
GenNode(const NodeDef * node)28 GenNode::GenNode(const NodeDef* node) : node_(node), op_(nullptr) {}
29 
BuildGraphInMap(const GraphDef & source,GenNodeMap * map)30 Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) {
31   for (const auto& n : source.node()) {
32     const string& name = n.name();
33     if (map->find(name) != map->end()) {
34       // This error code looks more meaningful than ALREADY_EXISTS.
35       return Status(error::INVALID_ARGUMENT,
36                     "Duplicate node name '" + name + "'.");
37     }
38     (*map)[name] = std::make_unique<GenNode>(&n);
39   }
40   // Now parse the links.
41   for (const auto& mapit : *map) {
42     Status st = mapit.second->ParseInputs(map);
43     if (!st.ok()) {
44       return st;
45     }
46   }
47   return OkStatus();
48 }
49 
ParseInputs(const GenNodeMap * map)50 Status GenNode::ParseInputs(const GenNodeMap* map) {
51   all_inputs_or_none_ = false;
52   Status st = OpRegistry::Global()->LookUpOpDef(opcode(), &op_);
53   if (!st.ok()) {
54     return Status(
55         error::INVALID_ARGUMENT,
56         absl::StrFormat("Node '%s' contains an undefined operation '%s': %s",
57                         name(), opcode(), st.error_message()));
58   }
59 
60   int n_inputs = node_->input_size();
61 
62   int n_named_inputs = op_->input_arg_size();
63 
64   int n_multi_inputs = 0;
65   for (const auto& inarg : op_->input_arg()) {
66     if (!inarg.number_attr().empty() || !inarg.type_list_attr().empty()) {
67       ++n_multi_inputs;
68     }
69   }
70   bool is_commutative = grappler::IsCommutative(*node_);
71 
72   if (n_multi_inputs > 1 || (n_multi_inputs > 0 && n_named_inputs > 1)) {
73     // Can't handle more than one multi-input at a time.
74     // And can't handle the commutativeness of only some arguments
75     // rather than all of them.
76     is_commutative = false;
77   }
78 
79   if (is_commutative) {
80     // If truly commutative, can treat all the inputs as one multi-input.
81     // It's possible to just treat the commutative nodes as AllInputsOrNone
82     // but (1) this way is a bit more efficient and (2) I want to preserve this
83     // more efficient code path that does all-or-none by a single input and
84     // perhaps extend its use in the future.
85     n_named_inputs = 1;
86     all_inputs_or_none_ = false;
87   } else if (n_multi_inputs > 0) {
88     all_inputs_or_none_ = true;
89   }
90 
91   for (int i = 0; i < n_inputs; ++i) {
92     int other_position;
93     string other_name = ParseNodeName(node_->input(i), &other_position);
94     auto other_it = map->find(other_name);
95     if (other_it == map->end()) {
96       return Status(
97           error::INVALID_ARGUMENT,
98           absl::StrFormat(
99               "Node '%s' input %d refers to a non-existing node '%s'.", name(),
100               i, other_name));
101     }
102     GenNode* other_node = other_it->second.get();
103 
104     int this_position = other_position < 0 ? -1 : (is_commutative ? 0 : i);
105 
106     if (this_position >= 0 && n_multi_inputs == 0 &&
107         this_position >= n_named_inputs) {
108       return Status(
109           error::INVALID_ARGUMENT,
110           absl::StrFormat(
111               "Node '%s' has a non-control input from '%s' at index %d but its "
112               "operation '%s' defines only %d inputs.",
113               name(), other_name, this_position, op_->name(), n_named_inputs));
114     }
115 
116     Port this_port(/*inbound=*/true, this_position);
117     Port other_port(/*inbound=*/false, other_position);
118 
119     links_[this_port].emplace_back(LinkTarget(other_node, other_port));
120     other_node->links_[other_port].emplace_back(LinkTarget(this, this_port));
121   }
122   return OkStatus();
123 }
124 
IsMultiInput(Port port) const125 bool GenNode::IsMultiInput(Port port) const {
126   if (!port.IsInbound()) {
127     return false;
128   }
129   auto it = links_.find(port);
130   if (it == links_.end()) {
131     return false;  // Shouldn't happen.
132   }
133   return (it->second.size() > 1);
134 }
135 
operator string() const136 GenNode::Port::operator string() const {
137   string result = this->IsInbound() ? "i" : "o";
138   if (this->IsControl()) {
139     result.append("C");
140   } else {
141     result.append(absl::StrFormat("%d", this->Id()));
142   }
143   return result;
144 }
145 
146 }  // end namespace graph_analyzer
147 }  // end namespace grappler
148 }  // end namespace tensorflow
149