/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/root_dataset.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/rewrite_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kDatasetType[] = "Root";
constexpr char kAlgorithm[] = "algorithm";
constexpr char kCpuBudget[] = "cpu_budget";
constexpr char kRamBudget[] = "ram_budget_bytes";
constexpr char kHillClimb[] = "hill_climb";
constexpr char kGradientDescent[] = "gradient_descent";
constexpr char kIntraOpParallelism[] = "intra_op_parallelism";
constexpr char kPrivateThreadpoolSize[] = "threadpool_size";

// Default share of available RAM that the model's internal buffers may use.
constexpr double kRamBudgetShare = 0.5;

// Returns the default value `z` if `x` matches `y`; otherwise returns `x`.
inline int64 value_or_default(int64_t x, int64_t y, int64_t z) {
  return x == y ? z : x;
}

}  // namespace

// static
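// Derives `RootDataset::Params` from the input dataset's `Options` (intra-op
// parallelism, private threadpool size, and autotuning budgets) and wraps
// `input` in a new `RootDataset`.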
Status RootDataset::FromOptions(DatasetBase* input, DatasetBase** output) {
  const Options& options = input->options();
  Params params;
  if (ShouldConfigureMaxIntraOpParallelism(options)) {
    params.max_intra_op_parallelism =
        options.threading_options().max_intra_op_parallelism();
  }
  if (ShouldUsePrivateThreadPool(options)) {
    params.private_threadpool_size =
        options.threading_options().private_threadpool_size();
  }
  params.autotune = ShouldUseAutotuning(options);
  if (params.autotune) {
    params.autotune_algorithm = model::AutotuneAlgorithm::HILL_CLIMB;
    params.autotune_cpu_budget = value_or_default(
        options.autotune_options().cpu_budget(), 0, port::NumSchedulableCPUs());
    params.autotune_ram_budget =
        value_or_default(options.autotune_options().ram_budget(), 0,
                         kRamBudgetShare * port::AvailableRam());
  }
  *output = new RootDataset(input, params);
  return Status::OK();
}

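// Iterator for `RootDataset`. It owns the autotuning model, the optional
// private threadpool, and the cancellation manager used to stop the model
// thread, and forwards all calls to the wrapped input iterator.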
class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<RootDataset>(params) {
    if (dataset()->params_.autotune) {
      model_ = std::make_shared<model::Model>();
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      max_intra_op_parallelism_ =
          value_or_default(dataset()->params_.max_intra_op_parallelism, 0,
                           port::MaxParallelism());
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      threadpool_size_ =
          value_or_default(dataset()->params_.private_threadpool_size, 0,
                           port::MaxParallelism());
      thread_pool_ = absl::make_unique<thread::ThreadPool>(
          Env::Default(), ThreadOptions{}, "data_private_threadpool",
          threadpool_size_);
    }
    cancellation_manager_ = absl::make_unique<CancellationManager>();
  }

  ~Iterator() override { cancellation_manager_->StartCancel(); }

  Status Initialize(IteratorContext* ctx) override {
    return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
                                           this, prefix(), &input_impl_);
  }

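  // Lazily starts the autotuning model thread on the first call (when
  // autotuning is enabled) and then delegates to the input iterator.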
  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    if (dataset()->params_.autotune) {
      TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
    }
    return input_impl_->GetNext(IteratorContext(CreateParams(ctx)), out_tensors,
                                end_of_sequence);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
    return Status::OK();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    TF_RETURN_IF_ERROR(
        RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_));
    return Status::OK();
  }

  TraceMeMetadata GetTraceMeMetadata() const override {
    return dataset()->traceme_metadata_;
  }

 private:
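  // Builds the `IteratorContext::Params` passed to the input pipeline,
  // injecting the autotuning model, the private threadpool runner, and the
  // intra-op parallelism limit as configured.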
  IteratorContext::Params CreateParams(IteratorContext* ctx) {
    IteratorContext::Params params(ctx);
    if (dataset()->params_.autotune) {
      params.model = model_;
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      params.runner = [pool = thread_pool_.get()](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = threadpool_size_;
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      params.runner =
          RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_);
    }
    return params;
  }

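  // Starts the background thread that runs the autotuning model's
  // optimization loop; a no-op if the thread is already running.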
  Status EnsureModelThreadStarted(IteratorContext* ctx) {
    mutex_lock l(mu_);
    if (!model_thread_) {
      model_thread_ = ctx->StartThread("tf_data_model", [this]() {
        Status status =
            model_->OptimizeLoop(dataset()->params_.autotune_algorithm,
                                 dataset()->params_.autotune_cpu_budget,
                                 dataset()->params_.autotune_ram_budget,
                                 cancellation_manager_.get());
        if (!status.ok()) {
          LOG(WARNING) << "Optimization loop failed: " << status.ToString();
        }
      });
    }
    return Status::OK();
  }

  std::shared_ptr<model::Model> model_ = nullptr;
  // Controls cancellation of `model_thread_`. Must be ordered before
  // `model_thread_` so that `model_thread_` is destroyed first.
  std::unique_ptr<CancellationManager> cancellation_manager_;
  mutex mu_;
  std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
  int64 max_intra_op_parallelism_;
  int64 threadpool_size_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // Must be ordered last as its execution may depend on other members.
  std::unique_ptr<IteratorBase> input_impl_;
};

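// Populates `traceme_metadata_` from the configured parameters and takes a
// reference on the input dataset.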
RootDataset::RootDataset(const DatasetBase* input, Params params)
    : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
                                  name_utils::OpName(kDatasetType)})),
      input_(input),
      params_(std::move(params)) {
  if (params_.autotune) {
    traceme_metadata_.push_back(std::make_pair(
        kAlgorithm,
        params_.autotune_algorithm == model::AutotuneAlgorithm::HILL_CLIMB
            ? kHillClimb
            : kGradientDescent));
    traceme_metadata_.push_back(std::make_pair(
        kCpuBudget, strings::Printf("%lld", static_cast<long long>(
                                                params_.autotune_cpu_budget))));
    traceme_metadata_.push_back(std::make_pair(
        kRamBudget, strings::Printf("%lld", static_cast<long long>(
                                                params_.autotune_ram_budget))));
  }
  if (params_.max_intra_op_parallelism >= 0) {
    traceme_metadata_.push_back(std::make_pair(
        kIntraOpParallelism,
        strings::Printf("%lld", static_cast<long long>(value_or_default(
                                    params_.max_intra_op_parallelism, 0,
                                    port::MaxParallelism())))));
  }
  if (params_.private_threadpool_size >= 0) {
    traceme_metadata_.push_back(std::make_pair(
        kPrivateThreadpoolSize,
        strings::Printf("%lld", static_cast<long long>(value_or_default(
                                    params_.private_threadpool_size, 0,
                                    port::MaxParallelism())))));
  }
  input_->Ref();
}

RootDataset::~RootDataset() { input_->Unref(); }

std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal(
    const string& prefix) const {
  return absl::make_unique<Iterator>(
      Iterator::Params{this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}

const DataTypeVector& RootDataset::output_dtypes() const {
  return input_->output_dtypes();
}

const std::vector<PartialTensorShape>& RootDataset::output_shapes() const {
  return input_->output_shapes();
}

string RootDataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}

int64 RootDataset::Cardinality() const { return input_->Cardinality(); }

Status RootDataset::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  inputs->push_back(input_);
  return Status::OK();
}

Status RootDataset::CheckExternalState() const {
  return input_->CheckExternalState();
}

Status RootDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  return errors::Unimplemented("RootDataset does not support serialization.");
}

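// Applies the selected static graph-rewrite optimizations (if any) to `input`
// and wraps the result in a `RootDataset`. If the rewrite exceeds its
// deadline, falls back to wrapping the unrewritten input.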
#if !defined(IS_MOBILE_PLATFORM)
Status FinalizeDataset(OpKernelContext* ctx, DatasetBase* input,
                       DatasetBase** output) {
  const Options& options = input->options();
  absl::flat_hash_set<tstring> optimizations_enabled;
  absl::flat_hash_set<tstring> optimizations_disabled;
  absl::flat_hash_set<tstring> optimizations_default;
  GetOptimizations(options, &optimizations_enabled, &optimizations_disabled,
                   &optimizations_default);
  // Disable `enable_gradient_descent` as it assumes the presence of
  // `ModelDatasetOp`.
  optimizations_disabled.insert("enable_gradient_descent");

  auto experiments = GetExperiments();
  LogAndRecordExperiments(experiments);
  auto optimizations =
      SelectOptimizations(experiments, optimizations_enabled,
                          optimizations_disabled, optimizations_default);
  if (optimizations.empty()) {
    return RootDataset::FromOptions(input, output);
  }

  auto optimization_configs = CreateGraphRewriteConfigs(options);
  auto config_factory = [&optimizations, &optimization_configs]() {
    return CreateRewriterConfig(optimizations, optimization_configs);
  };
  Status s = RewriteDataset(ctx, input, std::move(config_factory),
                            /*record_fingerprint=*/true, output);
  if (errors::IsDeadlineExceeded(s)) {
    // Ignore DeadlineExceeded errors: they mean the attempted rewrite took too
    // long, which should not prevent further computation.
    LOG(WARNING) << s.ToString();
    return RootDataset::FromOptions(input, output);
  }
  if (!s.ok()) {
    return s;
  }
  input = *output;
  TF_RETURN_IF_ERROR(RootDataset::FromOptions(input, output));
  input->Unref();
  return Status::OK();
}
#else   // !IS_MOBILE_PLATFORM
Status FinalizeDataset(OpKernelContext* ctx, DatasetBase* input,
                       DatasetBase** output) {
  return RootDataset::FromOptions(input, output);
}
#endif  // !IS_MOBILE_PLATFORM

}  // namespace data
}  // namespace tensorflow