/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_builder.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/op_def_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24
25 namespace tensorflow {
26
NodeOut(StringPiece n,int i,DataType dt)27 NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
28 : node(n), index(i), data_type(dt) {}
29
NodeOut()30 NodeDefBuilder::NodeOut::NodeOut() {
31 // uninitialized, call Reset() before use.
32 }
33
Reset(StringPiece n,int i,DataType dt)34 void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
35 node = string(n);
36 index = i;
37 data_type = dt;
38 }
39
NodeDefBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)40 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
41 const OpRegistryInterface* op_registry,
42 const NodeDebugInfo* debug) {
43 node_def_.set_name(string(name));
44 const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
45 if (status.ok()) {
46 Initialize();
47 } else {
48 errors_.push_back(status.error_message());
49 inputs_specified_ = 0;
50 }
51 if (debug != nullptr) MergeDebugInfo(*debug, &node_def_);
52 }
53
NodeDefBuilder(StringPiece name,StringPiece op_name,const NodeDebugInfo & debug)54 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
55 const NodeDebugInfo& debug)
56 : NodeDefBuilder(name, op_name) {
57 MergeDebugInfo(debug, &node_def_);
58 }
59
NodeDefBuilder(StringPiece name,const OpDef * op_def)60 NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
61 : op_def_(op_def) {
62 node_def_.set_name(string(name));
63 Initialize();
64 }
65
Initialize()66 void NodeDefBuilder::Initialize() {
67 inputs_specified_ = 0;
68 node_def_.set_op(op_def_->name());
69 }
70
NextArgDef()71 const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
72 if (!NextArgAvailable()) return nullptr;
73 return &op_def_->input_arg(inputs_specified_++);
74 }
75
NextArgAvailable()76 bool NodeDefBuilder::NextArgAvailable() {
77 if (op_def_ == nullptr) {
78 return false;
79 } else if (inputs_specified_ >= op_def_->input_arg_size()) {
80 errors_.push_back(strings::StrCat("More Input() calls than the ",
81 op_def_->input_arg_size(),
82 " input_args"));
83 return false;
84 }
85 return true;
86 }
87
Input(FakeInputFunctor fake_input)88 NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
89 if (NextArgAvailable()) {
90 Status status = fake_input(*op_def_, inputs_specified_, node_def_, this);
91 if (!status.ok()) errors_.push_back(status.error_message());
92 }
93 return *this;
94 }
95
Input(StringPiece src_node,int src_index,DataType dt)96 NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
97 DataType dt) {
98 const OpDef::ArgDef* arg = NextArgDef();
99 if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
100 return *this;
101 }
102
Input(const NodeOut & src)103 NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
104 Input(src.node, src.index, src.data_type);
105 return *this;
106 }
107
108 // For inputs that take a list of tensors.
Input(gtl::ArraySlice<NodeOut> src_list)109 NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
110 const OpDef::ArgDef* arg = NextArgDef();
111 if (arg != nullptr) ListInput(arg, src_list);
112 return *this;
113 }
114
SingleInput(const OpDef::ArgDef * input_arg,StringPiece src_node,int src_index,DataType dt)115 void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
116 StringPiece src_node, int src_index,
117 DataType dt) {
118 AddInput(src_node, src_index);
119
120 if (!input_arg->number_attr().empty() ||
121 !input_arg->type_list_attr().empty()) {
122 errors_.push_back(strings::StrCat("Single tensor passed to '",
123 input_arg->name(), "', expected list"));
124 return;
125 }
126
127 if (input_arg->type() != DT_INVALID) {
128 const DataType expected = MaybeAddRef(input_arg, input_arg->type());
129 VerifyInputType(input_arg, expected, dt);
130 } else {
131 VerifyInputRef(input_arg, dt);
132 Attr(input_arg->type_attr(), BaseType(dt));
133 }
134 }
135
ListInput(const OpDef::ArgDef * input_arg,gtl::ArraySlice<NodeOut> src_list)136 void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
137 gtl::ArraySlice<NodeOut> src_list) {
138 for (const auto& node_out : src_list) {
139 AddInput(node_out.node, node_out.index);
140 }
141
142 if (!input_arg->number_attr().empty()) {
143 Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
144 if (input_arg->type() != DT_INVALID) {
145 const DataType expected = MaybeAddRef(input_arg, input_arg->type());
146 for (const auto& node_out : src_list) {
147 VerifyInputType(input_arg, expected, node_out.data_type);
148 }
149 } else if (!src_list.empty()) {
150 const DataType base = BaseType(src_list[0].data_type);
151 Attr(input_arg->type_attr(), base);
152 const DataType expected = MaybeAddRef(input_arg, base);
153 for (const auto& node_out : src_list) {
154 VerifyInputType(input_arg, expected, node_out.data_type);
155 }
156 }
157 } else if (!input_arg->type_list_attr().empty()) {
158 DataTypeVector type_vec;
159 type_vec.reserve(src_list.size());
160 for (const auto& node_out : src_list) {
161 const DataType dt = node_out.data_type;
162 VerifyInputRef(input_arg, dt);
163 type_vec.push_back(BaseType(dt));
164 }
165 Attr(input_arg->type_list_attr(), type_vec);
166 } else {
167 errors_.push_back(strings::StrCat("List provided to input '",
168 input_arg->name(),
169 "' when single Tensor expected"));
170 }
171 }
172
AddInput(StringPiece src_node,int src_index)173 void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
174 if (src_node.empty()) {
175 errors_.push_back("Empty input node name");
176 } else if (src_node[0] == '^') {
177 errors_.push_back(
178 strings::StrCat("Non-control input starting with ^: ", src_node));
179 } else if (src_index > 0) {
180 node_def_.add_input(strings::StrCat(src_node, ":", src_index));
181 } else {
182 node_def_.add_input(string(src_node));
183 }
184 }
185
VerifyInputType(const OpDef::ArgDef * input_arg,DataType expected,DataType dt)186 void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
187 DataType expected, DataType dt) {
188 if (!TypesCompatible(expected, dt)) {
189 errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
190 DataTypeString(dt), " expected ",
191 DataTypeString(expected)));
192 }
193 }
194
VerifyInputRef(const OpDef::ArgDef * input_arg,DataType dt)195 void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
196 DataType dt) {
197 if (input_arg->is_ref() && !IsRefType(dt)) {
198 errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
199 DataTypeString(dt),
200 " expected ref type"));
201 }
202 }
203
ControlInput(StringPiece src_node)204 NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
205 control_inputs_.emplace_back(src_node);
206 return *this;
207 }
208
Device(StringPiece device_spec)209 NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
210 node_def_.set_device(string(device_spec));
211 return *this;
212 }
213
Finalize(NodeDef * node_def) const214 Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
215 const std::vector<string>* errors_ptr = &errors_;
216 std::vector<string> errors_storage;
217 if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
218 // Since this is a const method, to add an error, we have to make
219 // a copy of the existing errors.
220 errors_storage = errors_;
221 errors_storage.push_back(
222 strings::StrCat(inputs_specified_, " inputs specified of ",
223 op_def_->input_arg_size(), " inputs in Op"));
224 errors_ptr = &errors_storage;
225 }
226
227 if (!errors_ptr->empty()) {
228 if (errors_ptr->size() == 1) {
229 if (op_def_ == nullptr) {
230 return errors::InvalidArgument((*errors_ptr)[0],
231 " while building NodeDef '",
232 node_def_.name(), "'");
233 }
234 return errors::InvalidArgument(
235 (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
236 "' using ", SummarizeOpDef(*op_def_));
237 } else {
238 return errors::InvalidArgument(
239 errors_ptr->size(), " errors while building NodeDef '",
240 node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
241 str_util::Join(*errors_ptr, "\n"));
242 }
243 } else {
244 NodeDef node_def_backup;
245 if (node_def == nullptr) node_def = &node_def_backup;
246 *node_def = node_def_;
247
248 // Add control inputs after the regular inputs.
249 for (const auto& control_input : control_inputs_) {
250 node_def->add_input(strings::StrCat("^", control_input));
251 }
252
253 // Add default values for unspecified attrs.
254 AddDefaultsToNodeDef(*op_def_, node_def);
255
256 return Status::OK();
257 }
258 }
259
Attr(StringPiece name,const AttrValue & value)260 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
261 if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
262 if (!AreAttrValuesEqual(*found, value)) {
263 errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
264 "' ", SummarizeAttrValue(*found),
265 " vs. ", SummarizeAttrValue(value)));
266 }
267 } else {
268 AddNodeAttr(name, value, &node_def_);
269 }
270 return *this;
271 }
272
273 #define ATTR(T) \
274 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
275 AttrValue attr_value; \
276 SetAttrValue(value, &attr_value); \
277 return Attr(name, attr_value); \
278 }
279 ATTR(StringPiece)
280 ATTR(const char*)
281 ATTR(int32)
282 ATTR(int64)
283 ATTR(float)
284 ATTR(double)
285 ATTR(bool)
286 ATTR(DataType)
287 ATTR(const PartialTensorShape&)
288 ATTR(const Tensor&)
289 ATTR(const TensorProto&)
290 ATTR(const NameAttrList&)
291 ATTR(gtl::ArraySlice<StringPiece>)
292 ATTR(gtl::ArraySlice<const char*>)
293 ATTR(gtl::ArraySlice<string>)
294 ATTR(gtl::ArraySlice<int32>)
295 ATTR(gtl::ArraySlice<int64>)
296 ATTR(gtl::ArraySlice<float>)
297 ATTR(gtl::ArraySlice<bool>)
298 ATTR(const std::vector<bool>&)
299 ATTR(gtl::ArraySlice<DataType>)
300 ATTR(gtl::ArraySlice<TensorShape>)
301 ATTR(gtl::ArraySlice<PartialTensorShape>)
302 ATTR(gtl::ArraySlice<TensorShapeProto>)
303 ATTR(gtl::ArraySlice<Tensor>)
304 ATTR(gtl::ArraySlice<NameAttrList>)
305 #undef ATTR
306
307 } // namespace tensorflow
308