1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
17 #define TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
18
19 #include <stddef.h>
20
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/graph_constructor.h"
31 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
32 #include "tensorflow/core/data/name_utils.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/dataset.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function_handle_cache.h"
38 #include "tensorflow/core/framework/function_testlib.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/resource_mgr.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_shape.h"
43 #include "tensorflow/core/framework/tensor_testutil.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/lib/gtl/array_slice.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/io/zlib_compression_options.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/platform/status.h"
52 #include "tensorflow/core/platform/test.h"
53 #include "tensorflow/core/platform/threadpool.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
56
57 namespace tensorflow {
58 namespace data {
59
60 typedef std::vector<
61 std::pair<string, tensorflow::FunctionDefHelper::AttrValueWrapper>>
62 AttributeVector;
63
// Default CPU count and thread count for dataset op tests.
// NOTE(review): the use sites are not in this header; presumably these seed
// `cpu_num_` / `thread_num_` of `DatasetOpsTestBase` -- confirm in the .cc.
constexpr int kDefaultCPUNum = 2;
constexpr int kDefaultThreadNum = 2;
66
67 // Creates a tensor with the specified dtype, shape, and value.
68 template <typename T>
CreateTensor(const TensorShape & input_shape,gtl::ArraySlice<T> input_data)69 static Tensor CreateTensor(const TensorShape& input_shape,
70 gtl::ArraySlice<T> input_data) {
71 Tensor tensor(DataTypeToEnum<T>::value, input_shape);
72 test::FillValues<T>(&tensor, input_data);
73 return tensor;
74 }
75
76 // Creates a tensor with the specified dtype and shape, with values 0, 1, 2, ...
77 template <typename T>
CreateTensor(const TensorShape & input_shape)78 static Tensor CreateTensor(const TensorShape& input_shape) {
79 Tensor tensor(DataTypeToEnum<T>::value, input_shape);
80 test::FillIota<T>(&tensor, 0);
81 return tensor;
82 }
83
84 // Creates a vector of tensors with the specified dtype, shape, and values.
85 template <typename T>
CreateTensors(const TensorShape & shape,const std::vector<gtl::ArraySlice<T>> & values)86 std::vector<Tensor> CreateTensors(
87 const TensorShape& shape, const std::vector<gtl::ArraySlice<T>>& values) {
88 std::vector<Tensor> result;
89 result.reserve(values.size());
90 for (auto& value : values) {
91 result.emplace_back(CreateTensor<T>(shape, value));
92 }
93 return result;
94 }
95
// Compression formats understood by the file-writing test helpers.
enum class CompressionType {
  ZLIB = 0,
  GZIP = 1,
  RAW = 2,
  UNCOMPRESSED = 3,
};
97
98 // Returns a string representation for the given compression type.
99 string ToString(CompressionType compression_type);
100
101 // Gets the specified zlib compression options according to the compression
102 // type. Note that `CompressionType::UNCOMPRESSED` is not supported because
103 // `ZlibCompressionOptions` does not have an option.
104 io::ZlibCompressionOptions GetZlibCompressionOptions(
105 CompressionType compression_type);
106
107 // Used to specify parameters when writing data into files with compression.
108 // `input_buffer_size` and `output_buffer_size` specify the input and output
109 // buffer size when ZLIB and GZIP compression is used.
110 struct CompressionParams {
111 CompressionType compression_type = CompressionType::UNCOMPRESSED;
112 int32 input_buffer_size = 0;
113 int32 output_buffer_size = 0;
114 };
115
116 // Writes the input data into the file without compression.
117 Status WriteDataToFile(const string& filename, const char* data);
118
119 // Writes the input data into the file with the specified compression.
120 Status WriteDataToFile(const string& filename, const char* data,
121 const CompressionParams& params);
122
123 // Writes the input data into the TFRecord file with the specified compression.
124 Status WriteDataToTFRecordFile(const string& filename,
125 const std::vector<absl::string_view>& records,
126 const CompressionParams& params);
127
128 // Provides the parameters for running the dataset op.
129 class DatasetParams {
130 public:
131 DatasetParams(DataTypeVector output_dtypes,
132 std::vector<PartialTensorShape> output_shapes,
133 string node_name);
134
~DatasetParams()135 virtual ~DatasetParams() {}
136
137 // Returns the inputs (except the input datasets) as a tensor vector.
138 virtual std::vector<Tensor> GetInputTensors() const = 0;
139
140 // Returns the dataset input names as a string vector.
141 virtual Status GetInputNames(std::vector<string>* input_names) const = 0;
142
143 // Returns the dataset attributes as a vector.
144 virtual Status GetAttributes(AttributeVector* attributes) const = 0;
145
146 // Checks if the tensor is a dataset variant tensor.
147 static bool IsDatasetTensor(const Tensor& tensor);
148
node_name()149 string node_name() const { return node_name_; }
150
output_dtypes()151 DataTypeVector output_dtypes() const { return output_dtypes_; }
152
output_shapes()153 std::vector<PartialTensorShape> output_shapes() const {
154 return output_shapes_;
155 }
156
iterator_prefix()157 string iterator_prefix() const { return iterator_prefix_; }
158
input_dataset_params()159 const std::vector<std::shared_ptr<DatasetParams>>& input_dataset_params()
160 const {
161 return input_dataset_params_;
162 }
163
164 // Returns the functions that will be used when running the dataset op.
func_lib()165 virtual std::vector<FunctionDef> func_lib() const { return {}; }
166
167 // Returns the dataset type for the op represented by these parameters. This
168 // type usually needs to match the constant called `kDatasetType` defined in
169 // the dataset kernel.
170 virtual string dataset_type() const = 0;
171
172 // Returns the dataset op name. By default, it returns the Op::kDatasetType
173 // concatenated with "Dataset". For ops that do not have "Dataset" suffix,
174 // this method can be overriden to return a different name.
op_name()175 virtual string op_name() const {
176 name_utils::OpNameParams params;
177 params.op_version = op_version();
178 return name_utils::OpName(dataset_type(), params);
179 }
180
op_version()181 virtual int op_version() const { return op_version_; }
182
183 protected:
184 std::vector<std::shared_ptr<DatasetParams>> input_dataset_params_;
185 DataTypeVector output_dtypes_;
186 std::vector<PartialTensorShape> output_shapes_;
187 string node_name_;
188 string iterator_prefix_ = "Iterator";
189 int op_version_ = 1;
190 };
191
192 // `RangeDatasetParams` is a common dataset parameter type that are used in
193 // testing.
194 class RangeDatasetParams : public DatasetParams {
195 public:
196 RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
197 DataTypeVector output_dtypes,
198 std::vector<PartialTensorShape> output_shapes,
199 string node_name);
200
201 RangeDatasetParams(int64_t start, int64_t stop, int64_t step);
202
203 RangeDatasetParams(int64_t start, int64_t stop, int64_t step,
204 DataTypeVector output_dtypes);
205
206 std::vector<Tensor> GetInputTensors() const override;
207
208 Status GetInputNames(std::vector<string>* input_names) const override;
209
210 Status GetAttributes(AttributeVector* attr_vector) const override;
211
212 string dataset_type() const override;
213
214 private:
215 int64_t start_;
216 int64_t stop_;
217 int64_t step_;
218 };
219
220 // `BatchDatasetParams` is a common dataset parameter type that are used in
221 // testing.
222 class BatchDatasetParams : public DatasetParams {
223 public:
224 template <typename T>
BatchDatasetParams(T input_dataset_params,int64_t batch_size,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)225 BatchDatasetParams(T input_dataset_params, int64_t batch_size,
226 bool drop_remainder, bool parallel_copy,
227 DataTypeVector output_dtypes,
228 std::vector<PartialTensorShape> output_shapes,
229 string node_name)
230 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
231 std::move(node_name)),
232 batch_size_(batch_size),
233 drop_remainder_(drop_remainder),
234 parallel_copy_(parallel_copy) {
235 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
236 op_version_ = 2;
237 iterator_prefix_ =
238 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
239 input_dataset_params.iterator_prefix());
240 }
241
242 std::vector<Tensor> GetInputTensors() const override;
243
244 Status GetInputNames(std::vector<string>* input_names) const override;
245
246 Status GetAttributes(AttributeVector* attr_vector) const override;
247
248 string dataset_type() const override;
249
250 private:
251 int64_t batch_size_;
252 bool drop_remainder_;
253 bool parallel_copy_;
254 };
255
256 // `MapDatasetParams` is a common dataset parameter type that are used in
257 // testing.
258 class MapDatasetParams : public DatasetParams {
259 public:
260 template <typename T>
MapDatasetParams(T input_dataset_params,std::vector<Tensor> other_arguments,FunctionDefHelper::AttrValueWrapper func,std::vector<FunctionDef> func_lib,DataTypeVector type_arguments,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,bool use_inter_op_parallelism,bool preserve_cardinality,string node_name)261 MapDatasetParams(T input_dataset_params, std::vector<Tensor> other_arguments,
262 FunctionDefHelper::AttrValueWrapper func,
263 std::vector<FunctionDef> func_lib,
264 DataTypeVector type_arguments, DataTypeVector output_dtypes,
265 std::vector<PartialTensorShape> output_shapes,
266 bool use_inter_op_parallelism, bool preserve_cardinality,
267 string node_name)
268 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
269 std::move(node_name)),
270 other_arguments_(std::move(other_arguments)),
271 func_(std::move(func)),
272 func_lib_(std::move(func_lib)),
273 type_arguments_(std::move(type_arguments)),
274 use_inter_op_parallelism_(use_inter_op_parallelism),
275 preserve_cardinality_(preserve_cardinality) {
276 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
277 iterator_prefix_ =
278 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
279 input_dataset_params.iterator_prefix());
280 }
281
282 std::vector<Tensor> GetInputTensors() const override;
283
284 Status GetInputNames(std::vector<string>* input_names) const override;
285
286 Status GetAttributes(AttributeVector* attr_vector) const override;
287
288 string dataset_type() const override;
289
290 std::vector<FunctionDef> func_lib() const override;
291
292 private:
293 std::vector<Tensor> other_arguments_;
294 FunctionDefHelper::AttrValueWrapper func_;
295 std::vector<FunctionDef> func_lib_;
296 DataTypeVector type_arguments_;
297 bool use_inter_op_parallelism_;
298 bool preserve_cardinality_;
299 };
300
301 // `TensorSliceDatasetParams` is a common dataset parameter type that are used
302 // in testing.
303 class TensorSliceDatasetParams : public DatasetParams {
304 public:
305 TensorSliceDatasetParams(std::vector<Tensor> components, string node_name,
306 bool is_files = false);
307
308 std::vector<Tensor> GetInputTensors() const override;
309
310 Status GetInputNames(std::vector<string>* input_names) const override;
311
312 Status GetAttributes(AttributeVector* attr_vector) const override;
313
314 string dataset_type() const override;
315
num_slices()316 int64_t num_slices() const { return components_[0].dim_size(0); }
317
num_tensors_per_slice()318 size_t num_tensors_per_slice() const { return components_.size(); }
319
320 private:
321 DataTypeVector TensorSliceDtypes(const std::vector<Tensor>& input_components);
322
323 std::vector<PartialTensorShape> TensorSliceShapes(
324 const std::vector<Tensor>& input_components);
325
326 public:
327 std::vector<Tensor> components_;
328 bool is_files_;
329 };
330
331 // `TakeDatasetParams` is a common dataset parameter type that are used in
332 // testing.
333 class TakeDatasetParams : public DatasetParams {
334 public:
335 template <typename T>
TakeDatasetParams(T input_dataset_params,int count,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)336 TakeDatasetParams(T input_dataset_params, int count,
337 DataTypeVector output_dtypes,
338 std::vector<PartialTensorShape> output_shapes,
339 string node_name)
340 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
341 std::move(node_name)),
342 count_(count) {
343 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
344 iterator_prefix_ =
345 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
346 input_dataset_params.iterator_prefix());
347 }
348
349 std::vector<Tensor> GetInputTensors() const override;
350
351 Status GetInputNames(std::vector<string>* input_names) const override;
352
353 Status GetAttributes(AttributeVector* attr_vector) const override;
354
355 string dataset_type() const override;
356
357 private:
358 int64_t count_;
359 };
360
361 // `ConcatenateDatasetParams` is a common dataset parameter type that are used
362 // in testing.
363 class ConcatenateDatasetParams : public DatasetParams {
364 public:
365 template <typename T, typename P>
ConcatenateDatasetParams(T input_dataset_params_0,P input_dataset_params_1,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)366 ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1,
367 DataTypeVector output_dtypes,
368 std::vector<PartialTensorShape> output_shapes,
369 string node_name)
370 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
371 std::move(node_name)) {
372 input_dataset_params_.push_back(
373 std::make_unique<T>(input_dataset_params_0));
374 input_dataset_params_.push_back(
375 std::make_unique<T>(input_dataset_params_1));
376 iterator_prefix_ =
377 name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(),
378 input_dataset_params_0.iterator_prefix());
379 }
380
381 std::vector<Tensor> GetInputTensors() const override;
382
383 Status GetInputNames(std::vector<string>* input_names) const override;
384
385 Status GetAttributes(AttributeVector* attr_vector) const override;
386
387 string dataset_type() const override;
388 };
389
390 // `OptionsDatasetParams` is a common dataset parameter type that is used in
391 // testing.
392 class OptionsDatasetParams : public DatasetParams {
393 public:
394 template <typename T>
OptionsDatasetParams(T input_dataset_params,const string & serialized_options,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)395 OptionsDatasetParams(T input_dataset_params, const string& serialized_options,
396 DataTypeVector output_dtypes,
397 std::vector<PartialTensorShape> output_shapes,
398 string node_name)
399 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
400 std::move(node_name)),
401 serialized_options_(serialized_options) {
402 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
403 }
404
405 std::vector<Tensor> GetInputTensors() const override;
406
407 Status GetInputNames(std::vector<string>* input_names) const override;
408
409 Status GetAttributes(AttributeVector* attr_vector) const override;
410
411 string dataset_type() const override;
412
413 private:
414 string serialized_options_;
415 };
416
417 template <typename T>
418 struct GetNextTestCase {
419 GetNextTestCase(T dataset_params, std::vector<Tensor> expected_outputs,
420 bool compare_order = true)
dataset_paramsGetNextTestCase421 : dataset_params(std::move(dataset_params)),
422 expected_outputs(std::move(expected_outputs)),
423 compare_order(compare_order) {}
424
425 T dataset_params;
426 std::vector<Tensor> expected_outputs;
427 bool compare_order;
428 };
429
430 template <typename T>
431 struct SkipTestCase {
432 SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped,
433 bool get_next = false, std::vector<Tensor> expected_outputs = {},
434 bool compare_order = true)
dataset_paramsSkipTestCase435 : dataset_params(std::move(dataset_params)),
436 num_to_skip(num_to_skip),
437 expected_num_skipped(expected_num_skipped),
438 get_next(get_next),
439 expected_outputs(std::move(expected_outputs)),
440 compare_order(compare_order) {}
441
442 T dataset_params;
443 int num_to_skip;
444 int expected_num_skipped;
445 bool get_next;
446 std::vector<Tensor> expected_outputs;
447 bool compare_order;
448 };
449
450 template <typename T>
451 struct DatasetNodeNameTestCase {
452 T dataset_params;
453 string expected_node_name;
454 };
455
456 template <typename T>
457 struct DatasetTypeStringTestCase {
458 T dataset_params;
459 string expected_dataset_type_string;
460 };
461
462 template <typename T>
463 struct DatasetOutputDtypesTestCase {
464 T dataset_params;
465 DataTypeVector expected_output_dtypes;
466 };
467
468 template <typename T>
469 struct DatasetOutputShapesTestCase {
470 T dataset_params;
471 std::vector<PartialTensorShape> expected_output_shapes;
472 };
473
// Test case pairing dataset params of type `T` with the expected dataset
// cardinality.
template <typename T>
struct CardinalityTestCase {
  T dataset_params;
  int64_t expected_cardinality;
};
479
// Test case for saving a dataset; it carries only the dataset params of type
// `T`.
template <typename T>
struct DatasetSaveTestCase {
  T dataset_params;
};
484
485 template <typename T>
486 struct IteratorOutputDtypesTestCase {
487 T dataset_params;
488 DataTypeVector expected_output_dtypes;
489 };
490
491 template <typename T>
492 struct IteratorOutputShapesTestCase {
493 T dataset_params;
494 std::vector<PartialTensorShape> expected_output_shapes;
495 };
496
497 template <typename T>
498 struct IteratorPrefixTestCase {
499 T dataset_params;
500 string expected_iterator_prefix;
501 };
502
503 template <typename T>
504 struct IteratorSaveAndRestoreTestCase {
505 IteratorSaveAndRestoreTestCase(T dataset_params, std::vector<int> breakpoints,
506 std::vector<Tensor> expected_outputs,
507 bool compare_order = true)
dataset_paramsIteratorSaveAndRestoreTestCase508 : dataset_params(std::move(dataset_params)),
509 breakpoints(std::move(breakpoints)),
510 expected_outputs(std::move(expected_outputs)),
511 compare_order(compare_order) {}
512
513 T dataset_params;
514 std::vector<int> breakpoints;
515 std::vector<Tensor> expected_outputs;
516 bool compare_order;
517 };
518
519 // Class composing a dataset with its dependencies.
520 class TestDataset {
521 public:
522 // TestDataset expects that the caller has Ref'd the wrapped dataset. When
523 // TestDataset is destroyed, it will Unref the dataset.
TestDataset(std::unique_ptr<OpKernel> kernel_,std::unique_ptr<OpKernelContext::Params> ctx_params,std::unique_ptr<OpKernelContext> ctx,std::vector<std::unique_ptr<Tensor>> input_tensors,DatasetBase * dataset)524 TestDataset(std::unique_ptr<OpKernel> kernel_,
525 std::unique_ptr<OpKernelContext::Params> ctx_params,
526 std::unique_ptr<OpKernelContext> ctx,
527 std::vector<std::unique_ptr<Tensor>> input_tensors,
528 DatasetBase* dataset)
529 : kernel_(std::move(kernel_)),
530 ctx_params_(std::move(ctx_params)),
531 ctx_(std::move(ctx)),
532 input_tensors_(std::move(input_tensors)),
533 dataset_(dataset),
534 scoped_unref_(dataset) {}
535
dataset()536 DatasetBase* dataset() const { return dataset_; }
537
op_kernel_context()538 OpKernelContext* op_kernel_context() const { return ctx_.get(); }
539
540 protected:
541 std::unique_ptr<OpKernel> kernel_;
542 std::unique_ptr<OpKernelContext::Params> ctx_params_;
543 std::unique_ptr<OpKernelContext> ctx_;
544 // The input tensors that this dataset depends on. They must outlive the
545 // dataset.
546 std::vector<std::unique_ptr<Tensor>> input_tensors_;
547 DatasetBase* dataset_;
548 core::ScopedUnref scoped_unref_;
549 };
550
551 // Class composing a dataset iterator with its dependencies.
552 class TestIterator {
553 public:
TestIterator(std::unique_ptr<IteratorContext> ctx,std::unique_ptr<IteratorBase> iterator)554 TestIterator(std::unique_ptr<IteratorContext> ctx,
555 std::unique_ptr<IteratorBase> iterator)
556 : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {}
557
iterator()558 IteratorBase* iterator() const { return iterator_.get(); }
559
ctx()560 IteratorContext* ctx() const { return ctx_.get(); }
561
GetNext(std::vector<Tensor> * out_tensors,bool * end_of_sequence)562 Status GetNext(std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
563 return iterator_->GetNext(ctx(), out_tensors, end_of_sequence);
564 }
565
566 protected:
567 std::unique_ptr<IteratorBase> iterator_;
568 std::unique_ptr<IteratorContext> ctx_;
569 };
570
571 // Helpful functions to test Dataset op kernels.
572 class DatasetOpsTestBase : public ::testing::Test {
573 public:
574 DatasetOpsTestBase();
575
576 // Initializes the runtime and creates a dataset and iterator.
577 Status Initialize(const DatasetParams& dataset_params);
578
579 // Initializes the parts of the runtime needed to run dataset ops.
580 Status InitializeRuntime(const DatasetParams& dataset_params);
581
582 // Creates a dataset.
583 Status MakeDataset(const DatasetParams& dataset_params,
584 std::unique_ptr<TestDataset>* dataset);
585
586 // Creates an iterator for the given dataset, using the specified split
587 // providers.
588 Status MakeIterator(
589 const DatasetParams& dataset_params, const TestDataset& dataset,
590 std::vector<std::unique_ptr<SplitProvider>> split_providers,
591 std::unique_ptr<TestIterator>* iterator);
592 // Creates an iterator for the given dataset.
593 Status MakeIterator(const DatasetParams& dataset_params,
594 const TestDataset& dataset,
595 std::unique_ptr<TestIterator>* iterator);
596
597 // Runs the dataset operation according to the predefined dataset params and
598 // produces outputs. Different from `MakeDataset()` which returns a Dataset
599 // object, `RunDatasetOp()` executes the dataset kernel based on the input
600 // DatasetParams and returns the produced outputs as a tensor vector. It can
601 // be used to run some dataset operations that do not have an internal
602 // customized `Dataset` class (e.g. `ReduceDatasetOp`).
603 Status RunDatasetOp(const DatasetParams& dataset_params,
604 std::vector<Tensor>* outputs);
605
606 // The method validates whether the two tensors have the same shape, dtype,
607 // and value.
608 static Status ExpectEqual(const Tensor& a, const Tensor& b);
609
610 // The method validates whether the two tensor vectors have the same tensors.
611 // If `compare_order` is false, the method will only evaluate whether the two
612 // vectors have the same elements regardless of order.
613 static Status ExpectEqual(std::vector<Tensor> produced_tensors,
614 std::vector<Tensor> expected_tensors,
615 bool compare_order);
616
617 // Checks `IteratorBase::GetNext()`.
618 Status CheckIteratorGetNext(const std::vector<Tensor>& expected_outputs,
619 bool compare_order);
620
621 // Checks `IteratorBase::GetNext()`.
622 Status CheckIteratorGetNext(TestIterator* iterator,
623 const std::vector<Tensor>& expected_outputs,
624 bool compare_order);
625
626 // Checks `IteratorBase::GetNext()`.
627 Status CheckIteratorGetNext(IteratorBase* iterator, IteratorContext* ctx,
628 const std::vector<Tensor>& expected_outputs,
629 bool compare_order);
630
631 // Checks `IteratorBase::Skip()`
632 Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped,
633 bool get_next,
634 const std::vector<Tensor>& expected_outputs,
635 bool compare_order);
636
637 // Checks that iterating through the dataset using a split provider produces
638 // the expected outputs.
639 Status CheckSplitProviderFullIteration(
640 const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
641
642 // Checks that iterating through the dataset using a sharded split provider
643 // with the given `num_shards` and `shard_index` produces the expected
644 // outputs.
645 Status CheckSplitProviderShardedIteration(
646 const DatasetParams& params, int64_t num_shards, int64_t shard_index,
647 const std::vector<Tensor>& expected_outputs);
648
649 // Checks `DatasetBase::node_name()`.
650 Status CheckDatasetNodeName(const string& expected_dataset_node_name);
651
652 // Checks `DatasetBase::type_string()`.
653 Status CheckDatasetTypeString(const string& expected_type_str);
654
655 // Checks `DatasetBase::output_dtypes()`.
656 Status CheckDatasetOutputDtypes(const DataTypeVector& expected_output_dtypes);
657
658 // Checks `DatasetBase::output_shapes()`.
659 Status CheckDatasetOutputShapes(
660 const std::vector<PartialTensorShape>& expected_output_shapes);
661
662 // Checks `DatasetBase::Cardinality()`.
663 Status CheckDatasetCardinality(int expected_cardinality);
664
665 // Checks `DatasetBase::options()`.
666 Status CheckDatasetOptions(const Options& expected_options);
667
668 // Checks `IteratorBase::output_dtypes()`.
669 Status CheckIteratorOutputDtypes(
670 const DataTypeVector& expected_output_dtypes);
671
672 // Checks `IteratorBase::output_shapes()`.
673 Status CheckIteratorOutputShapes(
674 const std::vector<PartialTensorShape>& expected_output_shapes);
675
676 // Checks `IteratorBase::prefix()`.
677 Status CheckIteratorPrefix(const string& expected_iterator_prefix);
678
679 Status CheckIteratorSaveAndRestore(
680 DatasetBase* dataset, IteratorContext* iterator_ctx,
681 const std::string& iterator_prefix,
682 const std::vector<Tensor>& expected_outputs,
683 const std::vector<int>& breakpoints, bool compare_order);
684
685 Status CheckIteratorSaveAndRestore(
686 const std::string& iterator_prefix,
687 const std::vector<Tensor>& expected_outputs,
688 const std::vector<int>& breakpoints, bool compare_order);
689
690 protected:
691 // Make destructor protected so that DatasetOpsTestBase objects cannot
692 // be instantiated directly. Only subclasses can be instantiated.
693 ~DatasetOpsTestBase() override;
694
695 // Creates a thread pool for parallel tasks.
696 Status InitThreadPool(int thread_num);
697
698 // Initializes the runtime for computing the dataset operation and registers
699 // the input function definitions. `InitThreadPool()' needs to be called
700 // before this method if we want to run the tasks in parallel.
701 Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
702 int cpu_num);
703
704 // Creates a new op kernel based on the node definition.
705 Status CreateOpKernel(const NodeDef& node_def,
706 std::unique_ptr<OpKernel>* op_kernel);
707
708 // Creates a new op kernel context.
709 Status CreateDatasetContext(
710 OpKernel* const dateset_kernel,
711 gtl::InlinedVector<TensorValue, 4>* const inputs,
712 std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
713 std::unique_ptr<OpKernelContext>* dataset_context);
714
715 // Creates a new dataset.
716 Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
717 DatasetBase** const dataset);
718
719 // Restores the state of the input iterator. It resets the iterator before
720 // restoring it to make sure the input iterator does not hold any
721 // resources or tasks. Otherwise, restoring an existing iterator may cause
722 // the timeout issue or duplicated elements.
723 Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
724 const string& output_prefix,
725 const DatasetBase& dataset,
726 std::unique_ptr<IteratorBase>* iterator);
727
728 // Fetches the dataset from the operation context.
729 Status GetDatasetFromContext(OpKernelContext* context, int output_index,
730 DatasetBase** const dataset);
731
732 // Runs an operation producing outputs.
733 Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
734
735 // Executes a function producing outputs.
736 Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs,
737 const std::vector<Tensor>& args,
738 const GraphConstructorOptions& graph_options,
739 std::vector<Tensor*> rets);
740
741 // Checks that the size of `inputs` matches the requirement of the op kernel.
742 Status CheckOpKernelInput(const OpKernel& kernel,
743 const gtl::InlinedVector<TensorValue, 4>& inputs);
744
745 // Creates a new context for running the dataset operation.
746 Status CreateOpKernelContext(OpKernel* kernel,
747 gtl::InlinedVector<TensorValue, 4>* inputs,
748 std::unique_ptr<OpKernelContext>* context);
749
750 // Creates a new context for running the dataset operation.
751 Status CreateOpKernelContext(OpKernel* kernel,
752 gtl::InlinedVector<TensorValue, 4>* inputs,
753 std::unique_ptr<OpKernelContext::Params>* params,
754 std::unique_ptr<OpKernelContext>* context);
755
756 // Creates a new iterator context for iterating the dataset.
757 Status CreateIteratorContext(
758 OpKernelContext* const op_context,
759 std::unique_ptr<IteratorContext>* iterator_context);
760
761 // Creates a new iterator context for iterating the dataset.
762 // Creates a new serialization context for serializing the dataset and
763 // iterator.
764 Status CreateSerializationContext(
765 std::unique_ptr<SerializationContext>* context);
766
767 // Creates the dataset op kernel.
768 Status MakeGetOptionsOpKernel(const DatasetParams& dataset_params,
769 std::unique_ptr<OpKernel>* op_kernel);
770
771 private:
772 // Runs the dataset operation according to the predefined dataset params and
773 // the produced outputs will be stored in `dataset_ctx`.
774 Status RunDatasetOp(
775 const DatasetParams& dataset_params,
776 std::unique_ptr<OpKernel>* dataset_kernel,
777 std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
778 std::vector<std::unique_ptr<Tensor>>* created_tensors,
779 std::unique_ptr<OpKernelContext>* dataset_ctx);
780
781 Status MakeDataset(
782 const DatasetParams& dataset_params,
783 std::unique_ptr<OpKernel>* dataset_kernel,
784 std::unique_ptr<OpKernelContext::Params>* dataset_ctx_params,
785 std::unique_ptr<OpKernelContext>* dataset_ctx,
786 std::vector<std::unique_ptr<Tensor>>* created_tensors,
787 DatasetBase** dataset);
788
789 // Creates the dataset op kernel.
790 Status MakeDatasetOpKernel(const DatasetParams& dataset_params,
791 std::unique_ptr<OpKernel>* dataset_kernel);
792
793 // Creates a dataset tensor according to the input dataset params.
794 Status MakeDatasetTensor(
795 const DatasetParams& dataset_params,
796 std::vector<std::unique_ptr<Tensor>>* created_tensors,
797 std::unique_ptr<Tensor>* dataset);
798
799 // Adds an empty tensor with the specified dtype and shape to the input
800 // vector.
801 Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
802 DataTypeVector input_types, DataType dtype,
803 const TensorShape& shape);
804
805 protected:
// State of the test harness. These members follow the `protected:` access
// specifier, so classes deriving from this one can use them directly.
806 std::unique_ptr<Device> device_;
807 DeviceType device_type_;
808 int cpu_num_;
809 int thread_num_;
810 Allocator* allocator_; // Owned by `AllocatorFactoryRegistry`.
811 std::vector<AllocatorAttributes> allocator_attrs_;
812 std::unique_ptr<ScopedStepContainer> step_container_;
813
814 // Device manager is used by function handle cache and needs to outlive it.
// C++ destroys members in reverse declaration order, so declaring
// `device_mgr_` ahead of `function_handle_cache_` provides that guarantee;
// keep this ordering when adding or moving members.
815 std::unique_ptr<DeviceMgr> device_mgr_;
816 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
817 FunctionLibraryRuntime* flr_; // Owned by `pflr_`.
818 std::unique_ptr<FunctionHandleCache> function_handle_cache_;
819 std::function<void(std::function<void()>)> runner_;
820 std::unique_ptr<FunctionLibraryDefinition> lib_def_;
821 std::unique_ptr<ResourceMgr> resource_mgr_;
822 std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
823 slice_reader_cache_;
824 std::unique_ptr<thread::ThreadPool> thread_pool_;
825 std::vector<std::unique_ptr<Tensor>> tensors_; // Owns tensors.
826 mutex lock_for_refs_; // Used as the Mutex for inputs added as refs.
827 std::unique_ptr<CancellationManager> cancellation_manager_;
828
829 // Indicates whether the fields below have been initialized.
// NOTE(review): presumably set by `Initialize(...)` -- confirm against the
// definition.
830 bool initialized_ = false;
831 std::unique_ptr<OpKernel> dataset_kernel_;
832 std::unique_ptr<OpKernelContext::Params> params_;
833 std::unique_ptr<OpKernelContext> dataset_ctx_;
// NOTE(review): raw pointer; who owns / releases the dataset is not visible
// from these declarations.
834 DatasetBase* dataset_ = nullptr;
835 std::unique_ptr<IteratorContext> iterator_ctx_;
836 std::unique_ptr<IteratorBase> iterator_;
837 };
838
// Generates a value-parameterized `GetNext` test:
//   * fixture `ParameterizedGetNextTest`, derived from `dataset_op_test_class`
//     and parameterized on `GetNextTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckIteratorGetNext` succeeds for
//     the case's `expected_outputs` and `compare_order`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \
                                 test_cases) \
  class ParameterizedGetNextTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            GetNextTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedGetNextTest, GetNext) { \
    GetNextTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK( \
        CheckIteratorGetNext(param.expected_outputs, \
                             /*compare_order=*/param.compare_order)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedGetNextTest, \
      ::testing::ValuesIn( \
          std::vector<GetNextTestCase<dataset_params_class>>(test_cases)));
858
// Generates a value-parameterized `Skip` test:
//   * fixture `ParameterizedSkipTest`, derived from `dataset_op_test_class`
//     and parameterized on `SkipTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckIteratorSkip` succeeds for the
//     case's `num_to_skip`, `expected_num_skipped`, `get_next`,
//     `expected_outputs` and `compare_order`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class, \
                             test_cases) \
  class ParameterizedSkipTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            SkipTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedSkipTest, Skip) { \
    SkipTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckIteratorSkip( \
        param.num_to_skip, param.expected_num_skipped, \
        param.get_next, param.expected_outputs, \
        /*compare_order=*/param.compare_order)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedSkipTest, \
      ::testing::ValuesIn( \
          std::vector<SkipTestCase<dataset_params_class>>(test_cases)));
