• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
17 #define TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
18 
19 #include <stddef.h>
20 
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/graph_constructor.h"
31 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
32 #include "tensorflow/core/data/name_utils.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/dataset.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/function_testlib.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/resource_mgr.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_shape.h"
43 #include "tensorflow/core/framework/tensor_testutil.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/lib/gtl/array_slice.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/io/zlib_compression_options.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/test.h"
53 #include "tensorflow/core/platform/threadpool.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
56 
57 namespace tensorflow {
58 namespace data {
59 
60 typedef std::vector<
61     std::pair<string, tensorflow::FunctionDefHelper::AttrValueWrapper>>
62     AttributeVector;
63 
64 constexpr int kDefaultCPUNum = 2;
65 constexpr int kDefaultThreadNum = 2;
66 
67 // Creates a tensor with the specified dtype, shape, and value.
68 template <typename T>
CreateTensor(const TensorShape & input_shape,gtl::ArraySlice<T> input_data)69 static Tensor CreateTensor(const TensorShape& input_shape,
70                            gtl::ArraySlice<T> input_data) {
71   Tensor tensor(DataTypeToEnum<T>::value, input_shape);
72   test::FillValues<T>(&tensor, input_data);
73   return tensor;
74 }
75 
76 // Creates a tensor with the specified dtype and shape, with values 0, 1, 2, ...
77 template <typename T>
CreateTensor(const TensorShape & input_shape)78 static Tensor CreateTensor(const TensorShape& input_shape) {
79   Tensor tensor(DataTypeToEnum<T>::value, input_shape);
80   test::FillIota<T>(&tensor, 0);
81   return tensor;
82 }
83 
84 // Creates a vector of tensors with the specified dtype, shape, and values.
85 template <typename T>
CreateTensors(const TensorShape & shape,const std::vector<gtl::ArraySlice<T>> & values)86 std::vector<Tensor> CreateTensors(
87     const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
88   std::vector<Tensor> result;
89   result.reserve(values.size());
90   for (auto& value : values) {
91     result.emplace_back(CreateTensor<T>(shape, value));
92   }
93   return result;
94 }
95 
// Compression formats that the file-writing test helpers can be asked to use.
enum class CompressionType {
  ZLIB = 0,
  GZIP = 1,
  RAW = 2,
  UNCOMPRESSED = 3,
};
97 
98 // Returns a string representation for the given compression type.
99 string ToString(CompressionType compression_type);
100 
101 // Gets the specified zlib compression options according to the compression
102 // type. Note that `CompressionType::UNCOMPRESSED` is not supported because
103 // `ZlibCompressionOptions` does not have an option.
104 io::ZlibCompressionOptions GetZlibCompressionOptions(
105     CompressionType compression_type);
106 
107 // Used to specify parameters when writing data into files with compression.
108 // `input_buffer_size` and `output_buffer_size` specify the input and output
109 // buffer size when ZLIB and GZIP compression is used.
110 struct CompressionParams {
111   CompressionType compression_type = CompressionType::UNCOMPRESSED;
112   int32 input_buffer_size = 0;
113   int32 output_buffer_size = 0;
114 };
115 
116 // Writes the input data into the file without compression.
117 Status WriteDataToFile(const string& filename, const char* data);
118 
119 // Writes the input data into the file with the specified compression.
120 Status WriteDataToFile(const string& filename, const char* data,
121                        const CompressionParams& params);
122 
123 // Writes the input data into the TFRecord file with the specified compression.
124 Status WriteDataToTFRecordFile(const string& filename,
125                                const std::vector<absl::string_view>& records,
126                                const CompressionParams& params);
127 
128 // Provides the parameters for running the dataset op.
129 class DatasetParams {
130  public:
131   DatasetParams(DataTypeVector output_dtypes,
132                 std::vector<PartialTensorShape> output_shapes,
133                 string node_name);
134 
~DatasetParams()135   virtual ~DatasetParams() {}
136 
137   // Returns the inputs (except the input datasets) as a tensor vector.
138   virtual std::vector<Tensor> GetInputTensors() const = 0;
139 
140   // Returns the dataset input names as a string vector.
141   virtual Status GetInputNames(std::vector<string>* input_names) const = 0;
142 
143   // Returns the dataset attributes as a vector.
144   virtual Status GetAttributes(AttributeVector* attributes) const = 0;
145 
146   // Checks if the tensor is a dataset variant tensor.
147   static bool IsDatasetTensor(const Tensor& tensor);
148 
node_name()149   string node_name() const { return node_name_; }
150 
output_dtypes()151   DataTypeVector output_dtypes() const { return output_dtypes_; }
152 
output_shapes()153   std::vector<PartialTensorShape> output_shapes() const {
154     return output_shapes_;
155   }
156 
iterator_prefix()157   string iterator_prefix() const { return iterator_prefix_; }
158 
input_dataset_params()159   const std::vector<std::shared_ptr<DatasetParams>>& input_dataset_params()
160       const {
161     return input_dataset_params_;
162   }
163 
164   // Returns the functions that will be used when running the dataset op.
func_lib()165   virtual std::vector<FunctionDef> func_lib() const { return {}; }
166 
167   // Returns the dataset type for the op represented by these parameters. This
168   // type usually needs to match the constant called `kDatasetType` defined in
169   // the dataset kernel.
170   virtual string dataset_type() const = 0;
171 
172   // Returns the dataset op name. By default, it returns the Op::kDatasetType
173   // concatenated with "Dataset". For ops that do not have "Dataset" suffix,
174   // this method can be overriden to return a different name.
op_name()175   virtual string op_name() const {
176     name_utils::OpNameParams params;
177     params.op_version = op_version();
178     return name_utils::OpName(dataset_type(), params);
179   }
180 
op_version()181   virtual int op_version() const { return op_version_; }
182 
183  protected:
184   std::vector<std::shared_ptr<DatasetParams>> input_dataset_params_;
185   DataTypeVector output_dtypes_;
186   std::vector<PartialTensorShape> output_shapes_;
187   string node_name_;
188   string iterator_prefix_ = "Iterator";
189   int op_version_ = 1;
190 };
191 
192 // `RangeDatasetParams` is a common dataset parameter type that are used in
193 // testing.
194 class RangeDatasetParams : public DatasetParams {
195  public:
196   RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
197                      DataTypeVector output_dtypes,
198                      std::vector<PartialTensorShape> output_shapes,
199                      string node_name);
200 
201   RangeDatasetParams(int64_t start, int64_t stop, int64_t step);
202 
203   RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
204                      DataTypeVector output_dtypes);
205 
206   std::vector<Tensor> GetInputTensors() const override;
207 
208   Status GetInputNames(std::vector<string>* input_names) const override;
209 
210   Status GetAttributes(AttributeVector* attr_vector) const override;
211 
212   string dataset_type() const override;
213 
214  private:
215   int64_t start_;
216   int64_t stop_;
217   int64_t step_;
218 };
219 
220 // `BatchDatasetParams` is a common dataset parameter type that are used in
221 // testing.
222 class BatchDatasetParams : public DatasetParams {
223  public:
224   template <typename T>
BatchDatasetParams(T input_dataset_params,int64_t batch_size,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)225   BatchDatasetParams(T input_dataset_params, int64_t batch_size,
226                      bool drop_remainder, bool parallel_copy,
227                      DataTypeVector output_dtypes,
228                      std::vector<PartialTensorShape> output_shapes,
229                      string node_name)
230       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
231                       std::move(node_name)),
232         batch_size_(batch_size),
233         drop_remainder_(drop_remainder),
234         parallel_copy_(parallel_copy) {
235     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
236     op_version_ = 2;
237     iterator_prefix_ =
238         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
239                                    input_dataset_params.iterator_prefix());
240   }
241 
242   std::vector<Tensor> GetInputTensors() const override;
243 
244   Status GetInputNames(std::vector<string>* input_names) const override;
245 
246   Status GetAttributes(AttributeVector* attr_vector) const override;
247 
248   string dataset_type() const override;
249 
250  private:
251   int64_t batch_size_;
252   bool drop_remainder_;
253   bool parallel_copy_;
254 };
255 
256 // `MapDatasetParams` is a common dataset parameter type that are used in
257 // testing.
258 class MapDatasetParams : public DatasetParams {
259  public:
260   template <typename T>
MapDatasetParams(T input_dataset_params,std::vector<Tensor> other_arguments,FunctionDefHelper::AttrValueWrapper func,std::vector<FunctionDef> func_lib,DataTypeVector type_arguments,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,bool use_inter_op_parallelism,bool preserve_cardinality,string node_name)261   MapDatasetParams(T input_dataset_params, std::vector<Tensor> other_arguments,
262                    FunctionDefHelper::AttrValueWrapper func,
263                    std::vector<FunctionDef> func_lib,
264                    DataTypeVector type_arguments, DataTypeVector output_dtypes,
265                    std::vector<PartialTensorShape> output_shapes,
266                    bool use_inter_op_parallelism, bool preserve_cardinality,
267                    string node_name)
268       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
269                       std::move(node_name)),
270         other_arguments_(std::move(other_arguments)),
271         func_(std::move(func)),
272         func_lib_(std::move(func_lib)),
273         type_arguments_(std::move(type_arguments)),
274         use_inter_op_parallelism_(use_inter_op_parallelism),
275         preserve_cardinality_(preserve_cardinality) {
276     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
277     iterator_prefix_ =
278         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
279                                    input_dataset_params.iterator_prefix());
280   }
281 
282   std::vector<Tensor> GetInputTensors() const override;
283 
284   Status GetInputNames(std::vector<string>* input_names) const override;
285 
286   Status GetAttributes(AttributeVector* attr_vector) const override;
287 
288   string dataset_type() const override;
289 
290   std::vector<FunctionDef> func_lib() const override;
291 
292  private:
293   std::vector<Tensor> other_arguments_;
294   FunctionDefHelper::AttrValueWrapper func_;
295   std::vector<FunctionDef> func_lib_;
296   DataTypeVector type_arguments_;
297   bool use_inter_op_parallelism_;
298   bool preserve_cardinality_;
299 };
300 
301 // `TensorSliceDatasetParams` is a common dataset parameter type that are used
302 // in testing.
303 class TensorSliceDatasetParams : public DatasetParams {
304  public:
305   TensorSliceDatasetParams(std::vector<Tensor> components, string node_name,
306                            bool is_files = false);
307 
308   std::vector<Tensor> GetInputTensors() const override;
309 
310   Status GetInputNames(std::vector<string>* input_names) const override;
311 
312   Status GetAttributes(AttributeVector* attr_vector) const override;
313 
314   string dataset_type() const override;
315 
num_slices()316   int64_t num_slices() const { return components_[0].dim_size(0); }
317 
num_tensors_per_slice()318   size_t num_tensors_per_slice() const { return components_.size(); }
319 
320  private:
321   DataTypeVector TensorSliceDtypes(const std::vector<Tensor>& input_components);
322 
323   std::vector<PartialTensorShape> TensorSliceShapes(
324       const std::vector<Tensor>& input_components);
325 
326  public:
327   std::vector<Tensor> components_;
328   bool is_files_;
329 };
330 
331 // `TakeDatasetParams` is a common dataset parameter type that are used in
332 // testing.
333 class TakeDatasetParams : public DatasetParams {
334  public:
335   template <typename T>
TakeDatasetParams(T input_dataset_params,int count,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)336   TakeDatasetParams(T input_dataset_params, int count,
337                     DataTypeVector output_dtypes,
338                     std::vector<PartialTensorShape> output_shapes,
339                     string node_name)
340       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
341                       std::move(node_name)),
342         count_(count) {
343     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
344     iterator_prefix_ =
345         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
346                                    input_dataset_params.iterator_prefix());
347   }
348 
349   std::vector<Tensor> GetInputTensors() const override;
350 
351   Status GetInputNames(std::vector<string>* input_names) const override;
352 
353   Status GetAttributes(AttributeVector* attr_vector) const override;
354 
355   string dataset_type() const override;
356 
357  private:
358   int64_t count_;
359 };
360 
361 // `ConcatenateDatasetParams` is a common dataset parameter type that are used
362 // in testing.
363 class ConcatenateDatasetParams : public DatasetParams {
364  public:
365   template <typename T, typename P>
ConcatenateDatasetParams(T input_dataset_params_0,P input_dataset_params_1,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)366   ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1,
367                            DataTypeVector output_dtypes,
368                            std::vector<PartialTensorShape> output_shapes,
369                            string node_name)
370       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
371                       std::move(node_name)) {
372     input_dataset_params_.push_back(
373         std::make_unique<T>(input_dataset_params_0));
374     input_dataset_params_.push_back(
375         std::make_unique<T>(input_dataset_params_1));
376     iterator_prefix_ =
377         name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(),
378                                    input_dataset_params_0.iterator_prefix());
379   }
380 
381   std::vector<Tensor> GetInputTensors() const override;
382 
383   Status GetInputNames(std::vector<string>* input_names) const override;
384 
385   Status GetAttributes(AttributeVector* attr_vector) const override;
386 
387   string dataset_type() const override;
388 };
389 
390 // `OptionsDatasetParams` is a common dataset parameter type that is used in
391 // testing.
392 class OptionsDatasetParams : public DatasetParams {
393  public:
394   template <typename T>
OptionsDatasetParams(T input_dataset_params,const string & serialized_options,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)395   OptionsDatasetParams(T input_dataset_params, const string& serialized_options,
396                        DataTypeVector output_dtypes,
397                        std::vector<PartialTensorShape> output_shapes,
398                        string node_name)
399       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
400                       std::move(node_name)),
401         serialized_options_(serialized_options) {
402     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
403   }
404 
405   std::vector<Tensor> GetInputTensors() const override;
406 
407   Status GetInputNames(std::vector<string>* input_names) const override;
408 
409   Status GetAttributes(AttributeVector* attr_vector) const override;
410 
411   string dataset_type() const override;
412 
413  private:
414   string serialized_options_;
415 };
416 
417 template <typename T>
418 struct GetNextTestCase {
419   GetNextTestCase(T dataset_params, std::vector<Tensor> expected_outputs,
420                   bool compare_order = true)
dataset_paramsGetNextTestCase421       : dataset_params(std::move(dataset_params)),
422         expected_outputs(std::move(expected_outputs)),
423         compare_order(compare_order) {}
424 
425   T dataset_params;
426   std::vector<Tensor> expected_outputs;
427   bool compare_order;
428 };
429 
430 template <typename T>
431 struct SkipTestCase {
432   SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped,
433                bool get_next = false, std::vector<Tensor> expected_outputs = {},
434                bool compare_order = true)
dataset_paramsSkipTestCase435       : dataset_params(std::move(dataset_params)),
436         num_to_skip(num_to_skip),
437         expected_num_skipped(expected_num_skipped),
438         get_next(get_next),
439         expected_outputs(std::move(expected_outputs)),
440         compare_order(compare_order) {}
441 
442   T dataset_params;
443   int num_to_skip;
444   int expected_num_skipped;
445   bool get_next;
446   std::vector<Tensor> expected_outputs;
447   bool compare_order;
448 };
449 
450 template <typename T>
451 struct DatasetNodeNameTestCase {
452   T dataset_params;
453   string expected_node_name;
454 };
455 
456 template <typename T>
457 struct DatasetTypeStringTestCase {
458   T dataset_params;
459   string expected_dataset_type_string;
460 };
461 
462 template <typename T>
463 struct DatasetOutputDtypesTestCase {
464   T dataset_params;
465   DataTypeVector expected_output_dtypes;
466 };
467 
468 template <typename T>
469 struct DatasetOutputShapesTestCase {
470   T dataset_params;
471   std::vector<PartialTensorShape> expected_output_shapes;
472 };
473 
// Test case pairing dataset parameters with the expected dataset cardinality.
template <class T>
struct CardinalityTestCase {
  T dataset_params;
  int64_t expected_cardinality;
};
479 
// Test case for saving a dataset; it carries only the dataset parameters.
template <class T>
struct DatasetSaveTestCase {
  T dataset_params;
};
484 
485 template <typename T>
486 struct IteratorOutputDtypesTestCase {
487   T dataset_params;
488   DataTypeVector expected_output_dtypes;
489 };
490 
491 template <typename T>
492 struct IteratorOutputShapesTestCase {
493   T dataset_params;
494   std::vector<PartialTensorShape> expected_output_shapes;
495 };
496 
497 template <typename T>
498 struct IteratorPrefixTestCase {
499   T dataset_params;
500   string expected_iterator_prefix;
501 };
502 
503 template <typename T>
504 struct IteratorSaveAndRestoreTestCase {
505   IteratorSaveAndRestoreTestCase(T dataset_params, std::vector<int> breakpoints,
506                                  std::vector<Tensor> expected_outputs,
507                                  bool compare_order = true)
dataset_paramsIteratorSaveAndRestoreTestCase508       : dataset_params(std::move(dataset_params)),
509         breakpoints(std::move(breakpoints)),
510         expected_outputs(std::move(expected_outputs)),
511         compare_order(compare_order) {}
512 
513   T dataset_params;
514   std::vector<int> breakpoints;
515   std::vector<Tensor> expected_outputs;
516   bool compare_order;
517 };
518 
519 // Class composing a dataset with its dependencies.
520 class TestDataset {
521  public:
522   // TestDataset expects that the caller has Ref'd the wrapped dataset. When
523   // TestDataset is destroyed, it will Unref the dataset.
TestDataset(std::unique_ptr<OpKernel> kernel_,std::unique_ptr<OpKernelContext::Params> ctx_params,std::unique_ptr<OpKernelContext> ctx,std::vector<std::unique_ptr<Tensor>> input_tensors,DatasetBase * dataset)524   TestDataset(std::unique_ptr<OpKernel> kernel_,
525               std::unique_ptr<OpKernelContext::Params> ctx_params,
526               std::unique_ptr<OpKernelContext> ctx,
527               std::vector<std::unique_ptr<Tensor>> input_tensors,
528               DatasetBase* dataset)
529       : kernel_(std::move(kernel_)),
530         ctx_params_(std::move(ctx_params)),
531         ctx_(std::move(ctx)),
532         input_tensors_(std::move(input_tensors)),
533         dataset_(dataset),
534         scoped_unref_(dataset) {}
535 
dataset()536   DatasetBase* dataset() const { return dataset_; }
537 
op_kernel_context()538   OpKernelContext* op_kernel_context() const { return ctx_.get(); }
539 
540  protected:
541   std::unique_ptr<OpKernel> kernel_;
542   std::unique_ptr<OpKernelContext::Params> ctx_params_;
543   std::unique_ptr<OpKernelContext> ctx_;
544   // The input tensors that this dataset depends on. They must outlive the
545   // dataset.
546   std::vector<std::unique_ptr<Tensor>> input_tensors_;
547   DatasetBase* dataset_;
548   core::ScopedUnref scoped_unref_;
549 };
550 
551 // Class composing a dataset iterator with its dependencies.
552 class TestIterator {
553  public:
TestIterator(std::unique_ptr<IteratorContext> ctx,std::unique_ptr<IteratorBase> iterator)554   TestIterator(std::unique_ptr<IteratorContext> ctx,
555                std::unique_ptr<IteratorBase> iterator)
556       : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {}
557 
iterator()558   IteratorBase* iterator() const { return iterator_.get(); }
559 
ctx()560   IteratorContext* ctx() const { return ctx_.get(); }
561 
GetNext(std::vector<Tensor> * out_tensors,bool * end_of_sequence)562   Status GetNext(std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
563     return iterator_->GetNext(ctx(), out_tensors, end_of_sequence);
564   }
565 
566  protected:
567   std::unique_ptr<IteratorBase> iterator_;
568   std::unique_ptr<IteratorContext> ctx_;
569 };
570 
571 // Helpful functions to test Dataset op kernels.
572 class DatasetOpsTestBase : public ::testing::Test {
573  public:
574   DatasetOpsTestBase();
575 
576   // Initializes the runtime and creates a dataset and iterator.
577   Status Initialize(const DatasetParams& dataset_params);
578 
579   // Initializes the parts of the runtime needed to run dataset ops.
580   Status InitializeRuntime(const DatasetParams& dataset_params);
581 
582   // Creates a dataset.
583   Status MakeDataset(const DatasetParams& dataset_params,
584                      std::unique_ptr<TestDataset>* dataset);
585 
586   // Creates an iterator for the given dataset, using the specified split
587   // providers.
588   Status MakeIterator(
589       const DatasetParams& dataset_params, const TestDataset& dataset,
590       std::vector<std::unique_ptr<SplitProvider>> split_providers,
591       std::unique_ptr<TestIterator>* iterator);
592   // Creates an iterator for the given dataset.
593   Status MakeIterator(const DatasetParams& dataset_params,
594                       const TestDataset& dataset,
595                       std::unique_ptr<TestIterator>* iterator);
596 
597   // Runs the dataset operation according to the predefined dataset params and
598   // produces outputs. Different from `MakeDataset()` which returns a Dataset
599   // object, `RunDatasetOp()` executes the dataset kernel based on the input
600   // DatasetParams and returns the produced outputs as a tensor vector. It can
601   // be used to run some dataset operations that do not have an internal
602   // customized `Dataset` class (e.g. `ReduceDatasetOp`).
603   Status RunDatasetOp(const DatasetParams& dataset_params,
604                       std::vector<Tensor>* outputs);
605 
606   // The method validates whether the two tensors have the same shape, dtype,
607   // and value.
608   static Status ExpectEqual(const Tensor& a, const Tensor& b);
609 
610   // The method validates whether the two tensor vectors have the same tensors.
611   // If `compare_order` is false, the method will only evaluate whether the two
612   // vectors have the same elements regardless of order.
613   static Status ExpectEqual(std::vector<Tensor> produced_tensors,
614                             std::vector<Tensor> expected_tensors,
615                             bool compare_order);
616 
617   // Checks `IteratorBase::GetNext()`.
618   Status CheckIteratorGetNext(const std::vector<Tensor>& expected_outputs,
619                               bool compare_order);
620 
621   // Checks `IteratorBase::GetNext()`.
622   Status CheckIteratorGetNext(TestIterator* iterator,
623                               const std::vector<Tensor>& expected_outputs,
624                               bool compare_order);
625 
626   // Checks `IteratorBase::GetNext()`.
627   Status CheckIteratorGetNext(IteratorBase* iterator, IteratorContext* ctx,
628                               const std::vector<Tensor>& expected_outputs,
629                               bool compare_order);
630 
631   // Checks `IteratorBase::Skip()`
632   Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped,
633                            bool get_next,
634                            const std::vector<Tensor>& expected_outputs,
635                            bool compare_order);
636 
637   // Checks that iterating through the dataset using a split provider produces
638   // the expected outputs.
639   Status CheckSplitProviderFullIteration(
640       const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
641 
642   // Checks that iterating through the dataset using a sharded split provider
643   // with the given `num_shards` and `shard_index` produces the expected
644   // outputs.
645   Status CheckSplitProviderShardedIteration(
646       const DatasetParams& params, int64_t num_shards, int64_t shard_index,
647       const std::vector<Tensor>& expected_outputs);
648 
649   // Checks `DatasetBase::node_name()`.
650   Status CheckDatasetNodeName(const string& expected_dataset_node_name);
651 
652   // Checks `DatasetBase::type_string()`.
653   Status CheckDatasetTypeString(const string& expected_type_str);
654 
655   // Checks `DatasetBase::output_dtypes()`.
656   Status CheckDatasetOutputDtypes(const DataTypeVector& expected_output_dtypes);
657 
658   // Checks `DatasetBase::output_shapes()`.
659   Status CheckDatasetOutputShapes(
660       const std::vector<PartialTensorShape>& expected_output_shapes);
661 
662   // Checks `DatasetBase::Cardinality()`.
663   Status CheckDatasetCardinality(int expected_cardinality);
664 
665   // Checks `DatasetBase::options()`.
666   Status CheckDatasetOptions(const Options& expected_options);
667 
668   // Checks `IteratorBase::output_dtypes()`.
669   Status CheckIteratorOutputDtypes(
670       const DataTypeVector& expected_output_dtypes);
671 
672   // Checks `IteratorBase::output_shapes()`.
673   Status CheckIteratorOutputShapes(
674       const std::vector<PartialTensorShape>& expected_output_shapes);
675 
676   // Checks `IteratorBase::prefix()`.
677   Status CheckIteratorPrefix(const string& expected_iterator_prefix);
678 
679   Status CheckIteratorSaveAndRestore(
680       DatasetBase* dataset, IteratorContext* iterator_ctx,
681       const std::string& iterator_prefix,
682       const std::vector<Tensor>& expected_outputs,
683       const std::vector<int>& breakpoints, bool compare_order);
684 
685   Status CheckIteratorSaveAndRestore(
686       const std::string& iterator_prefix,
687       const std::vector<Tensor>& expected_outputs,
688       const std::vector<int>& breakpoints, bool compare_order);
689 
690  protected:
691   // Make destructor protected so that DatasetOpsTestBase objects cannot
692   // be instantiated directly. Only subclasses can be instantiated.
693   ~DatasetOpsTestBase() override;
694 
695   // Creates a thread pool for parallel tasks.
696   Status InitThreadPool(int thread_num);
697 
698   // Initializes the runtime for computing the dataset operation and registers
699   // the input function definitions. `InitThreadPool()` needs to be called
700   // before this method if we want to run the tasks in parallel.
701   Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
702                                     int cpu_num);
703 
704   // Creates a new op kernel based on the node definition.
705   Status CreateOpKernel(const NodeDef& node_def,
706                         std::unique_ptr<OpKernel>* op_kernel);
707 
708   // Creates a new op kernel context.
709   Status CreateDatasetContext(
710       OpKernel* const dateset_kernel,
711       gtl::InlinedVector<TensorValue, 4>* const inputs,
712       std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
713       std::unique_ptr<OpKernelContext>* dataset_context);
714 
715   // Creates a new dataset.
716   Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
717                        DatasetBase** const dataset);
718 
719   // Restores the state of the input iterator. It resets the iterator before
720   // restoring it to make sure the input iterator does not hold any
721   // resources or tasks. Otherwise, restoring an existing iterator may cause
722   // the timeout issue or duplicated elements.
723   Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
724                          const string& output_prefix,
725                          const DatasetBase& dataset,
726                          std::unique_ptr<IteratorBase>* iterator);
727 
728   // Fetches the dataset from the operation context.
729   Status GetDatasetFromContext(OpKernelContext* context, int output_index,
730                                DatasetBase** const dataset);
731 
732   // Runs an operation producing outputs.
733   Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
734 
735   // Executes a function producing outputs. `fdef` is the function definition,
      // `attrs` its attributes and `args` its input tensors.
      // NOTE(review): `rets` is a by-value vector of `Tensor*`, so the outputs are
      // presumably written into the pointed-to tensors, which the caller must
      // have allocated; confirm against the definition.
