1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
17 #define TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
18
19 #include <stddef.h>
20
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/graph_constructor.h"
31 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
32 #include "tensorflow/core/data/name_utils.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/dataset.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/function_testlib.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/resource_mgr.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_shape.h"
43 #include "tensorflow/core/framework/tensor_testutil.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/lib/gtl/array_slice.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/io/zlib_compression_options.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/test.h"
53 #include "tensorflow/core/platform/threadpool.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
56
57 namespace tensorflow {
58 namespace data {
59
60 typedef std::vector<
61 std::pair<string, tensorflow::FunctionDefHelper::AttrValueWrapper>>
62 AttributeVector;
63
64 constexpr int kDefaultCPUNum = 2;
65 constexpr int kDefaultThreadNum = 2;
66
67 // Creates a tensor with the specified dtype, shape, and value.
68 template <typename T>
CreateTensor(const TensorShape & input_shape,gtl::ArraySlice<T> input_data)69 static Tensor CreateTensor(const TensorShape& input_shape,
70 gtl::ArraySlice<T> input_data) {
71 Tensor tensor(DataTypeToEnum<T>::value, input_shape);
72 test::FillValues<T>(&tensor, input_data);
73 return tensor;
74 }
75
76 // Creates a vector of tensors with the specified dtype, shape, and values.
77 template <typename T>
CreateTensors(const TensorShape & shape,const std::vector<gtl::ArraySlice<T>> & values)78 std::vector<Tensor> CreateTensors(
79 const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
80 std::vector<Tensor> result;
81 result.reserve(values.size());
82 for (auto& value : values) {
83 result.emplace_back(CreateTensor<T>(shape, value));
84 }
85 return result;
86 }
87
88 enum class CompressionType { ZLIB = 0, GZIP = 1, RAW = 2, UNCOMPRESSED = 3 };
89
90 // Returns a string representation for the given compression type.
91 string ToString(CompressionType compression_type);
92
93 // Gets the specified zlib compression options according to the compression
94 // type. Note that `CompressionType::UNCOMPRESSED` is not supported because
95 // `ZlibCompressionOptions` does not have an option.
96 io::ZlibCompressionOptions GetZlibCompressionOptions(
97 CompressionType compression_type);
98
99 // Used to specify parameters when writing data into files with compression.
100 // `input_buffer_size` and `output_buffer_size` specify the input and output
101 // buffer size when ZLIB and GZIP compression is used.
102 struct CompressionParams {
103 CompressionType compression_type = CompressionType::UNCOMPRESSED;
104 int32 input_buffer_size = 0;
105 int32 output_buffer_size = 0;
106 };
107
108 // Writes the input data into the file without compression.
109 Status WriteDataToFile(const string& filename, const char* data);
110
111 // Writes the input data into the file with the specified compression.
112 Status WriteDataToFile(const string& filename, const char* data,
113 const CompressionParams& params);
114
115 // Writes the input data into the TFRecord file with the specified compression.
116 Status WriteDataToTFRecordFile(const string& filename,
117 const std::vector<absl::string_view>& records,
118 const CompressionParams& params);
119
120 // Provides the parameters for running the dataset op.
121 class DatasetParams {
122 public:
123 DatasetParams(DataTypeVector output_dtypes,
124 std::vector<PartialTensorShape> output_shapes,
125 string node_name);
126
~DatasetParams()127 virtual ~DatasetParams() {}
128
129 // Returns the inputs (except the input datasets) as a tensor vector.
130 virtual std::vector<Tensor> GetInputTensors() const = 0;
131
132 // Returns the dataset input names as a string vector.
133 virtual Status GetInputNames(std::vector<string>* input_names) const = 0;
134
135 // Returns the dataset attributes as a vector.
136 virtual Status GetAttributes(AttributeVector* attributes) const = 0;
137
138 // Checks if the tensor is a dataset variant tensor.
139 static bool IsDatasetTensor(const Tensor& tensor);
140
node_name()141 string node_name() const { return node_name_; }
142
output_dtypes()143 DataTypeVector output_dtypes() const { return output_dtypes_; }
144
output_shapes()145 std::vector<PartialTensorShape> output_shapes() const {
146 return output_shapes_;
147 }
148
iterator_prefix()149 string iterator_prefix() const { return iterator_prefix_; }
150
input_dataset_params()151 const std::vector<std::shared_ptr<DatasetParams>>& input_dataset_params()
152 const {
153 return input_dataset_params_;
154 }
155
156 // Returns the functions that will be used when running the dataset op.
func_lib()157 virtual std::vector<FunctionDef> func_lib() const { return {}; }
158
159 // Returns the dataset type for the op represented by these parameters. This
160 // type usually needs to match the constant called `kDatasetType` defined in
161 // the dataset kernel.
162 virtual string dataset_type() const = 0;
163
164 // Returns the dataset op name. By default, it returns the Op::kDatasetType
165 // concatenated with "Dataset". For ops that do not have "Dataset" suffix,
166 // this method can be overriden to return a different name.
op_name()167 virtual string op_name() const {
168 name_utils::OpNameParams params;
169 params.op_version = op_version();
170 return name_utils::OpName(dataset_type(), params);
171 }
172
op_version()173 virtual int op_version() const { return op_version_; }
174
175 protected:
176 std::vector<std::shared_ptr<DatasetParams>> input_dataset_params_;
177 DataTypeVector output_dtypes_;
178 std::vector<PartialTensorShape> output_shapes_;
179 string node_name_;
180 string iterator_prefix_ = "Iterator";
181 int op_version_ = 1;
182 };
183
184 // `RangeDatasetParams` is a common dataset parameter type that are used in
185 // testing.
186 class RangeDatasetParams : public DatasetParams {
187 public:
188 RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
189 DataTypeVector output_dtypes,
190 std::vector<PartialTensorShape> output_shapes,
191 string node_name);
192
193 RangeDatasetParams(int64_t start, int64_t stop, int64_t step);
194
195 RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
196 DataTypeVector output_dtypes);
197
198 std::vector<Tensor> GetInputTensors() const override;
199
200 Status GetInputNames(std::vector<string>* input_names) const override;
201
202 Status GetAttributes(AttributeVector* attr_vector) const override;
203
204 string dataset_type() const override;
205
206 private:
207 int64 start_;
208 int64 stop_;
209 int64 step_;
210 };
211
212 // `BatchDatasetParams` is a common dataset parameter type that are used in
213 // testing.
214 class BatchDatasetParams : public DatasetParams {
215 public:
216 template <typename T>
BatchDatasetParams(T input_dataset_params,int64_t batch_size,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)217 BatchDatasetParams(T input_dataset_params, int64_t batch_size,
218 bool drop_remainder, bool parallel_copy,
219 DataTypeVector output_dtypes,
220 std::vector<PartialTensorShape> output_shapes,
221 string node_name)
222 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
223 std::move(node_name)),
224 batch_size_(batch_size),
225 drop_remainder_(drop_remainder),
226 parallel_copy_(parallel_copy) {
227 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
228 op_version_ = 2;
229 iterator_prefix_ =
230 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
231 input_dataset_params.iterator_prefix());
232 }
233
234 std::vector<Tensor> GetInputTensors() const override;
235
236 Status GetInputNames(std::vector<string>* input_names) const override;
237
238 Status GetAttributes(AttributeVector* attr_vector) const override;
239
240 string dataset_type() const override;
241
242 private:
243 int64 batch_size_;
244 bool drop_remainder_;
245 bool parallel_copy_;
246 };
247
248 // `MapDatasetParams` is a common dataset parameter type that are used in
249 // testing.
250 class MapDatasetParams : public DatasetParams {
251 public:
252 template <typename T>
MapDatasetParams(T input_dataset_params,std::vector<Tensor> other_arguments,FunctionDefHelper::AttrValueWrapper func,std::vector<FunctionDef> func_lib,DataTypeVector type_arguments,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,bool use_inter_op_parallelism,bool preserve_cardinality,string node_name)253 MapDatasetParams(T input_dataset_params, std::vector<Tensor> other_arguments,
254 FunctionDefHelper::AttrValueWrapper func,
255 std::vector<FunctionDef> func_lib,
256 DataTypeVector type_arguments, DataTypeVector output_dtypes,
257 std::vector<PartialTensorShape> output_shapes,
258 bool use_inter_op_parallelism, bool preserve_cardinality,
259 string node_name)
260 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
261 std::move(node_name)),
262 other_arguments_(std::move(other_arguments)),
263 func_(std::move(func)),
264 func_lib_(std::move(func_lib)),
265 type_arguments_(std::move(type_arguments)),
266 use_inter_op_parallelism_(use_inter_op_parallelism),
267 preserve_cardinality_(preserve_cardinality) {
268 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
269 iterator_prefix_ =
270 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
271 input_dataset_params.iterator_prefix());
272 }
273
274 std::vector<Tensor> GetInputTensors() const override;
275
276 Status GetInputNames(std::vector<string>* input_names) const override;
277
278 Status GetAttributes(AttributeVector* attr_vector) const override;
279
280 string dataset_type() const override;
281
282 std::vector<FunctionDef> func_lib() const override;
283
284 private:
285 std::vector<Tensor> other_arguments_;
286 FunctionDefHelper::AttrValueWrapper func_;
287 std::vector<FunctionDef> func_lib_;
288 DataTypeVector type_arguments_;
289 bool use_inter_op_parallelism_;
290 bool preserve_cardinality_;
291 };
292
293 // `TensorSliceDatasetParams` is a common dataset parameter type that are used
294 // in testing.
295 class TensorSliceDatasetParams : public DatasetParams {
296 public:
297 TensorSliceDatasetParams(std::vector<Tensor> components, string node_name);
298
299 std::vector<Tensor> GetInputTensors() const override;
300
301 Status GetInputNames(std::vector<string>* input_names) const override;
302
303 Status GetAttributes(AttributeVector* attr_vector) const override;
304
305 string dataset_type() const override;
306
num_slices()307 int64 num_slices() const { return components_[0].dim_size(0); }
308
num_tensors_per_slice()309 size_t num_tensors_per_slice() const { return components_.size(); }
310
311 private:
312 DataTypeVector TensorSliceDtypes(const std::vector<Tensor>& input_components);
313
314 std::vector<PartialTensorShape> TensorSliceShapes(
315 const std::vector<Tensor>& input_components);
316
317 public:
318 std::vector<Tensor> components_;
319 };
320
321 // `TakeDatasetParams` is a common dataset parameter type that are used in
322 // testing.
323 class TakeDatasetParams : public DatasetParams {
324 public:
325 template <typename T>
TakeDatasetParams(T input_dataset_params,int count,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)326 TakeDatasetParams(T input_dataset_params, int count,
327 DataTypeVector output_dtypes,
328 std::vector<PartialTensorShape> output_shapes,
329 string node_name)
330 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
331 std::move(node_name)),
332 count_(count) {
333 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
334 iterator_prefix_ =
335 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
336 input_dataset_params.iterator_prefix());
337 }
338
339 std::vector<Tensor> GetInputTensors() const override;
340
341 Status GetInputNames(std::vector<string>* input_names) const override;
342
343 Status GetAttributes(AttributeVector* attr_vector) const override;
344
345 string dataset_type() const override;
346
347 private:
348 int64 count_;
349 };
350
351 // `ConcatenateDatasetParams` is a common dataset parameter type that are used
352 // in testing.
353 class ConcatenateDatasetParams : public DatasetParams {
354 public:
355 template <typename T, typename P>
ConcatenateDatasetParams(T input_dataset_params_0,P input_dataset_params_1,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)356 ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1,
357 DataTypeVector output_dtypes,
358 std::vector<PartialTensorShape> output_shapes,
359 string node_name)
360 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
361 std::move(node_name)) {
362 input_dataset_params_.push_back(
363 absl::make_unique<T>(input_dataset_params_0));
364 input_dataset_params_.push_back(
365 absl::make_unique<T>(input_dataset_params_1));
366 iterator_prefix_ =
367 name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(),
368 input_dataset_params_0.iterator_prefix());
369 }
370
371 std::vector<Tensor> GetInputTensors() const override;
372
373 Status GetInputNames(std::vector<string>* input_names) const override;
374
375 Status GetAttributes(AttributeVector* attr_vector) const override;
376
377 string dataset_type() const override;
378 };
379
380 // `OptionsDatasetParams` is a common dataset parameter type that is used in
381 // testing.
382 class OptionsDatasetParams : public DatasetParams {
383 public:
384 template <typename T>
OptionsDatasetParams(T input_dataset_params,const string & serialized_options,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)385 OptionsDatasetParams(T input_dataset_params, const string& serialized_options,
386 DataTypeVector output_dtypes,
387 std::vector<PartialTensorShape> output_shapes,
388 string node_name)
389 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
390 std::move(node_name)),
391 serialized_options_(serialized_options) {
392 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
393 }
394
395 std::vector<Tensor> GetInputTensors() const override;
396
397 Status GetInputNames(std::vector<string>* input_names) const override;
398
399 Status GetAttributes(AttributeVector* attr_vector) const override;
400
401 string dataset_type() const override;
402
403 private:
404 string serialized_options_;
405 };
406
407 template <typename T>
408 struct GetNextTestCase {
409 GetNextTestCase(T dataset_params, std::vector<Tensor> expected_outputs,
410 bool compare_order = true)
dataset_paramsGetNextTestCase411 : dataset_params(std::move(dataset_params)),
412 expected_outputs(std::move(expected_outputs)),
413 compare_order(compare_order) {}
414
415 T dataset_params;
416 std::vector<Tensor> expected_outputs;
417 bool compare_order;
418 };
419
420 template <typename T>
421 struct SkipTestCase {
422 SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped,
423 bool get_next = false, std::vector<Tensor> expected_outputs = {},
424 bool compare_order = true)
dataset_paramsSkipTestCase425 : dataset_params(std::move(dataset_params)),
426 num_to_skip(num_to_skip),
427 expected_num_skipped(expected_num_skipped),
428 get_next(get_next),
429 expected_outputs(std::move(expected_outputs)),
430 compare_order(compare_order) {}
431
432 T dataset_params;
433 int num_to_skip;
434 int expected_num_skipped;
435 bool get_next;
436 std::vector<Tensor> expected_outputs;
437 bool compare_order;
438 };
439
440 template <typename T>
441 struct DatasetNodeNameTestCase {
442 T dataset_params;
443 string expected_node_name;
444 };
445
446 template <typename T>
447 struct DatasetTypeStringTestCase {
448 T dataset_params;
449 string expected_dataset_type_string;
450 };
451
452 template <typename T>
453 struct DatasetOutputDtypesTestCase {
454 T dataset_params;
455 DataTypeVector expected_output_dtypes;
456 };
457
458 template <typename T>
459 struct DatasetOutputShapesTestCase {
460 T dataset_params;
461 std::vector<PartialTensorShape> expected_output_shapes;
462 };
463
464 template <typename T>
465 struct CardinalityTestCase {
466 T dataset_params;
467 int64 expected_cardinality;
468 };
469
470 template <typename T>
471 struct DatasetSaveTestCase {
472 T dataset_params;
473 };
474
475 template <typename T>
476 struct IteratorOutputDtypesTestCase {
477 T dataset_params;
478 DataTypeVector expected_output_dtypes;
479 };
480
481 template <typename T>
482 struct IteratorOutputShapesTestCase {
483 T dataset_params;
484 std::vector<PartialTensorShape> expected_output_shapes;
485 };
486
487 template <typename T>
488 struct IteratorPrefixTestCase {
489 T dataset_params;
490 string expected_iterator_prefix;
491 };
492
493 template <typename T>
494 struct IteratorSaveAndRestoreTestCase {
495 IteratorSaveAndRestoreTestCase(T dataset_params, std::vector<int> breakpoints,
496 std::vector<Tensor> expected_outputs,
497 bool compare_order = true)
dataset_paramsIteratorSaveAndRestoreTestCase498 : dataset_params(std::move(dataset_params)),
499 breakpoints(std::move(breakpoints)),
500 expected_outputs(std::move(expected_outputs)),
501 compare_order(compare_order) {}
502
503 T dataset_params;
504 std::vector<int> breakpoints;
505 std::vector<Tensor> expected_outputs;
506 bool compare_order;
507 };
508
509 // Class composing a dataset with its dependencies.
510 class TestDataset {
511 public:
512 // TestDataset expects that the caller has Ref'd the wrapped dataset. When
513 // TestDataset is destroyed, it will Unref the dataset.
TestDataset(std::unique_ptr<OpKernel> kernel_,std::unique_ptr<OpKernelContext::Params> ctx_params,std::unique_ptr<OpKernelContext> ctx,std::vector<std::unique_ptr<Tensor>> input_tensors,DatasetBase * dataset)514 TestDataset(std::unique_ptr<OpKernel> kernel_,
515 std::unique_ptr<OpKernelContext::Params> ctx_params,
516 std::unique_ptr<OpKernelContext> ctx,
517 std::vector<std::unique_ptr<Tensor>> input_tensors,
518 DatasetBase* dataset)
519 : kernel_(std::move(kernel_)),
520 ctx_params_(std::move(ctx_params)),
521 ctx_(std::move(ctx)),
522 input_tensors_(std::move(input_tensors)),
523 dataset_(dataset),
524 scoped_unref_(dataset) {}
525
dataset()526 DatasetBase* dataset() const { return dataset_; }
527
op_kernel_context()528 OpKernelContext* op_kernel_context() const { return ctx_.get(); }
529
530 protected:
531 std::unique_ptr<OpKernel> kernel_;
532 std::unique_ptr<OpKernelContext::Params> ctx_params_;
533 std::unique_ptr<OpKernelContext> ctx_;
534 // The input tensors that this dataset depends on. They must outlive the
535 // dataset.
536 std::vector<std::unique_ptr<Tensor>> input_tensors_;
537 DatasetBase* dataset_;
538 core::ScopedUnref scoped_unref_;
539 };
540
541 // Class composing a dataset iterator with its dependencies.
542 class TestIterator {
543 public:
TestIterator(std::unique_ptr<IteratorContext> ctx,std::unique_ptr<IteratorBase> iterator)544 TestIterator(std::unique_ptr<IteratorContext> ctx,
545 std::unique_ptr<IteratorBase> iterator)
546 : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {}
547
iterator()548 IteratorBase* iterator() const { return iterator_.get(); }
549
ctx()550 IteratorContext* ctx() const { return ctx_.get(); }
551
GetNext(std::vector<Tensor> * out_tensors,bool * end_of_sequence)552 Status GetNext(std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
553 return iterator_->GetNext(ctx(), out_tensors, end_of_sequence);
554 }
555
556 protected:
557 std::unique_ptr<IteratorBase> iterator_;
558 std::unique_ptr<IteratorContext> ctx_;
559 };
560
// Helpful functions to test Dataset op kernels.
//
// The fixture owns the runtime objects needed to run a dataset op kernel:
// device, device manager, function library runtime, thread pool, resource
// manager, and so on. It provides helpers that build a dataset and iterator
// from a `DatasetParams` and check their properties.
class DatasetOpsTestBase : public ::testing::Test {
 public:
  DatasetOpsTestBase();