878
// Generates a value-parameterized `DatasetNodeName` test:
//   * fixture `ParameterizedDatasetNodeNameTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `DatasetNodeNameTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckDatasetNodeName` succeeds for
//     the case's `expected_node_name`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \
                                 test_cases) \
  class ParameterizedDatasetNodeNameTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            DatasetNodeNameTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) { \
    DatasetNodeNameTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckDatasetNodeName(param.expected_node_name)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedDatasetNodeNameTest, \
      ::testing::ValuesIn( \
          std::vector<DatasetNodeNameTestCase<dataset_params_class>>( \
              test_cases)));
897
// Generates a value-parameterized `DatasetTypeString` test:
//   * fixture `ParameterizedDatasetTypeStringTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `DatasetTypeStringTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckDatasetTypeString` succeeds for
//     the case's `expected_dataset_type_string`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class, \
                                   dataset_params_class, test_cases) \
  class ParameterizedDatasetTypeStringTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            DatasetTypeStringTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) { \
    DatasetTypeStringTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK( \
        CheckDatasetTypeString(param.expected_dataset_type_string)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedDatasetTypeStringTest, \
      ::testing::ValuesIn( \
          std::vector<DatasetTypeStringTestCase<dataset_params_class>>( \
              test_cases)));
917
// Generates a value-parameterized `DatasetOutputDtypes` test:
//   * fixture `ParameterizedDatasetOutputDtypesTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `DatasetOutputDtypesTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckDatasetOutputDtypes` succeeds
//     for the case's `expected_output_dtypes`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class, \
                                     dataset_params_class, test_cases) \
  class ParameterizedDatasetOutputDtypesTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            DatasetOutputDtypesTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) { \
    DatasetOutputDtypesTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckDatasetOutputDtypes(param.expected_output_dtypes)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedDatasetOutputDtypesTest, \
      ::testing::ValuesIn( \
          std::vector<DatasetOutputDtypesTestCase<dataset_params_class>>( \
              test_cases)));
