/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/root_dataset.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/rewrite_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
namespace tensorflow {
namespace data {
namespace {

constexpr char kDatasetType[] = "Root";
constexpr char kAlgorithm[] = "algorithm";
constexpr char kCpuBudget[] = "cpu_budget";
constexpr char kRamBudget[] = "ram_budget_bytes";
constexpr char kHillClimb[] = "hill_climb";
constexpr char kGradientDescent[] = "gradient_descent";
constexpr char kIntraOpParallelism[] = "intra_op_parallelism";
constexpr char kPrivateThreadpoolSize[] = "threadpool_size";

// Default share of available RAM that can be used by the model's internal
// buffers.
constexpr double kRamBudgetShare = 0.5;
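// For example, on a machine reporting 16 GiB of available RAM, the default
// autotuning RAM budget would be 8 GiB (see `FromOptions` below, which
// applies this share only when the user has not set an explicit
// `ram_budget`).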

// If value `x` matches `y`, returns default value `z`. Otherwise, returns `x`.
inline int64 value_or_default(int64_t x, int64_t y, int64_t z) {
  return x == y ? z : x;
}
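// Usage sketch: `value_or_default(cpu_budget, 0, port::NumSchedulableCPUs())`
// treats a `cpu_budget` of 0 as "unset" and substitutes the number of
// schedulable CPUs; any non-zero budget is passed through unchanged.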

}  // namespace

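// Wraps `input` in a `RootDataset`, translating the dataset's `Options` into
// `RootDataset::Params`: intra-op parallelism, private threadpool size, and,
// when autotuning is enabled, the autotune algorithm plus CPU and RAM
// budgets. Unset budgets default to the schedulable CPU count and to
// `kRamBudgetShare` of available RAM, respectively.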
// static
Status RootDataset::FromOptions(DatasetBase* input, DatasetBase** output) {
  const Options& options = input->options();
  Params params;
  if (ShouldConfigureMaxIntraOpParallelism(options)) {
    params.max_intra_op_parallelism =
        options.threading_options().max_intra_op_parallelism();
  }
  if (ShouldUsePrivateThreadPool(options)) {
    params.private_threadpool_size =
        options.threading_options().private_threadpool_size();
  }
  params.autotune = ShouldUseAutotuning(options);
  if (params.autotune) {
    params.autotune_algorithm = model::AutotuneAlgorithm::HILL_CLIMB;
    params.autotune_cpu_budget = value_or_default(
        options.autotune_options().cpu_budget(), 0, port::NumSchedulableCPUs());
    params.autotune_ram_budget =
        value_or_default(options.autotune_options().ram_budget(), 0,
                         kRamBudgetShare * port::AvailableRam());
  }
  *output = new RootDataset(input, params);
  return Status::OK();
}

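// Iterator for `RootDataset`. It owns the pipeline-wide resources configured
// by `Params` -- the autotuning `model::Model`, the private threadpool, and
// the cancellation manager for the background optimization thread -- and
// propagates them to the input pipeline through `CreateParams`.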
class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<RootDataset>(params) {
    if (dataset()->params_.autotune) {
      model_ = std::make_shared<model::Model>();
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      max_intra_op_parallelism_ =
          value_or_default(dataset()->params_.max_intra_op_parallelism, 0,
                           port::MaxParallelism());
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      threadpool_size_ =
          value_or_default(dataset()->params_.private_threadpool_size, 0,
                           port::MaxParallelism());
      thread_pool_ = absl::make_unique<thread::ThreadPool>(
          Env::Default(), ThreadOptions{}, "data_private_threadpool",
          threadpool_size_);
    }
    cancellation_manager_ = absl::make_unique<CancellationManager>();
  }

  ~Iterator() override { cancellation_manager_->StartCancel(); }

  Status Initialize(IteratorContext* ctx) override {
    return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
                                           this, prefix(), &input_impl_);
  }

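  // When autotuning is enabled, lazily starts the autotuning thread on the
  // first call, then delegates to the input iterator with the augmented
  // `IteratorContext`.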
  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    if (dataset()->params_.autotune) {
      TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
    }
    return input_impl_->GetNext(IteratorContext(CreateParams(ctx)), out_tensors,
                                end_of_sequence);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
    return Status::OK();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    TF_RETURN_IF_ERROR(
        RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_));
    return Status::OK();
  }

  TraceMeMetadata GetTraceMeMetadata() const override {
    return dataset()->traceme_metadata_;
  }

 private:
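  // Builds the `IteratorContext::Params` handed to the input pipeline:
  // attaches the shared autotuning model, routes closures through the private
  // threadpool when one is configured, and caps the runner's parallelism when
  // an intra-op parallelism limit is set.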
  IteratorContext::Params CreateParams(IteratorContext* ctx) {
    IteratorContext::Params params(ctx);
    if (dataset()->params_.autotune) {
      params.model = model_;
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      params.runner = [pool = thread_pool_.get()](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = threadpool_size_;
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      params.runner =
          RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_);
    }
    return params;
  }

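  // Starts the background thread that runs the model's optimization loop.
  // Guarded by `mu_` so the thread is created at most once; the loop exits
  // when `cancellation_manager_` is cancelled in the destructor.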
  Status EnsureModelThreadStarted(IteratorContext* ctx) {
    mutex_lock l(mu_);
    if (!model_thread_) {
      model_thread_ = ctx->StartThread("tf_data_model", [this]() {
        Status status =
            model_->OptimizeLoop(dataset()->params_.autotune_algorithm,
                                 dataset()->params_.autotune_cpu_budget,
                                 dataset()->params_.autotune_ram_budget,
                                 cancellation_manager_.get());
        if (!status.ok()) {
          LOG(WARNING) << "Optimization loop failed: " << status.ToString();
        }
      });
    }
    return Status::OK();
  }

  std::shared_ptr<model::Model> model_ = nullptr;
  // Controls cancellation of `model_thread_`. Must be ordered before
  // `model_thread_` so that `model_thread_` is destroyed first.
  std::unique_ptr<CancellationManager> cancellation_manager_;
  mutex mu_;
  std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
  int64 max_intra_op_parallelism_;
  int64 threadpool_size_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // Must be ordered last as its execution may depend on other members.
  std::unique_ptr<IteratorBase> input_impl_;
};

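// Records the effective configuration (autotune algorithm and budgets,
// intra-op parallelism, private threadpool size) as TraceMe metadata so that
// it shows up in profiler traces, and takes a reference on the input dataset.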
RootDataset::RootDataset(const DatasetBase* input, Params params)
    : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
                                  name_utils::OpName(kDatasetType)})),
      input_(input),
      params_(std::move(params)) {
  if (params_.autotune) {
    traceme_metadata_.push_back(std::make_pair(
        kAlgorithm,
        params_.autotune_algorithm == model::AutotuneAlgorithm::HILL_CLIMB
            ? kHillClimb
            : kGradientDescent));
    traceme_metadata_.push_back(std::make_pair(
        kCpuBudget, strings::Printf("%lld", static_cast<long long>(
                                                params_.autotune_cpu_budget))));
    traceme_metadata_.push_back(std::make_pair(
        kRamBudget, strings::Printf("%lld", static_cast<long long>(
                                                params_.autotune_ram_budget))));
  }
  if (params_.max_intra_op_parallelism >= 0) {
    traceme_metadata_.push_back(std::make_pair(
        kIntraOpParallelism,
        strings::Printf("%lld", static_cast<long long>(value_or_default(
                                    params_.max_intra_op_parallelism, 0,
                                    port::MaxParallelism())))));
  }
  if (params_.private_threadpool_size >= 0) {
    traceme_metadata_.push_back(std::make_pair(
        kPrivateThreadpoolSize,
        strings::Printf("%lld", static_cast<long long>(value_or_default(
                                    params_.private_threadpool_size, 0,
                                    port::MaxParallelism())))));
  }
  input_->Ref();
}

RootDataset::~RootDataset() { input_->Unref(); }

std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal(
    const string& prefix) const {
  return absl::make_unique<Iterator>(
      Iterator::Params{this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}

const DataTypeVector& RootDataset::output_dtypes() const {
  return input_->output_dtypes();
}

const std::vector<PartialTensorShape>& RootDataset::output_shapes() const {
  return input_->output_shapes();
}

string RootDataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}

int64 RootDataset::Cardinality() const { return input_->Cardinality(); }

Status RootDataset::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  inputs->push_back(input_);
  return Status::OK();
}

Status RootDataset::CheckExternalState() const {
  return input_->CheckExternalState();
}

Status RootDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  return errors::Unimplemented("RootDataset does not support serialization.");
}

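// Finalizes `input` into the dataset that will actually be iterated. On
// non-mobile platforms this applies the selected tf.data graph rewrites
// before wrapping the result in a `RootDataset`; on mobile platforms it
// wraps the input directly.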
#if !defined(IS_MOBILE_PLATFORM)
Status FinalizeDataset(OpKernelContext* ctx, DatasetBase* input,
                       DatasetBase** output) {
  const Options& options = input->options();
  absl::flat_hash_set<tstring> optimizations_enabled;
  absl::flat_hash_set<tstring> optimizations_disabled;
  absl::flat_hash_set<tstring> optimizations_default;
  GetOptimizations(options, &optimizations_enabled, &optimizations_disabled,
                   &optimizations_default);
  // Disable `enable_gradient_descent` as it assumes the presence of
  // ModelDatasetOp.
  optimizations_disabled.insert("enable_gradient_descent");

  auto experiments = GetExperiments();
  LogAndRecordExperiments(experiments);
  auto optimizations =
      SelectOptimizations(experiments, optimizations_enabled,
                          optimizations_disabled, optimizations_default);
  if (optimizations.empty()) {
    return RootDataset::FromOptions(input, output);
  }

  auto optimization_configs = CreateGraphRewriteConfigs(options);
  auto config_factory = [&optimizations, &optimization_configs]() {
    return CreateRewriterConfig(optimizations, optimization_configs);
  };
  Status s = RewriteDataset(ctx, input, std::move(config_factory),
                            /*record_fingerprint=*/true, output);
  if (errors::IsDeadlineExceeded(s)) {
    // Ignore DeadlineExceeded as it implies that the attempted rewrite took
    // too long, which should not prevent further computation.
    LOG(WARNING) << s.ToString();
    return RootDataset::FromOptions(input, output);
  }
  if (!s.ok()) {
    return s;
  }
  // The rewrite placed a new dataset in `*output`; wrap it in a RootDataset
  // and release the reference produced by the rewrite.
  input = *output;
  TF_RETURN_IF_ERROR(RootDataset::FromOptions(input, output));
  input->Unref();
  return Status::OK();
}
#else   // !IS_MOBILE_PLATFORM
Status FinalizeDataset(OpKernelContext* ctx, DatasetBase* input,
                       DatasetBase** output) {
  return RootDataset::FromOptions(input, output);
}
#endif  // !IS_MOBILE_PLATFORM
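
// A minimal caller-side sketch (the kernel context and variable names here
// are illustrative, not part of this file): a finalization kernel would
// replace its input dataset with the finalized pipeline root, e.g.
//
//   DatasetBase* finalized = nullptr;
//   OP_REQUIRES_OK(ctx, FinalizeDataset(ctx, input_dataset, &finalized));
//   // `finalized` carries a reference that the caller must eventually
//   // release via Unref().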

}  // namespace data
}  // namespace tensorflow