1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/optimize_dataset_op.h"
13
14 #include "tensorflow/core/data/dataset_test_base.h"
15 #include "tensorflow/core/kernels/data/range_dataset_op.h"
16 #include "tensorflow/core/kernels/data/take_dataset_op.h"
17
18 namespace tensorflow {
19 namespace data {
20 namespace {
21
// Optimization requested by the test below; it is passed as the single
// element of the `optimizations` input.
constexpr char kNoopElimination[] = "noop_elimination";
// Node name given to the optimize dataset under test.
constexpr char kNodeName[] = "optimize_dataset";
24
25 class OptimizeDatasetParams : public DatasetParams {
26 public:
27 template <typename T>
OptimizeDatasetParams(T input_dataset_params,string optimizations,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,std::vector<tstring> optimization_configs,string node_name)28 OptimizeDatasetParams(T input_dataset_params, string optimizations,
29 DataTypeVector output_dtypes,
30 std::vector<PartialTensorShape> output_shapes,
31 std::vector<tstring> optimization_configs,
32 string node_name)
33 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
34 std::move(node_name)),
35 optimizations_(std::move(optimizations)),
36 optimization_configs_(std::move(optimization_configs)) {
37 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
38 iterator_prefix_ =
39 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
40 input_dataset_params.iterator_prefix());
41 }
42
GetInputTensors() const43 std::vector<Tensor> GetInputTensors() const override {
44 return {CreateTensor<tstring>(TensorShape({1}), {optimizations_})};
45 }
46
GetInputNames(std::vector<string> * input_names) const47 Status GetInputNames(std::vector<string>* input_names) const override {
48 *input_names = {OptimizeDatasetOp::kInputDataset,
49 OptimizeDatasetOp::kOptimizations};
50 return Status::OK();
51 }
52
GetAttributes(AttributeVector * attr_vector) const53 Status GetAttributes(AttributeVector* attr_vector) const override {
54 *attr_vector = {
55 {OptimizeDatasetOp::kOutputShapes, output_shapes_},
56 {OptimizeDatasetOp::kOutputTypes, output_dtypes_},
57 {OptimizeDatasetOp::kOptimizationConfigs, optimization_configs_}};
58 return Status::OK();
59 }
60
dataset_type() const61 string dataset_type() const override {
62 return OptimizeDatasetOp::kDatasetType;
63 }
64
65 private:
66 string optimizations_;
67 std::vector<tstring> optimization_configs_;
68 };
69
70 class OptimizeDatasetOpTest : public DatasetOpsTestBase {};
71
// Applies `noop_elimination` to take(count=-3) over RangeDatasetParams(-3, 3, 1).
// The take has a negative count, and the expected outputs contain every
// element of the range in order. The pipeline must therefore yield the
// range unchanged.
// NOTE: this test only checks the produced elements; it does not inspect
// whether the take node was actually removed from the graph.
TEST_F(OptimizeDatasetOpTest, NoopElimination) {
  auto take_dataset_params =
      TakeDatasetParams(RangeDatasetParams(-3, 3, 1),
                        /*count=*/-3,
                        /*output_dtypes=*/{DT_INT64},
                        /*output_shapes=*/{PartialTensorShape({})},
                        /*node_name=*/"take_dataset");
  auto optimize_dataset_params =
      OptimizeDatasetParams(std::move(take_dataset_params),
                            /*optimizations=*/{kNoopElimination},
                            /*output_dtypes=*/{DT_INT64},
                            /*output_shapes=*/{PartialTensorShape({})},
                            /*optimization_configs=*/{},
                            /*node_name=*/kNodeName);
  // Scalar int64 elements -3 .. 2, in order.
  std::vector<Tensor> expected_outputs =
      CreateTensors<int64>(TensorShape({}), {{-3}, {-2}, {-1}, {0}, {1}, {2}});

  TF_ASSERT_OK(Initialize(optimize_dataset_params));
  TF_EXPECT_OK(CheckIteratorGetNext(expected_outputs, /*compare_order=*/true));
}
92
93 } // namespace
94 } // namespace data
95 } // namespace tensorflow
96