• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/optimize_dataset_op.h"
16 
17 // On mobile we do not provide optimize dataset op because not all of its
18 // dependencies are available there. The op is replaced with a no-op.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include <map>
21 
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/rewrite_utils.h"
24 #include "tensorflow/core/framework/partial_tensor_shape.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/platform/host_info.h"
28 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
29 
30 namespace tensorflow {
31 namespace data {
32 
33 /* static */ constexpr const char* const OptimizeDatasetOp::kDatasetType;
34 /* static */ constexpr const char* const OptimizeDatasetOp::kInputDataset;
35 /* static */ constexpr const char* const OptimizeDatasetOp::kOptimizations;
36 /* static */ constexpr const char* const
37     OptimizeDatasetOp::kOptimizationsEnabled;
38 /* static */ constexpr const char* const
39     OptimizeDatasetOp::kOptimizationsDisabled;
40 /* static */ constexpr const char* const
41     OptimizeDatasetOp::kOptimizationsDefault;
42 /* static */ constexpr const char* const OptimizeDatasetOp::kOutputTypes;
43 /* static */ constexpr const char* const OptimizeDatasetOp::kOutputShapes;
44 /* static */ constexpr const char* const
45     OptimizeDatasetOp::kOptimizationConfigs;
46 /* static */ constexpr const char* const OptimizeDatasetOp::kOptimizeDatasetV1;
47 /* static */ constexpr const char* const OptimizeDatasetOp::kOptimizeDatasetV2;
48 
49 namespace {
50 
51 // Applies given optimizations and optimizatin_config in dataset graph rewrite
52 // to return the OptimizeDataset.
MakeDatasetHelper(OpKernelContext * ctx,absl::flat_hash_set<tstring> & optimizations,const absl::flat_hash_set<tstring> & optimization_configs,DatasetBase * input,DatasetBase ** output)53 void MakeDatasetHelper(OpKernelContext* ctx,
54                        absl::flat_hash_set<tstring>& optimizations,
55                        const absl::flat_hash_set<tstring>& optimization_configs,
56                        DatasetBase* input, DatasetBase** output) {
57   // The vector stores the graduated experiment names which will be turned on
58   // for all input pipelines.
59   // clang-format off
60   std::vector<string> graduated_experiments = {
61     "disable_intra_op_parallelism",
62     "use_private_thread_pool"
63   };
64   // clang-format on
65 
66   // Add the graduated experiments to the optimization list and log them.
67   for (auto& experiment : graduated_experiments) {
68     if (!optimizations.contains(experiment)) {
69       optimizations.insert(experiment);
70     }
71     VLOG(1) << "The graduated experiment \"" << experiment << "\" is applied.";
72   }
73 
74   // If there are no optimizations to be applied, directly return the input.
75   if (optimizations.empty()) {
76     *output = input;
77     input->Ref();
78     return;
79   }
80 
81   auto config_factory = [&optimizations, &optimization_configs]() {
82     return CreateRewriterConfig(optimizations, optimization_configs);
83   };
84   Status s = RewriteDataset(ctx, input, std::move(config_factory),
85                             /*record_fingerprint=*/true, output);
86   if (errors::IsDeadlineExceeded(s)) {
87     // Ignore DeadlineExceeded as it implies that the attempted rewrite took too
88     // long which should not prevent further computation.
89     LOG(WARNING) << s.ToString();
90 
91     *output = input;
92     input->Ref();
93     return;
94   }
95   OP_REQUIRES_OK(ctx, s);
96 }
97 
98 }  // namespace
99 
100 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,const absl::flat_hash_set<tstring> & optimizations_enabled,const absl::flat_hash_set<tstring> & optimizations_disabled,const absl::flat_hash_set<tstring> & optimizations_default,const absl::flat_hash_set<tstring> & optimization_configs,DatasetBase ** output)101 void OptimizeDatasetOp::MakeDatasetFromOptions(
102     OpKernelContext* ctx, DatasetBase* input,
103     const absl::flat_hash_set<tstring>& optimizations_enabled,
104     const absl::flat_hash_set<tstring>& optimizations_disabled,
105     const absl::flat_hash_set<tstring>& optimizations_default,
106     const absl::flat_hash_set<tstring>& optimization_configs,
107     DatasetBase** output) {
108   auto experiments = GetExperiments();
109   LogAndRecordExperiments(experiments);
110   auto optimizations =
111       SelectOptimizations(experiments, optimizations_enabled,
112                           optimizations_disabled, optimizations_default);
113   MakeDatasetHelper(ctx, optimizations, optimization_configs, input, output);
114 }
115 
OptimizeDatasetOp(OpKernelConstruction * ctx)116 OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
117     : UnaryDatasetOpKernel(ctx) {
118   auto& op_name = ctx->def().op();
119   if (op_name == kOptimizeDatasetV1) {
120     op_version_ = 1;
121   } else if (op_name == kOptimizeDatasetV2) {
122     op_version_ = 2;
123   }
124   std::vector<tstring> optimization_configs;
125   OP_REQUIRES_OK(ctx,
126                  ctx->GetAttr(kOptimizationConfigs, &optimization_configs));
127   optimization_configs_.insert(optimization_configs.begin(),
128                                optimization_configs.end());
129 }
130 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)131 void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
132                                     DatasetBase** output) {
133   absl::flat_hash_set<tstring> optimizations;
134   if (op_version_ == 1) {
135     std::vector<tstring> optimizations_enabled;
136     OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizations,
137                                                      &optimizations_enabled));
138     optimizations.insert(optimizations_enabled.begin(),
139                          optimizations_enabled.end());
140   } else if (op_version_ == 2) {
141     std::vector<tstring> optimizations_enabled, optimizations_disabled,
142         optimizations_default;
143     OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizationsEnabled,
144                                                      &optimizations_enabled));
145     OP_REQUIRES_OK(ctx,
146                    ParseVectorArgument<tstring>(ctx, kOptimizationsDisabled,
147                                                 &optimizations_disabled));
148     OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizationsDefault,
149                                                      &optimizations_default));
150     auto experiments = GetExperiments();
151     LogAndRecordExperiments(experiments);
152     optimizations = SelectOptimizations(
153         experiments,
154         {optimizations_enabled.begin(), optimizations_enabled.end()},
155         {optimizations_disabled.begin(), optimizations_disabled.end()},
156         {optimizations_default.begin(), optimizations_default.end()});
157   }
158   MakeDatasetHelper(
159       ctx, optimizations,
160       {optimization_configs_.begin(), optimization_configs_.end()}, input,
161       output);
162 }
163 
164 namespace {
165 REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
166                         OptimizeDatasetOp);
167 REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU),
168                         OptimizeDatasetOp);
169 }  // namespace
170 }  // namespace data
171 }  // namespace tensorflow
172 #else   // !IS_MOBILE_PLATFORM
173 namespace tensorflow {
174 namespace data {
175 
176 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,const absl::flat_hash_set<tstring> & optimizations_enabled,const absl::flat_hash_set<tstring> & optimizations_disabled,const absl::flat_hash_set<tstring> & optimizations_default,const absl::flat_hash_set<tstring> & optimization_configs,DatasetBase ** output)177 void OptimizeDatasetOp::MakeDatasetFromOptions(
178     OpKernelContext* ctx, DatasetBase* input,
179     const absl::flat_hash_set<tstring>& optimizations_enabled,
180     const absl::flat_hash_set<tstring>& optimizations_disabled,
181     const absl::flat_hash_set<tstring>& optimizations_default,
182     const absl::flat_hash_set<tstring>& optimization_configs,
183     DatasetBase** output) {
184   input->Ref();
185   *output = input;
186 }
187 
OptimizeDatasetOp(OpKernelConstruction * ctx)188 OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
189     : UnaryDatasetOpKernel(ctx) {}
190 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)191 void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
192                                     DatasetBase** output) {
193   input->Ref();
194   *output = input;
195 }
196 
197 namespace {
198 REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
199                         OptimizeDatasetOp);
200 REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU),
201                         OptimizeDatasetOp);
202 }  // namespace
203 }  // namespace data
204 }  // namespace tensorflow
205 #endif  // !IS_MOBILE_PLATFORM
206