1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
17 #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
18
19 #include <stddef.h>
20
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/graph_constructor.h"
31 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
32 #include "tensorflow/core/framework/allocator.h"
33 #include "tensorflow/core/framework/cancellation.h"
34 #include "tensorflow/core/framework/dataset.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/function_handle_cache.h"
37 #include "tensorflow/core/framework/function_testlib.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/resource_mgr.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/tensor_shape.h"
42 #include "tensorflow/core/framework/tensor_testutil.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/kernels/data/name_utils.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/lib/gtl/array_slice.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/io/zlib_compression_options.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/test.h"
53 #include "tensorflow/core/platform/threadpool.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
56
57 namespace tensorflow {
58 namespace data {
59
60 typedef std::vector<
61 std::pair<string, tensorflow::FunctionDefHelper::AttrValueWrapper>>
62 AttributeVector;
63
// Default CPU count for the test runtime (presumably the `cpu_num` passed to
// `InitFunctionLibraryRuntime()` -- confirm against the .cc).
constexpr int kDefaultCPUNum{2};
// Default thread count (presumably the `thread_num` passed to
// `InitThreadPool()` -- confirm against the .cc).
constexpr int kDefaultThreadNum{2};
66
67 // Creates a tensor with the specified dtype, shape, and value.
68 template <typename T>
CreateTensor(const TensorShape & input_shape,const gtl::ArraySlice<T> & input_data)69 static Tensor CreateTensor(const TensorShape& input_shape,
70 const gtl::ArraySlice<T>& input_data) {
71 Tensor tensor(DataTypeToEnum<T>::value, input_shape);
72 test::FillValues<T>(&tensor, input_data);
73 return tensor;
74 }
75
76 // Creates a vector of tensors with the specified dtype, shape, and values.
77 template <typename T>
CreateTensors(const TensorShape & shape,const std::vector<gtl::ArraySlice<T>> & values)78 std::vector<Tensor> CreateTensors(
79 const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
80 std::vector<Tensor> result;
81 result.reserve(values.size());
82 for (auto& value : values) {
83 result.emplace_back(CreateTensor<T>(shape, value));
84 }
85 return result;
86 }
87
// Compression formats that the file-writing helpers below can be asked to use.
enum class CompressionType {
  ZLIB = 0,
  GZIP = 1,
  RAW = 2,
  UNCOMPRESSED = 3,
};
89
90 // Returns a string representation for the given compression type.
91 string ToString(CompressionType compression_type);
92
93 // Gets the specified zlib compression options according to the compression
94 // type. Note that `CompressionType::UNCOMPRESSED` is not supported because
95 // `ZlibCompressionOptions` does not have an option.
96 io::ZlibCompressionOptions GetZlibCompressionOptions(
97 CompressionType compression_type);
98
99 // Used to specify parameters when writing data into files with compression.
100 // `input_buffer_size` and `output_buffer_size` specify the input and output
101 // buffer size when ZLIB and GZIP compression is used.
102 struct CompressionParams {
103 CompressionType compression_type = CompressionType::UNCOMPRESSED;
104 int32 input_buffer_size = 0;
105 int32 output_buffer_size = 0;
106 };
107
108 // Writes the input data into the file without compression.
109 Status WriteDataToFile(const string& filename, const char* data);
110
111 // Writes the input data into the file with the specified compression.
112 Status WriteDataToFile(const string& filename, const char* data,
113 const CompressionParams& params);
114
115 // Writes the input data into the TFRecord file with the specified compression.
116 Status WriteDataToTFRecordFile(const string& filename,
117 const std::vector<absl::string_view>& records,
118 const CompressionParams& params);
119
120 // Provides the parameters for running the dataset op.
121 class DatasetParams {
122 public:
123 DatasetParams(DataTypeVector output_dtypes,
124 std::vector<PartialTensorShape> output_shapes,
125 string node_name);
126
~DatasetParams()127 virtual ~DatasetParams() {}
128
129 // Returns the inputs (except the input datasets) as a tensor vector.
130 virtual std::vector<Tensor> GetInputTensors() const = 0;
131
132 // Returns the dataset input names as a string vector.
133 virtual Status GetInputNames(std::vector<string>* input_names) const = 0;
134
135 // Returns the dataset attributes as a vector.
136 virtual Status GetAttributes(AttributeVector* attributes) const = 0;
137
138 // Checks if the tensor is a dataset variant tensor.
139 static bool IsDatasetTensor(const Tensor& tensor);
140
node_name()141 string node_name() const { return node_name_; }
142
output_dtypes()143 DataTypeVector output_dtypes() const { return output_dtypes_; }
144
output_shapes()145 std::vector<PartialTensorShape> output_shapes() const {
146 return output_shapes_;
147 }
148
iterator_prefix()149 string iterator_prefix() const { return iterator_prefix_; }
150
input_dataset_params()151 const std::vector<std::shared_ptr<DatasetParams>>& input_dataset_params()
152 const {
153 return input_dataset_params_;
154 }
155
156 // Returns the functions that will be used when running the dataset op.
func_lib()157 virtual std::vector<FunctionDef> func_lib() const { return {}; }
158
159 // Returns the dataset type for the op represented by these parameters. This
160 // type usually needs to match the constant called `kDatasetType` defined in
161 // the dataset kernel.
162 virtual string dataset_type() const = 0;
163
op_version()164 virtual int op_version() const { return op_version_; }
165
166 protected:
167 std::vector<std::shared_ptr<DatasetParams>> input_dataset_params_;
168 DataTypeVector output_dtypes_;
169 std::vector<PartialTensorShape> output_shapes_;
170 string node_name_;
171 string iterator_prefix_ = "Iterator";
172 int op_version_ = 1;
173 };
174
175 // `RangeDatasetParams` is a common dataset parameter type that are used in
176 // testing.
177 class RangeDatasetParams : public DatasetParams {
178 public:
179 RangeDatasetParams(int64 start, int64 stop, int64 step,
180 DataTypeVector output_dtypes,
181 std::vector<PartialTensorShape> output_shapes,
182 string node_name);
183
184 RangeDatasetParams(int64 start, int64 stop, int64 step);
185
186 RangeDatasetParams(int64 start, int64 stop, int64 step,
187 DataTypeVector output_dtypes);
188
189 std::vector<Tensor> GetInputTensors() const override;
190
191 Status GetInputNames(std::vector<string>* input_names) const override;
192
193 Status GetAttributes(AttributeVector* attr_vector) const override;
194
195 string dataset_type() const override;
196
197 private:
198 int64 start_;
199 int64 stop_;
200 int64 step_;
201 };
202
203 // `BatchDatasetParams` is a common dataset parameter type that are used in
204 // testing.
205 class BatchDatasetParams : public DatasetParams {
206 public:
207 template <typename T>
BatchDatasetParams(T input_dataset_params,int64 batch_size,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)208 BatchDatasetParams(T input_dataset_params, int64 batch_size,
209 bool drop_remainder, bool parallel_copy,
210 DataTypeVector output_dtypes,
211 std::vector<PartialTensorShape> output_shapes,
212 string node_name)
213 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
214 std::move(node_name)),
215 batch_size_(batch_size),
216 drop_remainder_(drop_remainder),
217 parallel_copy_(parallel_copy) {
218 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
219 op_version_ = 2;
220 iterator_prefix_ =
221 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
222 input_dataset_params.iterator_prefix());
223 }
224
225 std::vector<Tensor> GetInputTensors() const override;
226
227 Status GetInputNames(std::vector<string>* input_names) const override;
228
229 Status GetAttributes(AttributeVector* attr_vector) const override;
230
231 string dataset_type() const override;
232
233 private:
234 int64 batch_size_;
235 bool drop_remainder_;
236 bool parallel_copy_;
237 };
238
239 // `MapDatasetParams` is a common dataset parameter type that are used in
240 // testing.
241 class MapDatasetParams : public DatasetParams {
242 public:
243 template <typename T>
MapDatasetParams(T input_dataset_params,std::vector<Tensor> other_arguments,FunctionDefHelper::AttrValueWrapper func,std::vector<FunctionDef> func_lib,DataTypeVector type_arguments,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,bool use_inter_op_parallelism,bool preserve_cardinality,string node_name)244 MapDatasetParams(T input_dataset_params, std::vector<Tensor> other_arguments,
245 FunctionDefHelper::AttrValueWrapper func,
246 std::vector<FunctionDef> func_lib,
247 DataTypeVector type_arguments, DataTypeVector output_dtypes,
248 std::vector<PartialTensorShape> output_shapes,
249 bool use_inter_op_parallelism, bool preserve_cardinality,
250 string node_name)
251 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
252 std::move(node_name)),
253 other_arguments_(std::move(other_arguments)),
254 func_(std::move(func)),
255 func_lib_(std::move(func_lib)),
256 type_arguments_(std::move(type_arguments)),
257 use_inter_op_parallelism_(use_inter_op_parallelism),
258 preserve_cardinality_(preserve_cardinality) {
259 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
260 iterator_prefix_ =
261 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
262 input_dataset_params.iterator_prefix());
263 }
264
265 std::vector<Tensor> GetInputTensors() const override;
266
267 Status GetInputNames(std::vector<string>* input_names) const override;
268
269 Status GetAttributes(AttributeVector* attr_vector) const override;
270
271 string dataset_type() const override;
272
273 std::vector<FunctionDef> func_lib() const override;
274
275 private:
276 std::vector<Tensor> other_arguments_;
277 FunctionDefHelper::AttrValueWrapper func_;
278 std::vector<FunctionDef> func_lib_;
279 DataTypeVector type_arguments_;
280 bool use_inter_op_parallelism_;
281 bool preserve_cardinality_;
282 };
283
284 // `TensorSliceDatasetParams` is a common dataset parameter type that are used
285 // in testing.
286 class TensorSliceDatasetParams : public DatasetParams {
287 public:
288 TensorSliceDatasetParams(std::vector<Tensor> components, string node_name);
289
290 std::vector<Tensor> GetInputTensors() const override;
291
292 Status GetInputNames(std::vector<string>* input_names) const override;
293
294 Status GetAttributes(AttributeVector* attr_vector) const override;
295
296 string dataset_type() const override;
297
num_slices()298 int64 num_slices() const { return components_[0].dim_size(0); }
299
num_tensors_per_slice()300 size_t num_tensors_per_slice() const { return components_.size(); }
301
302 private:
303 DataTypeVector TensorSliceDtypes(const std::vector<Tensor>& input_components);
304
305 std::vector<PartialTensorShape> TensorSliceShapes(
306 const std::vector<Tensor>& input_components);
307
308 public:
309 std::vector<Tensor> components_;
310 };
311
312 // `TakeDatasetParams` is a common dataset parameter type that are used in
313 // testing.
314 class TakeDatasetParams : public DatasetParams {
315 public:
316 template <typename T>
TakeDatasetParams(T input_dataset_params,int count,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)317 TakeDatasetParams(T input_dataset_params, int count,
318 DataTypeVector output_dtypes,
319 std::vector<PartialTensorShape> output_shapes,
320 string node_name)
321 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
322 std::move(node_name)),
323 count_(count) {
324 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
325 iterator_prefix_ =
326 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
327 input_dataset_params.iterator_prefix());
328 }
329
330 std::vector<Tensor> GetInputTensors() const override;
331
332 Status GetInputNames(std::vector<string>* input_names) const override;
333
334 Status GetAttributes(AttributeVector* attr_vector) const override;
335
336 string dataset_type() const override;
337
338 private:
339 int64 count_;
340 };
341
342 // `ConcatenateDatasetParams` is a common dataset parameter type that are used
343 // in testing.
344 class ConcatenateDatasetParams : public DatasetParams {
345 public:
346 template <typename T, typename P>
ConcatenateDatasetParams(T input_dataset_params_0,P input_dataset_params_1,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)347 ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1,
348 DataTypeVector output_dtypes,
349 std::vector<PartialTensorShape> output_shapes,
350 string node_name)
351 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
352 std::move(node_name)) {
353 input_dataset_params_.push_back(
354 absl::make_unique<T>(input_dataset_params_0));
355 input_dataset_params_.push_back(
356 absl::make_unique<T>(input_dataset_params_1));
357 iterator_prefix_ =
358 name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(),
359 input_dataset_params_0.iterator_prefix());
360 }
361
362 std::vector<Tensor> GetInputTensors() const override;
363
364 Status GetInputNames(std::vector<string>* input_names) const override;
365
366 Status GetAttributes(AttributeVector* attr_vector) const override;
367
368 string dataset_type() const override;
369 };
370
371 template <typename T>
372 struct GetNextTestCase {
373 GetNextTestCase(T dataset_params, std::vector<Tensor> expected_outputs,
374 bool compare_order = true)
dataset_paramsGetNextTestCase375 : dataset_params(std::move(dataset_params)),
376 expected_outputs(std::move(expected_outputs)),
377 compare_order(compare_order) {}
378
379 T dataset_params;
380 std::vector<Tensor> expected_outputs;
381 bool compare_order;
382 };
383
384 template <typename T>
385 struct SkipTestCase {
386 SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped,
387 bool get_next = false, std::vector<Tensor> expected_outputs = {},
388 bool compare_order = true)
dataset_paramsSkipTestCase389 : dataset_params(std::move(dataset_params)),
390 num_to_skip(num_to_skip),
391 expected_num_skipped(expected_num_skipped),
392 get_next(get_next),
393 expected_outputs(std::move(expected_outputs)),
394 compare_order(compare_order) {}
395
396 T dataset_params;
397 int num_to_skip;
398 int expected_num_skipped;
399 bool get_next;
400 std::vector<Tensor> expected_outputs;
401 bool compare_order;
402 };
403
404 template <typename T>
405 struct DatasetNodeNameTestCase {
406 T dataset_params;
407 string expected_node_name;
408 };
409
410 template <typename T>
411 struct DatasetTypeStringTestCase {
412 T dataset_params;
413 string expected_dataset_type_string;
414 };
415
416 template <typename T>
417 struct DatasetOutputDtypesTestCase {
418 T dataset_params;
419 DataTypeVector expected_output_dtypes;
420 };
421
422 template <typename T>
423 struct DatasetOutputShapesTestCase {
424 T dataset_params;
425 std::vector<PartialTensorShape> expected_output_shapes;
426 };
427
428 template <typename T>
429 struct CardinalityTestCase {
430 T dataset_params;
431 int64 expected_cardinality;
432 };
433
434 template <typename T>
435 struct DatasetSaveTestCase {
436 T dataset_params;
437 };
438
439 template <typename T>
440 struct IteratorOutputDtypesTestCase {
441 T dataset_params;
442 DataTypeVector expected_output_dtypes;
443 };
444
445 template <typename T>
446 struct IteratorOutputShapesTestCase {
447 T dataset_params;
448 std::vector<PartialTensorShape> expected_output_shapes;
449 };
450
451 template <typename T>
452 struct IteratorPrefixTestCase {
453 T dataset_params;
454 string expected_iterator_prefix;
455 };
456
457 template <typename T>
458 struct IteratorSaveAndRestoreTestCase {
459 IteratorSaveAndRestoreTestCase(T dataset_params, std::vector<int> breakpoints,
460 std::vector<Tensor> expected_outputs,
461 bool compare_order = true)
dataset_paramsIteratorSaveAndRestoreTestCase462 : dataset_params(std::move(dataset_params)),
463 breakpoints(std::move(breakpoints)),
464 expected_outputs(std::move(expected_outputs)),
465 compare_order(compare_order) {}
466
467 T dataset_params;
468 std::vector<int> breakpoints;
469 std::vector<Tensor> expected_outputs;
470 bool compare_order;
471 };
472
473 // Class composing a dataset with its dependencies.
474 class TestDataset {
475 public:
476 // TestDataset expects that the caller has Ref'd the wrapped dataset. When
477 // TestDataset is destroyed, it will Unref the dataset.
TestDataset(std::unique_ptr<OpKernel> kernel_,std::unique_ptr<OpKernelContext::Params> ctx_params,std::unique_ptr<OpKernelContext> ctx,std::vector<std::unique_ptr<Tensor>> input_tensors,DatasetBase * dataset)478 TestDataset(std::unique_ptr<OpKernel> kernel_,
479 std::unique_ptr<OpKernelContext::Params> ctx_params,
480 std::unique_ptr<OpKernelContext> ctx,
481 std::vector<std::unique_ptr<Tensor>> input_tensors,
482 DatasetBase* dataset)
483 : kernel_(std::move(kernel_)),
484 ctx_params_(std::move(ctx_params)),
485 ctx_(std::move(ctx)),
486 input_tensors_(std::move(input_tensors)),
487 dataset_(dataset),
488 scoped_unref_(dataset) {}
489
dataset()490 DatasetBase* dataset() const { return dataset_; }
491
op_kernel_context()492 OpKernelContext* op_kernel_context() const { return ctx_.get(); }
493
494 protected:
495 std::unique_ptr<OpKernel> kernel_;
496 std::unique_ptr<OpKernelContext::Params> ctx_params_;
497 std::unique_ptr<OpKernelContext> ctx_;
498 // The input tensors that this dataset depends on. They must outlive the
499 // dataset.
500 std::vector<std::unique_ptr<Tensor>> input_tensors_;
501 DatasetBase* dataset_;
502 core::ScopedUnref scoped_unref_;
503 };
504
505 // Class composing a dataset iterator with its dependencies.
506 class TestIterator {
507 public:
TestIterator(std::unique_ptr<IteratorContext> ctx,std::unique_ptr<IteratorBase> iterator)508 TestIterator(std::unique_ptr<IteratorContext> ctx,
509 std::unique_ptr<IteratorBase> iterator)
510 : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {}
511
iterator()512 IteratorBase* iterator() const { return iterator_.get(); }
513
ctx()514 IteratorContext* ctx() const { return ctx_.get(); }
515
GetNext(std::vector<Tensor> * out_tensors,bool * end_of_sequence)516 Status GetNext(std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
517 return iterator_->GetNext(ctx(), out_tensors, end_of_sequence);
518 }
519
520 protected:
521 std::unique_ptr<IteratorBase> iterator_;
522 std::unique_ptr<IteratorContext> ctx_;
523 };
524
525 // Helpful functions to test Dataset op kernels.
526 class DatasetOpsTestBase : public ::testing::Test {
527 public:
528 DatasetOpsTestBase();
529
530 // Initializes the runtime and creates a dataset and iterator.
531 Status Initialize(const DatasetParams& dataset_params);
532
533 // Initializes the parts of the runtime needed to run dataset ops.
534 Status InitializeRuntime(const DatasetParams& dataset_params);
535
536 // Creates a dataset.
537 Status MakeDataset(const DatasetParams& dataset_params,
538 std::unique_ptr<TestDataset>* dataset);
539
540 // Creates an iterator for the given dataset, using the specified split
541 // provider.
542 Status MakeIterator(const DatasetParams& dataset_params,
543 const TestDataset& dataset,
544 std::unique_ptr<SplitProvider> split_provider,
545 std::unique_ptr<TestIterator>* iterator);
546 // Creates an iterator for the given dataset.
547 Status MakeIterator(const DatasetParams& dataset_params,
548 const TestDataset& dataset,
549 std::unique_ptr<TestIterator>* iterator);
550
551 // Runs the dataset operation according to the predefined dataset params and
552 // produces outputs. Different from `MakeDataset()` which returns a Dataset
553 // object, `RunDatasetOp()` executes the dataset kernel based on the input
554 // DatasetParams and returns the produced outputs as a tensor vector. It can
555 // be used to run some dataset operations that do not have an internal
556 // customized `Dataset` class (e.g. `ReduceDatasetOp`).
557 Status RunDatasetOp(const DatasetParams& dataset_params,
558 std::vector<Tensor>* outputs);
559
560 // The method validates whether the two tensors have the same shape, dtype,
561 // and value.
562 static Status ExpectEqual(const Tensor& a, const Tensor& b);
563
564 // The method validates whether the two tensor vectors have the same tensors.
565 // If `compare_order` is false, the method will only evaluate whether the two
566 // vectors have the same elements regardless of order.
567 static Status ExpectEqual(std::vector<Tensor> produced_tensors,
568 std::vector<Tensor> expected_tensors,
569 bool compare_order);
570
571 // Checks `IteratorBase::GetNext()`.
572 Status CheckIteratorGetNext(const std::vector<Tensor>& expected_outputs,
573 bool compare_order);
574
575 // Checks `IteratorBase::GetNext()`.
576 Status CheckIteratorGetNext(TestIterator* iterator,
577 const std::vector<Tensor>& expected_outputs,
578 bool compare_order);
579
580 // Checks `IteratorBase::GetNext()`.
581 Status CheckIteratorGetNext(IteratorBase* iterator, IteratorContext* ctx,
582 const std::vector<Tensor>& expected_outputs,
583 bool compare_order);
584
585 // Checks `IteratorBase::Skip()`
586 Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped,
587 bool get_next,
588 const std::vector<Tensor>& expected_outputs,
589 bool compare_order);
590
591 // Checks that iterating through the dataset using a split provider produces
592 // the expected outputs.
593 Status CheckSplitProviderFullIteration(
594 const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
595
596 // Checks that iterating through the dataset using a sharded split provider
597 // with the given `num_shards` and `shard_index` produces the expected
598 // outputs.
599 Status CheckSplitProviderShardedIteration(
600 const DatasetParams& params, int64 num_shards, int64 shard_index,
601 const std::vector<Tensor>& expected_outputs);
602
603 // Checks `DatasetBase::node_name()`.
604 Status CheckDatasetNodeName(const string& expected_dataset_node_name);
605
606 // Checks `DatasetBase::type_string()`.
607 Status CheckDatasetTypeString(const string& expected_type_str);
608
609 // Checks `DatasetBase::output_dtypes()`.
610 Status CheckDatasetOutputDtypes(const DataTypeVector& expected_output_dtypes);
611
612 // Checks `DatasetBase::output_shapes()`.
613 Status CheckDatasetOutputShapes(
614 const std::vector<PartialTensorShape>& expected_output_shapes);
615
616 // Checks `DatasetBase::Cardinality()`.
617 Status CheckDatasetCardinality(int expected_cardinality);
618
619 // Checks `IteratorBase::output_dtypes()`.
620 Status CheckIteratorOutputDtypes(
621 const DataTypeVector& expected_output_dtypes);
622
623 // Checks `IteratorBase::output_shapes()`.
624 Status CheckIteratorOutputShapes(
625 const std::vector<PartialTensorShape>& expected_output_shapes);
626
627 // Checks `IteratorBase::prefix()`.
628 Status CheckIteratorPrefix(const string& expected_iterator_prefix);
629
630 Status CheckIteratorSaveAndRestore(
631 DatasetBase* dataset, IteratorContext* iterator_ctx,
632 const std::string& iterator_prefix,
633 const std::vector<Tensor>& expected_outputs,
634 const std::vector<int>& breakpoints, bool compare_order);
635
636 Status CheckIteratorSaveAndRestore(
637 const std::string& iterator_prefix,
638 const std::vector<Tensor>& expected_outputs,
639 const std::vector<int>& breakpoints, bool compare_order);
640
641 protected:
642 // Make destructor protected so that DatasetOpsTestBase objects cannot
643 // be instantiated directly. Only subclasses can be instantiated.
644 ~DatasetOpsTestBase() override;
645
646 // Creates a thread pool for parallel tasks.
647 Status InitThreadPool(int thread_num);
648
649 // Initializes the runtime for computing the dataset operation and registers
650 // the input function definitions. `InitThreadPool()' needs to be called
651 // before this method if we want to run the tasks in parallel.
652 Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
653 int cpu_num);
654
655 // Creates a new op kernel based on the node definition.
656 Status CreateOpKernel(const NodeDef& node_def,
657 std::unique_ptr<OpKernel>* op_kernel);
658
659 // Creates a new op kernel context.
660 Status CreateDatasetContext(
661 OpKernel* const dateset_kernel,
662 gtl::InlinedVector<TensorValue, 4>* const inputs,
663 std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
664 std::unique_ptr<OpKernelContext>* dataset_context);
665
666 // Creates a new dataset.
667 Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
668 DatasetBase** const dataset);
669
670 // Restores the state of the input iterator. It resets the iterator before
671 // restoring it to make sure the input iterator does not hold any
672 // resources or tasks. Otherwise, restoring an existing iterator may cause
673 // the timeout issue or duplicated elements.
674 Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
675 const string& output_prefix,
676 const DatasetBase& dataset,
677 std::unique_ptr<IteratorBase>* iterator);
678
679 // Fetches the dataset from the operation context.
680 Status GetDatasetFromContext(OpKernelContext* context, int output_index,
681 DatasetBase** const dataset);
682
683 // Runs an operation producing outputs.
684 Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
685
686 // Executes a function producing outputs.
687 Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs,
688 const std::vector<Tensor>& args,
689 const GraphConstructorOptions& graph_options,
690 std::vector<Tensor*> rets);
691
692 // Checks that the size of `inputs` matches the requirement of the op kernel.
693 Status CheckOpKernelInput(const OpKernel& kernel,
694 const gtl::InlinedVector<TensorValue, 4>& inputs);
695
696 // Creates a new context for running the dataset operation.
697 Status CreateOpKernelContext(OpKernel* kernel,
698 gtl::InlinedVector<TensorValue, 4>* inputs,
699 std::unique_ptr<OpKernelContext>* context);
700
701 // Creates a new context for running the dataset operation.
702 Status CreateOpKernelContext(OpKernel* kernel,
703 gtl::InlinedVector<TensorValue, 4>* inputs,
704 std::unique_ptr<OpKernelContext::Params>* params,
705 std::unique_ptr<OpKernelContext>* context);
706
707 // Creates a new iterator context for iterating the dataset.
708 Status CreateIteratorContext(
709 OpKernelContext* const op_context,
710 std::unique_ptr<IteratorContext>* iterator_context);
711
712 // Creates a new iterator context for iterating the dataset.
713 // Creates a new serialization context for serializing the dataset and
714 // iterator.
715 Status CreateSerializationContext(
716 std::unique_ptr<SerializationContext>* context);
717
718 private:
719 // Runs the dataset operation according to the predefined dataset params and
720 // the produced outputs will be stored in `dataset_ctx`.
721 Status RunDatasetOp(
722 const DatasetParams& dataset_params,
723 std::unique_ptr<OpKernel>* dataset_kernel,
724 std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
725 std::vector<std::unique_ptr<Tensor>>* created_tensors,
726 std::unique_ptr<OpKernelContext>* dataset_ctx);
727
728 Status MakeDataset(
729 const DatasetParams& dataset_params,
730 std::unique_ptr<OpKernel>* dataset_kernel,
731 std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
732 std::unique_ptr<OpKernelContext>* dataset_ctx,
733 std::vector<std::unique_ptr<Tensor>>* created_tensors,
734 DatasetBase** dataset);
735
736 // Creates the dataset op kernel.
737 Status MakeDatasetOpKernel(const DatasetParams& dataset_params,
738 std::unique_ptr<OpKernel>* dataset_kernel);
739
740 // Creates a dataset tensor according to the input dataset params.
741 Status MakeDatasetTensor(
742 const DatasetParams& dataset_params,
743 std::vector<std::unique_ptr<Tensor>>* created_tensors,
744 std::unique_ptr<Tensor>* dataset);
745
746 // Adds an empty tensor with the specified dtype and shape to the input
747 // vector.
748 Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
749 DataTypeVector input_types, DataType dtype,
750 const TensorShape& shape);
751
752 protected:
753 std::unique_ptr<Device> device_;
754 DeviceType device_type_;
755 int cpu_num_;
756 int thread_num_;
757 Allocator* allocator_; // Owned by `AllocatorFactoryRegistry`.
758 std::vector<AllocatorAttributes> allocator_attrs_;
759 std::unique_ptr<ScopedStepContainer> step_container_;
760
761 // Device manager is used by function handle cache and needs to outlive it.
762 std::unique_ptr<DeviceMgr> device_mgr_;
763 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
764 FunctionLibraryRuntime* flr_; // Owned by `pflr_`.
765 std::unique_ptr<FunctionHandleCache> function_handle_cache_;
766 std::function<void(std::function<void()>)> runner_;
767 std::unique_ptr<FunctionLibraryDefinition> lib_def_;
768 std::unique_ptr<ResourceMgr> resource_mgr_;
769 std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
770 slice_reader_cache_;
771 std::unique_ptr<thread::ThreadPool> thread_pool_;
772 std::vector<std::unique_ptr<Tensor>> tensors_; // Owns tensors.
773 mutex lock_for_refs_; // Used as the Mutex for inputs added as refs.
774 std::unique_ptr<CancellationManager> cancellation_manager_;
775
776 // Indicates if the below fields have been initialized.
777 bool initialized_ = false;
778 std::unique_ptr<OpKernel> dataset_kernel_;
779 std::unique_ptr<OpKernelContext::Params> params_;
780 std::unique_ptr<OpKernelContext> dataset_ctx_;
781 DatasetBase* dataset_ = nullptr;
782 std::unique_ptr<IteratorContext> iterator_ctx_;
783 std::unique_ptr<IteratorBase> iterator_;
784 };
785
// Defines and instantiates a value-parameterized `GetNext` test on top of
// `dataset_op_test_class`. `test_cases` must be convertible to
// `std::vector<GetNextTestCase<dataset_params_class>>`. Each case initializes
// the dataset and then checks the iterator outputs with
// `CheckIteratorGetNext()`.
#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
                                 test_cases)                                 \
  class ParameterizedGetNextTest                                             \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            GetNextTestCase<dataset_params_class>> {};                       \
                                                                             \
  TEST_P(ParameterizedGetNextTest, GetNext) {                                \
    const auto& test_case = GetParam();                                      \
    TF_ASSERT_OK(Initialize(test_case.dataset_params));                      \
    TF_ASSERT_OK(CheckIteratorGetNext(                                       \
        test_case.expected_outputs,                                          \
        /*compare_order=*/test_case.compare_order));                         \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedGetNextTest,                       \
      ::testing::ValuesIn(                                                   \
          std::vector<GetNextTestCase<dataset_params_class>>(test_cases)));
