• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_builder.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/op_def_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 
25 namespace tensorflow {
26 
NodeOut(StringPiece n,int i,DataType dt)27 NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
28     : node(n), index(i), data_type(dt) {}
29 
NodeOut()30 NodeDefBuilder::NodeOut::NodeOut() {
31   // uninitialized, call Reset() before use.
32 }
33 
Reset(StringPiece n,int i,DataType dt)34 void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
35   node = string(n);
36   index = i;
37   data_type = dt;
38 }
39 
NodeDefBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)40 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
41                                const OpRegistryInterface* op_registry,
42                                const NodeDebugInfo* debug) {
43   node_def_.set_name(string(name));
44   const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
45   if (status.ok()) {
46     Initialize();
47   } else {
48     errors_.push_back(status.error_message());
49     inputs_specified_ = 0;
50   }
51   if (debug != nullptr) MergeDebugInfo(*debug, &node_def_);
52 }
53 
NodeDefBuilder(StringPiece name,StringPiece op_name,const NodeDebugInfo & debug)54 NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
55                                const NodeDebugInfo& debug)
56     : NodeDefBuilder(name, op_name) {
57   MergeDebugInfo(debug, &node_def_);
58 }
59 
NodeDefBuilder(StringPiece name,const OpDef * op_def)60 NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
61     : op_def_(op_def) {
62   node_def_.set_name(string(name));
63   Initialize();
64 }
65 
Initialize()66 void NodeDefBuilder::Initialize() {
67   inputs_specified_ = 0;
68   node_def_.set_op(op_def_->name());
69 }
70 
NextArgDef()71 const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
72   if (!NextArgAvailable()) return nullptr;
73   return &op_def_->input_arg(inputs_specified_++);
74 }
75 
NextArgAvailable()76 bool NodeDefBuilder::NextArgAvailable() {
77   if (op_def_ == nullptr) {
78     return false;
79   } else if (inputs_specified_ >= op_def_->input_arg_size()) {
80     errors_.push_back(strings::StrCat("More Input() calls than the ",
81                                       op_def_->input_arg_size(),
82                                       " input_args"));
83     return false;
84   }
85   return true;
86 }
87 
Input(FakeInputFunctor fake_input)88 NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
89   if (NextArgAvailable()) {
90     Status status = fake_input(*op_def_, inputs_specified_, node_def_, this);
91     if (!status.ok()) errors_.push_back(status.error_message());
92   }
93   return *this;
94 }
95 
Input(StringPiece src_node,int src_index,DataType dt)96 NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
97                                       DataType dt) {
98   const OpDef::ArgDef* arg = NextArgDef();
99   if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
100   return *this;
101 }
102 
Input(const NodeOut & src)103 NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
104   Input(src.node, src.index, src.data_type);
105   return *this;
106 }
107 
108 // For inputs that take a list of tensors.
Input(gtl::ArraySlice<NodeOut> src_list)109 NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
110   const OpDef::ArgDef* arg = NextArgDef();
111   if (arg != nullptr) ListInput(arg, src_list);
112   return *this;
113 }
114 
SingleInput(const OpDef::ArgDef * input_arg,StringPiece src_node,int src_index,DataType dt)115 void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
116                                  StringPiece src_node, int src_index,
117                                  DataType dt) {
118   AddInput(src_node, src_index);
119 
120   if (!input_arg->number_attr().empty() ||
121       !input_arg->type_list_attr().empty()) {
122     errors_.push_back(strings::StrCat("Single tensor passed to '",
123                                       input_arg->name(), "', expected list"));
124     return;
125   }
126 
127   if (input_arg->type() != DT_INVALID) {
128     const DataType expected = MaybeAddRef(input_arg, input_arg->type());
129     VerifyInputType(input_arg, expected, dt);
130   } else {
131     VerifyInputRef(input_arg, dt);
132     Attr(input_arg->type_attr(), BaseType(dt));
133   }
134 }
135 
ListInput(const OpDef::ArgDef * input_arg,gtl::ArraySlice<NodeOut> src_list)136 void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
137                                gtl::ArraySlice<NodeOut> src_list) {
138   for (const auto& node_out : src_list) {
139     AddInput(node_out.node, node_out.index);
140   }
141 
142   if (!input_arg->number_attr().empty()) {
143     Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
144     if (input_arg->type() != DT_INVALID) {
145       const DataType expected = MaybeAddRef(input_arg, input_arg->type());
146       for (const auto& node_out : src_list) {
147         VerifyInputType(input_arg, expected, node_out.data_type);
148       }
149     } else if (!src_list.empty()) {
150       const DataType base = BaseType(src_list[0].data_type);
151       Attr(input_arg->type_attr(), base);
152       const DataType expected = MaybeAddRef(input_arg, base);
153       for (const auto& node_out : src_list) {
154         VerifyInputType(input_arg, expected, node_out.data_type);
155       }
156     }
157   } else if (!input_arg->type_list_attr().empty()) {
158     DataTypeVector type_vec;
159     type_vec.reserve(src_list.size());
160     for (const auto& node_out : src_list) {
161       const DataType dt = node_out.data_type;
162       VerifyInputRef(input_arg, dt);
163       type_vec.push_back(BaseType(dt));
164     }
165     Attr(input_arg->type_list_attr(), type_vec);
166   } else {
167     errors_.push_back(strings::StrCat("List provided to input '",
168                                       input_arg->name(),
169                                       "' when single Tensor expected"));
170   }
171 }
172 
AddInput(StringPiece src_node,int src_index)173 void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
174   if (src_node.empty()) {
175     errors_.push_back("Empty input node name");
176   } else if (src_node[0] == '^') {
177     errors_.push_back(
178         strings::StrCat("Non-control input starting with ^: ", src_node));
179   } else if (src_index > 0) {
180     node_def_.add_input(strings::StrCat(src_node, ":", src_index));
181   } else {
182     node_def_.add_input(string(src_node));
183   }
184 }
185 
VerifyInputType(const OpDef::ArgDef * input_arg,DataType expected,DataType dt)186 void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
187                                      DataType expected, DataType dt) {
188   if (!TypesCompatible(expected, dt)) {
189     errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
190                                       DataTypeString(dt), " expected ",
191                                       DataTypeString(expected)));
192   }
193 }
194 
VerifyInputRef(const OpDef::ArgDef * input_arg,DataType dt)195 void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
196                                     DataType dt) {
197   if (input_arg->is_ref() && !IsRefType(dt)) {
198     errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
199                                       DataTypeString(dt),
200                                       " expected ref type"));
201   }
202 }
203 
ControlInput(StringPiece src_node)204 NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
205   control_inputs_.emplace_back(src_node);
206   return *this;
207 }
208 
Device(StringPiece device_spec)209 NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
210   node_def_.set_device(string(device_spec));
211   return *this;
212 }
213 
Finalize(NodeDef * node_def) const214 Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
215   const std::vector<string>* errors_ptr = &errors_;
216   std::vector<string> errors_storage;
217   if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
218     // Since this is a const method, to add an error, we have to make
219     // a copy of the existing errors.
220     errors_storage = errors_;
221     errors_storage.push_back(
222         strings::StrCat(inputs_specified_, " inputs specified of ",
223                         op_def_->input_arg_size(), " inputs in Op"));
224     errors_ptr = &errors_storage;
225   }
226 
227   if (!errors_ptr->empty()) {
228     if (errors_ptr->size() == 1) {
229       if (op_def_ == nullptr) {
230         return errors::InvalidArgument((*errors_ptr)[0],
231                                        " while building NodeDef '",
232                                        node_def_.name(), "'");
233       }
234       return errors::InvalidArgument(
235           (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
236           "' using ", SummarizeOpDef(*op_def_));
237     } else {
238       return errors::InvalidArgument(
239           errors_ptr->size(), " errors while building NodeDef '",
240           node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
241           str_util::Join(*errors_ptr, "\n"));
242     }
243   } else {
244     NodeDef node_def_backup;
245     if (node_def == nullptr) node_def = &node_def_backup;
246     *node_def = node_def_;
247 
248     // Add control inputs after the regular inputs.
249     for (const auto& control_input : control_inputs_) {
250       node_def->add_input(strings::StrCat("^", control_input));
251     }
252 
253     // Add default values for unspecified attrs.
254     AddDefaultsToNodeDef(*op_def_, node_def);
255 
256     return Status::OK();
257   }
258 }
259 
Attr(StringPiece name,const AttrValue & value)260 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
261   if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
262     if (!AreAttrValuesEqual(*found, value)) {
263       errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
264                                         "' ", SummarizeAttrValue(*found),
265                                         " vs. ", SummarizeAttrValue(value)));
266     }
267   } else {
268     AddNodeAttr(name, value, &node_def_);
269   }
270   return *this;
271 }
272 
273 #define ATTR(T)                                                     \
274   NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
275     AttrValue attr_value;                                           \
276     SetAttrValue(value, &attr_value);                               \
277     return Attr(name, attr_value);                                  \
278   }
279 ATTR(StringPiece)
280 ATTR(const char*)
281 ATTR(int32)
282 ATTR(int64)
283 ATTR(float)
284 ATTR(double)
285 ATTR(bool)
286 ATTR(DataType)
287 ATTR(const PartialTensorShape&)
288 ATTR(const Tensor&)
289 ATTR(const TensorProto&)
290 ATTR(const NameAttrList&)
291 ATTR(gtl::ArraySlice<StringPiece>)
292 ATTR(gtl::ArraySlice<const char*>)
293 ATTR(gtl::ArraySlice<string>)
294 ATTR(gtl::ArraySlice<int32>)
295 ATTR(gtl::ArraySlice<int64>)
296 ATTR(gtl::ArraySlice<float>)
297 ATTR(gtl::ArraySlice<bool>)
298 ATTR(const std::vector<bool>&)
299 ATTR(gtl::ArraySlice<DataType>)
300 ATTR(gtl::ArraySlice<TensorShape>)
301 ATTR(gtl::ArraySlice<PartialTensorShape>)
302 ATTR(gtl::ArraySlice<TensorShapeProto>)
303 ATTR(gtl::ArraySlice<Tensor>)
304 ATTR(gtl::ArraySlice<NameAttrList>)
305 #undef ATTR
306 
307 }  // namespace tensorflow
308