  // Initializes the runtime and creates a dataset and iterator.
  Status Initialize(const DatasetParams& dataset_params);

  // Initializes the parts of the runtime needed to run dataset ops.
  Status InitializeRuntime(const DatasetParams& dataset_params);

  // Creates a dataset.
  Status MakeDataset(const DatasetParams& dataset_params,
                     std::unique_ptr<TestDataset>* dataset);

  // Creates an iterator for the given dataset, using the specified split
  // providers.
  Status MakeIterator(
      const DatasetParams& dataset_params, const TestDataset& dataset,
      std::vector<std::unique_ptr<SplitProvider>> split_providers,
      std::unique_ptr<TestIterator>* iterator);

  // Creates an iterator for the given dataset.
  Status MakeIterator(const DatasetParams& dataset_params,
                      const TestDataset& dataset,
                      std::unique_ptr<TestIterator>* iterator);

  // Runs the dataset operation according to the predefined dataset params and
  // produces outputs. Unlike `MakeDataset()`, which returns a Dataset object,
  // `RunDatasetOp()` executes the dataset kernel based on the input
  // DatasetParams and returns the produced outputs as a tensor vector. It can
  // be used to run dataset operations that do not have an internal customized
  // `Dataset` class (e.g. `ReduceDatasetOp`).
  Status RunDatasetOp(const DatasetParams& dataset_params,
                      std::vector<Tensor>* outputs);

