• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
17 #define TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
18 
19 #include <stddef.h>
20 
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/graph_constructor.h"
31 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
32 #include "tensorflow/core/data/name_utils.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/dataset.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/function_testlib.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/resource_mgr.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_shape.h"
43 #include "tensorflow/core/framework/tensor_testutil.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/lib/gtl/array_slice.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/io/zlib_compression_options.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/test.h"
53 #include "tensorflow/core/platform/threadpool.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
56 
57 namespace tensorflow {
58 namespace data {
59 
60 typedef std::vector<
61     std::pair<string, tensorflow::FunctionDefHelper::AttrValueWrapper>>
62     AttributeVector;
63 
// Default CPU device count and thread count for the test runtime (presumably
// fed to `InitFunctionLibraryRuntime(cpu_num)` / `InitThreadPool(thread_num)`
// — confirm at the call sites).
constexpr int kDefaultCPUNum{2};
constexpr int kDefaultThreadNum{2};
66 
67 // Creates a tensor with the specified dtype, shape, and value.
68 template <typename T>
CreateTensor(const TensorShape & input_shape,gtl::ArraySlice<T> input_data)69 static Tensor CreateTensor(const TensorShape& input_shape,
70                            gtl::ArraySlice<T> input_data) {
71   Tensor tensor(DataTypeToEnum<T>::value, input_shape);
72   test::FillValues<T>(&tensor, input_data);
73   return tensor;
74 }
75 
76 // Creates a vector of tensors with the specified dtype, shape, and values.
77 template <typename T>
CreateTensors(const TensorShape & shape,const std::vector<gtl::ArraySlice<T>> & values)78 std::vector<Tensor> CreateTensors(
79     const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
80   std::vector<Tensor> result;
81   result.reserve(values.size());
82   for (auto& value : values) {
83     result.emplace_back(CreateTensor<T>(shape, value));
84   }
85   return result;
86 }
87 
// Compression formats used by `CompressionParams` and the zlib option helpers.
// The explicit values pin each enumerator's integer value.
enum class CompressionType : int {
  ZLIB = 0,
  GZIP = 1,
  RAW = 2,
  UNCOMPRESSED = 3,
};
89 
90 // Returns a string representation for the given compression type.
91 string ToString(CompressionType compression_type);
92 
93 // Gets the specified zlib compression options according to the compression
94 // type. Note that `CompressionType::UNCOMPRESSED` is not supported because
95 // `ZlibCompressionOptions` does not have an option.
96 io::ZlibCompressionOptions GetZlibCompressionOptions(
97     CompressionType compression_type);
98 
99 // Used to specify parameters when writing data into files with compression.
100 // `input_buffer_size` and `output_buffer_size` specify the input and output
101 // buffer size when ZLIB and GZIP compression is used.
102 struct CompressionParams {
103   CompressionType compression_type = CompressionType::UNCOMPRESSED;
104   int32 input_buffer_size = 0;
105   int32 output_buffer_size = 0;
106 };
107 
108 // Writes the input data into the file without compression.
109 Status WriteDataToFile(const string& filename, const char* data);
110 
111 // Writes the input data into the file with the specified compression.
112 Status WriteDataToFile(const string& filename, const char* data,
113                        const CompressionParams& params);
114 
115 // Writes the input data into the TFRecord file with the specified compression.
116 Status WriteDataToTFRecordFile(const string& filename,
117                                const std::vector<absl::string_view>& records,
118                                const CompressionParams& params);
119 
120 // Provides the parameters for running the dataset op.
121 class DatasetParams {
122  public:
123   DatasetParams(DataTypeVector output_dtypes,
124                 std::vector<PartialTensorShape> output_shapes,
125                 string node_name);
126 
~DatasetParams()127   virtual ~DatasetParams() {}
128 
129   // Returns the inputs (except the input datasets) as a tensor vector.
130   virtual std::vector<Tensor> GetInputTensors() const = 0;
131 
132   // Returns the dataset input names as a string vector.
133   virtual Status GetInputNames(std::vector<string>* input_names) const = 0;
134 
135   // Returns the dataset attributes as a vector.
136   virtual Status GetAttributes(AttributeVector* attributes) const = 0;
137 
138   // Checks if the tensor is a dataset variant tensor.
139   static bool IsDatasetTensor(const Tensor& tensor);
140 
node_name()141   string node_name() const { return node_name_; }
142 
output_dtypes()143   DataTypeVector output_dtypes() const { return output_dtypes_; }
144 
output_shapes()145   std::vector<PartialTensorShape> output_shapes() const {
146     return output_shapes_;
147   }
148 
iterator_prefix()149   string iterator_prefix() const { return iterator_prefix_; }
150 
input_dataset_params()151   const std::vector<std::shared_ptr<DatasetParams>>& input_dataset_params()
152       const {
153     return input_dataset_params_;
154   }
155 
156   // Returns the functions that will be used when running the dataset op.
func_lib()157   virtual std::vector<FunctionDef> func_lib() const { return {}; }
158 
159   // Returns the dataset type for the op represented by these parameters. This
160   // type usually needs to match the constant called `kDatasetType` defined in
161   // the dataset kernel.
162   virtual string dataset_type() const = 0;
163 
164   // Returns the dataset op name. By default, it returns the Op::kDatasetType
165   // concatenated with "Dataset". For ops that do not have "Dataset" suffix,
166   // this method can be overriden to return a different name.
op_name()167   virtual string op_name() const {
168     name_utils::OpNameParams params;
169     params.op_version = op_version();
170     return name_utils::OpName(dataset_type(), params);
171   }
172 
op_version()173   virtual int op_version() const { return op_version_; }
174 
175  protected:
176   std::vector<std::shared_ptr<DatasetParams>> input_dataset_params_;
177   DataTypeVector output_dtypes_;
178   std::vector<PartialTensorShape> output_shapes_;
179   string node_name_;
180   string iterator_prefix_ = "Iterator";
181   int op_version_ = 1;
182 };
183 
184 // `RangeDatasetParams` is a common dataset parameter type that are used in
185 // testing.
186 class RangeDatasetParams : public DatasetParams {
187  public:
188   RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
189                      DataTypeVector output_dtypes,
190                      std::vector<PartialTensorShape> output_shapes,
191                      string node_name);
192 
193   RangeDatasetParams(int64_t start, int64_t stop, int64_t step);
194 
195   RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
196                      DataTypeVector output_dtypes);
197 
198   std::vector<Tensor> GetInputTensors() const override;
199 
200   Status GetInputNames(std::vector<string>* input_names) const override;
201 
202   Status GetAttributes(AttributeVector* attr_vector) const override;
203 
204   string dataset_type() const override;
205 
206  private:
207   int64 start_;
208   int64 stop_;
209   int64 step_;
210 };
211 
212 // `BatchDatasetParams` is a common dataset parameter type that are used in
213 // testing.
214 class BatchDatasetParams : public DatasetParams {
215  public:
216   template <typename T>
BatchDatasetParams(T input_dataset_params,int64_t batch_size,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)217   BatchDatasetParams(T input_dataset_params, int64_t batch_size,
218                      bool drop_remainder, bool parallel_copy,
219                      DataTypeVector output_dtypes,
220                      std::vector<PartialTensorShape> output_shapes,
221                      string node_name)
222       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
223                       std::move(node_name)),
224         batch_size_(batch_size),
225         drop_remainder_(drop_remainder),
226         parallel_copy_(parallel_copy) {
227     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
228     op_version_ = 2;
229     iterator_prefix_ =
230         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
231                                    input_dataset_params.iterator_prefix());
232   }
233 
234   std::vector<Tensor> GetInputTensors() const override;
235 
236   Status GetInputNames(std::vector<string>* input_names) const override;
237 
238   Status GetAttributes(AttributeVector* attr_vector) const override;
239 
240   string dataset_type() const override;
241 
242  private:
243   int64 batch_size_;
244   bool drop_remainder_;
245   bool parallel_copy_;
246 };
247 
248 // `MapDatasetParams` is a common dataset parameter type that are used in
249 // testing.
250 class MapDatasetParams : public DatasetParams {
251  public:
252   template <typename T>
MapDatasetParams(T input_dataset_params,std::vector<Tensor> other_arguments,FunctionDefHelper::AttrValueWrapper func,std::vector<FunctionDef> func_lib,DataTypeVector type_arguments,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,bool use_inter_op_parallelism,bool preserve_cardinality,string node_name)253   MapDatasetParams(T input_dataset_params, std::vector<Tensor> other_arguments,
254                    FunctionDefHelper::AttrValueWrapper func,
255                    std::vector<FunctionDef> func_lib,
256                    DataTypeVector type_arguments, DataTypeVector output_dtypes,
257                    std::vector<PartialTensorShape> output_shapes,
258                    bool use_inter_op_parallelism, bool preserve_cardinality,
259                    string node_name)
260       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
261                       std::move(node_name)),
262         other_arguments_(std::move(other_arguments)),
263         func_(std::move(func)),
264         func_lib_(std::move(func_lib)),
265         type_arguments_(std::move(type_arguments)),
266         use_inter_op_parallelism_(use_inter_op_parallelism),
267         preserve_cardinality_(preserve_cardinality) {
268     input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
269     iterator_prefix_ =
270         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
271                                    input_dataset_params.iterator_prefix());
272   }
273 
274   std::vector<Tensor> GetInputTensors() const override;
275 
276   Status GetInputNames(std::vector<string>* input_names) const override;
277 
278   Status GetAttributes(AttributeVector* attr_vector) const override;
279 
280   string dataset_type() const override;
281 
282   std::vector<FunctionDef> func_lib() const override;
283 
284  private:
285   std::vector<Tensor> other_arguments_;
286   FunctionDefHelper::AttrValueWrapper func_;
287   std::vector<FunctionDef> func_lib_;
288   DataTypeVector type_arguments_;
289   bool use_inter_op_parallelism_;
290   bool preserve_cardinality_;
291 };
292 
293 // `TensorSliceDatasetParams` is a common dataset parameter type that are used
294 // in testing.
295 class TensorSliceDatasetParams : public DatasetParams {
296  public:
297   TensorSliceDatasetParams(std::vector<Tensor> components, string node_name);
298 
299   std::vector<Tensor> GetInputTensors() const override;
300 
301   Status GetInputNames(std::vector<string>* input_names) const override;
302 
303   Status GetAttributes(AttributeVector* attr_vector) const override;
304 
305   string dataset_type() const override;
306 
num_slices()307   int64 num_slices() const { return components_[0].dim_size(0); }
308 
num_tensors_per_slice()309   size_t num_tensors_per_slice() const { return components_.size(); }
310 
311  private:
312   DataTypeVector TensorSliceDtypes(const std::vector<Tensor>& input_components);
313 
314   std::vector<PartialTensorShape> TensorSliceShapes(
315       const std::vector<Tensor>& input_components);
316 
317  public:
318   std::vector<Tensor> components_;
319 };
320 
321 // `TakeDatasetParams` is a common dataset parameter type that are used in
322 // testing.
323 class TakeDatasetParams : public DatasetParams {
324  public:
325   template <typename T>
TakeDatasetParams(T input_dataset_params,int count,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)326   TakeDatasetParams(T input_dataset_params, int count,
327                     DataTypeVector output_dtypes,
328                     std::vector<PartialTensorShape> output_shapes,
329                     string node_name)
330       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
331                       std::move(node_name)),
332         count_(count) {
333     input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
334     iterator_prefix_ =
335         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
336                                    input_dataset_params.iterator_prefix());
337   }
338 
339   std::vector<Tensor> GetInputTensors() const override;
340 
341   Status GetInputNames(std::vector<string>* input_names) const override;
342 
343   Status GetAttributes(AttributeVector* attr_vector) const override;
344 
345   string dataset_type() const override;
346 
347  private:
348   int64 count_;
349 };
350 
351 // `ConcatenateDatasetParams` is a common dataset parameter type that are used
352 // in testing.
353 class ConcatenateDatasetParams : public DatasetParams {
354  public:
355   template <typename T, typename P>
ConcatenateDatasetParams(T input_dataset_params_0,P input_dataset_params_1,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)356   ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1,
357                            DataTypeVector output_dtypes,
358                            std::vector<PartialTensorShape> output_shapes,
359                            string node_name)
360       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
361                       std::move(node_name)) {
362     input_dataset_params_.push_back(
363         absl::make_unique<T>(input_dataset_params_0));
364     input_dataset_params_.push_back(
365         absl::make_unique<T>(input_dataset_params_1));
366     iterator_prefix_ =
367         name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(),
368                                    input_dataset_params_0.iterator_prefix());
369   }
370 
371   std::vector<Tensor> GetInputTensors() const override;
372 
373   Status GetInputNames(std::vector<string>* input_names) const override;
374 
375   Status GetAttributes(AttributeVector* attr_vector) const override;
376 
377   string dataset_type() const override;
378 };
379 
380 // `OptionsDatasetParams` is a common dataset parameter type that is used in
381 // testing.
382 class OptionsDatasetParams : public DatasetParams {
383  public:
384   template <typename T>
OptionsDatasetParams(T input_dataset_params,const string & serialized_options,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)385   OptionsDatasetParams(T input_dataset_params, const string& serialized_options,
386                        DataTypeVector output_dtypes,
387                        std::vector<PartialTensorShape> output_shapes,
388                        string node_name)
389       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
390                       std::move(node_name)),
391         serialized_options_(serialized_options) {
392     input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
393   }
394 
395   std::vector<Tensor> GetInputTensors() const override;
396 
397   Status GetInputNames(std::vector<string>* input_names) const override;
398 
399   Status GetAttributes(AttributeVector* attr_vector) const override;
400 
401   string dataset_type() const override;
402 
403  private:
404   string serialized_options_;
405 };
406 
407 template <typename T>
408 struct GetNextTestCase {
409   GetNextTestCase(T dataset_params, std::vector<Tensor> expected_outputs,
410                   bool compare_order = true)
dataset_paramsGetNextTestCase411       : dataset_params(std::move(dataset_params)),
412         expected_outputs(std::move(expected_outputs)),
413         compare_order(compare_order) {}
414 
415   T dataset_params;
416   std::vector<Tensor> expected_outputs;
417   bool compare_order;
418 };
419 
420 template <typename T>
421 struct SkipTestCase {
422   SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped,
423                bool get_next = false, std::vector<Tensor> expected_outputs = {},
424                bool compare_order = true)
dataset_paramsSkipTestCase425       : dataset_params(std::move(dataset_params)),
426         num_to_skip(num_to_skip),
427         expected_num_skipped(expected_num_skipped),
428         get_next(get_next),
429         expected_outputs(std::move(expected_outputs)),
430         compare_order(compare_order) {}
431 
432   T dataset_params;
433   int num_to_skip;
434   int expected_num_skipped;
435   bool get_next;
436   std::vector<Tensor> expected_outputs;
437   bool compare_order;
438 };
439 
440 template <typename T>
441 struct DatasetNodeNameTestCase {
442   T dataset_params;
443   string expected_node_name;
444 };
445 
446 template <typename T>
447 struct DatasetTypeStringTestCase {
448   T dataset_params;
449   string expected_dataset_type_string;
450 };
451 
452 template <typename T>
453 struct DatasetOutputDtypesTestCase {
454   T dataset_params;
455   DataTypeVector expected_output_dtypes;
456 };
457 
458 template <typename T>
459 struct DatasetOutputShapesTestCase {
460   T dataset_params;
461   std::vector<PartialTensorShape> expected_output_shapes;
462 };
463 
464 template <typename T>
465 struct CardinalityTestCase {
466   T dataset_params;
467   int64 expected_cardinality;
468 };
469 
470 template <typename T>
471 struct DatasetSaveTestCase {
472   T dataset_params;
473 };
474 
475 template <typename T>
476 struct IteratorOutputDtypesTestCase {
477   T dataset_params;
478   DataTypeVector expected_output_dtypes;
479 };
480 
481 template <typename T>
482 struct IteratorOutputShapesTestCase {
483   T dataset_params;
484   std::vector<PartialTensorShape> expected_output_shapes;
485 };
486 
487 template <typename T>
488 struct IteratorPrefixTestCase {
489   T dataset_params;
490   string expected_iterator_prefix;
491 };
492 
493 template <typename T>
494 struct IteratorSaveAndRestoreTestCase {
495   IteratorSaveAndRestoreTestCase(T dataset_params, std::vector<int> breakpoints,
496                                  std::vector<Tensor> expected_outputs,
497                                  bool compare_order = true)
dataset_paramsIteratorSaveAndRestoreTestCase498       : dataset_params(std::move(dataset_params)),
499         breakpoints(std::move(breakpoints)),
500         expected_outputs(std::move(expected_outputs)),
501         compare_order(compare_order) {}
502 
503   T dataset_params;
504   std::vector<int> breakpoints;
505   std::vector<Tensor> expected_outputs;
506   bool compare_order;
507 };
508 
509 // Class composing a dataset with its dependencies.
510 class TestDataset {
511  public:
512   // TestDataset expects that the caller has Ref'd the wrapped dataset. When
513   // TestDataset is destroyed, it will Unref the dataset.
TestDataset(std::unique_ptr<OpKernel> kernel_,std::unique_ptr<OpKernelContext::Params> ctx_params,std::unique_ptr<OpKernelContext> ctx,std::vector<std::unique_ptr<Tensor>> input_tensors,DatasetBase * dataset)514   TestDataset(std::unique_ptr<OpKernel> kernel_,
515               std::unique_ptr<OpKernelContext::Params> ctx_params,
516               std::unique_ptr<OpKernelContext> ctx,
517               std::vector<std::unique_ptr<Tensor>> input_tensors,
518               DatasetBase* dataset)
519       : kernel_(std::move(kernel_)),
520         ctx_params_(std::move(ctx_params)),
521         ctx_(std::move(ctx)),
522         input_tensors_(std::move(input_tensors)),
523         dataset_(dataset),
524         scoped_unref_(dataset) {}
525 
dataset()526   DatasetBase* dataset() const { return dataset_; }
527 
op_kernel_context()528   OpKernelContext* op_kernel_context() const { return ctx_.get(); }
529 
530  protected:
531   std::unique_ptr<OpKernel> kernel_;
532   std::unique_ptr<OpKernelContext::Params> ctx_params_;
533   std::unique_ptr<OpKernelContext> ctx_;
534   // The input tensors that this dataset depends on. They must outlive the
535   // dataset.
536   std::vector<std::unique_ptr<Tensor>> input_tensors_;
537   DatasetBase* dataset_;
538   core::ScopedUnref scoped_unref_;
539 };
540 
541 // Class composing a dataset iterator with its dependencies.
542 class TestIterator {
543  public:
TestIterator(std::unique_ptr<IteratorContext> ctx,std::unique_ptr<IteratorBase> iterator)544   TestIterator(std::unique_ptr<IteratorContext> ctx,
545                std::unique_ptr<IteratorBase> iterator)
546       : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {}
547 
iterator()548   IteratorBase* iterator() const { return iterator_.get(); }
549 
ctx()550   IteratorContext* ctx() const { return ctx_.get(); }
551 
GetNext(std::vector<Tensor> * out_tensors,bool * end_of_sequence)552   Status GetNext(std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
553     return iterator_->GetNext(ctx(), out_tensors, end_of_sequence);
554   }
555 
556  protected:
557   std::unique_ptr<IteratorBase> iterator_;
558   std::unique_ptr<IteratorContext> ctx_;
559 };
560 
561 // Helpful functions to test Dataset op kernels.
562 class DatasetOpsTestBase : public ::testing::Test {
563  public:
564   DatasetOpsTestBase();
565 
566   // Initializes the runtime and creates a dataset and iterator.
567   Status Initialize(const DatasetParams& dataset_params);
568 
569   // Initializes the parts of the runtime needed to run dataset ops.
570   Status InitializeRuntime(const DatasetParams& dataset_params);
571 
572   // Creates a dataset.
573   Status MakeDataset(const DatasetParams& dataset_params,
574                      std::unique_ptr<TestDataset>* dataset);
575 
576   // Creates an iterator for the given dataset, using the specified split
577   // providers.
578   Status MakeIterator(
579       const DatasetParams& dataset_params, const TestDataset& dataset,
580       std::vector<std::unique_ptr<SplitProvider>> split_providers,
581       std::unique_ptr<TestIterator>* iterator);
582   // Creates an iterator for the given dataset.
583   Status MakeIterator(const DatasetParams& dataset_params,
584                       const TestDataset& dataset,
585                       std::unique_ptr<TestIterator>* iterator);
586 
587   // Runs the dataset operation according to the predefined dataset params and
588   // produces outputs. Different from `MakeDataset()` which returns a Dataset
589   // object, `RunDatasetOp()` executes the dataset kernel based on the input
590   // DatasetParams and returns the produced outputs as a tensor vector. It can
591   // be used to run some dataset operations that do not have an internal
592   // customized `Dataset` class (e.g. `ReduceDatasetOp`).
593   Status RunDatasetOp(const DatasetParams& dataset_params,
594                       std::vector<Tensor>* outputs);
595 
596   // The method validates whether the two tensors have the same shape, dtype,
597   // and value.
598   static Status ExpectEqual(const Tensor& a, const Tensor& b);
599 
600   // The method validates whether the two tensor vectors have the same tensors.
601   // If `compare_order` is false, the method will only evaluate whether the two
602   // vectors have the same elements regardless of order.
603   static Status ExpectEqual(std::vector<Tensor> produced_tensors,
604                             std::vector<Tensor> expected_tensors,
605                             bool compare_order);
606 
607   // Checks `IteratorBase::GetNext()`.
608   Status CheckIteratorGetNext(const std::vector<Tensor>& expected_outputs,
609                               bool compare_order);
610 
611   // Checks `IteratorBase::GetNext()`.
612   Status CheckIteratorGetNext(TestIterator* iterator,
613                               const std::vector<Tensor>& expected_outputs,
614                               bool compare_order);
615 
616   // Checks `IteratorBase::GetNext()`.
617   Status CheckIteratorGetNext(IteratorBase* iterator, IteratorContext* ctx,
618                               const std::vector<Tensor>& expected_outputs,
619                               bool compare_order);
620 
621   // Checks `IteratorBase::Skip()`
622   Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped,
623                            bool get_next,
624                            const std::vector<Tensor>& expected_outputs,
625                            bool compare_order);
626 
627   // Checks that iterating through the dataset using a split provider produces
628   // the expected outputs.
629   Status CheckSplitProviderFullIteration(
630       const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
631 
632   // Checks that iterating through the dataset using a sharded split provider
633   // with the given `num_shards` and `shard_index` produces the expected
634   // outputs.
635   Status CheckSplitProviderShardedIteration(
636       const DatasetParams& params, int64_t num_shards, int64_t shard_index,
637       const std::vector<Tensor>& expected_outputs);
638 
639   // Checks `DatasetBase::node_name()`.
640   Status CheckDatasetNodeName(const string& expected_dataset_node_name);
641 
642   // Checks `DatasetBase::type_string()`.
643   Status CheckDatasetTypeString(const string& expected_type_str);
644 
645   // Checks `DatasetBase::output_dtypes()`.
646   Status CheckDatasetOutputDtypes(const DataTypeVector& expected_output_dtypes);
647 
648   // Checks `DatasetBase::output_shapes()`.
649   Status CheckDatasetOutputShapes(
650       const std::vector<PartialTensorShape>& expected_output_shapes);
651 
652   // Checks `DatasetBase::Cardinality()`.
653   Status CheckDatasetCardinality(int expected_cardinality);
654 
655   // Checks `DatasetBase::options()`.
656   Status CheckDatasetOptions(const Options& expected_options);
657 
658   // Checks `IteratorBase::output_dtypes()`.
659   Status CheckIteratorOutputDtypes(
660       const DataTypeVector& expected_output_dtypes);
661 
662   // Checks `IteratorBase::output_shapes()`.
663   Status CheckIteratorOutputShapes(
664       const std::vector<PartialTensorShape>& expected_output_shapes);
665 
666   // Checks `IteratorBase::prefix()`.
667   Status CheckIteratorPrefix(const string& expected_iterator_prefix);
668 
669   Status CheckIteratorSaveAndRestore(
670       DatasetBase* dataset, IteratorContext* iterator_ctx,
671       const std::string& iterator_prefix,
672       const std::vector<Tensor>& expected_outputs,
673       const std::vector<int>& breakpoints, bool compare_order);
674 
675   Status CheckIteratorSaveAndRestore(
676       const std::string& iterator_prefix,
677       const std::vector<Tensor>& expected_outputs,
678       const std::vector<int>& breakpoints, bool compare_order);
679 
680  protected:
681   // Make destructor protected so that DatasetOpsTestBase objects cannot
682   // be instantiated directly. Only subclasses can be instantiated.
683   ~DatasetOpsTestBase() override;
684 
685   // Creates a thread pool for parallel tasks.
686   Status InitThreadPool(int thread_num);
687 
688   // Initializes the runtime for computing the dataset operation and registers
689   // the input function definitions. `InitThreadPool()' needs to be called
690   // before this method if we want to run the tasks in parallel.
691   Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
692                                     int cpu_num);
693 
694   // Creates a new op kernel based on the node definition.
695   Status CreateOpKernel(const NodeDef& node_def,
696                         std::unique_ptr<OpKernel>* op_kernel);
697 
698   // Creates a new op kernel context.
699   Status CreateDatasetContext(
700       OpKernel* const dateset_kernel,
701       gtl::InlinedVector<TensorValue, 4>* const inputs,
702       std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
703       std::unique_ptr<OpKernelContext>* dataset_context);
704 
705   // Creates a new dataset.
706   Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
707                        DatasetBase** const dataset);
708 
709   // Restores the state of the input iterator. It resets the iterator before
710   // restoring it to make sure the input iterator does not hold any
711   // resources or tasks. Otherwise, restoring an existing iterator may cause
712   // the timeout issue or duplicated elements.
713   Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
714                          const string& output_prefix,
715                          const DatasetBase& dataset,
716                          std::unique_ptr<IteratorBase>* iterator);
717 
718   // Fetches the dataset from the operation context.
719   Status GetDatasetFromContext(OpKernelContext* context, int output_index,
720                                DatasetBase** const dataset);
721 
722   // Runs an operation producing outputs.
723   Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
724 
725   // Executes a function producing outputs.
726   Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs,
727                      const std::vector<Tensor>& args,
728                      const GraphConstructorOptions& graph_options,
729                      std::vector<Tensor*> rets);
730 
  // Checks that the size of `inputs` matches the number of inputs required by
  // `kernel`.
  Status CheckOpKernelInput(const OpKernel& kernel,
                            const gtl::InlinedVector<TensorValue, 4>& inputs);

  // Creates a new context for running `kernel` on `inputs` and stores it in
  // `*context`.
  // NOTE(review): unlike the four-argument overload, this one does not hand
  // the backing OpKernelContext::Params to the caller; confirm how their
  // lifetime is managed in the definition.
  Status CreateOpKernelContext(OpKernel* kernel,
                               gtl::InlinedVector<TensorValue, 4>* inputs,
                               std::unique_ptr<OpKernelContext>* context);

  // Creates a new context for running `kernel` on `inputs`. The context is
  // returned in `*context`, and the OpKernelContext::Params it is presumably
  // built from are returned in `*params`.
  Status CreateOpKernelContext(OpKernel* kernel,
                               gtl::InlinedVector<TensorValue, 4>* inputs,
                               std::unique_ptr<OpKernelContext::Params>* params,
                               std::unique_ptr<OpKernelContext>* context);
