• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
17 #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
18 
19 #include <stddef.h>
20 
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/graph_constructor.h"
31 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
32 #include "tensorflow/core/framework/allocator.h"
33 #include "tensorflow/core/framework/cancellation.h"
34 #include "tensorflow/core/framework/dataset.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/function_handle_cache.h"
37 #include "tensorflow/core/framework/function_testlib.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/resource_mgr.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/tensor_shape.h"
42 #include "tensorflow/core/framework/tensor_testutil.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/kernels/data/name_utils.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/lib/gtl/array_slice.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/io/zlib_compression_options.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/test.h"
53 #include "tensorflow/core/platform/threadpool.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
56 
57 namespace tensorflow {
58 namespace data {
59 
60 typedef std::vector<
61     std::pair<string, tensorflow::FunctionDefHelper::AttrValueWrapper>>
62     AttributeVector;
63 
64 constexpr int kDefaultCPUNum = 2;
65 constexpr int kDefaultThreadNum = 2;
66 
67 // Creates a tensor with the specified dtype, shape, and value.
68 template <typename T>
CreateTensor(const TensorShape & input_shape,const gtl::ArraySlice<T> & input_data)69 static Tensor CreateTensor(const TensorShape& input_shape,
70                            const gtl::ArraySlice<T>& input_data) {
71   Tensor tensor(DataTypeToEnum<T>::value, input_shape);
72   test::FillValues<T>(&tensor, input_data);
73   return tensor;
74 }
75 
76 // Creates a vector of tensors with the specified dtype, shape, and values.
77 template <typename T>
CreateTensors(const TensorShape & shape,const std::vector<gtl::ArraySlice<T>> & values)78 std::vector<Tensor> CreateTensors(
79     const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
80   std::vector<Tensor> result;
81   result.reserve(values.size());
82   for (auto& value : values) {
83     result.emplace_back(CreateTensor<T>(shape, value));
84   }
85   return result;
86 }
87 
// Compression types that can be selected when writing test data to files.
enum class CompressionType {
  ZLIB = 0,
  GZIP = 1,
  RAW = 2,
  UNCOMPRESSED = 3,
};
89 
// Returns a string representation of `compression_type`.
// NOTE(review): the exact strings are defined in the .cc file; presumably they
// match the names expected by the dataset ops -- confirm before relying on
// them.
string ToString(CompressionType compression_type);
92 
// Returns the zlib compression options that correspond to `compression_type`.
// `CompressionType::UNCOMPRESSED` is not supported because
// `ZlibCompressionOptions` has no matching option.
io::ZlibCompressionOptions GetZlibCompressionOptions(
    CompressionType compression_type);
98 
// Parameters used when writing data into files with compression.
// `input_buffer_size` and `output_buffer_size` give the input and output
// buffer sizes used for ZLIB and GZIP compression.
struct CompressionParams {
  // Defaults to writing without compression.
  CompressionType compression_type = CompressionType::UNCOMPRESSED;
  int32 input_buffer_size = 0;
  int32 output_buffer_size = 0;
};
107 
// Writes `data` into the file named `filename` without compression.
Status WriteDataToFile(const string& filename, const char* data);

// Writes `data` into the file named `filename`, using the compression
// described by `params`.
Status WriteDataToFile(const string& filename, const char* data,
                       const CompressionParams& params);
114 
// Writes `records` into the TFRecord file named `filename`, using the
// compression described by `params`.
Status WriteDataToTFRecordFile(const string& filename,
                               const std::vector<absl::string_view>& records,
                               const CompressionParams& params);
119 
120 // Provides the parameters for running the dataset op.
121 class DatasetParams {
122  public:
123   DatasetParams(DataTypeVector output_dtypes,
124                 std::vector<PartialTensorShape> output_shapes,
125                 string node_name);
126 
~DatasetParams()127   virtual ~DatasetParams() {}
128 
129   // Returns the inputs (except the input datasets) as a tensor vector.
130   virtual std::vector<Tensor> GetInputTensors() const = 0;
131 
132   // Returns the dataset input names as a string vector.
133   virtual Status GetInputNames(std::vector<string>* input_names) const = 0;
134 
135   // Returns the dataset attributes as a vector.
136   virtual Status GetAttributes(AttributeVector* attributes) const = 0;
137 
138   // Checks if the tensor is a dataset variant tensor.
139   static bool IsDatasetTensor(const Tensor& tensor);
140 
node_name()141   string node_name() const { return node_name_; }
142 
output_dtypes()143   DataTypeVector output_dtypes() const { return output_dtypes_; }
144 
output_shapes()145   std::vector<PartialTensorShape> output_shapes() const {
146     return output_shapes_;
147   }
148 
iterator_prefix()149   string iterator_prefix() const { return iterator_prefix_; }
150 
input_dataset_params()151   const std::vector<std::shared_ptr<DatasetParams>>& input_dataset_params()
152       const {
153     return input_dataset_params_;
154   }
155 
156   // Returns the functions that will be used when running the dataset op.
func_lib()157   virtual std::vector<FunctionDef> func_lib() const { return {}; }
158 
159   // Returns the dataset type for the op represented by these parameters. This
160   // type usually needs to match the constant called `kDatasetType` defined in
161   // the dataset kernel.
162   virtual string dataset_type() const = 0;
163 
op_version()164   virtual int op_version() const { return op_version_; }
165 
166  protected:
167   std::vector<std::shared_ptr<DatasetParams>> input_dataset_params_;
168   DataTypeVector output_dtypes_;
169   std::vector<PartialTensorShape> output_shapes_;
170   string node_name_;
171   string iterator_prefix_ = "Iterator";
172   int op_version_ = 1;
173 };
174 
175 // `RangeDatasetParams` is a common dataset parameter type that are used in
176 // testing.
177 class RangeDatasetParams : public DatasetParams {
178  public:
179   RangeDatasetParams(int64 start, int64 stop, int64 step,
180                      DataTypeVector output_dtypes,
181                      std::vector<PartialTensorShape> output_shapes,
182                      string node_name);
183 
184   RangeDatasetParams(int64 start, int64 stop, int64 step);
185 
186   RangeDatasetParams(int64 start, int64 stop, int64 step,
187                      DataTypeVector output_dtypes);
188 
189   std::vector<Tensor> GetInputTensors() const override;
190 
191   Status GetInputNames(std::vector<string>* input_names) const override;
192 
193   Status GetAttributes(AttributeVector* attr_vector) const override;
194 
195   string dataset_type() const override;
196 
197  private:
198   int64 start_;
199   int64 stop_;
200   int64 step_;
201 };
202 
203 // `BatchDatasetParams` is a common dataset parameter type that are used in
204 // testing.
205 class BatchDatasetParams : public DatasetParams {
206  public:
207   template <typename T>
BatchDatasetParams(T input_dataset_params,int64 batch_size,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)208   BatchDatasetParams(T input_dataset_params, int64 batch_size,
209                      bool drop_remainder, bool parallel_copy,
210                      DataTypeVector output_dtypes,
211                      std::vector<PartialTensorShape> output_shapes,
212                      string node_name)
213       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
214                       std::move(node_name)),
215         batch_size_(batch_size),
216         drop_remainder_(drop_remainder),
217         parallel_copy_(parallel_copy) {
218     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
219     op_version_ = 2;
220     iterator_prefix_ =
221         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
222                                    input_dataset_params.iterator_prefix());
223   }
224 
225   std::vector<Tensor> GetInputTensors() const override;
226 
227   Status GetInputNames(std::vector<string>* input_names) const override;
228 
229   Status GetAttributes(AttributeVector* attr_vector) const override;
230 
231   string dataset_type() const override;
232 
233  private:
234   int64 batch_size_;
235   bool drop_remainder_;
236   bool parallel_copy_;
237 };
238 
239 // `MapDatasetParams` is a common dataset parameter type that are used in
240 // testing.
241 class MapDatasetParams : public DatasetParams {
242  public:
243   template <typename T>
MapDatasetParams(T input_dataset_params,std::vector<Tensor> other_arguments,FunctionDefHelper::AttrValueWrapper func,std::vector<FunctionDef> func_lib,DataTypeVector type_arguments,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,bool use_inter_op_parallelism,bool preserve_cardinality,string node_name)244   MapDatasetParams(T input_dataset_params, std::vector<Tensor> other_arguments,
245                    FunctionDefHelper::AttrValueWrapper func,
246                    std::vector<FunctionDef> func_lib,
247                    DataTypeVector type_arguments, DataTypeVector output_dtypes,
248                    std::vector<PartialTensorShape> output_shapes,
249                    bool use_inter_op_parallelism, bool preserve_cardinality,
250                    string node_name)
251       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
252                       std::move(node_name)),
253         other_arguments_(std::move(other_arguments)),
254         func_(std::move(func)),
255         func_lib_(std::move(func_lib)),
256         type_arguments_(std::move(type_arguments)),
257         use_inter_op_parallelism_(use_inter_op_parallelism),
258         preserve_cardinality_(preserve_cardinality) {
259     input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
260     iterator_prefix_ =
261         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
262                                    input_dataset_params.iterator_prefix());
263   }
264 
265   std::vector<Tensor> GetInputTensors() const override;
266 
267   Status GetInputNames(std::vector<string>* input_names) const override;
268 
269   Status GetAttributes(AttributeVector* attr_vector) const override;
270 
271   string dataset_type() const override;
272 
273   std::vector<FunctionDef> func_lib() const override;
274 
275  private:
276   std::vector<Tensor> other_arguments_;
277   FunctionDefHelper::AttrValueWrapper func_;
278   std::vector<FunctionDef> func_lib_;
279   DataTypeVector type_arguments_;
280   bool use_inter_op_parallelism_;
281   bool preserve_cardinality_;
282 };
283 
284 // `TensorSliceDatasetParams` is a common dataset parameter type that are used
285 // in testing.
286 class TensorSliceDatasetParams : public DatasetParams {
287  public:
288   TensorSliceDatasetParams(std::vector<Tensor> components, string node_name);
289 
290   std::vector<Tensor> GetInputTensors() const override;
291 
292   Status GetInputNames(std::vector<string>* input_names) const override;
293 
294   Status GetAttributes(AttributeVector* attr_vector) const override;
295 
296   string dataset_type() const override;
297 
num_slices()298   int64 num_slices() const { return components_[0].dim_size(0); }
299 
num_tensors_per_slice()300   size_t num_tensors_per_slice() const { return components_.size(); }
301 
302  private:
303   DataTypeVector TensorSliceDtypes(const std::vector<Tensor>& input_components);
304 
305   std::vector<PartialTensorShape> TensorSliceShapes(
306       const std::vector<Tensor>& input_components);
307 
308  public:
309   std::vector<Tensor> components_;
310 };
311 
312 // `TakeDatasetParams` is a common dataset parameter type that are used in
313 // testing.
314 class TakeDatasetParams : public DatasetParams {
315  public:
316   template <typename T>
TakeDatasetParams(T input_dataset_params,int count,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)317   TakeDatasetParams(T input_dataset_params, int count,
318                     DataTypeVector output_dtypes,
319                     std::vector<PartialTensorShape> output_shapes,
320                     string node_name)
321       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
322                       std::move(node_name)),
323         count_(count) {
324     input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
325     iterator_prefix_ =
326         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
327                                    input_dataset_params.iterator_prefix());
328   }
329 
330   std::vector<Tensor> GetInputTensors() const override;
331 
332   Status GetInputNames(std::vector<string>* input_names) const override;
333 
334   Status GetAttributes(AttributeVector* attr_vector) const override;
335 
336   string dataset_type() const override;
337 
338  private:
339   int64 count_;
340 };
341 
342 // `ConcatenateDatasetParams` is a common dataset parameter type that are used
343 // in testing.
344 class ConcatenateDatasetParams : public DatasetParams {
345  public:
346   template <typename T, typename P>
ConcatenateDatasetParams(T input_dataset_params_0,P input_dataset_params_1,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)347   ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1,
348                            DataTypeVector output_dtypes,
349                            std::vector<PartialTensorShape> output_shapes,
350                            string node_name)
351       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
352                       std::move(node_name)) {
353     input_dataset_params_.push_back(
354         absl::make_unique<T>(input_dataset_params_0));
355     input_dataset_params_.push_back(
356         absl::make_unique<T>(input_dataset_params_1));
357     iterator_prefix_ =
358         name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(),
359                                    input_dataset_params_0.iterator_prefix());
360   }
361 
362   std::vector<Tensor> GetInputTensors() const override;
363 
364   Status GetInputNames(std::vector<string>* input_names) const override;
365 
366   Status GetAttributes(AttributeVector* attr_vector) const override;
367 
368   string dataset_type() const override;
369 };
370 
371 template <typename T>
372 struct GetNextTestCase {
373   GetNextTestCase(T dataset_params, std::vector<Tensor> expected_outputs,
374                   bool compare_order = true)
dataset_paramsGetNextTestCase375       : dataset_params(std::move(dataset_params)),
376         expected_outputs(std::move(expected_outputs)),
377         compare_order(compare_order) {}
378 
379   T dataset_params;
380   std::vector<Tensor> expected_outputs;
381   bool compare_order;
382 };
383 
384 template <typename T>
385 struct SkipTestCase {
386   SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped,
387                bool get_next = false, std::vector<Tensor> expected_outputs = {},
388                bool compare_order = true)
dataset_paramsSkipTestCase389       : dataset_params(std::move(dataset_params)),
390         num_to_skip(num_to_skip),
391         expected_num_skipped(expected_num_skipped),
392         get_next(get_next),
393         expected_outputs(std::move(expected_outputs)),
394         compare_order(compare_order) {}
395 
396   T dataset_params;
397   int num_to_skip;
398   int expected_num_skipped;
399   bool get_next;
400   std::vector<Tensor> expected_outputs;
401   bool compare_order;
402 };
403 
// Test case pairing dataset parameters with the node name the dataset is
// expected to report.
template <typename T>
struct DatasetNodeNameTestCase {
  T dataset_params;
  string expected_node_name;
};
409 
// Test case pairing dataset parameters with the type string the dataset is
// expected to report.
template <typename T>
struct DatasetTypeStringTestCase {
  T dataset_params;
  string expected_dataset_type_string;
};
415 
// Test case pairing dataset parameters with the output dtypes the dataset is
// expected to report.
template <typename T>
struct DatasetOutputDtypesTestCase {
  T dataset_params;
  DataTypeVector expected_output_dtypes;
};
421 
// Test case pairing dataset parameters with the output shapes the dataset is
// expected to report.
template <typename T>
struct DatasetOutputShapesTestCase {
  T dataset_params;
  std::vector<PartialTensorShape> expected_output_shapes;
};
427 
// Test case pairing dataset parameters with the cardinality the dataset is
// expected to report.
template <typename T>
struct CardinalityTestCase {
  T dataset_params;
  int64 expected_cardinality;
};
433 
// Test case for saving a dataset; it carries only the dataset parameters.
template <typename T>
struct DatasetSaveTestCase {
  T dataset_params;
};
438 
// Test case pairing dataset parameters with the output dtypes the iterator is
// expected to report.
template <typename T>
struct IteratorOutputDtypesTestCase {
  T dataset_params;
  DataTypeVector expected_output_dtypes;
};
444 
// Test case pairing dataset parameters with the output shapes the iterator is
// expected to report.
template <typename T>
struct IteratorOutputShapesTestCase {
  T dataset_params;
  std::vector<PartialTensorShape> expected_output_shapes;
};
450 
// Test case pairing dataset parameters with the prefix the iterator is
// expected to report.
template <typename T>
struct IteratorPrefixTestCase {
  T dataset_params;
  string expected_iterator_prefix;
};
456 
457 template <typename T>
458 struct IteratorSaveAndRestoreTestCase {
459   IteratorSaveAndRestoreTestCase(T dataset_params, std::vector<int> breakpoints,
460                                  std::vector<Tensor> expected_outputs,
461                                  bool compare_order = true)
dataset_paramsIteratorSaveAndRestoreTestCase462       : dataset_params(std::move(dataset_params)),
463         breakpoints(std::move(breakpoints)),
464         expected_outputs(std::move(expected_outputs)),
465         compare_order(compare_order) {}
466 
467   T dataset_params;
468   std::vector<int> breakpoints;
469   std::vector<Tensor> expected_outputs;
470   bool compare_order;
471 };
472 
473 // Class composing a dataset with its dependencies.
474 class TestDataset {
475  public:
476   // TestDataset expects that the caller has Ref'd the wrapped dataset. When
477   // TestDataset is destroyed, it will Unref the dataset.
TestDataset(std::unique_ptr<OpKernel> kernel_,std::unique_ptr<OpKernelContext::Params> ctx_params,std::unique_ptr<OpKernelContext> ctx,std::vector<std::unique_ptr<Tensor>> input_tensors,DatasetBase * dataset)478   TestDataset(std::unique_ptr<OpKernel> kernel_,
479               std::unique_ptr<OpKernelContext::Params> ctx_params,
480               std::unique_ptr<OpKernelContext> ctx,
481               std::vector<std::unique_ptr<Tensor>> input_tensors,
482               DatasetBase* dataset)
483       : kernel_(std::move(kernel_)),
484         ctx_params_(std::move(ctx_params)),
485         ctx_(std::move(ctx)),
486         input_tensors_(std::move(input_tensors)),
487         dataset_(dataset),
488         scoped_unref_(dataset) {}
489 
dataset()490   DatasetBase* dataset() const { return dataset_; }
491 
op_kernel_context()492   OpKernelContext* op_kernel_context() const { return ctx_.get(); }
493 
494  protected:
495   std::unique_ptr<OpKernel> kernel_;
496   std::unique_ptr<OpKernelContext::Params> ctx_params_;
497   std::unique_ptr<OpKernelContext> ctx_;
498   // The input tensors that this dataset depends on. They must outlive the
499   // dataset.
500   std::vector<std::unique_ptr<Tensor>> input_tensors_;
501   DatasetBase* dataset_;
502   core::ScopedUnref scoped_unref_;
503 };
504 
505 // Class composing a dataset iterator with its dependencies.
506 class TestIterator {
507  public:
TestIterator(std::unique_ptr<IteratorContext> ctx,std::unique_ptr<IteratorBase> iterator)508   TestIterator(std::unique_ptr<IteratorContext> ctx,
509                std::unique_ptr<IteratorBase> iterator)
510       : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {}
511 
iterator()512   IteratorBase* iterator() const { return iterator_.get(); }
513 
ctx()514   IteratorContext* ctx() const { return ctx_.get(); }
515 
GetNext(std::vector<Tensor> * out_tensors,bool * end_of_sequence)516   Status GetNext(std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
517     return iterator_->GetNext(ctx(), out_tensors, end_of_sequence);
518   }
519 
520  protected:
521   std::unique_ptr<IteratorBase> iterator_;
522   std::unique_ptr<IteratorContext> ctx_;
523 };
524 
525 // Helpful functions to test Dataset op kernels.
526 class DatasetOpsTestBase : public ::testing::Test {
527  public:
528   DatasetOpsTestBase();
529 
530   // Initializes the runtime and creates a dataset and iterator.
531   Status Initialize(const DatasetParams& dataset_params);
532 
533   // Initializes the parts of the runtime needed to run dataset ops.
534   Status InitializeRuntime(const DatasetParams& dataset_params);
535 
536   // Creates a dataset.
537   Status MakeDataset(const DatasetParams& dataset_params,
538                      std::unique_ptr<TestDataset>* dataset);
539 
540   // Creates an iterator for the given dataset, using the specified split
541   // provider.
542   Status MakeIterator(const DatasetParams& dataset_params,
543                       const TestDataset& dataset,
544                       std::unique_ptr<SplitProvider> split_provider,
545                       std::unique_ptr<TestIterator>* iterator);
546   // Creates an iterator for the given dataset.
547   Status MakeIterator(const DatasetParams& dataset_params,
548                       const TestDataset& dataset,
549                       std::unique_ptr<TestIterator>* iterator);
550 
551   // Runs the dataset operation according to the predefined dataset params and
552   // produces outputs. Different from `MakeDataset()` which returns a Dataset
553   // object, `RunDatasetOp()` executes the dataset kernel based on the input
554   // DatasetParams and returns the produced outputs as a tensor vector. It can
555   // be used to run some dataset operations that do not have an internal
556   // customized `Dataset` class (e.g. `ReduceDatasetOp`).
557   Status RunDatasetOp(const DatasetParams& dataset_params,
558                       std::vector<Tensor>* outputs);
559 
560   // The method validates whether the two tensors have the same shape, dtype,
561   // and value.
562   static Status ExpectEqual(const Tensor& a, const Tensor& b);
563 
564   // The method validates whether the two tensor vectors have the same tensors.
565   // If `compare_order` is false, the method will only evaluate whether the two
566   // vectors have the same elements regardless of order.
567   static Status ExpectEqual(std::vector<Tensor> produced_tensors,
568                             std::vector<Tensor> expected_tensors,
569                             bool compare_order);
570 
571   // Checks `IteratorBase::GetNext()`.
572   Status CheckIteratorGetNext(const std::vector<Tensor>& expected_outputs,
573                               bool compare_order);
574 
575   // Checks `IteratorBase::GetNext()`.
576   Status CheckIteratorGetNext(TestIterator* iterator,
577                               const std::vector<Tensor>& expected_outputs,
578                               bool compare_order);
579 
580   // Checks `IteratorBase::GetNext()`.
581   Status CheckIteratorGetNext(IteratorBase* iterator, IteratorContext* ctx,
582                               const std::vector<Tensor>& expected_outputs,
583                               bool compare_order);
584 
585   // Checks `IteratorBase::Skip()`
586   Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped,
587                            bool get_next,
588                            const std::vector<Tensor>& expected_outputs,
589                            bool compare_order);
590 
591   // Checks that iterating through the dataset using a split provider produces
592   // the expected outputs.
593   Status CheckSplitProviderFullIteration(
594       const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
595 
596   // Checks that iterating through the dataset using a sharded split provider
597   // with the given `num_shards` and `shard_index` produces the expected
598   // outputs.
599   Status CheckSplitProviderShardedIteration(
600       const DatasetParams& params, int64 num_shards, int64 shard_index,
601       const std::vector<Tensor>& expected_outputs);
602 
603   // Checks `DatasetBase::node_name()`.
604   Status CheckDatasetNodeName(const string& expected_dataset_node_name);
605 
606   // Checks `DatasetBase::type_string()`.
607   Status CheckDatasetTypeString(const string& expected_type_str);
608 
609   // Checks `DatasetBase::output_dtypes()`.
610   Status CheckDatasetOutputDtypes(const DataTypeVector& expected_output_dtypes);
611 
612   // Checks `DatasetBase::output_shapes()`.
613   Status CheckDatasetOutputShapes(
614       const std::vector<PartialTensorShape>& expected_output_shapes);
615 
616   // Checks `DatasetBase::Cardinality()`.
617   Status CheckDatasetCardinality(int expected_cardinality);
618 
619   // Checks `IteratorBase::output_dtypes()`.
620   Status CheckIteratorOutputDtypes(
621       const DataTypeVector& expected_output_dtypes);
622 
623   // Checks `IteratorBase::output_shapes()`.
624   Status CheckIteratorOutputShapes(
625       const std::vector<PartialTensorShape>& expected_output_shapes);
626 
627   // Checks `IteratorBase::prefix()`.
628   Status CheckIteratorPrefix(const string& expected_iterator_prefix);
629 
630   Status CheckIteratorSaveAndRestore(
631       DatasetBase* dataset, IteratorContext* iterator_ctx,
632       const std::string& iterator_prefix,
633       const std::vector<Tensor>& expected_outputs,
634       const std::vector<int>& breakpoints, bool compare_order);
635 
636   Status CheckIteratorSaveAndRestore(
637       const std::string& iterator_prefix,
638       const std::vector<Tensor>& expected_outputs,
639       const std::vector<int>& breakpoints, bool compare_order);
640 
641  protected:
642   // Make destructor protected so that DatasetOpsTestBase objects cannot
643   // be instantiated directly. Only subclasses can be instantiated.
644   ~DatasetOpsTestBase() override;
645 
646   // Creates a thread pool for parallel tasks.
647   Status InitThreadPool(int thread_num);
648 
  // Initializes the runtime for computing the dataset operation and registers
  // the input function definitions. `InitThreadPool()` needs to be called
  // before this method if we want to run the tasks in parallel.
652   Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
653                                     int cpu_num);
654 
655   // Creates a new op kernel based on the node definition.
656   Status CreateOpKernel(const NodeDef& node_def,
657                         std::unique_ptr<OpKernel>* op_kernel);
658 
659   // Creates a new op kernel context.
660   Status CreateDatasetContext(
661       OpKernel* const dateset_kernel,
662       gtl::InlinedVector<TensorValue, 4>* const inputs,
663       std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
664       std::unique_ptr<OpKernelContext>* dataset_context);
665 
666   // Creates a new dataset.
667   Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
668                        DatasetBase** const dataset);
669 
670   // Restores the state of the input iterator. It resets the iterator before
671   // restoring it to make sure the input iterator does not hold any
672   // resources or tasks. Otherwise, restoring an existing iterator may cause
673   // the timeout issue or duplicated elements.
674   Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
675                          const string& output_prefix,
676                          const DatasetBase& dataset,
677                          std::unique_ptr<IteratorBase>* iterator);
678 
679   // Fetches the dataset from the operation context.
680   Status GetDatasetFromContext(OpKernelContext* context, int output_index,
681                                DatasetBase** const dataset);
682 
683   // Runs an operation producing outputs.
684   Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
685 
686   // Executes a function producing outputs.
687   Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs,
688                      const std::vector<Tensor>& args,
689                      const GraphConstructorOptions& graph_options,
690                      std::vector<Tensor*> rets);
691 
692   // Checks that the size of `inputs` matches the requirement of the op kernel.
693   Status CheckOpKernelInput(const OpKernel& kernel,
694                             const gtl::InlinedVector<TensorValue, 4>& inputs);
695 
696   // Creates a new context for running the dataset operation.
697   Status CreateOpKernelContext(OpKernel* kernel,
698                                gtl::InlinedVector<TensorValue, 4>* inputs,
699                                std::unique_ptr<OpKernelContext>* context);
700 
701   // Creates a new context for running the dataset operation.
702   Status CreateOpKernelContext(OpKernel* kernel,
703                                gtl::InlinedVector<TensorValue, 4>* inputs,
704                                std::unique_ptr<OpKernelContext::Params>* params,
705                                std::unique_ptr<OpKernelContext>* context);
706 
707   // Creates a new iterator context for iterating the dataset.
708   Status CreateIteratorContext(
709       OpKernelContext* const op_context,
710       std::unique_ptr<IteratorContext>* iterator_context);
711 
  // Creates a new serialization context for serializing the dataset and
  // iterator.