  // Validates that the two tensors have the same shape, dtype, and value.
  static Status ExpectEqual(const Tensor& a, const Tensor& b);

  // Validates that the two tensor vectors contain the same tensors. If
  // `compare_order` is false, only checks that both vectors have the same
  // elements regardless of order.
  static Status ExpectEqual(std::vector<Tensor> produced_tensors,
                            std::vector<Tensor> expected_tensors,
                            bool compare_order);

  // Checks `IteratorBase::GetNext()`. The iterator is not passed explicitly;
  // presumably this uses the fixture's own iterator set up by `Initialize()`
  // (confirm in the .cc).
  Status CheckIteratorGetNext(const std::vector<Tensor>& expected_outputs,
                              bool compare_order);

  // Checks `IteratorBase::GetNext()` on the given `TestIterator`.
  Status CheckIteratorGetNext(TestIterator* iterator,
                              const std::vector<Tensor>& expected_outputs,
                              bool compare_order);

  // Checks `IteratorBase::GetNext()` on `iterator`, running it under `ctx`.
  Status CheckIteratorGetNext(IteratorBase* iterator, IteratorContext* ctx,
                              const std::vector<Tensor>& expected_outputs,
                              bool compare_order);

  // Checks `IteratorBase::Skip()`. The parameters mirror the fields of
  // `SkipTestCase`.
  Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped,
                           bool get_next,
                           const std::vector<Tensor>& expected_outputs,
                           bool compare_order);

  // Checks that iterating through the dataset using a split provider produces
  // the expected outputs.
  Status CheckSplitProviderFullIteration(
      const DatasetParams& params, const std::vector<Tensor>& expected_outputs);

  // Checks that iterating through the dataset using a sharded split provider
  // with the given `num_shards` and `shard_index` produces the expected
  // outputs.
  Status CheckSplitProviderShardedIteration(
      const DatasetParams& params, int64_t num_shards, int64_t shard_index,
      const std::vector<Tensor>& expected_outputs);

  // Checks `DatasetBase::node_name()`.
  Status CheckDatasetNodeName(const string& expected_dataset_node_name);

  // Checks `DatasetBase::type_string()`.
  Status CheckDatasetTypeString(const string& expected_type_str);

  // Checks `DatasetBase::output_dtypes()`.
  Status CheckDatasetOutputDtypes(const DataTypeVector& expected_output_dtypes);

  // Checks `DatasetBase::output_shapes()`.
  Status CheckDatasetOutputShapes(
      const std::vector<PartialTensorShape>& expected_output_shapes);

  // Checks `DatasetBase::Cardinality()`.
  // NOTE(review): `expected_cardinality` is an `int`, while
  // `CardinalityTestCase::expected_cardinality` is an `int64`.
  Status CheckDatasetCardinality(int expected_cardinality);

  // Checks `DatasetBase::options()`.
  Status CheckDatasetOptions(const Options& expected_options);

  // Checks `IteratorBase::output_dtypes()`.
  Status CheckIteratorOutputDtypes(
      const DataTypeVector& expected_output_dtypes);

  // Checks `IteratorBase::output_shapes()`.
  Status CheckIteratorOutputShapes(
      const std::vector<PartialTensorShape>& expected_output_shapes);

  // Checks `IteratorBase::prefix()`.
  Status CheckIteratorPrefix(const string& expected_iterator_prefix);

  // Presumably checks that saving and restoring the iterator at each of
  // `breakpoints` still yields `expected_outputs` (compared in order unless
  // `compare_order` is false). The parameters mirror
  // `IteratorSaveAndRestoreTestCase`. This overload operates on an explicit
  // dataset and iterator context.
  Status CheckIteratorSaveAndRestore(
      DatasetBase* dataset, IteratorContext* iterator_ctx,
      const std::string& iterator_prefix,
      const std::vector<Tensor>& expected_outputs,
      const std::vector<int>& breakpoints, bool compare_order);

  // Same check as above, without an explicit dataset or iterator context.
  // Presumably it uses the fixture's own (confirm in the .cc).
  Status CheckIteratorSaveAndRestore(
      const std::string& iterator_prefix,
      const std::vector<Tensor>& expected_outputs,
      const std::vector<int>& breakpoints, bool compare_order);

 protected:
  // Make destructor protected so that DatasetOpsTestBase objects cannot
  // be instantiated directly. Only subclasses can be instantiated.
  ~DatasetOpsTestBase() override;

  // Creates a thread pool for parallel tasks.
  Status InitThreadPool(int thread_num);

  // Initializes the runtime for computing the dataset operation and registers
  // the input function definitions. `InitThreadPool()` needs to be called
  // before this method if the tasks should run in parallel.
  Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
                                    int cpu_num);

  // Creates a new op kernel based on the node definition.
  Status CreateOpKernel(const NodeDef& node_def,
                        std::unique_ptr<OpKernel>* op_kernel);

  // Creates a new op kernel context for running a dataset kernel, returning
  // both the context and the params it was built from.
  Status CreateDatasetContext(
      OpKernel* const dateset_kernel,
      gtl::InlinedVector<TensorValue, 4>* const inputs,
      std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
      std::unique_ptr<OpKernelContext>* dataset_context);

  // Creates a new dataset.
  Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
                       DatasetBase** const dataset);

  // Restores the state of the input iterator. It resets the iterator before
  // restoring it to make sure the input iterator does not hold any
  // resources or tasks. Otherwise, restoring an existing iterator may cause
  // timeouts or duplicated elements.
  Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
                         const string& output_prefix,
                         const DatasetBase& dataset,
                         std::unique_ptr<IteratorBase>* iterator);

  // Fetches the dataset from the operation context.
  Status GetDatasetFromContext(OpKernelContext* context, int output_index,
                               DatasetBase** const dataset);

  // Runs an operation producing outputs.
  Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);

  // Executes a function producing outputs.
  Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs,
                     const std::vector<Tensor>& args,
                     const GraphConstructorOptions& graph_options,
                     std::vector<Tensor*> rets);

  // Checks that the size of `inputs` matches the requirement of the op kernel.
  Status CheckOpKernelInput(const OpKernel& kernel,
                            const gtl::InlinedVector<TensorValue, 4>& inputs);

  // Creates a new context for running the dataset operation.
  Status CreateOpKernelContext(OpKernel* kernel,
                               gtl::InlinedVector<TensorValue, 4>* inputs,
                               std::unique_ptr<OpKernelContext>* context);

  // Creates a new context for running the dataset operation, also returning
  // the params the context was built from.
  Status CreateOpKernelContext(OpKernel* kernel,
                               gtl::InlinedVector<TensorValue, 4>* inputs,
                               std::unique_ptr<OpKernelContext::Params>* params,
                               std::unique_ptr<OpKernelContext>* context);

  // Creates a new iterator context for iterating the dataset.
  Status CreateIteratorContext(
      OpKernelContext* const op_context,
      std::unique_ptr<IteratorContext>* iterator_context);

  // Creates a new serialization context for serializing the dataset and
  // iterator.
  Status CreateSerializationContext(
      std::unique_ptr<SerializationContext>* context);

  // Creates an op kernel for `dataset_params`. The method name suggests it is
  // the kernel that retrieves dataset options. NOTE(review): confirm against
  // the definition.
  Status MakeGetOptionsOpKernel(const DatasetParams& dataset_params,
                                std::unique_ptr<OpKernel>* op_kernel);

 private:
  // Runs the dataset operation according to the predefined dataset params.
  // The produced outputs are stored in `dataset_ctx`.
  Status RunDatasetOp(
      const DatasetParams& dataset_params,
      std::unique_ptr<OpKernel>* dataset_kernel,
      std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
      std::vector<std::unique_ptr<Tensor>>* created_tensors,
      std::unique_ptr<OpKernelContext>* dataset_ctx);

  // Creates a dataset, also returning the kernel, context, context params and
  // tensors it was created with.
  Status MakeDataset(
      const DatasetParams& dataset_params,
      std::unique_ptr<OpKernel>* dataset_kernel,
      std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
      std::unique_ptr<OpKernelContext>* dataset_ctx,
      std::vector<std::unique_ptr<Tensor>>* created_tensors,
      DatasetBase** dataset);

  // Creates the dataset op kernel.
  Status MakeDatasetOpKernel(const DatasetParams& dataset_params,
                             std::unique_ptr<OpKernel>* dataset_kernel);

  // Creates a dataset tensor according to the input dataset params.
  Status MakeDatasetTensor(
      const DatasetParams& dataset_params,
      std::vector<std::unique_ptr<Tensor>>* created_tensors,
      std::unique_ptr<Tensor>* dataset);

  // Adds an empty tensor with the specified dtype and shape to the input
  // vector.
  Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
                         DataTypeVector input_types, DataType dtype,
                         const TensorShape& shape);

 protected:
  // NOTE: members are constructed in declaration order and destroyed in
  // reverse. Some lifetimes below depend on that order (see the comment on
  // `device_mgr_`), so do not reorder them.
  std::unique_ptr<Device> device_;
  DeviceType device_type_;
  int cpu_num_;
  int thread_num_;
  Allocator* allocator_;  // Owned by `AllocatorFactoryRegistry`.
  std::vector<AllocatorAttributes> allocator_attrs_;
  std::unique_ptr<ScopedStepContainer> step_container_;

  // Device manager is used by function handle cache and needs to outlive it.
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionLibraryRuntime* flr_;  // Owned by `pflr_`.
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  std::function<void(std::function<void()>)> runner_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  std::unique_ptr<ResourceMgr> resource_mgr_;
  std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
      slice_reader_cache_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;
  std::vector<std::unique_ptr<Tensor>> tensors_;  // Owns tensors.
  mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs.
  std::unique_ptr<CancellationManager> cancellation_manager_;

  // Indicates whether the fields below have been initialized (presumably by
  // `Initialize()`).
  bool initialized_ = false;
  std::unique_ptr<OpKernel> dataset_kernel_;
  std::unique_ptr<OpKernelContext::Params> params_;
  std::unique_ptr<OpKernelContext> dataset_ctx_;
  DatasetBase* dataset_ = nullptr;
  std::unique_ptr<IteratorContext> iterator_ctx_;
  std::unique_ptr<IteratorBase> iterator_;
};
828
// Declares the value-parameterized suite `ParameterizedGetNextTest` on top of
// `dataset_op_test_class` and instantiates it from `test_cases`. `test_cases`
// is converted to `std::vector<GetNextTestCase<dataset_params_class>>`.
// For every case, the fixture is initialized from the case's `dataset_params`.
// Then `CheckIteratorGetNext` is invoked with the case's `expected_outputs`
// and `compare_order`.
#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class,  \
                                 test_cases)                                   \
  class ParameterizedGetNextTest                                               \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            GetNextTestCase<dataset_params_class>> {};                         \
                                                                               \
  TEST_P(ParameterizedGetNextTest, GetNext) {                                  \
    GetNextTestCase<dataset_params_class> get_next_case = GetParam();          \
    TF_ASSERT_OK(Initialize(get_next_case.dataset_params));                    \
    TF_ASSERT_OK(CheckIteratorGetNext(get_next_case.expected_outputs,          \
                                      get_next_case.compare_order));           \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedGetNextTest,                         \
      ::testing::ValuesIn(                                                     \
          std::vector<GetNextTestCase<dataset_params_class>>(test_cases)));