805
// Defines and instantiates a value-parameterized `Skip` test on top of
// `dataset_op_test_class`. `test_cases` must be convertible to
// `std::vector<SkipTestCase<dataset_params_class>>`. Each case initializes the
// dataset and then checks `IteratorBase::Skip()` with `CheckIteratorSkip()`.
// The test is named `Skip` (it was previously a copy-pasted `GetNext`, which
// misdescribed what it checks).
#define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class, \
                             test_cases)                                  \
  class ParameterizedSkipTest : public dataset_op_test_class,             \
                                public ::testing::WithParamInterface<     \
                                    SkipTestCase<dataset_params_class>> {}; \
                                                                          \
  TEST_P(ParameterizedSkipTest, Skip) {                                   \
    const auto& test_case = GetParam();                                   \
    TF_ASSERT_OK(Initialize(test_case.dataset_params));                   \
    TF_ASSERT_OK(CheckIteratorSkip(                                       \
        test_case.num_to_skip, test_case.expected_num_skipped,            \
        test_case.get_next, test_case.expected_outputs,                   \
        /*compare_order=*/test_case.compare_order));                      \
  }                                                                       \
                                                                          \
  INSTANTIATE_TEST_SUITE_P(                                               \
      dataset_op_test_class, ParameterizedSkipTest,                       \
      ::testing::ValuesIn(                                                \
          std::vector<SkipTestCase<dataset_params_class>>(test_cases)));
