1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/node_builder.h"
17
18 #include <unordered_map>
19 #include <vector>
20
21 #include "tensorflow/core/framework/full_type.pb.h"
22 #include "tensorflow/core/framework/full_type_util.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/statusor.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29
30 namespace tensorflow {
31
NodeOut(Node * n,int32_t i)32 NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i) // NOLINT(runtime/explicit)
33 : node(n),
34 error(false),
35 name(node != nullptr ? node->name() : (error = true, "")),
36 index(i),
37 dt(SafeGetOutput(node, i, &error)) {}
38
NodeOut(OutputTensor t)39 NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
40
NodeOut(StringPiece n,int32_t i,DataType t)41 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t)
42 : node(nullptr), error(false), name(n), index(i), dt(t) {}
43
NodeOut()44 NodeBuilder::NodeOut::NodeOut()
45 : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
46
NodeBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)47 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
48 const OpRegistryInterface* op_registry,
49 const NodeDebugInfo* debug)
50 : def_builder_(name, op_name, op_registry, debug) {}
51
NodeBuilder(StringPiece name,const OpDef * op_def)52 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
53 : def_builder_(name, op_def) {}
54
NodeBuilder(const NodeDefBuilder & def_builder)55 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
56 : def_builder_(def_builder) {}
57
Input(Node * src_node,int src_index)58 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
59 inputs_.emplace_back(src_node, src_index);
60 DataType dt;
61 if (GetOutputType(src_node, src_index, &dt)) {
62 def_builder_.Input(src_node->name(), src_index, dt);
63 }
64 return *this;
65 }
66
Input(NodeOut src)67 NodeBuilder& NodeBuilder::Input(NodeOut src) {
68 if (src.error) {
69 AddIndexError(src.node, src.index);
70 } else {
71 inputs_.emplace_back(src.node, src.index);
72 def_builder_.Input(src.name, src.index, src.dt);
73 }
74 return *this;
75 }
76
Input(gtl::ArraySlice<NodeOut> src_list)77 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
78 std::vector<NodeDefBuilder::NodeOut> srcs;
79 srcs.reserve(src_list.size());
80 for (const auto& node_out : src_list) {
81 if (node_out.error) {
82 AddIndexError(node_out.node, node_out.index);
83 } else {
84 srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
85 inputs_.emplace_back(node_out.node, node_out.index);
86 }
87 }
88 def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
89 return *this;
90 }
91
ControlInput(Node * src_node)92 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
93 control_inputs_.emplace_back(src_node);
94 def_builder_.ControlInput(src_node->name());
95 return *this;
96 }
97
ControlInputs(gtl::ArraySlice<Node * > src_nodes)98 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
99 control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
100 src_nodes.end());
101 for (const Node* src_node : src_nodes) {
102 def_builder_.ControlInput(src_node->name());
103 }
104 return *this;
105 }
106
Device(StringPiece device_spec)107 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
108 def_builder_.Device(device_spec);
109 return *this;
110 }
111
AssignedDevice(StringPiece device)112 NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
113 assigned_device_ = string(device);
114 return *this;
115 }
116
XlaCluster(StringPiece xla_cluster)117 NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
118 def_builder_.Attr("_XlaCluster", xla_cluster);
119 return *this;
120 }
121
122 namespace {
123
run_type_constructor(Graph * graph,const NodeDef & node_def)124 StatusOr<FullTypeDef> run_type_constructor(Graph* graph,
125 const NodeDef& node_def) {
126 // TODO(mdan): Decouple this from graph building, or run again after.
127 const auto* op_registry = graph->op_registry();
128 const tensorflow::OpRegistrationData* op_reg_data;
129 TF_RETURN_IF_ERROR(op_registry->LookUp(node_def.op(), &op_reg_data));
130 if (op_reg_data->type_ctor == nullptr) {
131 // Default to the default unset type.
132 return FullTypeDef();
133 }
134
135 // TODO(mdan): Do we still need to save this info in the Graph object?
136 return full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def);
137 }
138
139 } // namespace
140
Finalize(Graph * graph,Node ** created_node,bool consume)141 Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
142 // In case of error, set *created_node to nullptr.
143 if (created_node != nullptr) *created_node = nullptr;
144 if (!errors_.empty()) {
145 return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
146 }
147
148 NodeDef node_def;
149 TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
150 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
151 TF_RETURN_IF_ERROR(
152 CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
153
154 const auto ret = run_type_constructor(graph, node_def);
155 TF_RETURN_IF_ERROR(ret.status());
156
157 Status status;
158 Node* node = graph->AddNode(std::move(node_def), &status);
159 TF_RETURN_IF_ERROR(status);
160
161 FullTypeDef ft = ret.ValueOrDie();
162 if (ft.type_id() != TFT_UNSET) {
163 graph->SetNodeType(node->name(), ft);
164 }
165
166 node->set_assigned_device_name(assigned_device_);
167
168 for (size_t i = 0; i < inputs_.size(); ++i) {
169 if (inputs_[i].node != nullptr) { // Skip back edges.
170 graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
171 }
172 }
173 for (Node* control_input : control_inputs_) {
174 graph->AddControlEdge(control_input, node);
175 }
176 if (created_node != nullptr) *created_node = node;
177 return Status::OK();
178 }
179
AddIndexError(const Node * node,int i)180 void NodeBuilder::AddIndexError(const Node* node, int i) {
181 if (node == nullptr) {
182 errors_.emplace_back(
183 strings::StrCat("Attempt to add nullptr Node to node with type ",
184 def_builder_.op_def().name()));
185 } else {
186 errors_.emplace_back(strings::StrCat(
187 "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
188 node->num_outputs(), ") to node with type ",
189 def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
190 }
191 }
192
GetOutputType(const Node * node,int i,DataType * dt)193 bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
194 bool error;
195 *dt = SafeGetOutput(node, i, &error);
196 if (error) AddIndexError(node, i);
197 return !error;
198 }
199
200 } // namespace tensorflow
201