736   Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs,
737                      const std::vector<Tensor>& args,
738                      const GraphConstructorOptions& graph_options,
739                      std::vector<Tensor*> rets);
740 
741   // Checks that the size of `inputs` matches the requirement of the op kernel
      // `kernel`. Neither argument is modified (both are taken by const reference).
742   Status CheckOpKernelInput(const OpKernel& kernel,
743                             const gtl::InlinedVector<TensorValue, 4>& inputs);
744 
745   // Creates a new context for running the dataset operation and hands it back
      // through the `context` output parameter. Unlike the four-argument overload,
      // this one does not expose the `OpKernelContext::Params` to the caller.
746   Status CreateOpKernelContext(OpKernel* kernel,
747                                gtl::InlinedVector<TensorValue, 4>* inputs,
748                                std::unique_ptr<OpKernelContext>* context);
749 
750   // Creates a new context for running the dataset operation. This overload
      // hands back both the `OpKernelContext::Params` (through `params`) and the
      // context itself (through `context`).
751   Status CreateOpKernelContext(OpKernel* kernel,
752                                gtl::InlinedVector<TensorValue, 4>* inputs,
753                                std::unique_ptr<OpKernelContext::Params>* params,
754                                std::unique_ptr<OpKernelContext>* context);
755 
756   // Creates a new iterator context for iterating the dataset, derived from the
      // op kernel context `op_context`, and hands it back through the
      // `iterator_context` output parameter.