745 
  // Creates a new iterator context for iterating the dataset, derived from
  // `op_context`, and stores it in `*iterator_context`.
  Status CreateIteratorContext(
      OpKernelContext* const op_context,
      std::unique_ptr<IteratorContext>* iterator_context);
750 
  // Creates a new serialization context for serializing the dataset and
  // iterator, and stores it in `*context`.
  Status CreateSerializationContext(
      std::unique_ptr<SerializationContext>* context);
756 
  // Creates an op kernel for the GetOptions operation (presumably the
  // `GetOptions` dataset op) over the dataset described by `dataset_params`,
  // and stores it in `*op_kernel`.
  Status MakeGetOptionsOpKernel(const DatasetParams& dataset_params,
                                std::unique_ptr<OpKernel>* op_kernel);
760 
 private:
  // Runs the dataset operation described by the predefined `dataset_params`.
  // The produced outputs are stored in `*dataset_ctx`. The kernel, the context
  // params and `created_tensors` (presumably tensors allocated for the op's
  // inputs) are also returned, so the caller keeps them alive.
  Status RunDatasetOp(
      const DatasetParams& dataset_params,
      std::unique_ptr<OpKernel>* dataset_kernel,
      std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
      std::vector<std::unique_ptr<Tensor>>* created_tensors,
      std::unique_ptr<OpKernelContext>* dataset_ctx);

  // Builds the dataset described by `dataset_params` and stores it in
  // `*dataset`. The remaining out-params presumably keep the kernel, context
  // params, context and created tensors alive for the caller (confirm against
  // the definition).
  // NOTE(review): `dataset_ctx` and `created_tensors` are declared in the
  // opposite order from RunDatasetOp().
  Status MakeDataset(
      const DatasetParams& dataset_params,
      std::unique_ptr<OpKernel>* dataset_kernel,
      std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
      std::unique_ptr<OpKernelContext>* dataset_ctx,
      std::vector<std::unique_ptr<Tensor>>* created_tensors,
      DatasetBase** dataset);

  // Creates the op kernel for the dataset op described by `dataset_params` and
  // stores it in `*dataset_kernel`.
  Status MakeDatasetOpKernel(const DatasetParams& dataset_params,
                             std::unique_ptr<OpKernel>* dataset_kernel);

  // Creates a dataset tensor according to the input dataset params and stores
  // it in `*dataset`. `created_tensors` presumably collects intermediate
  // tensors that must outlive the result.
  Status MakeDatasetTensor(
      const DatasetParams& dataset_params,
      std::vector<std::unique_ptr<Tensor>>* created_tensors,
      std::unique_ptr<Tensor>* dataset);

  // Adds an empty tensor with the specified `dtype` and `shape` to `inputs`.
  // NOTE(review): `input_types` is presumably the kernel's expected input
  // types, used to validate `dtype` — confirm against the definition.
  Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
                         DataTypeVector input_types, DataType dtype,
                         const TensorShape& shape);
