• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
16 
17 #include "tensorflow/core/kernels/data/dataset_test_base.h"
18 #include "tensorflow/core/kernels/data/dataset_utils.h"
19 
20 namespace tensorflow {
21 namespace data {
22 namespace {
23 
24 constexpr char kNodeName[] = "tensor_slice_dataset";
25 
26 class TensorSliceDatasetOpTest : public DatasetOpsTestBase {};
27 
PlainTensorSliceDatasetParams()28 TensorSliceDatasetParams PlainTensorSliceDatasetParams() {
29   std::vector<Tensor> components = {
30       CreateTensor<int64>(TensorShape({2}), {1, 2}),
31       CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4}),
32       CreateTensor<uint32>(TensorShape({2}), {2, 3}),
33       CreateTensor<uint32>(TensorShape({2, 2}), {2, 3, 4, 5}),
34       CreateTensor<uint64>(TensorShape({2}), {3, 4}),
35       CreateTensor<uint64>(TensorShape({2, 2}), {3, 4, 5, 6}),
36       CreateTensor<double>(TensorShape({2, 1}), {37.0, 38.0}),
37       CreateTensor<tstring>(TensorShape({2, 1}), {"a", "b"})};
38 
39   return {std::move(components), kNodeName};
40 }
41 
NestedTensorSliceDatasetParams()42 TensorSliceDatasetParams NestedTensorSliceDatasetParams() {
43   std::vector<Tensor> components = {
44       CreateTensor<Variant>(
45           TensorShape({2, 1}),
46           {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0}),
47            CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
48       CreateTensor<Variant>(
49           TensorShape({2, 1}),
50           {CreateTensor<tstring>(TensorShape({1, 2}), {"a", "b"}),
51            CreateTensor<tstring>(TensorShape({1, 2}), {"c", "d"})}),
52       CreateTensor<int64>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})};
53 
54   return {std::move(components), kNodeName};
55 }
56 
GetNextTestCases()57 std::vector<GetNextTestCase<TensorSliceDatasetParams>> GetNextTestCases() {
58   return {
59       {/*dataset_params=*/PlainTensorSliceDatasetParams(),
60        /*expected_outputs=*/{CreateTensor<int64>(TensorShape({}), {1}),
61                              CreateTensor<int64>(TensorShape({2}), {1, 2}),
62                              CreateTensor<uint32>(TensorShape({}), {2}),
63                              CreateTensor<uint32>(TensorShape({2}), {2, 3}),
64                              CreateTensor<uint64>(TensorShape({}), {3}),
65                              CreateTensor<uint64>(TensorShape({2}), {3, 4}),
66                              CreateTensor<double>(TensorShape({1}), {37.0}),
67                              CreateTensor<tstring>(TensorShape({1}), {"a"}),
68                              CreateTensor<int64>(TensorShape({}), {2}),
69                              CreateTensor<int64>(TensorShape({2}), {3, 4}),
70                              CreateTensor<uint32>(TensorShape({}), {3}),
71                              CreateTensor<uint32>(TensorShape({2}), {4, 5}),
72                              CreateTensor<uint64>(TensorShape({}), {4}),
73                              CreateTensor<uint64>(TensorShape({2}), {5, 6}),
74                              CreateTensor<double>(TensorShape({1}), {38.0}),
75                              CreateTensor<tstring>(TensorShape({1}), {"b"})}},
76       {/*dataset_params=*/NestedTensorSliceDatasetParams(),
77        /*expected_outputs=*/
78        {CreateTensor<Variant>(
79             TensorShape({1}),
80             {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
81         CreateTensor<Variant>(
82             TensorShape({1}),
83             {CreateTensor<tstring>(TensorShape({1, 2}), {"a", "b"})}),
84         CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
85         CreateTensor<Variant>(
86             TensorShape({1}),
87             {CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
88         CreateTensor<Variant>(
89             TensorShape({1}),
90             {CreateTensor<tstring>(TensorShape({1, 2}), {"c", "d"})}),
91         CreateTensor<int64>(TensorShape({3}), {4, 5, 6})}}};
92 }
93 
94 class ParameterizedGetNextTest
95     : public TensorSliceDatasetOpTest,
96       public ::testing::WithParamInterface<
97           GetNextTestCase<TensorSliceDatasetParams>> {};
98 
TEST_P(ParameterizedGetNextTest,GetNext)99 TEST_P(ParameterizedGetNextTest, GetNext) {
100   auto test_case = GetParam();
101   TF_ASSERT_OK(Initialize(test_case.dataset_params));
102 
103   std::vector<string> input_names;
104   TF_ASSERT_OK(test_case.dataset_params.GetInputNames(&input_names));
105   size_t num_tensors_per_slice = input_names.size();
106   bool end_of_sequence = false;
107   std::vector<Tensor> out_tensors;
108   int cur_slice = 0;
109 
110   while (!end_of_sequence) {
111     TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
112                                     &end_of_sequence));
113     for (int i = 0; i < out_tensors.size(); ++i) {
114       EXPECT_LT(i + num_tensors_per_slice * cur_slice,
115                 test_case.expected_outputs.size());
116       if (out_tensors[i].dtype() == DT_VARIANT) {
117         // Currently `ExpectEqual()` does not support the variant tensor
118         // yet, so we manually cast the variant to numeric/string tensor.
119         const Tensor* output = out_tensors[i].scalar<Variant>()().get<Tensor>();
120         const Tensor* expected_output =
121             test_case.expected_outputs[i + num_tensors_per_slice * cur_slice]
122                 .scalar<Variant>()()
123                 .get<Tensor>();
124         TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
125       } else {
126         TF_EXPECT_OK(ExpectEqual(
127             out_tensors[i],
128             test_case.expected_outputs[i + num_tensors_per_slice * cur_slice]));
129       }
130     }
131     out_tensors.clear();
132     cur_slice++;
133   }
134 }
135 
136 INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest, ParameterizedGetNextTest,
137                          ::testing::ValuesIn(GetNextTestCases()));
138 
// The dataset reports the node name carried by the params it was built from.
TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
  TensorSliceDatasetParams params = PlainTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(params));
  TF_ASSERT_OK(CheckDatasetNodeName(params.node_name()));
}
144 
// The dataset's type string is the op name derived from
// `TensorSliceDatasetOp::kDatasetType`.
TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
  TensorSliceDatasetParams params = PlainTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(params));
  const string expected_type_string =
      name_utils::OpName(TensorSliceDatasetOp::kDatasetType);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type_string));
}
151 
152 std::vector<DatasetOutputDtypesTestCase<TensorSliceDatasetParams>>
DatasetOutputTypesTestCases()153 DatasetOutputTypesTestCases() {
154   return {{PlainTensorSliceDatasetParams(),
155            PlainTensorSliceDatasetParams().output_dtypes()},
156           {NestedTensorSliceDatasetParams(),
157            NestedTensorSliceDatasetParams().output_dtypes()}};
158 }
159 
DATASET_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,DatasetOutputTypesTestCases ())160 DATASET_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest, TensorSliceDatasetParams,
161                              DatasetOutputTypesTestCases())
162 
163 std::vector<DatasetOutputShapesTestCase<TensorSliceDatasetParams>>
164 DatasetOutputShapesTestCases() {
165   return {{PlainTensorSliceDatasetParams(),
166            PlainTensorSliceDatasetParams().output_shapes()},
167           {NestedTensorSliceDatasetParams(),
168            NestedTensorSliceDatasetParams().output_shapes()}};
169 }
170 
DATASET_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,DatasetOutputShapesTestCases ())171 DATASET_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest, TensorSliceDatasetParams,
172                              DatasetOutputShapesTestCases())
173 
174 std::vector<CardinalityTestCase<TensorSliceDatasetParams>>
175 DatasetCardinalityTestCases() {
176   return {{PlainTensorSliceDatasetParams(), /*expected_cardinality=*/2},
177           {NestedTensorSliceDatasetParams(), /*expected_cardinality=*/2}};
178 }
179 
DATASET_CARDINALITY_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,DatasetCardinalityTestCases ())180 DATASET_CARDINALITY_TEST_P(TensorSliceDatasetOpTest, TensorSliceDatasetParams,
181                            DatasetCardinalityTestCases())
182 
183 std::vector<IteratorOutputDtypesTestCase<TensorSliceDatasetParams>>
184 IteratorOutputTypesTestCases() {
185   return {{PlainTensorSliceDatasetParams(),
186            PlainTensorSliceDatasetParams().output_dtypes()},
187           {NestedTensorSliceDatasetParams(),
188            NestedTensorSliceDatasetParams().output_dtypes()}};
189 }
190 
ITERATOR_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,IteratorOutputTypesTestCases ())191 ITERATOR_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest,
192                               TensorSliceDatasetParams,
193                               IteratorOutputTypesTestCases())
194 
195 std::vector<IteratorOutputShapesTestCase<TensorSliceDatasetParams>>
196 IteratorOutputShapesTestCases() {
197   return {{PlainTensorSliceDatasetParams(),
198            PlainTensorSliceDatasetParams().output_shapes()},
199           {NestedTensorSliceDatasetParams(),
200            NestedTensorSliceDatasetParams().output_shapes()}};
201 }
202 
ITERATOR_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,IteratorOutputShapesTestCases ())203 ITERATOR_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest,
204                               TensorSliceDatasetParams,
205                               IteratorOutputShapesTestCases())
206 
207 TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
208   auto dataset_params = PlainTensorSliceDatasetParams();
209   TF_ASSERT_OK(Initialize(dataset_params));
210   TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
211       TensorSliceDatasetOp::kDatasetType, dataset_params.iterator_prefix())));
212 }
213 
214 std::vector<IteratorSaveAndRestoreTestCase<TensorSliceDatasetParams>>
IteratorSaveAndRestoreTestCases()215 IteratorSaveAndRestoreTestCases() {
216   return {
217       {/*dataset_params=*/PlainTensorSliceDatasetParams(),
218        /*breakpoints=*/{0, 1, 2},
219        /*expected_outputs=*/
220        {CreateTensor<int64>(TensorShape({}), {1}),
221         CreateTensor<int64>(TensorShape({2}), {1, 2}),
222         CreateTensor<uint32>(TensorShape({}), {2}),
223         CreateTensor<uint32>(TensorShape({2}), {2, 3}),
224         CreateTensor<uint64>(TensorShape({}), {3}),
225         CreateTensor<uint64>(TensorShape({2}), {3, 4}),
226         CreateTensor<double>(TensorShape({1}), {37.0}),
227         CreateTensor<tstring>(TensorShape({1}), {"a"}),
228         CreateTensor<int64>(TensorShape({}), {2}),
229         CreateTensor<int64>(TensorShape({2}), {3, 4}),
230         CreateTensor<uint32>(TensorShape({}), {3}),
231         CreateTensor<uint32>(TensorShape({2}), {4, 5}),
232         CreateTensor<uint64>(TensorShape({}), {4}),
233         CreateTensor<uint64>(TensorShape({2}), {5, 6}),
234         CreateTensor<double>(TensorShape({1}), {38.0}),
235         CreateTensor<tstring>(TensorShape({1}), {"b"})}},
236       {/*dataset_params=*/NestedTensorSliceDatasetParams(),
237        /*breakpoints=*/{0, 1, 2},
238        /*expected_outputs=*/
239        {CreateTensor<Variant>(
240             TensorShape({1}),
241             {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
242         CreateTensor<Variant>(
243             TensorShape({1}),
244             {CreateTensor<tstring>(TensorShape({1, 2}), {"a", "b"})}),
245         CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
246         CreateTensor<Variant>(
247             TensorShape({1}),
248             {CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
249         CreateTensor<Variant>(
250             TensorShape({1}),
251             {CreateTensor<tstring>(TensorShape({1, 2}), {"c", "d"})}),
252         CreateTensor<int64>(TensorShape({3}), {4, 5, 6})}}};
253 }
254 
255 class ParameterizedIteratorSaveAndRestoreTest
256     : public TensorSliceDatasetOpTest,
257       public ::testing::WithParamInterface<
258           IteratorSaveAndRestoreTestCase<TensorSliceDatasetParams>> {};
259 
TEST_P(ParameterizedIteratorSaveAndRestoreTest,SaveAndRestore)260 TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) {
261   auto test_case = GetParam();
262   TF_ASSERT_OK(Initialize(test_case.dataset_params));
263 
264   std::unique_ptr<SerializationContext> serialization_context;
265   TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
266 
267   int cur_iteration = 0;
268   bool end_of_sequence = false;
269 
270   auto params =
271       static_cast<TensorSliceDatasetParams&>(test_case.dataset_params);
272   int64 num_slices = params.num_slices();
273   size_t num_tensors_per_slice = params.num_tensors_per_slice();
274   std::vector<Tensor> out_tensors;
275   const std::vector<int>& breakpoints = test_case.breakpoints;
276   for (int breakpoint : breakpoints) {
277     while (cur_iteration < breakpoint) {
278       TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
279                                       &end_of_sequence));
280       cur_iteration++;
281     }
282 
283     if (breakpoint == 0) {
284       EXPECT_FALSE(end_of_sequence);
285     } else if (breakpoint <= num_slices) {
286       for (int i = 0; i < out_tensors.size(); ++i) {
287         if (out_tensors[i].dtype() == DT_VARIANT) {
288           const Tensor* output =
289               out_tensors[i].scalar<Variant>()().get<Tensor>();
290           const Tensor* expected_output =
291               test_case
292                   .expected_outputs[i +
293                                     num_tensors_per_slice * (cur_iteration - 1)]
294                   .scalar<Variant>()()
295                   .get<Tensor>();
296           TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
297         } else {
298           TF_EXPECT_OK(ExpectEqual(
299               out_tensors[i],
300               test_case.expected_outputs[i + num_tensors_per_slice *
301                                                  (cur_iteration - 1)]));
302         }
303       }
304     } else {
305       EXPECT_TRUE(end_of_sequence);
306     }
307 
308     VariantTensorDataWriter writer;
309     TF_ASSERT_OK(iterator_->Save(serialization_context.get(), &writer));
310     std::vector<const VariantTensorData*> data;
311     writer.GetData(&data);
312     VariantTensorDataReader reader(data);
313     TF_EXPECT_OK(RestoreIterator(iterator_ctx_.get(), &reader, "Iterator",
314                                  *dataset_, &iterator_));
315   }
316 }
317 
318 INSTANTIATE_TEST_SUITE_P(
319     TensorSliceDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
320     ::testing::ValuesIn(IteratorSaveAndRestoreTestCases()));
321 
// A 7-element dataset is checked two ways:
//   - full split-provider iteration yields every element in order;
//   - with 3 shards, shard 1 sees the elements at positions 1 and 4.
TEST_F(TensorSliceDatasetOpTest, SplitProvider) {
  TensorSliceDatasetParams params(
      CreateTensors<int64>(TensorShape({7}), {{6, 2, 3, 8, 7, 0, 10}}),
      kNodeName);
  TF_ASSERT_OK(InitializeRuntime(params));

  const std::vector<Tensor> all_elements = CreateTensors<int64>(
      TensorShape({}), {{6}, {2}, {3}, {8}, {7}, {0}, {10}});
  TF_EXPECT_OK(CheckSplitProviderFullIteration(params, all_elements));

  const std::vector<Tensor> shard_one_of_three =
      CreateTensors<int64>(TensorShape({}), {{2}, {7}});
  TF_EXPECT_OK(CheckSplitProviderShardedIteration(
      params, /*num_shards=*/3, /*shard_index=*/1, shard_one_of_three));
}
334 
// An empty (zero-length) component yields no elements. This holds both for
// full iteration and for any shard.
TEST_F(TensorSliceDatasetOpTest, SplitProviderEmpty) {
  TensorSliceDatasetParams params(
      CreateTensors<int64>(TensorShape({0}), {{}}), kNodeName);
  TF_ASSERT_OK(InitializeRuntime(params));

  const std::vector<Tensor> no_elements =
      CreateTensors<int64>(TensorShape({}), {});
  TF_EXPECT_OK(CheckSplitProviderFullIteration(params, no_elements));

  const std::vector<Tensor> no_shard_elements =
      CreateTensors<int64>(TensorShape({}), {});
  TF_EXPECT_OK(CheckSplitProviderShardedIteration(
      params, /*num_shards=*/3, /*shard_index=*/1, no_shard_elements));
}
345 
346 }  // namespace
347 }  // namespace data
348 }  // namespace tensorflow
349