757   Status CreateIteratorContext(
758       OpKernelContext* const op_context,
759       std::unique_ptr<IteratorContext>* iterator_context);
760 
761   // Creates a new serialization context for serializing the dataset and
762   // iterator.
763   // The new context is handed back through the `context` output parameter.
764   Status CreateSerializationContext(
765       std::unique_ptr<SerializationContext>* context);
766 
767   // Creates the dataset op kernel.
      // NOTE(review): the wording above is identical to `MakeDatasetOpKernel`'s
      // comment, while the method name suggests this builds the kernel of a
      // "get options" op for `dataset_params` instead; confirm against the
      // definition and tighten this comment. The kernel is handed back through
      // the `op_kernel` output parameter.
768   Status MakeGetOptionsOpKernel(const DatasetParams& dataset_params,
769                                 std::unique_ptr<OpKernel>* op_kernel);
770 
771  private:
772   // Runs the dataset operation according to the predefined dataset params and
773   // the produced outputs will be stored in `dataset_ctx`.
      // The kernel and the context params are handed back through `dataset_kernel`
      // and `dataset_ctx_params`.
      // NOTE(review): `created_tensors` presumably collects the tensors allocated
      // along the way so the caller can keep them alive; confirm against the
      // definition.
774   Status RunDatasetOp(
775       const DatasetParams& dataset_params,
776       std::unique_ptr<OpKernel>* dataset_kernel,
777       std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
778       std::vector<std::unique_ptr<Tensor>>* created_tensors,
779       std::unique_ptr<OpKernelContext>* dataset_ctx);
780
      // Takes the same arguments as `RunDatasetOp` (note that `created_tensors`
      // and `dataset_ctx` appear in the opposite order here) plus a `dataset`
      // output parameter.
      // NOTE(review): presumably runs the dataset op for `dataset_params` and then
      // extracts the resulting dataset into `*dataset`; confirm against the
      // definition.