825
// DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class,
//                          test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedDatasetNodeNameTest`
//     deriving from `dataset_op_test_class` with
//     `DatasetNodeNameTestCase<dataset_params_class>` as its parameter type;
//   * a `DatasetNodeName` test that calls `Initialize(dataset_params)` and
//     then `CheckDatasetNodeName(expected_node_name)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
                                 test_cases)                                 \
  class ParameterizedDatasetNodeNameTest                                     \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            DatasetNodeNameTestCase<dataset_params_class>> {};               \
                                                                             \
  TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) {                \
    auto param = GetParam();                                                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                          \
    TF_ASSERT_OK(CheckDatasetNodeName(param.expected_node_name));            \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedDatasetNodeNameTest,               \
      ::testing::ValuesIn(                                                   \
          std::vector<DatasetNodeNameTestCase<dataset_params_class>>(        \
              test_cases)));
844
// DATASET_TYPE_STRING_TEST_P(dataset_op_test_class, dataset_params_class,
//                            test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedDatasetTypeStringTest`
//     deriving from `dataset_op_test_class` with
//     `DatasetTypeStringTestCase<dataset_params_class>` as its parameter type;
//   * a `DatasetTypeString` test that calls `Initialize(dataset_params)` and
//     then `CheckDatasetTypeString(expected_dataset_type_string)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class,                    \
                                   dataset_params_class, test_cases)         \
  class ParameterizedDatasetTypeStringTest                                   \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            DatasetTypeStringTestCase<dataset_params_class>> {};             \
                                                                             \
  TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) {            \
    auto param = GetParam();                                                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                          \
    TF_ASSERT_OK(CheckDatasetTypeString(param.expected_dataset_type_string)); \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedDatasetTypeStringTest,             \
      ::testing::ValuesIn(                                                   \
          std::vector<DatasetTypeStringTestCase<dataset_params_class>>(      \
              test_cases)));