848
// Declares the value-parameterized suite `ParameterizedSkipTest` on top of
// `dataset_op_test_class` and instantiates it from `test_cases`. `test_cases`
// is converted to `std::vector<SkipTestCase<dataset_params_class>>`.
// For every case, the fixture is initialized from the case's `dataset_params`.
// Then `num_to_skip`, `expected_num_skipped`, `get_next`, `expected_outputs`
// and `compare_order` are forwarded to `CheckIteratorSkip`.
#define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class,      \
                             test_cases)                                       \
  class ParameterizedSkipTest                                                  \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            SkipTestCase<dataset_params_class>> {};                            \
                                                                               \
  TEST_P(ParameterizedSkipTest, Skip) {                                        \
    SkipTestCase<dataset_params_class> skip_case = GetParam();                 \
    TF_ASSERT_OK(Initialize(skip_case.dataset_params));                        \
    TF_ASSERT_OK(CheckIteratorSkip(skip_case.num_to_skip,                      \
                                   skip_case.expected_num_skipped,             \
                                   skip_case.get_next,                         \
                                   skip_case.expected_outputs,                 \
                                   skip_case.compare_order));                  \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedSkipTest,                            \
      ::testing::ValuesIn(                                                     \
          std::vector<SkipTestCase<dataset_params_class>>(test_cases)));