781   Status MakeDataset(
782       const DatasetParams& dataset_params,
783       std::unique_ptr<OpKernel>* dataset_kernel,
784       std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
785       std::unique_ptr<OpKernelContext>* dataset_ctx,
786       std::vector<std::unique_ptr<Tensor>>* created_tensors,
787       DatasetBase** dataset);
788 
789   // Creates the dataset op kernel for `dataset_params` and hands it back
      // through the `dataset_kernel` output parameter.
790   Status MakeDatasetOpKernel(const DatasetParams& dataset_params,
791                              std::unique_ptr<OpKernel>* dataset_kernel);
792 
793   // Creates a dataset tensor according to the input dataset params and hands it
      // back through the `dataset` output parameter.
      // NOTE(review): `created_tensors` presumably receives any intermediate
      // tensors that must outlive this call; confirm against the definition.
794   Status MakeDatasetTensor(
795       const DatasetParams& dataset_params,
796       std::vector<std::unique_ptr<Tensor>>* created_tensors,
797       std::unique_ptr<Tensor>* dataset);
798 
799   // Adds an empty tensor with the specified dtype and shape to the input
800   // vector.
      // `input_types` is taken by value, so each call copies the vector.
      // NOTE(review): `input_types` is presumably the kernel's expected input
      // types, used to validate `dtype` for the slot being added; confirm against
      // the definition.
