1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/optimize_dataset_op.h"
16
// On mobile we do not provide the optimize dataset op because not all of its
// dependencies are available there. Instead, the kernel is replaced with a
// no-op that forwards its input dataset unchanged.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include <map>
21
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/rewrite_utils.h"
24 #include "tensorflow/core/framework/partial_tensor_shape.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/platform/host_info.h"
28 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
29
30 namespace tensorflow {
31 namespace data {
32
// Out-of-line definitions of the static constexpr data members of
// OptimizeDatasetOp (declared in the class, see optimize_dataset_op.h).
// Before C++17 such members are not implicitly inline, so a namespace-scope
// definition is required whenever one of them is odr-used (e.g. bound to a
// reference).
/* static */ constexpr const char* const OptimizeDatasetOp::kDatasetType;
/* static */ constexpr const char* const OptimizeDatasetOp::kInputDataset;
/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizations;
/* static */ constexpr const char* const
    OptimizeDatasetOp::kOptimizationsEnabled;
/* static */ constexpr const char* const
    OptimizeDatasetOp::kOptimizationsDisabled;
/* static */ constexpr const char* const
    OptimizeDatasetOp::kOptimizationsDefault;
/* static */ constexpr const char* const OptimizeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const OptimizeDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    OptimizeDatasetOp::kOptimizationConfigs;
/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizeDatasetV1;
/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizeDatasetV2;
48
49 namespace {
50
51 // Applies given optimizations and optimizatin_config in dataset graph rewrite
52 // to return the OptimizeDataset.
MakeDatasetHelper(OpKernelContext * ctx,absl::flat_hash_set<tstring> & optimizations,const absl::flat_hash_set<tstring> & optimization_configs,DatasetBase * input,DatasetBase ** output)53 void MakeDatasetHelper(OpKernelContext* ctx,
54 absl::flat_hash_set<tstring>& optimizations,
55 const absl::flat_hash_set<tstring>& optimization_configs,
56 DatasetBase* input, DatasetBase** output) {
57 // The vector stores the graduated experiment names which will be turned on
58 // for all input pipelines.
59 // clang-format off
60 std::vector<string> graduated_experiments = {
61 "disable_intra_op_parallelism",
62 "use_private_thread_pool"
63 };
64 // clang-format on
65
66 // Add the graduated experiments to the optimization list and log them.
67 for (auto& experiment : graduated_experiments) {
68 if (!optimizations.contains(experiment)) {
69 optimizations.insert(experiment);
70 }
71 VLOG(1) << "The graduated experiment \"" << experiment << "\" is applied.";
72 }
73
74 // If there are no optimizations to be applied, directly return the input.
75 if (optimizations.empty()) {
76 *output = input;
77 input->Ref();
78 return;
79 }
80
81 auto config_factory = [&optimizations, &optimization_configs]() {
82 return CreateRewriterConfig(optimizations, optimization_configs);
83 };
84 Status s = RewriteDataset(ctx, input, std::move(config_factory),
85 /*record_fingerprint=*/true, output);
86 if (errors::IsDeadlineExceeded(s)) {
87 // Ignore DeadlineExceeded as it implies that the attempted rewrite took too
88 // long which should not prevent further computation.
89 LOG(WARNING) << s.ToString();
90
91 *output = input;
92 input->Ref();
93 return;
94 }
95 OP_REQUIRES_OK(ctx, s);
96 }
97
98 } // namespace
99
100 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,const absl::flat_hash_set<tstring> & optimizations_enabled,const absl::flat_hash_set<tstring> & optimizations_disabled,const absl::flat_hash_set<tstring> & optimizations_default,const absl::flat_hash_set<tstring> & optimization_configs,DatasetBase ** output)101 void OptimizeDatasetOp::MakeDatasetFromOptions(
102 OpKernelContext* ctx, DatasetBase* input,
103 const absl::flat_hash_set<tstring>& optimizations_enabled,
104 const absl::flat_hash_set<tstring>& optimizations_disabled,
105 const absl::flat_hash_set<tstring>& optimizations_default,
106 const absl::flat_hash_set<tstring>& optimization_configs,
107 DatasetBase** output) {
108 auto experiments = GetExperiments();
109 LogAndRecordExperiments(experiments);
110 auto optimizations =
111 SelectOptimizations(experiments, optimizations_enabled,
112 optimizations_disabled, optimizations_default);
113 MakeDatasetHelper(ctx, optimizations, optimization_configs, input, output);
114 }
115
OptimizeDatasetOp(OpKernelConstruction * ctx)116 OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
117 : UnaryDatasetOpKernel(ctx) {
118 auto& op_name = ctx->def().op();
119 if (op_name == kOptimizeDatasetV1) {
120 op_version_ = 1;
121 } else if (op_name == kOptimizeDatasetV2) {
122 op_version_ = 2;
123 }
124 std::vector<tstring> optimization_configs;
125 OP_REQUIRES_OK(ctx,
126 ctx->GetAttr(kOptimizationConfigs, &optimization_configs));
127 optimization_configs_.insert(optimization_configs.begin(),
128 optimization_configs.end());
129 }
130
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)131 void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
132 DatasetBase** output) {
133 absl::flat_hash_set<tstring> optimizations;
134 if (op_version_ == 1) {
135 std::vector<tstring> optimizations_enabled;
136 OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizations,
137 &optimizations_enabled));
138 optimizations.insert(optimizations_enabled.begin(),
139 optimizations_enabled.end());
140 } else if (op_version_ == 2) {
141 std::vector<tstring> optimizations_enabled, optimizations_disabled,
142 optimizations_default;
143 OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizationsEnabled,
144 &optimizations_enabled));
145 OP_REQUIRES_OK(ctx,
146 ParseVectorArgument<tstring>(ctx, kOptimizationsDisabled,
147 &optimizations_disabled));
148 OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizationsDefault,
149 &optimizations_default));
150 auto experiments = GetExperiments();
151 LogAndRecordExperiments(experiments);
152 optimizations = SelectOptimizations(
153 experiments,
154 {optimizations_enabled.begin(), optimizations_enabled.end()},
155 {optimizations_disabled.begin(), optimizations_disabled.end()},
156 {optimizations_default.begin(), optimizations_default.end()});
157 }
158 MakeDatasetHelper(
159 ctx, optimizations,
160 {optimization_configs_.begin(), optimization_configs_.end()}, input,
161 output);
162 }
163
namespace {
// Both OptimizeDataset (v1) and OptimizeDatasetV2 are served by the same CPU
// kernel class; its constructor inspects the op name to pick the version.
REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
                        OptimizeDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU),
                        OptimizeDatasetOp);
}  // namespace
170 } // namespace data
171 } // namespace tensorflow
172 #else // !IS_MOBILE_PLATFORM
173 namespace tensorflow {
174 namespace data {
175
176 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,const absl::flat_hash_set<tstring> & optimizations_enabled,const absl::flat_hash_set<tstring> & optimizations_disabled,const absl::flat_hash_set<tstring> & optimizations_default,const absl::flat_hash_set<tstring> & optimization_configs,DatasetBase ** output)177 void OptimizeDatasetOp::MakeDatasetFromOptions(
178 OpKernelContext* ctx, DatasetBase* input,
179 const absl::flat_hash_set<tstring>& optimizations_enabled,
180 const absl::flat_hash_set<tstring>& optimizations_disabled,
181 const absl::flat_hash_set<tstring>& optimizations_default,
182 const absl::flat_hash_set<tstring>& optimization_configs,
183 DatasetBase** output) {
184 input->Ref();
185 *output = input;
186 }
187
OptimizeDatasetOp(OpKernelConstruction * ctx)188 OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
189 : UnaryDatasetOpKernel(ctx) {}
190
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)191 void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
192 DatasetBase** output) {
193 input->Ref();
194 *output = input;
195 }
196
namespace {
// On mobile both op names are bound to the pass-through kernel defined in this
// branch of the file.
REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
                        OptimizeDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU),
                        OptimizeDatasetOp);
}  // namespace
203 } // namespace data
204 } // namespace tensorflow
205 #endif // !IS_MOBILE_PLATFORM
206