937
// Generates a value-parameterized `DatasetOutputShapes` test:
//   * fixture `ParameterizedDatasetOutputShapesTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `DatasetOutputShapesTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckDatasetOutputShapes` succeeds
//     for the case's `expected_output_shapes`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class, \
                                     dataset_params_class, test_cases) \
  class ParameterizedDatasetOutputShapesTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            DatasetOutputShapesTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) { \
    DatasetOutputShapesTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckDatasetOutputShapes(param.expected_output_shapes)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedDatasetOutputShapesTest, \
      ::testing::ValuesIn( \
          std::vector<DatasetOutputShapesTestCase<dataset_params_class>>( \
              test_cases)));
957
// Generates a value-parameterized `Cardinality` test:
//   * fixture `ParameterizedCardinalityTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `CardinalityTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckDatasetCardinality` succeeds for
//     the case's `expected_cardinality`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class, \
                                   dataset_params_class, test_cases) \
  class ParameterizedCardinalityTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            CardinalityTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedCardinalityTest, Cardinality) { \
    CardinalityTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckDatasetCardinality(param.expected_cardinality)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedCardinalityTest, \
      ::testing::ValuesIn( \
          std::vector<CardinalityTestCase<dataset_params_class>>( \
              test_cases)));
977
// Generates a value-parameterized `IteratorOutputDtypes` test:
//   * fixture `ParameterizedIteratorOutputDtypesTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `IteratorOutputDtypesTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckIteratorOutputDtypes` succeeds
//     for the case's `expected_output_dtypes`;
//   * an instantiation of the suite over `test_cases`.
// The check must be the iterator-level one (`CheckIteratorOutputDtypes`), in
// line with `ITERATOR_OUTPUT_SHAPES_TEST_P` calling
// `CheckIteratorOutputShapes`; the dataset-level dtypes are covered separately
// by `DATASET_OUTPUT_DTYPES_TEST_P`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class, \
                                      dataset_params_class, test_cases) \
  class ParameterizedIteratorOutputDtypesTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            IteratorOutputDtypesTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) { \
    auto test_case = GetParam(); \
    TF_ASSERT_OK(Initialize(test_case.dataset_params)); \
    TF_ASSERT_OK( \
        CheckIteratorOutputDtypes(test_case.expected_output_dtypes)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedIteratorOutputDtypesTest, \
      ::testing::ValuesIn( \
          std::vector<IteratorOutputDtypesTestCase<dataset_params_class>>( \
              test_cases)));
996
// Generates a value-parameterized `IteratorOutputShapes` test:
//   * fixture `ParameterizedIteratorOutputShapesTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `IteratorOutputShapesTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckIteratorOutputShapes` succeeds
//     for the case's `expected_output_shapes`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class, \
                                      dataset_params_class, test_cases) \
  class ParameterizedIteratorOutputShapesTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            IteratorOutputShapesTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) { \
    IteratorOutputShapesTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckIteratorOutputShapes(param.expected_output_shapes)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedIteratorOutputShapesTest, \
      ::testing::ValuesIn( \
          std::vector<IteratorOutputShapesTestCase<dataset_params_class>>( \
              test_cases)));