864
// DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class, dataset_params_class,
//                              test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedDatasetOutputDtypesTest`
//     deriving from `dataset_op_test_class` with
//     `DatasetOutputDtypesTestCase<dataset_params_class>` as its parameter
//     type;
//   * a `DatasetOutputDtypes` test that calls `Initialize(dataset_params)`
//     and then `CheckDatasetOutputDtypes(expected_output_dtypes)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                  \
                                     dataset_params_class, test_cases)       \
  class ParameterizedDatasetOutputDtypesTest                                 \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            DatasetOutputDtypesTestCase<dataset_params_class>> {};           \
                                                                             \
  TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) {        \
    auto param = GetParam();                                                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                          \
    TF_ASSERT_OK(CheckDatasetOutputDtypes(param.expected_output_dtypes));    \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedDatasetOutputDtypesTest,           \
      ::testing::ValuesIn(                                                   \
          std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>(    \
              test_cases)));
884
// DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class, dataset_params_class,
//                              test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedDatasetOutputShapesTest`
//     deriving from `dataset_op_test_class` with
//     `DatasetOutputShapesTestCase<dataset_params_class>` as its parameter
//     type;
//   * a `DatasetOutputShapes` test that calls `Initialize(dataset_params)`
//     and then `CheckDatasetOutputShapes(expected_output_shapes)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                  \
                                     dataset_params_class, test_cases)       \
  class ParameterizedDatasetOutputShapesTest                                 \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            DatasetOutputShapesTestCase<dataset_params_class>> {};           \
                                                                             \
  TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {        \
    auto param = GetParam();                                                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                          \
    TF_ASSERT_OK(CheckDatasetOutputShapes(param.expected_output_shapes));    \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedDatasetOutputShapesTest,           \
      ::testing::ValuesIn(                                                   \
          std::vector<DatasetOutputShapesTestCase<dataset_params_class>>(    \
              test_cases)));