715   Status CreateSerializationContext(
716       std::unique_ptr<SerializationContext>* context);
717 
718  private:
719   // Runs the dataset operation according to the predefined dataset params and
720   // the produced outputs will be stored in `dataset_ctx`.
721   Status RunDatasetOp(
722       const DatasetParams& dataset_params,
723       std::unique_ptr<OpKernel>* dataset_kernel,
724       std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
725       std::vector<std::unique_ptr<Tensor>>* created_tensors,
726       std::unique_ptr<OpKernelContext>* dataset_ctx);
727 
      // NOTE(review): undocumented in the original. Judging by the signature,
      // this builds a dataset from `dataset_params` and returns it through
      // `dataset`, while the kernel, context params, context and tensors created
      // on the way are returned through the other output parameters -- confirm
      // against the .cc file. Ownership of `*dataset` is not visible here.
      // Note that `dataset_ctx` and `created_tensors` appear in the opposite
      // order compared to `RunDatasetOp`.
728   Status MakeDataset(
729       const DatasetParams& dataset_params,
730       std::unique_ptr<OpKernel>* dataset_kernel,
731       std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
732       std::unique_ptr<OpKernelContext>* dataset_ctx,
733       std::vector<std::unique_ptr<Tensor>>* created_tensors,
734       DatasetBase** dataset);
735 
736   // Creates the dataset op kernel.
      //   `dataset_params`: parameters describing the dataset op to build.
      //   `dataset_kernel`: output parameter that receives the created kernel.