794 
 protected:
  // Members are destroyed in reverse declaration order. Reordering the fields
  // below changes teardown order, which matters for the lifetime dependencies
  // noted on `device_mgr_`.
  std::unique_ptr<Device> device_;
  DeviceType device_type_;
  // NOTE(review): `cpu_num_` and `thread_num_` have no in-class initializer;
  // presumably the constructor sets them — confirm.
  int cpu_num_;
  int thread_num_;
  Allocator* allocator_;  // Owned by `AllocatorFactoryRegistry`.
  std::vector<AllocatorAttributes> allocator_attrs_;
  std::unique_ptr<ScopedStepContainer> step_container_;

  // Device manager is used by the function handle cache and needs to outlive
  // it. It is declared before `function_handle_cache_`, so it is destroyed
  // after it.
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionLibraryRuntime* flr_;  // Owned by `pflr_`.
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  std::function<void(std::function<void()>)> runner_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  std::unique_ptr<ResourceMgr> resource_mgr_;
  std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
      slice_reader_cache_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;
  std::vector<std::unique_ptr<Tensor>> tensors_;  // Owns tensors.
  mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs.
  std::unique_ptr<CancellationManager> cancellation_manager_;

  // Indicates whether the fields below have been initialized.
  bool initialized_ = false;
  std::unique_ptr<OpKernel> dataset_kernel_;
  // `dataset_ctx_` is declared after `params_`, so it is destroyed first.
  std::unique_ptr<OpKernelContext::Params> params_;
  std::unique_ptr<OpKernelContext> dataset_ctx_;
  // NOTE(review): raw pointer; how its reference is released is not visible
  // here — confirm in the destructor.
  DatasetBase* dataset_ = nullptr;
  std::unique_ptr<IteratorContext> iterator_ctx_;
  std::unique_ptr<IteratorBase> iterator_;