868
// Declares the value-parameterized suite `ParameterizedDatasetNodeNameTest`
// on top of `dataset_op_test_class` and instantiates it from `test_cases`.
// `test_cases` is converted to
// `std::vector<DatasetNodeNameTestCase<dataset_params_class>>`. For every
// case, the fixture is initialized from the case's `dataset_params`, and
// `expected_node_name` is passed to `CheckDatasetNodeName`.
#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class,  \
                                 test_cases)                                   \
  class ParameterizedDatasetNodeNameTest                                       \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            DatasetNodeNameTestCase<dataset_params_class>> {};                 \
                                                                               \
  TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) {                  \
    DatasetNodeNameTestCase<dataset_params_class> node_name_case =             \
        GetParam();                                                            \
    TF_ASSERT_OK(Initialize(node_name_case.dataset_params));                   \
    TF_ASSERT_OK(CheckDatasetNodeName(node_name_case.expected_node_name));     \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedDatasetNodeNameTest,                 \
      ::testing::ValuesIn(                                                     \
          std::vector<DatasetNodeNameTestCase<dataset_params_class>>(          \
              test_cases)));
887
// Declares the value-parameterized suite `ParameterizedDatasetTypeStringTest`
// on top of `dataset_op_test_class` and instantiates it from `test_cases`.
// `test_cases` is converted to
// `std::vector<DatasetTypeStringTestCase<dataset_params_class>>`. For every
// case, the fixture is initialized from the case's `dataset_params`, and
// `expected_dataset_type_string` is passed to `CheckDatasetTypeString`.
#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class,                      \
                                   dataset_params_class, test_cases)           \
  class ParameterizedDatasetTypeStringTest                                     \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            DatasetTypeStringTestCase<dataset_params_class>> {};               \
                                                                               \
  TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) {              \
    DatasetTypeStringTestCase<dataset_params_class> type_string_case =         \
        GetParam();                                                            \
    TF_ASSERT_OK(Initialize(type_string_case.dataset_params));                 \
    TF_ASSERT_OK(CheckDatasetTypeString(                                       \
        type_string_case.expected_dataset_type_string));                       \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedDatasetTypeStringTest,               \
      ::testing::ValuesIn(                                                     \
          std::vector<DatasetTypeStringTestCase<dataset_params_class>>(        \
              test_cases)));
