• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/standalone.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 
25 namespace tensorflow {
26 namespace data {
27 namespace standalone {
28 namespace {
29 
// range(10)
//
// Text-format GraphDef with three scalar DT_INT64 Const nodes (0, 10 and 1)
// feeding a RangeDataset node, whose output is returned through the "dataset"
// _Retval node. The test below expects it to yield 0 through 9.
constexpr const char* const kRangeGraphProto = R"proto(
  node {
    name: "Const/_0"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 0
        }
      }
    }
  }
  node {
    name: "Const/_1"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 10
        }
      }
    }
  }
  node {
    name: "Const/_2"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 1
        }
      }
    }
  }
  node {
    name: "RangeDataset/_3"
    op: "RangeDataset"
    input: "Const/_0"
    input: "Const/_1"
    input: "Const/_2"
    attr {
      key: "output_shapes"
      value { list { shape {} } }
    }
    attr {
      key: "output_types"
      value { list { type: DT_INT64 } }
    }
  }
  node {
    name: "dataset"
    op: "_Retval"
    input: "RangeDataset/_3"
    attr {
      key: "T"
      value { type: DT_VARIANT }
    }
    attr {
      key: "index"
      value { i: 0 }
    }
  }
  library {}
  versions { producer: 96 }
)proto";
116 
// range(10).map(lambda x: x*x)
//
// Text-format GraphDef in which a RangeDataset node (fed by the scalar
// DT_INT64 constants 0, 10 and 1) is the input of a MapDataset node. The map
// function lives in the graph's function library and multiplies its single
// DT_INT64 argument by itself. The MapDataset output is returned through the
// "dataset" _Retval node; the test below expects the squares 0, 1, 4, ..., 81.
constexpr const char* const kMapGraphProto = R"proto(
  node {
    name: "Const/_0"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 0
        }
      }
    }
  }
  node {
    name: "Const/_1"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 10
        }
      }
    }
  }
  node {
    name: "Const/_2"
    op: "Const"
    attr {
      key: "dtype"
      value { type: DT_INT64 }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT64
          tensor_shape {}
          int64_val: 1
        }
      }
    }
  }
  node {
    name: "RangeDataset/_3"
    op: "RangeDataset"
    input: "Const/_0"
    input: "Const/_1"
    input: "Const/_2"
    attr {
      key: "output_shapes"
      value { list { shape {} } }
    }
    attr {
      key: "output_types"
      value { list { type: DT_INT64 } }
    }
  }
  node {
    name: "MapDataset/_4"
    op: "MapDataset"
    input: "RangeDataset/_3"
    attr {
      key: "Targuments"
      value { list {} }
    }
    attr {
      key: "f"
      value { func { name: "__inference_Dataset_map_<lambda>_67" } }
    }
    attr {
      key: "output_shapes"
      value { list { shape {} } }
    }
    attr {
      key: "output_types"
      value { list { type: DT_INT64 } }
    }
    attr {
      key: "preserve_cardinality"
      value { b: false }
    }
    attr {
      key: "use_inter_op_parallelism"
      value { b: true }
    }
  }
  node {
    name: "dataset"
    op: "_Retval"
    input: "MapDataset/_4"
    attr {
      key: "T"
      value { type: DT_VARIANT }
    }
    attr {
      key: "index"
      value { i: 0 }
    }
  }
  library {
    function {
      signature {
        name: "__inference_Dataset_map_<lambda>_67"
        input_arg { name: "args_0" type: DT_INT64 }
        output_arg { name: "identity" type: DT_INT64 }
      }
      node_def {
        name: "mul"
        op: "Mul"
        input: "args_0"
        input: "args_0"
        attr {
          key: "T"
          value { type: DT_INT64 }
        }
      }
      node_def {
        name: "Identity"
        op: "Identity"
        input: "mul:z:0"
        attr {
          key: "T"
          value { type: DT_INT64 }
        }
      }
      ret { key: "identity" value: "Identity:output:0" }
      arg_attr {
        key: 0
        value {
          attr {
            key: "_user_specified_name"
            value { s: "args_0" }
          }
        }
      }
    }
  }
  versions { producer: 96 min_consumer: 12 }
)proto";
270 
// Builds a standalone Dataset from each text-format GraphDef, iterates it to
// exhaustion, and checks that it yields exactly the expected sequence of
// scalar int64 values.
TEST(Scalar, Standalone) {
  struct TestCase {
    string graph_string;
    std::vector<int64_t> expected_outputs;
  };
  const std::vector<TestCase> test_cases = {
      TestCase{kRangeGraphProto, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}},
      TestCase{kMapGraphProto, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81}},
  };
  // Iterate by const reference so the graph string and expectation vector are
  // not copied for every case.
  for (const TestCase& test_case : test_cases) {
    GraphDef graph_def;
    // A malformed proto would otherwise silently produce an empty graph.
    ASSERT_TRUE(protobuf::TextFormat::ParseFromString(test_case.graph_string,
                                                      &graph_def));

    // Fatal assertions: on failure `dataset` / `iterator` stay null and must
    // not be dereferenced below.
    std::unique_ptr<Dataset> dataset;
    TF_ASSERT_OK(Dataset::FromGraph({}, graph_def, &dataset));
    std::unique_ptr<Iterator> iterator;
    TF_ASSERT_OK(dataset->MakeIterator(&iterator));

    bool end_of_input = false;
    for (size_t num_outputs = 0; !end_of_input; ++num_outputs) {
      std::vector<tensorflow::Tensor> outputs;
      // Fatal: if GetNext fails, `end_of_input` may never become true and
      // `outputs` may be empty, so continuing would loop forever or read out
      // of bounds.
      TF_ASSERT_OK(iterator->GetNext(&outputs, &end_of_input));
      if (!end_of_input) {
        // Guard both index operations before reading the element.
        ASSERT_LT(num_outputs, test_case.expected_outputs.size());
        ASSERT_FALSE(outputs.empty());
        EXPECT_EQ(outputs[0].scalar<int64_t>()(),
                  test_case.expected_outputs[num_outputs]);
      } else {
        // The dataset must not end before producing every expected element.
        EXPECT_EQ(test_case.expected_outputs.size(), num_outputs);
      }
    }
  }
}
303 
304 }  // namespace
305 }  // namespace standalone
306 }  // namespace data
307 }  // namespace tensorflow
308