827 };
828 
// Declares the value-parameterized test `ParameterizedGetNextTest.GetNext` on
// top of `dataset_op_test_class` and instantiates it over `test_cases`, which
// must be convertible to std::vector<GetNextTestCase<dataset_params_class>>.
// Each case initializes the fixture from its `dataset_params`, then calls
// CheckIteratorGetNext() with its `expected_outputs` and `compare_order`.
#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
                                 test_cases)                                  \
  class ParameterizedGetNextTest                                              \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            GetNextTestCase<dataset_params_class>> {};                        \
                                                                              \
  TEST_P(ParameterizedGetNextTest, GetNext) {                                 \
    auto param = GetParam();                                                  \
    auto& expected = param.expected_outputs;                                  \
    TF_ASSERT_OK(Initialize(param.dataset_params));                           \
    TF_ASSERT_OK(CheckIteratorGetNext(                                        \
        expected, /*compare_order=*/param.compare_order));                    \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedGetNextTest,                        \
      ::testing::ValuesIn(                                                    \
          std::vector<GetNextTestCase<dataset_params_class>>(test_cases)));
848 
// Declares `ParameterizedSkipTest.Skip` on top of `dataset_op_test_class` and
// instantiates it over `test_cases`. Each SkipTestCase initializes the fixture
// from its `dataset_params`, then forwards `num_to_skip`,
// `expected_num_skipped`, `get_next`, `expected_outputs` and `compare_order`
// to CheckIteratorSkip().
#define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class,   \
                             test_cases)                                    \
  class ParameterizedSkipTest                                               \
      : public dataset_op_test_class,                                       \
        public ::testing::WithParamInterface<                               \
            SkipTestCase<dataset_params_class>> {};                         \
                                                                            \
  TEST_P(ParameterizedSkipTest, Skip) {                                     \
    auto param = GetParam();                                                \
    TF_ASSERT_OK(Initialize(param.dataset_params));                         \
    TF_ASSERT_OK(CheckIteratorSkip(param.num_to_skip,                       \
                                   param.expected_num_skipped,              \
                                   param.get_next, param.expected_outputs,  \
                                   /*compare_order=*/param.compare_order)); \
  }                                                                         \
                                                                            \
  INSTANTIATE_TEST_SUITE_P(                                                 \
      dataset_op_test_class, ParameterizedSkipTest,                         \
      ::testing::ValuesIn(                                                  \
          std::vector<SkipTestCase<dataset_params_class>>(test_cases)));