801   Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
802                          DataTypeVector input_types, DataType dtype,
803                          const TensorShape& shape);
804 
805  protected:
806   std::unique_ptr<Device> device_;
807   DeviceType device_type_;
      // `cpu_num_`, `thread_num_`, `allocator_` and `flr_` have no in-class
      // initializer. NOTE(review): confirm the constructor / Init* methods assign
      // them before any read.
808   int cpu_num_;
809   int thread_num_;
810   Allocator* allocator_;  // Owned by `AllocatorFactoryRegistry`.
811   std::vector<AllocatorAttributes> allocator_attrs_;
812   std::unique_ptr<ScopedStepContainer> step_container_;
813
814   // Device manager is used by function handle cache and needs to outlive it.
      // Members are destroyed in reverse declaration order, so declaring
      // `device_mgr_` before `function_handle_cache_` is what provides that
      // guarantee; do not reorder these fields.
815   std::unique_ptr<DeviceMgr> device_mgr_;
816   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
817   FunctionLibraryRuntime* flr_;  // Owned by `pflr_`.
818   std::unique_ptr<FunctionHandleCache> function_handle_cache_;
      // Callable that accepts a unit of work (`std::function<void()>`).
      // NOTE(review): presumably schedules the work on `thread_pool_`; confirm.
819   std::function<void(std::function<void()>)> runner_;
820   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
821   std::unique_ptr<ResourceMgr> resource_mgr_;
822   std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
823       slice_reader_cache_;
824   std::unique_ptr<thread::ThreadPool> thread_pool_;
825   std::vector<std::unique_ptr<Tensor>> tensors_;  // Owns tensors.
826   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs.
827   std::unique_ptr<CancellationManager> cancellation_manager_;
828
829   // Indicates if the below fields have been initialized.
830   bool initialized_ = false;
831   std::unique_ptr<OpKernel> dataset_kernel_;
832   std::unique_ptr<OpKernelContext::Params> params_;
833   std::unique_ptr<OpKernelContext> dataset_ctx_;
      // Raw pointer, null until set. NOTE(review): how its lifetime is managed is
      // not visible from this declaration; confirm against the definition.
