• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
12 #include "tensorflow/core/kernels/data/optimize_dataset_op.h"
13 
14 #include "tensorflow/core/data/dataset_test_base.h"
15 #include "tensorflow/core/kernels/data/range_dataset_op.h"
16 #include "tensorflow/core/kernels/data/take_dataset_op.h"
17 
18 namespace tensorflow {
19 namespace data {
20 namespace {
21 
// Optimization name handed to OptimizeDataset in the NoopElimination test.
constexpr char kNoopElimination[] = "noop_elimination";

// Node name used for the OptimizeDataset node built in this file's tests.
constexpr char kNodeName[] = "optimize_dataset";
24 
25 class OptimizeDatasetParams : public DatasetParams {
26  public:
27   template <typename T>
OptimizeDatasetParams(T input_dataset_params,string optimizations,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,std::vector<tstring> optimization_configs,string node_name)28   OptimizeDatasetParams(T input_dataset_params, string optimizations,
29                         DataTypeVector output_dtypes,
30                         std::vector<PartialTensorShape> output_shapes,
31                         std::vector<tstring> optimization_configs,
32                         string node_name)
33       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
34                       std::move(node_name)),
35         optimizations_(std::move(optimizations)),
36         optimization_configs_(std::move(optimization_configs)) {
37     input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
38     iterator_prefix_ =
39         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
40                                    input_dataset_params.iterator_prefix());
41   }
42 
GetInputTensors() const43   std::vector<Tensor> GetInputTensors() const override {
44     return {CreateTensor<tstring>(TensorShape({1}), {optimizations_})};
45   }
46 
GetInputNames(std::vector<string> * input_names) const47   Status GetInputNames(std::vector<string>* input_names) const override {
48     *input_names = {OptimizeDatasetOp::kInputDataset,
49                     OptimizeDatasetOp::kOptimizations};
50     return Status::OK();
51   }
52 
GetAttributes(AttributeVector * attr_vector) const53   Status GetAttributes(AttributeVector* attr_vector) const override {
54     *attr_vector = {
55         {OptimizeDatasetOp::kOutputShapes, output_shapes_},
56         {OptimizeDatasetOp::kOutputTypes, output_dtypes_},
57         {OptimizeDatasetOp::kOptimizationConfigs, optimization_configs_}};
58     return Status::OK();
59   }
60 
dataset_type() const61   string dataset_type() const override {
62     return OptimizeDatasetOp::kDatasetType;
63   }
64 
65  private:
66   string optimizations_;
67   std::vector<tstring> optimization_configs_;
68 };
69 
70 class OptimizeDatasetOpTest : public DatasetOpsTestBase {};
71 
// Runs OptimizeDataset with the `noop_elimination` optimization over
// take(count=-3) applied to range(-3, 3, 1). The test asserts that the
// optimized pipeline yields exactly the range's elements, -3 through 2, in
// order.
TEST_F(OptimizeDatasetOpTest, NoopElimination) {
  // A negative `count` is expected to pass every input element through (see
  // `expected_outputs` below); presumably this is the take that
  // noop_elimination targets.
  auto take_dataset_params =
      TakeDatasetParams(RangeDatasetParams(-3, 3, 1),
                        /*count=*/-3,
                        /*output_dtypes=*/{DT_INT64},
                        /*output_shapes=*/{PartialTensorShape({})},
                        /*node_name=*/"take_dataset");
  auto optimize_dataset_params =
      OptimizeDatasetParams(std::move(take_dataset_params),
                            /*optimizations=*/{kNoopElimination},
                            /*output_dtypes=*/{DT_INT64},
                            /*output_shapes=*/{PartialTensorShape({})},
                            /*optimization_configs=*/{},
                            /*node_name=*/kNodeName);
  // Scalar int64 elements of range(-3, 3, 1).
  std::vector<Tensor> expected_outputs =
      CreateTensors<int64>(TensorShape({}), {{-3}, {-2}, {-1}, {0}, {1}, {2}});

  TF_ASSERT_OK(Initialize(optimize_dataset_params));
  TF_EXPECT_OK(CheckIteratorGetNext(expected_outputs, /*compare_order=*/true));
}
92 
93 }  // namespace
94 }  // namespace data
95 }  // namespace tensorflow
96