868 
// Declares `ParameterizedDatasetNodeNameTest.DatasetNodeName` on top of
// `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_node_name` to
// CheckDatasetNodeName().
#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
                                 test_cases)                                  \
  class ParameterizedDatasetNodeNameTest                                      \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            DatasetNodeNameTestCase<dataset_params_class>> {};                \
                                                                              \
  TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) {                 \
    auto param = GetParam();                                                  \
    auto& expected_name = param.expected_node_name;                           \
    TF_ASSERT_OK(Initialize(param.dataset_params));                           \
    TF_ASSERT_OK(CheckDatasetNodeName(expected_name));                        \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedDatasetNodeNameTest,                \
      ::testing::ValuesIn(                                                    \
          std::vector<DatasetNodeNameTestCase<dataset_params_class>>(         \
              test_cases)));
887 
// Declares `ParameterizedDatasetTypeStringTest.DatasetTypeString` on top of
// `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_dataset_type_string` to
// CheckDatasetTypeString().
#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class,                \
                                   dataset_params_class, test_cases)     \
  class ParameterizedDatasetTypeStringTest                               \
      : public dataset_op_test_class,                                    \
        public ::testing::WithParamInterface<                            \
            DatasetTypeStringTestCase<dataset_params_class>> {};         \
                                                                         \
  TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) {        \
    auto param = GetParam();                                             \
    auto& expected_type = param.expected_dataset_type_string;            \
    TF_ASSERT_OK(Initialize(param.dataset_params));                      \
    TF_ASSERT_OK(CheckDatasetTypeString(expected_type));                 \
  }                                                                      \
                                                                         \
  INSTANTIATE_TEST_SUITE_P(                                              \
      dataset_op_test_class, ParameterizedDatasetTypeStringTest,         \
      ::testing::ValuesIn(                                               \
          std::vector<DatasetTypeStringTestCase<dataset_params_class>>(  \
              test_cases)));