834   DatasetBase* dataset_ = nullptr;
835   std::unique_ptr<IteratorContext> iterator_ctx_;
836   std::unique_ptr<IteratorBase> iterator_;
837 };
838
    // Declares the value-parameterized fixture `ParameterizedGetNextTest`
    // (deriving from `dataset_op_test_class`), defines a `GetNext` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_outputs` and `compare_order` to `CheckIteratorGetNext`, and
    // instantiates the suite over `test_cases` (converted to a
    // `std::vector<GetNextTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
839 #define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
840                                  test_cases)                                  \
841   class ParameterizedGetNextTest                                              \
842       : public dataset_op_test_class,                                         \
843         public ::testing::WithParamInterface<                                 \
844             GetNextTestCase<dataset_params_class>> {};                        \
845                                                                               \
846   TEST_P(ParameterizedGetNextTest, GetNext) {                                 \
847     auto test_case = GetParam();                                              \
848     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
849     TF_ASSERT_OK(                                                             \
850         CheckIteratorGetNext(test_case.expected_outputs,                      \
851                              /*compare_order=*/test_case.compare_order));     \
852   }                                                                           \
853                                                                               \
854   INSTANTIATE_TEST_SUITE_P(                                                   \
855       dataset_op_test_class, ParameterizedGetNextTest,                        \
856       ::testing::ValuesIn(                                                    \
857           std::vector<GetNextTestCase<dataset_params_class>>(test_cases)));
858
    // Declares the value-parameterized fixture `ParameterizedSkipTest` (deriving
    // from `dataset_op_test_class`), defines a `Skip` test that initializes the
    // fixture from each case's `dataset_params` and passes the case's
    // `num_to_skip`, `expected_num_skipped`, `get_next`, `expected_outputs` and
    // `compare_order` to `CheckIteratorSkip`, and instantiates the suite over
    // `test_cases` (converted to a
    // `std::vector<SkipTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
859 #define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class,   \
860                              test_cases)                                    \
861   class ParameterizedSkipTest : public dataset_op_test_class,               \
862                                 public ::testing::WithParamInterface<       \
863                                     SkipTestCase<dataset_params_class>> {}; \
864                                                                             \
865   TEST_P(ParameterizedSkipTest, Skip) {                                     \
866     auto test_case = GetParam();                                            \
867     TF_ASSERT_OK(Initialize(test_case.dataset_params));                     \
868     TF_ASSERT_OK(CheckIteratorSkip(                                         \
869         test_case.num_to_skip, test_case.expected_num_skipped,              \
870         test_case.get_next, test_case.expected_outputs,                     \
871         /*compare_order=*/test_case.compare_order));                        \
872   }                                                                         \
873                                                                             \
874   INSTANTIATE_TEST_SUITE_P(                                                 \
875       dataset_op_test_class, ParameterizedSkipTest,                         \
876       ::testing::ValuesIn(                                                  \
877           std::vector<SkipTestCase<dataset_params_class>>(test_cases)));
878
    // Declares the value-parameterized fixture
    // `ParameterizedDatasetNodeNameTest` (deriving from `dataset_op_test_class`),
    // defines a `DatasetNodeName` test that initializes the fixture from each
    // case's `dataset_params` and passes its `expected_node_name` to
    // `CheckDatasetNodeName`, and instantiates the suite over `test_cases`
    // (converted to a
    // `std::vector<DatasetNodeNameTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