907
// Declares the value-parameterized suite
// `ParameterizedDatasetOutputDtypesTest` on top of `dataset_op_test_class`
// and instantiates it from `test_cases`. `test_cases` is converted to
// `std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>`. For every
// case, the fixture is initialized from the case's `dataset_params`, and
// `expected_output_dtypes` is passed to `CheckDatasetOutputDtypes`.
#define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                    \
                                     dataset_params_class, test_cases)         \
  class ParameterizedDatasetOutputDtypesTest                                   \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            DatasetOutputDtypesTestCase<dataset_params_class>> {};             \
                                                                               \
  TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) {          \
    DatasetOutputDtypesTestCase<dataset_params_class> dtypes_case =            \
        GetParam();                                                            \
    TF_ASSERT_OK(Initialize(dtypes_case.dataset_params));                      \
    TF_ASSERT_OK(                                                              \
        CheckDatasetOutputDtypes(dtypes_case.expected_output_dtypes));         \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedDatasetOutputDtypesTest,             \
      ::testing::ValuesIn(                                                     \
          std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>(      \
              test_cases)));
927
// Declares the value-parameterized suite
// `ParameterizedDatasetOutputShapesTest` on top of `dataset_op_test_class`
// and instantiates it from `test_cases`. `test_cases` is converted to
// `std::vector<DatasetOutputShapesTestCase<dataset_params_class>>`. For every
// case, the fixture is initialized from the case's `dataset_params`, and
// `expected_output_shapes` is passed to `CheckDatasetOutputShapes`.
#define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                    \
                                     dataset_params_class, test_cases)         \
  class ParameterizedDatasetOutputShapesTest                                   \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            DatasetOutputShapesTestCase<dataset_params_class>> {};             \
                                                                               \
  TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {          \
    DatasetOutputShapesTestCase<dataset_params_class> shapes_case =            \
        GetParam();                                                            \
    TF_ASSERT_OK(Initialize(shapes_case.dataset_params));                      \
    TF_ASSERT_OK(                                                              \
        CheckDatasetOutputShapes(shapes_case.expected_output_shapes));         \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedDatasetOutputShapesTest,             \
      ::testing::ValuesIn(                                                     \
          std::vector<DatasetOutputShapesTestCase<dataset_params_class>>(      \
              test_cases)));