737   Status MakeDatasetOpKernel(const DatasetParams& dataset_params,
738                              std::unique_ptr<OpKernel>* dataset_kernel);
739 
740   // Creates a dataset tensor according to the input dataset params.
      //   `dataset_params`:  parameters describing the dataset.
      //   `created_tensors`: output vector of owned tensors; presumably holds the
      //                      tensors created while building the dataset --
      //                      TODO confirm in the .cc file.
      //   `dataset`:         output parameter that receives the dataset tensor.
741   Status MakeDatasetTensor(
742       const DatasetParams& dataset_params,
743       std::vector<std::unique_ptr<Tensor>>* created_tensors,
744       std::unique_ptr<Tensor>* dataset);
745 
746   // Adds an empty tensor with the specified dtype and shape to the input
747   // vector.
      //   `inputs`:      vector the new tensor value is appended to.
      //   `input_types`: passed by value, so the vector is copied on every call.
      //                  NOTE(review): presumably the kernel's expected input
      //                  types, used to validate `dtype` -- confirm in the .cc.
      //   `dtype`/`shape`: element type and shape of the tensor to add.
748   Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
749                          DataTypeVector input_types, DataType dtype,
750                          const TensorShape& shape);
751 
752  protected:
      // Data members; `protected` makes them accessible to derived classes.
      // NOTE(review): the ints and raw pointers in the first two groups have no
      // in-class initializer; presumably they are set by a constructor or setup
      // method that is not visible in this part of the file -- confirm.