879 #define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
880                                  test_cases)                                  \
881   class ParameterizedDatasetNodeNameTest                                      \
882       : public dataset_op_test_class,                                         \
883         public ::testing::WithParamInterface<                                 \
884             DatasetNodeNameTestCase<dataset_params_class>> {};                \
885                                                                               \
886   TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) {                 \
887     auto test_case = GetParam();                                              \
888     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
889     TF_ASSERT_OK(CheckDatasetNodeName(test_case.expected_node_name));         \
890   }                                                                           \
891                                                                               \
892   INSTANTIATE_TEST_SUITE_P(                                                   \
893       dataset_op_test_class, ParameterizedDatasetNodeNameTest,                \
894       ::testing::ValuesIn(                                                    \
895           std::vector<DatasetNodeNameTestCase<dataset_params_class>>(         \
896               test_cases)));
897
    // Declares the value-parameterized fixture
    // `ParameterizedDatasetTypeStringTest` (deriving from
    // `dataset_op_test_class`), defines a `DatasetTypeString` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_dataset_type_string` to `CheckDatasetTypeString`, and
    // instantiates the suite over `test_cases` (converted to a
    // `std::vector<DatasetTypeStringTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
898 #define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class,                \
899                                    dataset_params_class, test_cases)     \
900   class ParameterizedDatasetTypeStringTest                               \
901       : public dataset_op_test_class,                                    \
902         public ::testing::WithParamInterface<                            \
903             DatasetTypeStringTestCase<dataset_params_class>> {};         \
904                                                                          \
905   TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) {        \
906     auto test_case = GetParam();                                         \
907     TF_ASSERT_OK(Initialize(test_case.dataset_params));                  \
908     TF_ASSERT_OK(                                                        \
909         CheckDatasetTypeString(test_case.expected_dataset_type_string)); \
910   }                                                                      \
911                                                                          \
912   INSTANTIATE_TEST_SUITE_P(                                              \
913       dataset_op_test_class, ParameterizedDatasetTypeStringTest,         \
914       ::testing::ValuesIn(                                               \
915           std::vector<DatasetTypeStringTestCase<dataset_params_class>>(  \
916               test_cases)));
917
    // Declares the value-parameterized fixture
    // `ParameterizedDatasetOutputDtypesTest` (deriving from
    // `dataset_op_test_class`), defines a `DatasetOutputDtypes` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_output_dtypes` to `CheckDatasetOutputDtypes`, and instantiates
    // the suite over `test_cases` (converted to a
    // `std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
918 #define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
919                                      dataset_params_class, test_cases)        \
920                                                                               \
921   class ParameterizedDatasetOutputDtypesTest                                  \
922       : public dataset_op_test_class,                                         \
923         public ::testing::WithParamInterface<                                 \
924             DatasetOutputDtypesTestCase<dataset_params_class>> {};            \
925                                                                               \
926   TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) {         \
927     auto test_case = GetParam();                                              \
928     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
929     TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \
930   }                                                                           \
931                                                                               \
932   INSTANTIATE_TEST_SUITE_P(                                                   \
933       dataset_op_test_class, ParameterizedDatasetOutputDtypesTest,            \
934       ::testing::ValuesIn(                                                    \
935           std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>(     \
936               test_cases)));
937
    // Declares the value-parameterized fixture
    // `ParameterizedDatasetOutputShapesTest` (deriving from
    // `dataset_op_test_class`), defines a `DatasetOutputShapes` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_output_shapes` to `CheckDatasetOutputShapes`, and instantiates
    // the suite over `test_cases` (converted to a
    // `std::vector<DatasetOutputShapesTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
938 #define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
939                                      dataset_params_class, test_cases)        \
940                                                                               \
941   class ParameterizedDatasetOutputShapesTest                                  \
942       : public dataset_op_test_class,                                         \
943         public ::testing::WithParamInterface<                                 \
944             DatasetOutputShapesTestCase<dataset_params_class>> {};            \
945                                                                               \
946   TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) {         \
947     auto test_case = GetParam();                                              \
948     TF_ASSERT_OK(Initialize(test_case.dataset_params));                       \
949     TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); \
950   }                                                                           \
951                                                                               \
952   INSTANTIATE_TEST_SUITE_P(                                                   \
953       dataset_op_test_class, ParameterizedDatasetOutputShapesTest,            \
954       ::testing::ValuesIn(                                                    \
955           std::vector<DatasetOutputShapesTestCase<dataset_params_class>>(     \
956               test_cases)));
957
    // Declares the value-parameterized fixture `ParameterizedCardinalityTest`
    // (deriving from `dataset_op_test_class`), defines a `Cardinality` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_cardinality` to `CheckDatasetCardinality`, and instantiates the
    // suite over `test_cases` (converted to a
    // `std::vector<CardinalityTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