904
// DATASET_CARDINALITY_TEST_P(dataset_op_test_class, dataset_params_class,
//                            test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedCardinalityTest` deriving
//     from `dataset_op_test_class` with
//     `CardinalityTestCase<dataset_params_class>` as its parameter type;
//   * a `Cardinality` test that calls `Initialize(dataset_params)` and then
//     `CheckDatasetCardinality(expected_cardinality)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class,                    \
                                   dataset_params_class, test_cases)         \
  class ParameterizedCardinalityTest                                         \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            CardinalityTestCase<dataset_params_class>> {};                   \
                                                                             \
  TEST_P(ParameterizedCardinalityTest, Cardinality) {                        \
    auto param = GetParam();                                                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                          \
    TF_ASSERT_OK(CheckDatasetCardinality(param.expected_cardinality));       \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedCardinalityTest,                   \
      ::testing::ValuesIn(                                                   \
          std::vector<CardinalityTestCase<dataset_params_class>>(            \
              test_cases)));
924
// ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class, dataset_params_class,
//                               test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedIteratorOutputDtypesTest`
//     deriving from `dataset_op_test_class` with
//     `IteratorOutputDtypesTestCase<dataset_params_class>` as its parameter
//     type;
//   * an `IteratorOutputDtypes` test that calls `Initialize(dataset_params)`
//     and then `CheckIteratorOutputDtypes(expected_output_dtypes)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// This uses the iterator-level check rather than
// `CheckDatasetOutputDtypes`, mirroring how ITERATOR_OUTPUT_SHAPES_TEST_P
// uses `CheckIteratorOutputShapes`. Otherwise this test would merely
// duplicate DATASET_OUTPUT_DTYPES_TEST_P.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                  \
                                      dataset_params_class, test_cases)       \
  class ParameterizedIteratorOutputDtypesTest                                 \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            IteratorOutputDtypesTestCase<dataset_params_class>> {};           \
                                                                              \
  TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) {       \
    auto test_case = GetParam();                                              \
    TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
    TF_ASSERT_OK(                                                             \
        CheckIteratorOutputDtypes(test_case.expected_output_dtypes));         \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedIteratorOutputDtypesTest,           \
      ::testing::ValuesIn(                                                    \
          std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>(    \
              test_cases)));
