1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
17 #include "absl/memory/memory.h"
18 #include "absl/strings/str_format.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
21 #include "tensorflow/core/grappler/op_types.h"
22 #include "tensorflow/core/grappler/utils.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_analyzer {
27
GenNode(const NodeDef * node)28 GenNode::GenNode(const NodeDef* node) : node_(node), op_(nullptr) {}
29
BuildGraphInMap(const GraphDef & source,GenNodeMap * map)30 Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) {
31 for (const auto& n : source.node()) {
32 const string& name = n.name();
33 if (map->find(name) != map->end()) {
34 // This error code looks more meaningful than ALREADY_EXISTS.
35 return Status(error::INVALID_ARGUMENT,
36 "Duplicate node name '" + name + "'.");
37 }
38 (*map)[name] = absl::make_unique<GenNode>(&n);
39 }
40 // Now parse the links.
41 for (const auto& mapit : *map) {
42 Status st = mapit.second->ParseInputs(map);
43 if (!st.ok()) {
44 return st;
45 }
46 }
47 return Status::OK();
48 }
49
ParseInputs(const GenNodeMap * map)50 Status GenNode::ParseInputs(const GenNodeMap* map) {
51 all_inputs_or_none_ = false;
52 Status st = OpRegistry::Global()->LookUpOpDef(opcode(), &op_);
53 if (!st.ok()) {
54 return Status(
55 error::INVALID_ARGUMENT,
56 absl::StrFormat("Node '%s' contains an undefined operation '%s': %s",
57 name(), opcode(), st.error_message()));
58 }
59
60 int n_inputs = node_->input_size();
61
62 int n_named_inputs = op_->input_arg_size();
63
64 int n_multi_inputs = 0;
65 for (const auto& inarg : op_->input_arg()) {
66 if (!inarg.number_attr().empty() || !inarg.type_list_attr().empty()) {
67 ++n_multi_inputs;
68 }
69 }
70 bool is_commutative = grappler::IsCommutative(*node_);
71
72 if (n_multi_inputs > 1 || (n_multi_inputs > 0 && n_named_inputs > 1)) {
73 // Can't handle more than one multi-input at a time.
74 // And can't handle the commutativeness of only some arguments
75 // rather than all of them.
76 is_commutative = false;
77 }
78
79 if (is_commutative) {
80 // If truly commutative, can treat all the inputs as one multi-input.
81 // It's possible to just treat the commutative nodes as AllInputsOrNone
82 // but (1) this way is a bit more efficient and (2) I want to preserve this
83 // more efficient code path that does all-or-none by a single input and
84 // perhaps extend its use in the future.
85 n_named_inputs = 1;
86 all_inputs_or_none_ = false;
87 } else if (n_multi_inputs > 0) {
88 all_inputs_or_none_ = true;
89 }
90
91 for (int i = 0; i < n_inputs; ++i) {
92 int other_position;
93 string other_name = ParseNodeName(node_->input(i), &other_position);
94 auto other_it = map->find(other_name);
95 if (other_it == map->end()) {
96 return Status(
97 error::INVALID_ARGUMENT,
98 absl::StrFormat(
99 "Node '%s' input %d refers to a non-existing node '%s'.", name(),
100 i, other_name));
101 }
102 GenNode* other_node = other_it->second.get();
103
104 int this_position = other_position < 0 ? -1 : (is_commutative ? 0 : i);
105
106 if (this_position >= 0 && n_multi_inputs == 0 &&
107 this_position >= n_named_inputs) {
108 return Status(
109 error::INVALID_ARGUMENT,
110 absl::StrFormat(
111 "Node '%s' has a non-control input from '%s' at index %d but its "
112 "operation '%s' defines only %d inputs.",
113 name(), other_name, this_position, op_->name(), n_named_inputs));
114 }
115
116 Port this_port(/*inbound=*/true, this_position);
117 Port other_port(/*inbound=*/false, other_position);
118
119 links_[this_port].emplace_back(LinkTarget(other_node, other_port));
120 other_node->links_[other_port].emplace_back(LinkTarget(this, this_port));
121 }
122 return Status::OK();
123 }
124
IsMultiInput(Port port) const125 bool GenNode::IsMultiInput(Port port) const {
126 if (!port.IsInbound()) {
127 return false;
128 }
129 auto it = links_.find(port);
130 if (it == links_.end()) {
131 return false; // Shouldn't happen.
132 }
133 return (it->second.size() > 1);
134 }
135
operator string() const136 GenNode::Port::operator string() const {
137 string result = this->IsInbound() ? "i" : "o";
138 if (this->IsControl()) {
139 result.append("C");
140 } else {
141 result.append(absl::StrFormat("%d", this->Id()));
142 }
143 return result;
144 }
145
146 } // end namespace graph_analyzer
147 } // end namespace grappler
148 } // end namespace tensorflow
149