907 
// Declares `ParameterizedDatasetOutputDtypesTest.DatasetOutputDtypes` on top
// of `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_output_dtypes` to
// CheckDatasetOutputDtypes().
#define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
                                     dataset_params_class, test_cases)        \
  class ParameterizedDatasetOutputDtypesTest                                  \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            DatasetOutputDtypesTestCase<dataset_params_class>> {};            \
                                                                              \
  TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) {         \
    auto param = GetParam();                                                  \
    auto& expected_dtypes = param.expected_output_dtypes;                     \
    TF_ASSERT_OK(Initialize(param.dataset_params));                           \
    TF_ASSERT_OK(CheckDatasetOutputDtypes(expected_dtypes));                  \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedDatasetOutputDtypesTest,            \
      ::testing::ValuesIn(                                                    \
          std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>(     \
              test_cases)));
927 
// Declares `ParameterizedDatasetOutputShapesTest.DatasetOutputShapes` on top
// of `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_output_shapes` to
// CheckDatasetOutputShapes().
#define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
                                     dataset_params_class, test_cases)        \
  class ParameterizedDatasetOutputShapesTest                                  \
      : public dataset_op_test_class,                                         \
        public ::testing::WithParamInterface<                                 \
            DatasetOutputShapesTestCase<dataset_params_class>> {};            \
                                                                              \
  TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {         \
    auto param = GetParam();                                                  \
    auto& expected_shapes = param.expected_output_shapes;                     \
    TF_ASSERT_OK(Initialize(param.dataset_params));                           \
    TF_ASSERT_OK(CheckDatasetOutputShapes(expected_shapes));                  \
  }                                                                           \
                                                                              \
  INSTANTIATE_TEST_SUITE_P(                                                   \
      dataset_op_test_class, ParameterizedDatasetOutputShapesTest,            \
      ::testing::ValuesIn(                                                    \
          std::vector<DatasetOutputShapesTestCase<dataset_params_class>>(     \
              test_cases)));