943
// ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class, dataset_params_class,
//                               test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedIteratorOutputShapesTest`
//     deriving from `dataset_op_test_class` with
//     `IteratorOutputShapesTestCase<dataset_params_class>` as its parameter
//     type;
//   * an `IteratorOutputShapes` test that calls `Initialize(dataset_params)`
//     and then `CheckIteratorOutputShapes(expected_output_shapes)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                  \
                                      dataset_params_class, test_cases)       \
  class ParameterizedIteratorOutputShapesTest                                 \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            IteratorOutputShapesTestCase<dataset_params_class>> {};           \
                                                                              \
  TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {       \
    auto param = GetParam();                                                  \
    TF_ASSERT_OK(Initialize(param.dataset_params));                           \
    TF_ASSERT_OK(CheckIteratorOutputShapes(param.expected_output_shapes));    \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedIteratorOutputShapesTest,           \
      ::testing::ValuesIn(                                                    \
          std::vector<IteratorOutputShapesTestCase<dataset_params_class>>(    \
              test_cases)));
962
// ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class,
//                        test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedIteratorPrefixTest`
//     deriving from `dataset_op_test_class` with
//     `IteratorPrefixTestCase<dataset_params_class>` as its parameter type;
//   * an `IteratorPrefix` test that calls `Initialize(dataset_params)` and
//     then `CheckIteratorPrefix(expected_iterator_prefix)`;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class,  \
                               test_cases)                                   \
  class ParameterizedIteratorPrefixTest                                      \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            IteratorPrefixTestCase<dataset_params_class>> {};                \
                                                                             \
  TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) {                  \
    auto param = GetParam();                                                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                          \
    TF_ASSERT_OK(CheckIteratorPrefix(param.expected_iterator_prefix));       \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedIteratorPrefixTest,                \
      ::testing::ValuesIn(                                                   \
          std::vector<IteratorPrefixTestCase<dataset_params_class>>(         \
              test_cases)));