1015
// Generates a value-parameterized `IteratorPrefix` test:
//   * fixture `ParameterizedIteratorPrefixTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `IteratorPrefixTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckIteratorPrefix` succeeds for the
//     case's `expected_iterator_prefix`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \
                               test_cases) \
  class ParameterizedIteratorPrefixTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            IteratorPrefixTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) { \
    IteratorPrefixTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckIteratorPrefix(param.expected_iterator_prefix)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedIteratorPrefixTest, \
      ::testing::ValuesIn( \
          std::vector<IteratorPrefixTestCase<dataset_params_class>>( \
              test_cases)));
1034
// Generates a value-parameterized `IteratorSaveAndRestore` test:
//   * fixture `ParameterizedIteratorSaveAndRestoreTest`, derived from
//     `dataset_op_test_class` and parameterized on
//     `IteratorSaveAndRestoreTestCase<dataset_params_class>`;
//   * a test body that initializes the fixture from the case's
//     `dataset_params` and asserts that `CheckIteratorSaveAndRestore`
//     succeeds for `dataset_params.iterator_prefix()` and the case's
//     `expected_outputs`, `breakpoints` and `compare_order`;
//   * an instantiation of the suite over `test_cases`.
// The fixture name is fixed, so the macro can be expanded at most once per
// namespace scope.
#define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class, \
                                         dataset_params_class, test_cases) \
  class ParameterizedIteratorSaveAndRestoreTest \
      : public dataset_op_test_class, \
        public ::testing::WithParamInterface< \
            IteratorSaveAndRestoreTestCase<dataset_params_class>> {}; \
  \
  TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { \
    IteratorSaveAndRestoreTestCase<dataset_params_class> param = GetParam(); \
    TF_ASSERT_OK(Initialize(param.dataset_params)); \
    TF_ASSERT_OK(CheckIteratorSaveAndRestore( \
        param.dataset_params.iterator_prefix(), \
        param.expected_outputs, param.breakpoints, \
        /*compare_order=*/param.compare_order)); \
  } \
  \
  INSTANTIATE_TEST_SUITE_P( \
      dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest, \
      ::testing::ValuesIn( \
          std::vector<IteratorSaveAndRestoreTestCase<dataset_params_class>>( \
              test_cases)));
1054
1055 } // namespace data
1056 } // namespace tensorflow
1057
1058 #endif // TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_
1059