753   std::unique_ptr<Device> device_;
754   DeviceType device_type_;
755   int cpu_num_;
756   int thread_num_;
757   Allocator* allocator_;  // Not owned; owned by `AllocatorFactoryRegistry`.
758   std::vector<AllocatorAttributes> allocator_attrs_;
759   std::unique_ptr<ScopedStepContainer> step_container_;
760 
761   // Device manager is used by function handle cache and needs to outlive it.
      // Members are destroyed in reverse declaration order, so `device_mgr_` has
      // to stay declared above `function_handle_cache_`.
762   std::unique_ptr<DeviceMgr> device_mgr_;
763   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
764   FunctionLibraryRuntime* flr_;  // Not owned; owned by `pflr_`.
765   std::unique_ptr<FunctionHandleCache> function_handle_cache_;
766   std::function<void(std::function<void()>)> runner_;
767   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
768   std::unique_ptr<ResourceMgr> resource_mgr_;
769   std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
770       slice_reader_cache_;
771   std::unique_ptr<thread::ThreadPool> thread_pool_;
772   std::vector<std::unique_ptr<Tensor>> tensors_;  // Owns tensors.
773   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs.
774   std::unique_ptr<CancellationManager> cancellation_manager_;
775 
776   // Indicates if the below fields have been initialized.
777   bool initialized_ = false;
778   std::unique_ptr<OpKernel> dataset_kernel_;
779   std::unique_ptr<OpKernelContext::Params> params_;
780   std::unique_ptr<OpKernelContext> dataset_ctx_;
      // NOTE(review): raw pointer; who releases the dataset (and how) is not
      // visible in this part of the file -- confirm.
