1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/fake_input.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25
26 namespace tensorflow {
27 namespace {
28
29 class FakeInputImpl {
30 public:
31 FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
32 NodeDefBuilder* builder);
33 void SetN(int n);
34 void SetDataType(DataType dt);
35 void SetTypeList(DataTypeSlice dts);
36 Status AddInputToBuilder();
37
38 private:
39 static string FakeNodeName(int in_index);
40 Status GetN(int* n) const;
41 Status GetDataType(DataType* dt) const;
42 void NSources(int n, DataType dt) const;
43 void SourceList(DataTypeSlice dts) const;
44
45 const OpDef* const op_def_;
46 const OpDef::ArgDef* const arg_;
47 const string in_node_;
48 const NodeDef* const node_def_;
49 NodeDefBuilder* const builder_;
50
51 bool n_specified_;
52 int n_;
53 bool dt_specified_;
54 DataType dt_;
55 bool dts_specified_;
56 DataTypeSlice dts_;
57 };
58
FakeInputImpl(const OpDef * op_def,int in_index,const NodeDef * node_def,NodeDefBuilder * builder)59 FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
60 const NodeDef* node_def, NodeDefBuilder* builder)
61 : op_def_(op_def),
62 arg_(&op_def->input_arg(in_index)),
63 in_node_(FakeNodeName(in_index)),
64 node_def_(node_def),
65 builder_(builder),
66 n_specified_(false),
67 dt_specified_(false),
68 dts_specified_(false) {}
69
SetN(int n)70 void FakeInputImpl::SetN(int n) {
71 n_specified_ = true;
72 n_ = n;
73 }
74
SetDataType(DataType dt)75 void FakeInputImpl::SetDataType(DataType dt) {
76 dt_specified_ = true;
77 dt_ = dt;
78 }
79
SetTypeList(DataTypeSlice dts)80 void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
81 dts_specified_ = true;
82 dts_ = dts;
83 }
84
AddInputToBuilder()85 Status FakeInputImpl::AddInputToBuilder() {
86 if (dts_specified_) {
87 SourceList(dts_);
88
89 } else if (n_specified_ || !arg_->number_attr().empty()) {
90 int n;
91 TF_RETURN_IF_ERROR(GetN(&n));
92
93 DataType dt;
94 if (n > 0) {
95 TF_RETURN_IF_ERROR(GetDataType(&dt));
96 } else {
97 dt = DT_FLOAT;
98 }
99
100 NSources(n, dt);
101 } else {
102 if (!dt_specified_ && !arg_->type_list_attr().empty()) {
103 DataTypeVector dts;
104 Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
105 if (!status.ok()) {
106 return errors::InvalidArgument(
107 "Could not infer list of types for input '", arg_->name(),
108 "': ", status.error_message());
109 }
110 SourceList(dts);
111 return Status::OK();
112 }
113
114 DataType dt;
115 TF_RETURN_IF_ERROR(GetDataType(&dt));
116 builder_->Input(in_node_, 0, dt);
117 }
118 return Status::OK();
119 }
120
121 // static
FakeNodeName(int in_index)122 string FakeInputImpl::FakeNodeName(int in_index) {
123 char c = 'a' + (in_index % 26);
124 return string(&c, 1);
125 }
126
GetN(int * n) const127 Status FakeInputImpl::GetN(int* n) const {
128 if (n_specified_) {
129 *n = n_;
130 } else {
131 Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
132 if (!status.ok()) {
133 return errors::InvalidArgument("Could not infer length of input '",
134 arg_->name(),
135 "': ", status.error_message());
136 }
137 }
138 return Status::OK();
139 }
140
GetDataType(DataType * dt) const141 Status FakeInputImpl::GetDataType(DataType* dt) const {
142 if (dt_specified_) {
143 *dt = dt_;
144 return Status::OK(); // Ignore is_ref field of arg_.
145 } else if (arg_->type() != DT_INVALID) {
146 *dt = arg_->type();
147 } else if (!arg_->type_attr().empty()) {
148 Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
149 if (!status.ok()) {
150 // Check if the type attr has a default
151 const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_);
152 if (attr && attr->has_default_value()) {
153 *dt = attr->default_value().type();
154 } else {
155 return errors::InvalidArgument("Could not infer type for input '",
156 arg_->name(),
157 "': ", status.error_message());
158 }
159 }
160 } else {
161 return errors::InvalidArgument("No type or type_attr field in arg '",
162 arg_->name(), "'");
163 }
164 if (arg_->is_ref()) {
165 *dt = MakeRefType(*dt);
166 }
167 return Status::OK();
168 }
169
NSources(int n,DataType dt) const170 void FakeInputImpl::NSources(int n, DataType dt) const {
171 std::vector<NodeDefBuilder::NodeOut> srcs;
172 srcs.reserve(n);
173 for (int i = 0; i < n; ++i) {
174 srcs.emplace_back(in_node_, i, dt);
175 }
176 builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
177 }
178
SourceList(DataTypeSlice dts) const179 void FakeInputImpl::SourceList(DataTypeSlice dts) const {
180 std::vector<NodeDefBuilder::NodeOut> srcs;
181 srcs.reserve(dts.size());
182 for (size_t i = 0; i < dts.size(); ++i) {
183 srcs.emplace_back(in_node_, i, dts[i]);
184 }
185 builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
186 }
187
188 } // namespace
189
190 // Public interface ------------------------------------------------------------
191
FakeInput()192 FakeInputFunctor FakeInput() {
193 return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
194 NodeDefBuilder* builder) {
195 FakeInputImpl impl(&op_def, in_index, &node_def, builder);
196 return impl.AddInputToBuilder();
197 };
198 }
199
FakeInput(DataType dt)200 FakeInputFunctor FakeInput(DataType dt) {
201 return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
202 NodeDefBuilder* builder) {
203 FakeInputImpl impl(&op_def, in_index, &node_def, builder);
204 impl.SetDataType(dt);
205 return impl.AddInputToBuilder();
206 };
207 }
208
FakeInput(int n)209 FakeInputFunctor FakeInput(int n) {
210 return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
211 NodeDefBuilder* builder) {
212 FakeInputImpl impl(&op_def, in_index, &node_def, builder);
213 impl.SetN(n);
214 return impl.AddInputToBuilder();
215 };
216 }
217
FakeInput(int n,DataType dt)218 FakeInputFunctor FakeInput(int n, DataType dt) {
219 return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
220 NodeDefBuilder* builder) {
221 FakeInputImpl impl(&op_def, in_index, &node_def, builder);
222 impl.SetN(n);
223 impl.SetDataType(dt);
224 return impl.AddInputToBuilder();
225 };
226 }
227
FakeInput(DataTypeSlice dts)228 FakeInputFunctor FakeInput(DataTypeSlice dts) {
229 // Make a copy to ensure the data will still be around when the lambda is
230 // called.
231 DataTypeVector dtv(dts.begin(), dts.end());
232 return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
233 NodeDefBuilder* builder) {
234 FakeInputImpl impl(&op_def, in_index, &node_def, builder);
235 impl.SetTypeList(dtv);
236 return impl.AddInputToBuilder();
237 };
238 }
239
240 } // namespace tensorflow
241