958 #define DATASET_CARDINALITY_TEST_P(dataset_op_test_class,                  \
959                                    dataset_params_class, test_cases)       \
960                                                                            \
961   class ParameterizedCardinalityTest                                       \
962       : public dataset_op_test_class,                                      \
963         public ::testing::WithParamInterface<                              \
964             CardinalityTestCase<dataset_params_class>> {};                 \
965                                                                            \
966   TEST_P(ParameterizedCardinalityTest, Cardinality) {                      \
967     auto test_case = GetParam();                                           \
968     TF_ASSERT_OK(Initialize(test_case.dataset_params));                    \
969     TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); \
970   }                                                                        \
971                                                                            \
972   INSTANTIATE_TEST_SUITE_P(                                                \
973       dataset_op_test_class, ParameterizedCardinalityTest,                 \
974       ::testing::ValuesIn(                                                 \
975           std::vector<CardinalityTestCase<dataset_params_class>>(          \
976               test_cases)));
977
    // Declares the value-parameterized fixture
    // `ParameterizedIteratorOutputDtypesTest` (deriving from
    // `dataset_op_test_class`), defines an `IteratorOutputDtypes` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_output_dtypes` to `CheckIteratorOutputDtypes`, and instantiates
    // the suite over `test_cases` (converted to a
    // `std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>`).
    // The check must target the iterator, not the dataset: the dataset-level
    // dtypes are already covered by `DATASET_OUTPUT_DTYPES_TEST_P`, and this
    // mirrors how `ITERATOR_OUTPUT_SHAPES_TEST_P` calls
    // `CheckIteratorOutputShapes`.
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
978 #define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class,                   \
979                                       dataset_params_class, test_cases)        \
980   class ParameterizedIteratorOutputDtypesTest                                  \
981       : public dataset_op_test_class,                                          \
982         public ::testing::WithParamInterface<                                  \
983             IteratorOutputDtypesTestCase<dataset_params_class>> {};            \
984                                                                                \
985   TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) {        \
986     auto test_case = GetParam();                                               \
987     TF_ASSERT_OK(Initialize(test_case.dataset_params));                        \
988     TF_ASSERT_OK(CheckIteratorOutputDtypes(test_case.expected_output_dtypes)); \
989   }                                                                            \
990                                                                                \
991   INSTANTIATE_TEST_SUITE_P(                                                    \
992       dataset_op_test_class, ParameterizedIteratorOutputDtypesTest,            \
993       ::testing::ValuesIn(                                                     \
994           std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>(     \
995               test_cases)));
996
    // Declares the value-parameterized fixture
    // `ParameterizedIteratorOutputShapesTest` (deriving from
    // `dataset_op_test_class`), defines an `IteratorOutputShapes` test that
    // initializes the fixture from each case's `dataset_params` and passes its
    // `expected_output_shapes` to `CheckIteratorOutputShapes`, and instantiates
    // the suite over `test_cases` (converted to a
    // `std::vector<IteratorOutputShapesTestCase<dataset_params_class>>`).
    // The fixture class name is fixed, so expanding this macro twice in the same
    // scope redefines the class.
997 #define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class,                   \
998                                       dataset_params_class, test_cases)        \
999   class ParameterizedIteratorOutputShapesTest                                  \
1000       : public dataset_op_test_class,                                          \
1001         public ::testing::WithParamInterface<                                  \
1002             IteratorOutputShapesTestCase<dataset_params_class>> {};            \
1003                                                                                \
1004   TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) {        \
1005     auto test_case = GetParam();                                               \
1006     TF_ASSERT_OK(Initialize(test_case.dataset_params));                        \
1007     TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); \
1008   }                                                                            \
1009                                                                                \
1010   INSTANTIATE_TEST_SUITE_P(                                                    \
1011       dataset_op_test_class, ParameterizedIteratorOutputShapesTest,            \
1012       ::testing::ValuesIn(                                                     \
1013           std::vector<IteratorOutputShapesTestCase<dataset_params_class>>(     \
1014               test_cases)));
1015
     // Declares the value-parameterized fixture `ParameterizedIteratorPrefixTest`
     // (deriving from `dataset_op_test_class`), defines an `IteratorPrefix` test
     // that initializes the fixture from each case's `dataset_params` and passes
     // its `expected_iterator_prefix` to `CheckIteratorPrefix`, and instantiates
     // the suite over `test_cases` (converted to a
     // `std::vector<IteratorPrefixTestCase<dataset_params_class>>`).
     // The fixture class name is fixed, so expanding this macro twice in the same
     // scope redefines the class.
1016 #define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \
1017                                test_cases)                                  \
1018   class ParameterizedIteratorPrefixTest                                     \
1019       : public dataset_op_test_class,                                       \
1020         public ::testing::WithParamInterface<                               \
1021             IteratorPrefixTestCase<dataset_params_class>> {};               \
1022                                                                             \
1023   TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) {                 \
1024     auto test_case = GetParam();                                            \
1025     TF_ASSERT_OK(Initialize(test_case.dataset_params));                     \
1026     TF_ASSERT_OK(CheckIteratorPrefix(test_case.expected_iterator_prefix));  \
1027   }                                                                         \
1028                                                                             \
1029   INSTANTIATE_TEST_SUITE_P(                                                 \
1030       dataset_op_test_class, ParameterizedIteratorPrefixTest,               \
1031       ::testing::ValuesIn(                                                  \
1032           std::vector<IteratorPrefixTestCase<dataset_params_class>>(        \
1033               test_cases)));
1034
     // Declares the value-parameterized fixture
     // `ParameterizedIteratorSaveAndRestoreTest` (deriving from
     // `dataset_op_test_class`), defines an `IteratorSaveAndRestore` test that
     // initializes the fixture from each case's `dataset_params` and passes
     // `dataset_params.iterator_prefix()` together with the case's
     // `expected_outputs`, `breakpoints` and `compare_order` to
     // `CheckIteratorSaveAndRestore`, and instantiates the suite over `test_cases`
     // (converted to a
     // `std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>`).
     // The fixture class name is fixed, so expanding this macro twice in the same
     // scope redefines the class.
1035 #define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class,              \
1036                                          dataset_params_class, test_cases)   \
1037   class ParameterizedIteratorSaveAndRestoreTest                              \
1038       : public dataset_op_test_class,                                        \
1039         public ::testing::WithParamInterface<                                \
1040             IteratorSaveAndRestoreTestCase<dataset_params_class>> {};        \
1041   TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {  \
1042     auto test_case = GetParam();                                             \
1043     TF_ASSERT_OK(Initialize(test_case.dataset_params));                      \
1044     TF_ASSERT_OK(CheckIteratorSaveAndRestore(                                \
1045         test_case.dataset_params.iterator_prefix(),                          \
1046         test_case.expected_outputs, test_case.breakpoints,                   \
1047         test_case.compare_order));                                           \
1048   }                                                                          \
1049   INSTANTIATE_TEST_SUITE_P(                                                  \
1050       dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest,        \
1051       ::testing::ValuesIn(                                                   \
1052           std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>( \
1053               test_cases)));
1054 
1055 }  // namespace data
1056 }  // namespace tensorflow
1057 
1058 #endif  // TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
1059