981
// ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class,
//                                  dataset_params_class, test_cases)
//
// Expands to three things:
//   * a value-parameterized fixture `ParameterizedIteratorSaveAndRestoreTest`
//     deriving from `dataset_op_test_class` with
//     `IteratorSaveAndRestoreTestCase<dataset_params_class>` as its parameter
//     type;
//   * an `IteratorSaveAndRestore` test that calls `Initialize(dataset_params)`
//     and then `CheckIteratorSaveAndRestore`. That call receives the
//     dataset's `iterator_prefix()`, `expected_outputs`, `breakpoints` and
//     `compare_order` from the test case;
//   * an instantiation of the suite over a `std::vector` constructed from
//     `test_cases`.
//
// The fixture name is fixed, so expand this macro at most once per
// translation unit.
#define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class,               \
                                         dataset_params_class, test_cases)    \
  class ParameterizedIteratorSaveAndRestoreTest                               \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            IteratorSaveAndRestoreTestCase<dataset_params_class>> {};         \
                                                                              \
  TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {   \
    auto param = GetParam();                                                  \
    TF_ASSERT_OK(Initialize(param.dataset_params));                           \
    TF_ASSERT_OK(CheckIteratorSaveAndRestore(                                 \
        param.dataset_params.iterator_prefix(), param.expected_outputs,       \
        param.breakpoints, /*compare_order=*/param.compare_order));           \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest,         \
      ::testing::ValuesIn(                                                    \
          std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>(  \
              test_cases)));
1001
1002 } // namespace data
1003 } // namespace tensorflow
1004
1005 #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
1006