1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/standalone.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24
25 namespace tensorflow {
26 namespace data {
27 namespace standalone {
28 namespace {
29
// Text-format GraphDef for `range(10)`.
//
// Three scalar int64 `Const` nodes holding 0, 10 and 1 feed, in that order,
// the start/stop/step inputs of a `RangeDataset`. The dataset's variant handle
// is exposed through the `_Retval` node named "dataset" (index 0), which is
// what `Dataset::FromGraph` consumes in the test below. Iterating it yields the
// int64 scalars 0 through 9.
constexpr const char* const kRangeGraphProto = R"proto(
  node {
    name: "Const/_0"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 0
        }
      }
    }
  }
  node {
    name: "Const/_1"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 10
        }
      }
    }
  }
  node {
    name: "Const/_2"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 1
        }
      }
    }
  }
  node {
    name: "RangeDataset/_3"
    op: "RangeDataset"
    input: "Const/_0"
    input: "Const/_1"
    input: "Const/_2"
    attr {
      key: "output_shapes"
      value { list { shape {} } }
    }
    attr {
      key: "output_types"
      value { list { type: DT_INT64 } }
    }
  }
  node {
    name: "dataset"
    op: "_Retval"
    input: "RangeDataset/_3"
    attr {
      key: "T"
      value { type: DT_VARIANT }
    }
    attr {
      key: "index"
      value { i: 0 }
    }
  }
  library {}
  versions { producer: 96 }
)proto";
116
// Text-format GraphDef for `range(10).map(lambda x: x*x)`.
//
// A `RangeDataset` built from int64 constants 0, 10 and 1 feeds a
// `MapDataset`. The map's `f` attr names the library function
// `__inference_Dataset_map_<lambda>_67`, which multiplies its single int64
// argument by itself and returns the product via an `Identity` node. The
// mapped dataset is returned through the `_Retval` node named "dataset", so
// iterating it yields the squares 0, 1, 4, ..., 81.
constexpr const char* const kMapGraphProto = R"proto(
  node {
    name: "Const/_0"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 0
        }
      }
    }
  }
  node {
    name: "Const/_1"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 10
        }
      }
    }
  }
  node {
    name: "Const/_2"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 1
        }
      }
    }
  }
  node {
    name: "RangeDataset/_3"
    op: "RangeDataset"
    input: "Const/_0"
    input: "Const/_1"
    input: "Const/_2"
    attr {
      key: "output_shapes"
      value { list { shape {} } }
    }
    attr {
      key: "output_types"
      value { list { type: DT_INT64 } }
    }
  }
  node {
    name: "MapDataset/_4"
    op: "MapDataset"
    input: "RangeDataset/_3"
    attr {
      key: "Targuments"
      value { list {} }
    }
    attr {
      key: "f"
      value { func { name: "__inference_Dataset_map_<lambda>_67" } }
    }
    attr {
      key: "output_shapes"
      value { list { shape {} } }
    }
    attr {
      key: "output_types"
      value { list { type: DT_INT64 } }
    }
    attr {
      key: "preserve_cardinality"
      value { b: false }
    }
    attr {
      key: "use_inter_op_parallelism"
      value { b: true }
    }
  }
  node {
    name: "dataset"
    op: "_Retval"
    input: "MapDataset/_4"
    attr {
      key: "T"
      value { type: DT_VARIANT }
    }
    attr {
      key: "index"
      value { i: 0 }
    }
  }
  library {
    function {
      signature {
        name: "__inference_Dataset_map_<lambda>_67"
        input_arg { name: "args_0" type: DT_INT64 }
        output_arg { name: "identity" type: DT_INT64 }
      }
      node_def {
        name: "mul"
        op: "Mul"
        input: "args_0"
        input: "args_0"
        attr {
          key: "T"
          value { type: DT_INT64 }
        }
      }
      node_def {
        name: "Identity"
        op: "Identity"
        input: "mul:z:0"
        attr {
          key: "T"
          value { type: DT_INT64 }
        }
      }
      ret { key: "identity" value: "Identity:output:0" }
      arg_attr {
        key: 0
        value {
          attr {
            key: "_user_specified_name"
            value { s: "args_0" }
          }
        }
      }
    }
  }
  versions { producer: 96 min_consumer: 12 }
)proto";
270
// For each text-format graph: parses it into a GraphDef, builds a standalone
// Dataset from it, drains an iterator over that dataset, and checks that the
// produced int64 scalars match `expected_outputs` exactly (values and count).
TEST(Scalar, Standalone) {
  struct TestCase {
    string graph_string;
    std::vector<int64_t> expected_outputs;
  };
  const std::vector<TestCase> test_cases = {
      {kRangeGraphProto, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}},
      {kMapGraphProto, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81}},
  };
  for (const TestCase& test_case : test_cases) {
    GraphDef graph_def;
    // A malformed text proto would otherwise leave `graph_def` silently empty
    // and surface later as a confusing FromGraph error.
    ASSERT_TRUE(protobuf::TextFormat::ParseFromString(test_case.graph_string,
                                                      &graph_def));

    // ASSERT (not EXPECT): on failure `dataset` / `iterator` stay null and the
    // following line would dereference them.
    std::unique_ptr<Dataset> dataset;
    TF_ASSERT_OK(Dataset::FromGraph({}, graph_def, &dataset));
    std::unique_ptr<Iterator> iterator;
    TF_ASSERT_OK(dataset->MakeIterator(&iterator));

    size_t num_outputs = 0;
    bool end_of_input = false;
    while (true) {
      std::vector<tensorflow::Tensor> outputs;
      // A failing GetNext may never set `end_of_input`; stop instead of
      // looping forever or reading an empty `outputs`.
      TF_ASSERT_OK(iterator->GetNext(&outputs, &end_of_input));
      if (end_of_input) {
        break;
      }
      // Guard the indexing below against a dataset that yields too many
      // elements.
      ASSERT_LT(num_outputs, test_case.expected_outputs.size());
      // Each graph declares a single int64 scalar component.
      ASSERT_EQ(outputs.size(), 1u);
      EXPECT_EQ(outputs[0].scalar<int64_t>()(),
                test_case.expected_outputs[num_outputs]);
      ++num_outputs;
    }
    EXPECT_EQ(num_outputs, test_case.expected_outputs.size());
  }
}
303
304 } // namespace
305 } // namespace standalone
306 } // namespace data
307 } // namespace tensorflow
308