1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/node_def_builder.h"

#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
24
25 namespace tensorflow {
26
NodeOut(StringPiece n,int i,DataType dt)27 NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
28 : node(n), index(i), data_type(dt) {}
29
NodeOut()30 NodeDefBuilder::NodeOut::NodeOut() {
31 // uninitialized, call Reset() before use.
32 }
33
Reset(StringPiece n,int i,DataType dt)34 void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
35 node = string(n);
36 index = i;
37 data_type = dt;
38 }
39
NodeDefBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)40 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
41 const OpRegistryInterface* op_registry,
42 const NodeDebugInfo* debug) {
43 node_def_.set_name(string(name));
44 const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
45 if (status.ok()) {
46 Initialize();
47 } else {
48 errors_.push_back(status.error_message());
49 inputs_specified_ = 0;
50 }
51 if (debug != nullptr) MergeDebugInfo(*debug, &node_def_);
52 }
53
NodeDefBuilder(StringPiece name,StringPiece op_name,const NodeDebugInfo & debug)54 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
55 const NodeDebugInfo& debug)
56 : NodeDefBuilder(name, op_name) {
57 MergeDebugInfo(debug, &node_def_);
58 }
59
NodeDefBuilder(StringPiece name,const OpDef * op_def)60 NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
61 : op_def_(op_def) {
62 node_def_.set_name(string(name));
63 Initialize();
64 }
65
Initialize()66 void NodeDefBuilder::Initialize() {
67 inputs_specified_ = 0;
68 node_def_.set_op(op_def_->name());
69 }
70
NextArgDef()71 const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
72 if (!NextArgAvailable()) return nullptr;
73 return &op_def_->input_arg(inputs_specified_++);
74 }
75
NextArgAvailable()76 bool NodeDefBuilder::NextArgAvailable() {
77 if (op_def_ == nullptr) {
78 return false;
79 } else if (inputs_specified_ >= op_def_->input_arg_size()) {
80 errors_.push_back(strings::StrCat("More Input() calls than the ",
81 op_def_->input_arg_size(),
82 " input_args"));
83 return false;
84 }
85 return true;
86 }
87
Input(FakeInputFunctor fake_input)88 NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
89 if (NextArgAvailable()) {
90 Status status = fake_input(*op_def_, inputs_specified_, node_def_, this);
91 if (!status.ok()) errors_.push_back(status.error_message());
92 }
93 return *this;
94 }
95
Input(StringPiece src_node,int src_index,DataType dt)96 NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
97 DataType dt) {
98 const OpDef::ArgDef* arg = NextArgDef();
99 if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
100 return *this;
101 }
102
Input(const NodeOut & src)103 NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
104 Input(src.node, src.index, src.data_type);
105 return *this;
106 }
107
108 // For inputs that take a list of tensors.
Input(gtl::ArraySlice<NodeOut> src_list)109 NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
110 const OpDef::ArgDef* arg = NextArgDef();
111 if (arg != nullptr) ListInput(arg, src_list);
112 return *this;
113 }
114
SingleInput(const OpDef::ArgDef * input_arg,StringPiece src_node,int src_index,DataType dt)115 void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
116 StringPiece src_node, int src_index,
117 DataType dt) {
118 AddInput(src_node, src_index);
119
120 if (!input_arg->number_attr().empty() ||
121 !input_arg->type_list_attr().empty()) {
122 errors_.push_back(strings::StrCat("Single tensor passed to '",
123 input_arg->name(), "', expected list"));
124 return;
125 }
126
127 if (input_arg->type() != DT_INVALID) {
128 const DataType expected = MaybeAddRef(input_arg, input_arg->type());
129 VerifyInputType(input_arg, expected, dt);
130 } else {
131 VerifyInputRef(input_arg, dt);
132 Attr(input_arg->type_attr(), BaseType(dt));
133 }
134 }
135
ListInput(const OpDef::ArgDef * input_arg,gtl::ArraySlice<NodeOut> src_list)136 void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
137 gtl::ArraySlice<NodeOut> src_list) {
138 for (const auto& node_out : src_list) {
139 AddInput(node_out.node, node_out.index);
140 }
141
142 if (!input_arg->number_attr().empty()) {
143 Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
144 if (input_arg->type() != DT_INVALID) {
145 const DataType expected = MaybeAddRef(input_arg, input_arg->type());
146 for (const auto& node_out : src_list) {
147 VerifyInputType(input_arg, expected, node_out.data_type);
148 }
149 } else if (!src_list.empty()) {
150 const DataType base = BaseType(src_list[0].data_type);
151 Attr(input_arg->type_attr(), base);
152 const DataType expected = MaybeAddRef(input_arg, base);
153 for (const auto& node_out : src_list) {
154 VerifyInputType(input_arg, expected, node_out.data_type);
155 }
156 }
157 } else if (!input_arg->type_list_attr().empty()) {
158 DataTypeVector type_vec;
159 type_vec.reserve(src_list.size());
160 for (const auto& node_out : src_list) {
161 const DataType dt = node_out.data_type;
162 VerifyInputRef(input_arg, dt);
163 type_vec.push_back(BaseType(dt));
164 }
165 Attr(input_arg->type_list_attr(), type_vec);
166 } else {
167 errors_.push_back(strings::StrCat("List provided to input '",
168 input_arg->name(),
169 "' when single Tensor expected"));
170 }
171 }
172
AddInput(StringPiece src_node,int src_index)173 void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
174 if (src_node.empty()) {
175 errors_.push_back("Empty input node name");
176 } else if (src_node[0] == '^') {
177 errors_.push_back(
178 strings::StrCat("Non-control input starting with ^: ", src_node));
179 } else if (src_index > 0) {
180 node_def_.add_input(strings::StrCat(src_node, ":", src_index));
181 } else {
182 node_def_.add_input(string(src_node));
183 }
184 }
185
VerifyInputType(const OpDef::ArgDef * input_arg,DataType expected,DataType dt)186 void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
187 DataType expected, DataType dt) {
188 if (!TypesCompatible(expected, dt)) {
189 errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
190 DataTypeString(dt), " expected ",
191 DataTypeString(expected)));
192 }
193 }
194
VerifyInputRef(const OpDef::ArgDef * input_arg,DataType dt)195 void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
196 DataType dt) {
197 if (input_arg->is_ref() && !IsRefType(dt)) {
198 errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
199 DataTypeString(dt),
200 " expected ref type"));
201 }
202 }
203
ControlInput(StringPiece src_node)204 NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
205 control_inputs_.emplace_back(src_node);
206 return *this;
207 }
208
Device(StringPiece device_spec)209 NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
210 node_def_.set_device(string(device_spec));
211 return *this;
212 }
213
Finalize(NodeDef * node_def,bool consume)214 Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
215 const std::vector<string>* errors_ptr = &errors_;
216 std::vector<string> errors_storage;
217 if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
218 // Since this is a const method, to add an error, we have to make
219 // a copy of the existing errors.
220 errors_storage = errors_;
221 errors_storage.push_back(
222 strings::StrCat(inputs_specified_, " inputs specified of ",
223 op_def_->input_arg_size(), " inputs in Op"));
224 errors_ptr = &errors_storage;
225 }
226
227 if (!errors_ptr->empty()) {
228 if (errors_ptr->size() == 1) {
229 if (op_def_ == nullptr) {
230 return errors::InvalidArgument((*errors_ptr)[0],
231 " while building NodeDef '",
232 node_def_.name(), "'");
233 }
234 return errors::InvalidArgument(
235 (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
236 "' using ", SummarizeOpDef(*op_def_));
237 } else {
238 return errors::InvalidArgument(
239 errors_ptr->size(), " errors while building NodeDef '",
240 node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
241 absl::StrJoin(*errors_ptr, "\n"));
242 }
243 } else {
244 NodeDef node_def_backup;
245 if (node_def == nullptr) node_def = &node_def_backup;
246 if (consume) {
247 *node_def = std::move(node_def_);
248 } else {
249 *node_def = node_def_;
250 }
251
252 // Add control inputs after the regular inputs.
253 for (const auto& control_input : control_inputs_) {
254 node_def->add_input(strings::StrCat("^", control_input));
255 }
256
257 // Add default values for unspecified attrs.
258 AddDefaultsToNodeDef(*op_def_, node_def);
259
260 return Status::OK();
261 }
262 }
263
AttrValueAlreadyPresent(StringPiece name,const AttrValue & value)264 bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
265 const AttrValue& value) {
266 if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
267 if (!AreAttrValuesEqual(*found, value)) {
268 errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
269 "' ", SummarizeAttrValue(*found),
270 " vs. ", SummarizeAttrValue(value)));
271 }
272 return true;
273 }
274 return false;
275 }
276
Attr(StringPiece name,const AttrValue & value)277 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
278 if (!AttrValueAlreadyPresent(name, value)) {
279 AddNodeAttr(name, value, &node_def_);
280 }
281 return *this;
282 }
283
Attr(StringPiece name,AttrValue && value)284 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
285 if (!AttrValueAlreadyPresent(name, value)) {
286 AddNodeAttr(name, std::move(value), &node_def_);
287 }
288 return *this;
289 }
290
291 #define ATTR(T) \
292 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
293 AttrValue attr_value; \
294 SetAttrValue(value, &attr_value); \
295 return Attr(name, attr_value); \
296 }
297 ATTR(StringPiece)
298 ATTR(const char*)
299 ATTR(int32)
300 ATTR(int64)
301 ATTR(float)
302 ATTR(double)
303 ATTR(bool)
304 ATTR(DataType)
305 ATTR(const PartialTensorShape&)
306 ATTR(const Tensor&)
307 ATTR(const TensorProto&)
308 ATTR(const NameAttrList&)
309 ATTR(gtl::ArraySlice<StringPiece>)
310 ATTR(gtl::ArraySlice<const char*>)
311 ATTR(gtl::ArraySlice<string>)
312 ATTR(gtl::ArraySlice<tstring>)
313 ATTR(gtl::ArraySlice<int32>)
314 ATTR(gtl::ArraySlice<int64>)
315 ATTR(gtl::ArraySlice<float>)
316 ATTR(gtl::ArraySlice<bool>)
317 ATTR(const std::vector<bool>&)
318 ATTR(gtl::ArraySlice<DataType>)
319 ATTR(gtl::ArraySlice<TensorShape>)
320 ATTR(gtl::ArraySlice<PartialTensorShape>)
321 ATTR(gtl::ArraySlice<TensorShapeProto>)
322 ATTR(gtl::ArraySlice<Tensor>)
323 ATTR(gtl::ArraySlice<NameAttrList>)
324 #undef ATTR
325
326 } // namespace tensorflow
327