947 
// Declares `ParameterizedCardinalityTest.Cardinality` on top of
// `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_cardinality` to
// CheckDatasetCardinality().
#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class,                  \
                                   dataset_params_class, test_cases)       \
  class ParameterizedCardinalityTest                                       \
      : public dataset_op_test_class,                                      \
        public ::testing::WithParamInterface<                              \
            CardinalityTestCase<dataset_params_class>> {};                 \
                                                                           \
  TEST_P(ParameterizedCardinalityTest, Cardinality) {                      \
    auto param = GetParam();                                               \
    auto& expected = param.expected_cardinality;                           \
    TF_ASSERT_OK(Initialize(param.dataset_params));                        \
    TF_ASSERT_OK(CheckDatasetCardinality(expected));                       \
  }                                                                        \
                                                                           \
  INSTANTIATE_TEST_SUITE_P(                                                \
      dataset_op_test_class, ParameterizedCardinalityTest,                 \
      ::testing::ValuesIn(                                                 \
          std::vector<CardinalityTestCase<dataset_params_class>>(          \
              test_cases)));
967 
// Declares `ParameterizedIteratorOutputDtypesTest.IteratorOutputDtypes` on top
// of `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then checks the *iterator's* output dtypes against
// its `expected_output_dtypes`. Dataset-level dtypes are covered separately by
// DATASET_OUTPUT_DTYPES_TEST_P, so this test must call the iterator check.
#define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
                                      dataset_params_class, test_cases)        \
  class ParameterizedIteratorOutputDtypesTest                                  \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            IteratorOutputDtypesTestCase<dataset_params_class>> {};            \
                                                                               \
  TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) {        \
    auto test_case = GetParam();                                               \
    TF_ASSERT_OK(Initialize(test_case.dataset_params));                        \
    TF_ASSERT_OK(CheckIteratorOutputDtypes(test_case.expected_output_dtypes)); \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedIteratorOutputDtypesTest,            \
      ::testing::ValuesIn(                                                     \
          std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>(     \
              test_cases)));