781   DatasetBase* dataset_ = nullptr;
782   std::unique_ptr<IteratorContext> iterator_ctx_;
783   std::unique_ptr<IteratorBase> iterator_;
784 };
785 
// Defines and instantiates a value-parameterized test that runs
// `CheckIteratorGetNext` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `GetNextTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<GetNextTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckIteratorGetNext(expected_outputs, compare_order)`; both calls are
// unqualified, so they are presumably members of `dataset_op_test_class`.
// The generated class is always named `ParameterizedGetNextTest`, so the macro
// can be expanded only once per scope.
786 #define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
787                                  test_cases)                                  \
788   class ParameterizedGetNextTest                                              \
789       : public dataset_op_test_class,                                         \
790         public ::testing::WithParamInterface<                                 \
791             GetNextTestCase<dataset_params_class>> {};                        \
792                                                                               \
793   TEST_P(ParameterizedGetNextTest, GetNext) {                                 \
794     auto test_case = GetParam();                                              \
795     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
796     TF_ASSERT_OK(                                                             \
797         CheckIteratorGetNext(test_case.expected_outputs,                      \
798                              /*compare_order=*/test_case.compare_order));     \
799   }                                                                           \
800                                                                               \
801   INSTANTIATE_TEST_SUITE_P(                                                   \
802       dataset_op_test_class, ParameterizedGetNextTest,                        \
803       ::testing::ValuesIn(                                                    \
804           std::vector<GetNextTestCase<dataset_params_class>>(test_cases)));
805 
// Defines and instantiates a value-parameterized test that runs
// `CheckIteratorSkip` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `SkipTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<SkipTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckIteratorSkip(num_to_skip, expected_num_skipped, get_next,
// expected_outputs, compare_order)`.
// The generated class is always named `ParameterizedSkipTest`, so the macro
// can be expanded only once per scope.
// NOTE(review): the test itself is named `GetNext` although it exercises
// skipping; this looks like a copy-paste leftover. Renaming it would change
// the emitted test names (and any filters relying on them), so it is left
// as is -- confirm before changing.
806 #define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class,   \
807                              test_cases)                                    \
808   class ParameterizedSkipTest : public dataset_op_test_class,               \
809                                 public ::testing::WithParamInterface<       \
810                                     SkipTestCase<dataset_params_class>> {}; \
811                                                                             \
812   TEST_P(ParameterizedSkipTest, GetNext) {                                  \
813     auto test_case = GetParam();                                            \
814     TF_ASSERT_OK(Initialize(test_case.dataset_params));                     \
815     TF_ASSERT_OK(CheckIteratorSkip(                                         \
816         test_case.num_to_skip, test_case.expected_num_skipped,              \
817         test_case.get_next, test_case.expected_outputs,                     \
818         /*compare_order=*/test_case.compare_order));                        \
819   }                                                                         \
820                                                                             \
821   INSTANTIATE_TEST_SUITE_P(                                                 \
822       dataset_op_test_class, ParameterizedSkipTest,                         \
823       ::testing::ValuesIn(                                                  \
824           std::vector<SkipTestCase<dataset_params_class>>(test_cases)));
825 
// Defines and instantiates a value-parameterized test that runs
// `CheckDatasetNodeName` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `DatasetNodeNameTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<DatasetNodeNameTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckDatasetNodeName(expected_node_name)`.
// The generated class is always named `ParameterizedDatasetNodeNameTest`, so
// the macro can be expanded only once per scope.
826 #define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
827                                  test_cases)                                  \
828   class ParameterizedDatasetNodeNameTest                                      \
829       : public dataset_op_test_class,                                         \
830         public ::testing::WithParamInterface<                                 \
831             DatasetNodeNameTestCase<dataset_params_class>> {};                \
832                                                                               \
833   TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) {                 \
834     auto test_case = GetParam();                                              \
835     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
836     TF_ASSERT_OK(CheckDatasetNodeName(test_case.expected_node_name));         \
837   }                                                                           \
838                                                                               \
839   INSTANTIATE_TEST_SUITE_P(                                                   \
840       dataset_op_test_class, ParameterizedDatasetNodeNameTest,                \
841       ::testing::ValuesIn(                                                    \
842           std::vector<DatasetNodeNameTestCase<dataset_params_class>>(         \
843               test_cases)));
844 
// Defines and instantiates a value-parameterized test that runs
// `CheckDatasetTypeString` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `DatasetTypeStringTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<DatasetTypeStringTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckDatasetTypeString(expected_dataset_type_string)`.
// The generated class is always named `ParameterizedDatasetTypeStringTest`, so
// the macro can be expanded only once per scope.
845 #define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class,                \
846                                    dataset_params_class, test_cases)     \
847   class ParameterizedDatasetTypeStringTest                               \
848       : public dataset_op_test_class,                                    \
849         public ::testing::WithParamInterface<                            \
850             DatasetTypeStringTestCase<dataset_params_class>> {};         \
851                                                                          \
852   TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) {        \
853     auto test_case = GetParam();                                         \
854     TF_ASSERT_OK(Initialize(test_case.dataset_params));                  \
855     TF_ASSERT_OK(                                                        \
856         CheckDatasetTypeString(test_case.expected_dataset_type_string)); \
857   }                                                                      \
858                                                                          \
859   INSTANTIATE_TEST_SUITE_P(                                              \
860       dataset_op_test_class, ParameterizedDatasetTypeStringTest,         \
861       ::testing::ValuesIn(                                               \
862           std::vector<DatasetTypeStringTestCase<dataset_params_class>>(  \
863               test_cases)));
864 
// Defines and instantiates a value-parameterized test that runs
// `CheckDatasetOutputDtypes` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `DatasetOutputDtypesTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckDatasetOutputDtypes(expected_output_dtypes)`.
// The generated class is always named `ParameterizedDatasetOutputDtypesTest`,
// so the macro can be expanded only once per scope.
865 #define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
866                                      dataset_params_class, test_cases)        \
867                                                                               \
868   class ParameterizedDatasetOutputDtypesTest                                  \
869       : public dataset_op_test_class,                                         \
870         public ::testing::WithParamInterface<                                 \
871             DatasetOutputDtypesTestCase<dataset_params_class>> {};            \
872                                                                               \
873   TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) {         \
874     auto test_case = GetParam();                                              \
875     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
876     TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \
877   }                                                                           \
878                                                                               \
879   INSTANTIATE_TEST_SUITE_P(                                                   \
880       dataset_op_test_class, ParameterizedDatasetOutputDtypesTest,            \
881       ::testing::ValuesIn(                                                    \
882           std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>(     \
883               test_cases)));
884 
// Defines and instantiates a value-parameterized test that runs
// `CheckDatasetOutputShapes` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `DatasetOutputShapesTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<DatasetOutputShapesTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckDatasetOutputShapes(expected_output_shapes)`.
// The generated class is always named `ParameterizedDatasetOutputShapesTest`,
// so the macro can be expanded only once per scope.
885 #define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
886                                      dataset_params_class, test_cases)        \
887                                                                               \
888   class ParameterizedDatasetOutputShapesTest                                  \
889       : public dataset_op_test_class,                                         \
890         public ::testing::WithParamInterface<                                 \
891             DatasetOutputShapesTestCase<dataset_params_class>> {};            \
892                                                                               \
893   TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {         \
894     auto test_case = GetParam();                                              \
895     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
896     TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); \
897   }                                                                           \
898                                                                               \
899   INSTANTIATE_TEST_SUITE_P(                                                   \
900       dataset_op_test_class, ParameterizedDatasetOutputShapesTest,            \
901       ::testing::ValuesIn(                                                    \
902           std::vector<DatasetOutputShapesTestCase<dataset_params_class>>(     \
903               test_cases)));
904 
// Defines and instantiates a value-parameterized test that runs
// `CheckDatasetCardinality` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `CardinalityTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<CardinalityTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckDatasetCardinality(expected_cardinality)`.
// The generated class is always named `ParameterizedCardinalityTest`, so the
// macro can be expanded only once per scope.
905 #define DATASET_CARDINALITY_TEST_P(dataset_op_test_class,                  \
906                                    dataset_params_class, test_cases)       \
907                                                                            \
908   class ParameterizedCardinalityTest                                       \
909       : public dataset_op_test_class,                                      \
910         public ::testing::WithParamInterface<                              \
911             CardinalityTestCase<dataset_params_class>> {};                 \
912                                                                            \
913   TEST_P(ParameterizedCardinalityTest, Cardinality) {                      \
914     auto test_case = GetParam();                                           \
915     TF_ASSERT_OK(Initialize(test_case.dataset_params));                    \
916     TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); \
917   }                                                                        \
918                                                                            \
919   INSTANTIATE_TEST_SUITE_P(                                                \
920       dataset_op_test_class, ParameterizedCardinalityTest,                 \
921       ::testing::ValuesIn(                                                 \
922           std::vector<CardinalityTestCase<dataset_params_class>>(          \
923               test_cases)));
924 
// Defines and instantiates a value-parameterized test that checks the output
// dtypes of the iterator for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `IteratorOutputDtypesTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckIteratorOutputDtypes(expected_output_dtypes)`. The iterator-level
// check is used on purpose (mirroring ITERATOR_OUTPUT_SHAPES_TEST_P); the
// dataset-level dtypes are covered by DATASET_OUTPUT_DTYPES_TEST_P.
// The generated class is always named `ParameterizedIteratorOutputDtypesTest`,
// so the macro can be expanded only once per scope.
925 #define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
926                                       dataset_params_class, test_cases)        \
927   class ParameterizedIteratorOutputDtypesTest                                  \
928       : public dataset_op_test_class,                                          \
929         public ::testing::WithParamInterface<                                  \
930             IteratorOutputDtypesTestCase<dataset_params_class>> {};            \
931                                                                                \
932   TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) {        \
933     auto test_case = GetParam();                                               \
934     TF_ASSERT_OK(Initialize(test_case.dataset_params));                        \
935     TF_ASSERT_OK(CheckIteratorOutputDtypes(test_case.expected_output_dtypes)); \
936   }                                                                            \
937                                                                                \
938   INSTANTIATE_TEST_SUITE_P(                                                    \
939       dataset_op_test_class, ParameterizedIteratorOutputDtypesTest,            \
940       ::testing::ValuesIn(                                                     \
941           std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>(     \
942               test_cases)));
943 
// Defines and instantiates a value-parameterized test that runs
// `CheckIteratorOutputShapes` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `IteratorOutputShapesTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<IteratorOutputShapesTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckIteratorOutputShapes(expected_output_shapes)`.
// The generated class is always named `ParameterizedIteratorOutputShapesTest`,
// so the macro can be expanded only once per scope.
944 #define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
945                                       dataset_params_class, test_cases)        \
946   class ParameterizedIteratorOutputShapesTest                                  \
947       : public dataset_op_test_class,                                          \
948         public ::testing::WithParamInterface<                                  \
949             IteratorOutputShapesTestCase<dataset_params_class>> {};            \
950                                                                                \
951   TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {        \
952     auto test_case = GetParam();                                               \
953     TF_ASSERT_OK(Initialize(test_case.dataset_params));                        \
954     TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); \
955   }                                                                            \
956                                                                                \
957   INSTANTIATE_TEST_SUITE_P(                                                    \
958       dataset_op_test_class, ParameterizedIteratorOutputShapesTest,            \
959       ::testing::ValuesIn(                                                     \
960           std::vector<IteratorOutputShapesTestCase<dataset_params_class>>(     \
961               test_cases)));
962 
// Defines and instantiates a value-parameterized test that runs
// `CheckIteratorPrefix` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of `IteratorPrefixTestCase`.
//   test_cases:            expression used to construct a
//       `std::vector<IteratorPrefixTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckIteratorPrefix(expected_iterator_prefix)`.
// The generated class is always named `ParameterizedIteratorPrefixTest`, so
// the macro can be expanded only once per scope.
963 #define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \
964                                test_cases)                                  \
965   class ParameterizedIteratorPrefixTest                                     \
966       : public dataset_op_test_class,                                       \
967         public ::testing::WithParamInterface<                               \
968             IteratorPrefixTestCase<dataset_params_class>> {};               \
969                                                                             \
970   TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) {                 \
971     auto test_case = GetParam();                                            \
972     TF_ASSERT_OK(Initialize(test_case.dataset_params));                     \
973     TF_ASSERT_OK(CheckIteratorPrefix(test_case.expected_iterator_prefix));  \
974   }                                                                         \
975                                                                             \
976   INSTANTIATE_TEST_SUITE_P(                                                 \
977       dataset_op_test_class, ParameterizedIteratorPrefixTest,               \
978       ::testing::ValuesIn(                                                  \
979           std::vector<IteratorPrefixTestCase<dataset_params_class>>(        \
980               test_cases)));
981 
// Defines and instantiates a value-parameterized test that runs
// `CheckIteratorSaveAndRestore` for every entry of `test_cases`.
//   dataset_op_test_class: base class of the generated test class; also used
//       as the instantiation name of the suite.
//   dataset_params_class:  template argument of
//       `IteratorSaveAndRestoreTestCase`; the test calls `iterator_prefix()`
//       on the case's `dataset_params`.
//   test_cases:            expression used to construct a
//       `std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>`.
// Each test copies its parameter, calls `Initialize(dataset_params)` and then
// `CheckIteratorSaveAndRestore(dataset_params.iterator_prefix(),
// expected_outputs, breakpoints, compare_order)`.
// The generated class is always named
// `ParameterizedIteratorSaveAndRestoreTest`, so the macro can be expanded only
// once per scope.
982 #define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class,              \
983                                          dataset_params_class, test_cases)   \
984   class ParameterizedIteratorSaveAndRestoreTest                              \
985       : public dataset_op_test_class,                                        \
986         public ::testing::WithParamInterface<                                \
987             IteratorSaveAndRestoreTestCase<dataset_params_class>> {};        \
988   TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {  \
989     auto test_case = GetParam();                                             \
990     TF_ASSERT_OK(Initialize(test_case.dataset_params));                      \
991     TF_ASSERT_OK(CheckIteratorSaveAndRestore(                                \
992         test_case.dataset_params.iterator_prefix(),                          \
993         test_case.expected_outputs, test_case.breakpoints,                   \
994         test_case.compare_order));                                           \
995   }                                                                          \
996   INSTANTIATE_TEST_SUITE_P(                                                  \
997       dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest,        \
998       ::testing::ValuesIn(                                                   \
999           std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>( \
1000               test_cases)));
1001 
1002 }  // namespace data
1003 }  // namespace tensorflow
1004 
1005 #endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
1006