947
// Declares the value-parameterized suite `ParameterizedCardinalityTest` on
// top of `dataset_op_test_class` and instantiates it from `test_cases`.
// `test_cases` is converted to
// `std::vector<CardinalityTestCase<dataset_params_class>>`. For every case,
// the fixture is initialized from the case's `dataset_params`, and
// `expected_cardinality` is passed to `CheckDatasetCardinality`.
#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class,                      \
                                   dataset_params_class, test_cases)           \
  class ParameterizedCardinalityTest                                           \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            CardinalityTestCase<dataset_params_class>> {};                     \
                                                                               \
  TEST_P(ParameterizedCardinalityTest, Cardinality) {                          \
    CardinalityTestCase<dataset_params_class> cardinality_case = GetParam();   \
    TF_ASSERT_OK(Initialize(cardinality_case.dataset_params));                 \
    TF_ASSERT_OK(                                                              \
        CheckDatasetCardinality(cardinality_case.expected_cardinality));       \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedCardinalityTest,                     \
      ::testing::ValuesIn(                                                     \
          std::vector<CardinalityTestCase<dataset_params_class>>(              \
              test_cases)));
967
// Declares the value-parameterized suite
// `ParameterizedIteratorOutputDtypesTest` on top of `dataset_op_test_class`
// and instantiates it from `test_cases`. `test_cases` is converted to
// `std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>`.
//
// For every case, the fixture is initialized from the case's `dataset_params`.
// The *iterator's* output dtypes are then checked against
// `expected_output_dtypes` via `CheckIteratorOutputDtypes`. This mirrors
// `ITERATOR_OUTPUT_SHAPES_TEST_P`, which calls `CheckIteratorOutputShapes`.
// Calling `CheckDatasetOutputDtypes` here would merely duplicate
// `DATASET_OUTPUT_DTYPES_TEST_P` instead of exercising the iterator.
#define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
                                      dataset_params_class, test_cases)        \
  class ParameterizedIteratorOutputDtypesTest                                  \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            IteratorOutputDtypesTestCase<dataset_params_class>> {};            \
                                                                               \
  TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) {        \
    auto test_case = GetParam();                                               \
    TF_ASSERT_OK(Initialize(test_case.dataset_params));                        \
    TF_ASSERT_OK(                                                              \
        CheckIteratorOutputDtypes(test_case.expected_output_dtypes));          \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedIteratorOutputDtypesTest,            \
      ::testing::ValuesIn(                                                     \
          std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>(     \
              test_cases)));