986 
// Declares `ParameterizedIteratorOutputShapesTest.IteratorOutputShapes` on top
// of `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_output_shapes` to
// CheckIteratorOutputShapes().
#define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
                                      dataset_params_class, test_cases)        \
  class ParameterizedIteratorOutputShapesTest                                  \
      : public dataset_op_test_class,                                          \
        public ::testing::WithParamInterface<                                  \
            IteratorOutputShapesTestCase<dataset_params_class>> {};            \
                                                                               \
  TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {        \
    auto param = GetParam();                                                   \
    auto& expected_shapes = param.expected_output_shapes;                      \
    TF_ASSERT_OK(Initialize(param.dataset_params));                            \
    TF_ASSERT_OK(CheckIteratorOutputShapes(expected_shapes));                  \
  }                                                                            \
                                                                               \
  INSTANTIATE_TEST_SUITE_P(                                                    \
      dataset_op_test_class, ParameterizedIteratorOutputShapesTest,            \
      ::testing::ValuesIn(                                                     \
          std::vector<IteratorOutputShapesTestCase<dataset_params_class>>(     \
              test_cases)));
1005 
// Declares `ParameterizedIteratorPrefixTest.IteratorPrefix` on top of
// `dataset_op_test_class` and instantiates it over `test_cases`. Each case
// initializes the fixture, then passes its `expected_iterator_prefix` to
// CheckIteratorPrefix().
#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \
                               test_cases)                                  \
  class ParameterizedIteratorPrefixTest                                     \
      : public dataset_op_test_class,                                       \
        public ::testing::WithParamInterface<                               \
            IteratorPrefixTestCase<dataset_params_class>> {};               \
                                                                            \
  TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) {                 \
    auto param = GetParam();                                                \
    auto& expected_prefix = param.expected_iterator_prefix;                 \
    TF_ASSERT_OK(Initialize(param.dataset_params));                         \
    TF_ASSERT_OK(CheckIteratorPrefix(expected_prefix));                     \
  }                                                                         \
                                                                            \
  INSTANTIATE_TEST_SUITE_P(                                                 \
      dataset_op_test_class, ParameterizedIteratorPrefixTest,               \
      ::testing::ValuesIn(                                                  \
          std::vector<IteratorPrefixTestCase<dataset_params_class>>(        \
              test_cases)));
1024 
// Declares `ParameterizedIteratorSaveAndRestoreTest.IteratorSaveAndRestore` on
// top of `dataset_op_test_class` and instantiates it over `test_cases`. Each
// case initializes the fixture, then calls CheckIteratorSaveAndRestore() with
// its params' iterator_prefix(), `expected_outputs`, `breakpoints` and
// `compare_order`.
#define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class,              \
                                         dataset_params_class, test_cases)   \
  class ParameterizedIteratorSaveAndRestoreTest                              \
      : public dataset_op_test_class,                                        \
        public ::testing::WithParamInterface<                                \
            IteratorSaveAndRestoreTestCase<dataset_params_class>> {};        \
                                                                             \
  TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {  \
    auto param = GetParam();                                                 \
    auto& params = param.dataset_params;                                     \
    TF_ASSERT_OK(Initialize(params));                                        \
    TF_ASSERT_OK(CheckIteratorSaveAndRestore(                                \
        params.iterator_prefix(), param.expected_outputs, param.breakpoints, \
        param.compare_order));                                               \
  }                                                                          \
                                                                             \
  INSTANTIATE_TEST_SUITE_P(                                                  \
      dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest,        \
      ::testing::ValuesIn(                                                   \
          std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>( \
              test_cases)));
1044 
1045 }  // namespace data
1046 }  // namespace tensorflow
1047 
1048 #endif  // TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
1049