• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_builder.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/op_def_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 
25 namespace tensorflow {
26 
NodeOut(StringPiece n,int i,DataType dt)27 NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
28     : node(n), index(i), data_type(dt) {}
29 
NodeOut()30 NodeDefBuilder::NodeOut::NodeOut() {
31   // uninitialized, call Reset() before use.
32 }
33 
Reset(StringPiece n,int i,DataType dt)34 void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
35   node = string(n);
36   index = i;
37   data_type = dt;
38 }
39 
NodeDefBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)40 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
41                                const OpRegistryInterface* op_registry,
42                                const NodeDebugInfo* debug) {
43   node_def_.set_name(string(name));
44   const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
45   if (status.ok()) {
46     Initialize();
47   } else {
48     errors_.push_back(status.error_message());
49     inputs_specified_ = 0;
50   }
51   if (debug != nullptr) MergeDebugInfo(*debug, &node_def_);
52 }
53 
NodeDefBuilder(StringPiece name,StringPiece op_name,const NodeDebugInfo & debug)54 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
55                                const NodeDebugInfo& debug)
56     : NodeDefBuilder(name, op_name) {
57   MergeDebugInfo(debug, &node_def_);
58 }
59 
NodeDefBuilder(StringPiece name,const OpDef * op_def)60 NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
61     : op_def_(op_def) {
62   node_def_.set_name(string(name));
63   Initialize();
64 }
65 
Initialize()66 void NodeDefBuilder::Initialize() {
67   inputs_specified_ = 0;
68   node_def_.set_op(op_def_->name());
69 }
70 
NextArgDef()71 const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
72   if (!NextArgAvailable()) return nullptr;
73   return &op_def_->input_arg(inputs_specified_++);
74 }
75 
NextArgAvailable()76 bool NodeDefBuilder::NextArgAvailable() {
77   if (op_def_ == nullptr) {
78     return false;
79   } else if (inputs_specified_ >= op_def_->input_arg_size()) {
80     errors_.push_back(strings::StrCat("More Input() calls than the ",
81                                       op_def_->input_arg_size(),
82                                       " input_args"));
83     return false;
84   }
85   return true;
86 }
87 
Input(FakeInputFunctor fake_input)88 NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
89   if (NextArgAvailable()) {
90     Status status = fake_input(*op_def_, inputs_specified_, node_def_, this);
91     if (!status.ok()) errors_.push_back(status.error_message());
92   }
93   return *this;
94 }
95 
Input(StringPiece src_node,int src_index,DataType dt)96 NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
97                                       DataType dt) {
98   const OpDef::ArgDef* arg = NextArgDef();
99   if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
100   return *this;
101 }
102 
Input(const NodeOut & src)103 NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
104   Input(src.node, src.index, src.data_type);
105   return *this;
106 }
107 
108 // For inputs that take a list of tensors.
Input(gtl::ArraySlice<NodeOut> src_list)109 NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
110   const OpDef::ArgDef* arg = NextArgDef();
111   if (arg != nullptr) ListInput(arg, src_list);
112   return *this;
113 }
114 
SingleInput(const OpDef::ArgDef * input_arg,StringPiece src_node,int src_index,DataType dt)115 void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
116                                  StringPiece src_node, int src_index,
117                                  DataType dt) {
118   AddInput(src_node, src_index);
119 
120   if (!input_arg->number_attr().empty() ||
121       !input_arg->type_list_attr().empty()) {
122     errors_.push_back(strings::StrCat("Single tensor passed to '",
123                                       input_arg->name(), "', expected list"));
124     return;
125   }
126 
127   if (input_arg->type() != DT_INVALID) {
128     const DataType expected = MaybeAddRef(input_arg, input_arg->type());
129     VerifyInputType(input_arg, expected, dt);
130   } else {
131     VerifyInputRef(input_arg, dt);
132     Attr(input_arg->type_attr(), BaseType(dt));
133   }
134 }
135 
ListInput(const OpDef::ArgDef * input_arg,gtl::ArraySlice<NodeOut> src_list)136 void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
137                                gtl::ArraySlice<NodeOut> src_list) {
138   for (const auto& node_out : src_list) {
139     AddInput(node_out.node, node_out.index);
140   }
141 
142   if (!input_arg->number_attr().empty()) {
143     Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
144     if (input_arg->type() != DT_INVALID) {
145       const DataType expected = MaybeAddRef(input_arg, input_arg->type());
146       for (const auto& node_out : src_list) {
147         VerifyInputType(input_arg, expected, node_out.data_type);
148       }
149     } else if (!src_list.empty()) {
150       const DataType base = BaseType(src_list[0].data_type);
151       Attr(input_arg->type_attr(), base);
152       const DataType expected = MaybeAddRef(input_arg, base);
153       for (const auto& node_out : src_list) {
154         VerifyInputType(input_arg, expected, node_out.data_type);
155       }
156     }
157   } else if (!input_arg->type_list_attr().empty()) {
158     DataTypeVector type_vec;
159     type_vec.reserve(src_list.size());
160     for (const auto& node_out : src_list) {
161       const DataType dt = node_out.data_type;
162       VerifyInputRef(input_arg, dt);
163       type_vec.push_back(BaseType(dt));
164     }
165     Attr(input_arg->type_list_attr(), type_vec);
166   } else {
167     errors_.push_back(strings::StrCat("List provided to input '",
168                                       input_arg->name(),
169                                       "' when single Tensor expected"));
170   }
171 }
172 
AddInput(StringPiece src_node,int src_index)173 void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
174   if (src_node.empty()) {
175     errors_.push_back("Empty input node name");
176   } else if (src_node[0] == '^') {
177     errors_.push_back(
178         strings::StrCat("Non-control input starting with ^: ", src_node));
179   } else if (src_index > 0) {
180     node_def_.add_input(strings::StrCat(src_node, ":", src_index));
181   } else {
182     node_def_.add_input(string(src_node));
183   }
184 }
185 
VerifyInputType(const OpDef::ArgDef * input_arg,DataType expected,DataType dt)186 void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
187                                      DataType expected, DataType dt) {
188   if (!TypesCompatible(expected, dt)) {
189     errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
190                                       DataTypeString(dt), " expected ",
191                                       DataTypeString(expected)));
192   }
193 }
194 
VerifyInputRef(const OpDef::ArgDef * input_arg,DataType dt)195 void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
196                                     DataType dt) {
197   if (input_arg->is_ref() && !IsRefType(dt)) {
198     errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
199                                       DataTypeString(dt),
200                                       " expected ref type"));
201   }
202 }
203 
ControlInput(StringPiece src_node)204 NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
205   control_inputs_.emplace_back(src_node);
206   return *this;
207 }
208 
Device(StringPiece device_spec)209 NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
210   node_def_.set_device(string(device_spec));
211   return *this;
212 }
213 
Finalize(NodeDef * node_def,bool consume)214 Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
215   const std::vector<string>* errors_ptr = &errors_;
216   std::vector<string> errors_storage;
217   if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
218     // Since this is a const method, to add an error, we have to make
219     // a copy of the existing errors.
220     errors_storage = errors_;
221     errors_storage.push_back(
222         strings::StrCat(inputs_specified_, " inputs specified of ",
223                         op_def_->input_arg_size(), " inputs in Op"));
224     errors_ptr = &errors_storage;
225   }
226 
227   if (!errors_ptr->empty()) {
228     if (errors_ptr->size() == 1) {
229       if (op_def_ == nullptr) {
230         return errors::InvalidArgument((*errors_ptr)[0],
231                                        " while building NodeDef '",
232                                        node_def_.name(), "'");
233       }
234       return errors::InvalidArgument(
235           (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
236           "' using ", SummarizeOpDef(*op_def_));
237     } else {
238       return errors::InvalidArgument(
239           errors_ptr->size(), " errors while building NodeDef '",
240           node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
241           absl::StrJoin(*errors_ptr, "\n"));
242     }
243   } else {
244     NodeDef node_def_backup;
245     if (node_def == nullptr) node_def = &node_def_backup;
246     if (consume) {
247       *node_def = std::move(node_def_);
248     } else {
249       *node_def = node_def_;
250     }
251 
252     // Add control inputs after the regular inputs.
253     for (const auto& control_input : control_inputs_) {
254       node_def->add_input(strings::StrCat("^", control_input));
255     }
256 
257     // Add default values for unspecified attrs.
258     AddDefaultsToNodeDef(*op_def_, node_def);
259 
260     return Status::OK();
261   }
262 }
263 
AttrValueAlreadyPresent(StringPiece name,const AttrValue & value)264 bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
265                                              const AttrValue& value) {
266   if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
267     if (!AreAttrValuesEqual(*found, value)) {
268       errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
269                                         "' ", SummarizeAttrValue(*found),
270                                         " vs. ", SummarizeAttrValue(value)));
271     }
272     return true;
273   }
274   return false;
275 }
276 
Attr(StringPiece name,const AttrValue & value)277 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
278   if (!AttrValueAlreadyPresent(name, value)) {
279     AddNodeAttr(name, value, &node_def_);
280   }
281   return *this;
282 }
283 
Attr(StringPiece name,AttrValue && value)284 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
285   if (!AttrValueAlreadyPresent(name, value)) {
286     AddNodeAttr(name, std::move(value), &node_def_);
287   }
288   return *this;
289 }
290 
291 #define ATTR(T)                                                     \
292   NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
293     AttrValue attr_value;                                           \
294     SetAttrValue(value, &attr_value);                               \
295     return Attr(name, attr_value);                                  \
296   }
297 ATTR(StringPiece)
298 ATTR(const char*)
299 ATTR(int32)
300 ATTR(int64)
301 ATTR(float)
302 ATTR(double)
303 ATTR(bool)
304 ATTR(DataType)
305 ATTR(const PartialTensorShape&)
306 ATTR(const Tensor&)
307 ATTR(const TensorProto&)
308 ATTR(const NameAttrList&)
309 ATTR(gtl::ArraySlice<StringPiece>)
310 ATTR(gtl::ArraySlice<const char*>)
311 ATTR(gtl::ArraySlice<string>)
312 ATTR(gtl::ArraySlice<tstring>)
313 ATTR(gtl::ArraySlice<int32>)
314 ATTR(gtl::ArraySlice<int64>)
315 ATTR(gtl::ArraySlice<float>)
316 ATTR(gtl::ArraySlice<bool>)
317 ATTR(const std::vector<bool>&)
318 ATTR(gtl::ArraySlice<DataType>)
319 ATTR(gtl::ArraySlice<TensorShape>)
320 ATTR(gtl::ArraySlice<PartialTensorShape>)
321 ATTR(gtl::ArraySlice<TensorShapeProto>)
322 ATTR(gtl::ArraySlice<Tensor>)
323 ATTR(gtl::ArraySlice<NameAttrList>)
324 #undef ATTR
325 
326 }  // namespace tensorflow
327