986
// Declares the value-parameterized suite
// `ParameterizedIteratorOutputShapesTest` on top of `dataset_op_test_class`
// and instantiates it from `test_cases`. `test_cases` is converted to
// `std::vector<IteratorOutputShapesTestCase<dataset_params_class>>`. For
// every case, the fixture is initialized from the case's `dataset_params`,
// and `expected_output_shapes` is passed to `CheckIteratorOutputShapes`.
#define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
                                      dataset_params_class, test_cases)        \
  class ParameterizedIteratorOutputShapesTest                                  \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            IteratorOutputShapesTestCase<dataset_params_class>> {};            \
                                                                               \
  TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {        \
    IteratorOutputShapesTestCase<dataset_params_class> iterator_shapes_case =  \
        GetParam();                                                            \
    TF_ASSERT_OK(Initialize(iterator_shapes_case.dataset_params));             \
    TF_ASSERT_OK(CheckIteratorOutputShapes(                                    \
        iterator_shapes_case.expected_output_shapes));                         \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedIteratorOutputShapesTest,            \
      ::testing::ValuesIn(                                                     \
          std::vector<IteratorOutputShapesTestCase<dataset_params_class>>(     \
              test_cases)));
1005
// Declares the value-parameterized suite `ParameterizedIteratorPrefixTest` on
// top of `dataset_op_test_class` and instantiates it from `test_cases`.
// `test_cases` is converted to
// `std::vector<IteratorPrefixTestCase<dataset_params_class>>`. For every
// case, the fixture is initialized from the case's `dataset_params`, and
// `expected_iterator_prefix` is passed to `CheckIteratorPrefix`.
#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class,    \
                               test_cases)                                     \
  class ParameterizedIteratorPrefixTest                                        \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            IteratorPrefixTestCase<dataset_params_class>> {};                  \
                                                                               \
  TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) {                    \
    IteratorPrefixTestCase<dataset_params_class> prefix_case = GetParam();     \
    TF_ASSERT_OK(Initialize(prefix_case.dataset_params));                      \
    TF_ASSERT_OK(CheckIteratorPrefix(prefix_case.expected_iterator_prefix));   \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedIteratorPrefixTest,                  \
      ::testing::ValuesIn(                                                     \
          std::vector<IteratorPrefixTestCase<dataset_params_class>>(           \
              test_cases)));
1024
// Declares the value-parameterized suite
// `ParameterizedIteratorSaveAndRestoreTest` on top of `dataset_op_test_class`
// and instantiates it from `test_cases`. `test_cases` is converted to
// `std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>`.
// For every case, the fixture is initialized from the case's
// `dataset_params`. Then `CheckIteratorSaveAndRestore` is called with the
// params' `iterator_prefix()`, plus the case's `expected_outputs`,
// `breakpoints` and `compare_order`.
#define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class,                \
                                         dataset_params_class, test_cases)     \
  class ParameterizedIteratorSaveAndRestoreTest                                \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            IteratorSaveAndRestoreTestCase<dataset_params_class>> {};          \
                                                                               \
  TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {    \
    IteratorSaveAndRestoreTestCase<dataset_params_class> save_restore_case =   \
        GetParam();                                                            \
    TF_ASSERT_OK(Initialize(save_restore_case.dataset_params));                \
    TF_ASSERT_OK(CheckIteratorSaveAndRestore(                                  \
        save_restore_case.dataset_params.iterator_prefix(),                    \
        save_restore_case.expected_outputs, save_restore_case.breakpoints,     \
        save_restore_case.compare_order));                                     \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest,          \
      ::testing::ValuesIn(                                                     \
          std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>(   \
              test_cases)));
1044
1045 } // namespace data
1046 } // namespace tensorflow
1047
1048 #endif // TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
1049