• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/model_dataset_op.h"
16 
17 #include "tensorflow/core/framework/cancellation.h"
18 
19 // On mobile we do not provide model dataset op because not all of its
20 // dependencies are available there. The op is replaced with a no-op.
21 #if !defined(IS_MOBILE_PLATFORM)
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/metrics.h"
25 #include "tensorflow/core/framework/model.h"
26 #include "tensorflow/core/framework/partial_tensor_shape.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/lib/random/random.h"
29 #include "tensorflow/core/platform/cpu_info.h"
30 #include "tensorflow/core/platform/stringprintf.h"
31 #include "tensorflow/core/util/ptr_util.h"
32 
33 namespace tensorflow {
34 namespace data {
namespace {

// Default share of the available RAM that the model's internal buffers may
// consume when no explicit RAM budget is configured (i.e. budget == 0).
constexpr double kRamBudgetShare = 0.5;

}  // namespace
41 
42 /* static */ constexpr const char* const ModelDatasetOp::kDatasetType;
43 /* static */ constexpr const char* const ModelDatasetOp::kDatasetOp;
44 /* static */ constexpr const char* const ModelDatasetOp::kAlgorithm;
45 /* static */ constexpr const char* const ModelDatasetOp::kCpuBudget;
46 /* static */ constexpr const char* const ModelDatasetOp::kRamBudget;
47 
48 class ModelDatasetOp::Dataset : public DatasetBase {
49  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,model::AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget)50   Dataset(OpKernelContext* ctx, const DatasetBase* input,
51           model::AutotuneAlgorithm algorithm, int64_t cpu_budget,
52           int64_t ram_budget)
53       : Dataset(DatasetContext(ctx), input, algorithm, cpu_budget, ram_budget) {
54   }
55 
Dataset(DatasetContext && ctx,const DatasetBase * input,model::AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget)56   Dataset(DatasetContext&& ctx, const DatasetBase* input,
57           model::AutotuneAlgorithm algorithm, int64_t cpu_budget,
58           int64_t ram_budget)
59       : DatasetBase(std::move(ctx)),
60         input_(input),
61         algorithm_(algorithm),
62         cpu_budget_(cpu_budget),
63         ram_budget_(ram_budget),
64         traceme_metadata_(
65             {{"algorithm", algorithm == model::AutotuneAlgorithm::HILL_CLIMB
66                                ? "hill climb"
67                                : "gradient descent"},
68              {"cpu_budget",
69               strings::Printf("%lld", static_cast<long long>(cpu_budget))},
70              {"ram_budget",
71               strings::Printf("%lldB", static_cast<long long>(ram_budget))}}) {
72     input_->Ref();
73   }
74 
~Dataset()75   ~Dataset() override { input_->Unref(); }
76 
MakeIteratorInternal(const string & prefix) const77   std::unique_ptr<IteratorBase> MakeIteratorInternal(
78       const string& prefix) const override {
79     return absl::make_unique<Iterator>(
80         Iterator::Params{this, strings::StrCat(prefix, "::Model")});
81   }
82 
output_dtypes() const83   const DataTypeVector& output_dtypes() const override {
84     return input_->output_dtypes();
85   }
output_shapes() const86   const std::vector<PartialTensorShape>& output_shapes() const override {
87     return input_->output_shapes();
88   }
89 
DebugString() const90   string DebugString() const override { return "ModelDatasetOp::Dataset"; }
91 
Cardinality() const92   int64 Cardinality() const override { return input_->Cardinality(); }
93 
InputDatasets(std::vector<const DatasetBase * > * inputs) const94   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
95     inputs->push_back(input_);
96     return Status::OK();
97   }
98 
CheckExternalState() const99   Status CheckExternalState() const override {
100     return input_->CheckExternalState();
101   }
102 
103  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const104   Status AsGraphDefInternal(SerializationContext* ctx,
105                             DatasetGraphDefBuilder* b,
106                             Node** output) const override {
107     Node* input_graph_node = nullptr;
108     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
109     TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
110     AttrValue algorithm_attr;
111     b->BuildAttrValue(static_cast<int64>(algorithm_), &algorithm_attr);
112     AttrValue cpu_budget_attr;
113     b->BuildAttrValue(cpu_budget_, &cpu_budget_attr);
114     AttrValue ram_budget_attr;
115     b->BuildAttrValue(ram_budget_, &ram_budget_attr);
116 
117     TF_RETURN_IF_ERROR(
118         b->AddDataset(this, {input_graph_node},
119                       {std::make_pair(kAlgorithm, algorithm_attr),
120                        std::make_pair(kCpuBudget, cpu_budget_attr),
121                        std::make_pair(kRamBudget, ram_budget_attr)},
122                       output));
123     return Status::OK();
124   }
125 
126  private:
127   class Iterator : public DatasetIterator<Dataset> {
128    public:
Iterator(const Params & params)129     explicit Iterator(const Params& params)
130         : DatasetIterator<Dataset>(params),
131           cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs()
132                                                   : dataset()->cpu_budget_),
133           ram_budget_(dataset()->ram_budget_ == 0
134                           ? kRamBudgetShare * port::AvailableRam()
135                           : dataset()->ram_budget_) {
136       cancellation_manager_ = absl::make_unique<CancellationManager>();
137       model_ = std::make_shared<model::Model>();
138     }
139 
~Iterator()140     ~Iterator() override { cancellation_manager_->StartCancel(); }
141 
Initialize(IteratorContext * ctx)142     Status Initialize(IteratorContext* ctx) override {
143       return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
144                                              this, prefix(), &input_impl_);
145     }
146 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)147     Status GetNextInternal(IteratorContext* ctx,
148                            std::vector<Tensor>* out_tensors,
149                            bool* end_of_sequence) override {
150       if (!ctx->model()) {
151         mutex_lock l(mu_);
152         TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
153       }
154       return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
155                                   out_tensors, end_of_sequence);
156     }
157 
158    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const159     std::shared_ptr<model::Node> CreateNode(
160         IteratorContext* ctx, model::Node::Args args) const override {
161       return model::MakeKnownRatioNode(std::move(args),
162                                        /*ratio=*/1);
163     }
164 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)165     Status SaveInternal(SerializationContext* ctx,
166                         IteratorStateWriter* writer) override {
167       return SaveInput(ctx, writer, input_impl_);
168     }
169 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)170     Status RestoreInternal(IteratorContext* ctx,
171                            IteratorStateReader* reader) override {
172       return RestoreInput(IteratorContext(CreateParams(ctx)), reader,
173                           input_impl_);
174     }
175 
GetTraceMeMetadata() const176     TraceMeMetadata GetTraceMeMetadata() const override {
177       return dataset()->traceme_metadata_;
178     }
179 
180    private:
CreateParams(IteratorContext * ctx)181     IteratorContext::Params CreateParams(IteratorContext* ctx) {
182       IteratorContext::Params params(ctx);
183       if (!ctx->model()) {
184         params.model = model_;
185       }
186       return params;
187     }
188 
EnsureOptimizationLoopThreadStarted(IteratorContext * ctx)189     Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
190         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
191       if (!model_thread_) {
192         model_thread_ = ctx->StartThread("tf_data_model", [this]() {
193           Status status =
194               model_->OptimizeLoop(dataset()->algorithm_, cpu_budget_,
195                                    ram_budget_, cancellation_manager_.get());
196           if (!status.ok()) {
197             LOG(WARNING) << "Optimization loop failed: " << status.ToString();
198           }
199         });
200       }
201       return Status::OK();
202     }
203 
204     mutex mu_;
205     std::shared_ptr<model::Model> model_;
206     // Controls cancellation of `model_thread_`. Must be ordered before
207     // `model_thread_` so that `model_thread_` is destroyed first.
208     std::unique_ptr<CancellationManager> cancellation_manager_;
209     std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
210     std::unique_ptr<IteratorBase> input_impl_;
211     const int64 cpu_budget_;
212     const int64 ram_budget_;
213   };
214 
215   const DatasetBase* input_;
216   const model::AutotuneAlgorithm algorithm_;
217   const int64 cpu_budget_;
218   const int64 ram_budget_;
219   const TraceMeMetadata traceme_metadata_;
220 };
221 
222 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,model::AutotuneAlgorithm algorithm,bool cpu_budget,bool ram_budget,DatasetBase ** output)223 void ModelDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
224                                             DatasetBase* input,
225                                             model::AutotuneAlgorithm algorithm,
226                                             bool cpu_budget, bool ram_budget,
227                                             DatasetBase** output) {
228   *output = new ModelDatasetOp::Dataset(
229       DatasetContext(DatasetContext::Params(
230           {ModelDatasetOp::kDatasetType, ModelDatasetOp::kDatasetOp})),
231       input, algorithm, cpu_budget, ram_budget);
232 }
233 
ModelDatasetOp(OpKernelConstruction * ctx)234 ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
235     : UnaryDatasetOpKernel(ctx) {
236   if (ctx->HasAttr(kAlgorithm)) {
237     int64_t algorithm;
238     OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm));
239     algorithm_ = model::AutotuneAlgorithm(algorithm);
240   } else {
241     algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB;
242   }
243   OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_));
244   OP_REQUIRES(ctx, cpu_budget_ >= 0,
245               errors::InvalidArgument("CPU budget must be positive but is ",
246                                       cpu_budget_, "."));
247   if (ctx->HasAttr(kRamBudget)) {
248     OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_));
249   } else {
250     ram_budget_ = 0;
251   }
252   OP_REQUIRES(ctx, ram_budget_ >= 0,
253               errors::InvalidArgument("RAM budget must be positive but is ",
254                                       ram_budget_, "."));
255 }
256 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)257 void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
258                                  DatasetBase** output) {
259   *output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_,
260                                         ram_budget_);
261 }
262 
263 namespace {
264 REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
265                         ModelDatasetOp);
266 }  // namespace
267 }  // namespace data
268 }  // namespace tensorflow
269 #else   // !IS_MOBILE_PLATFORM
270 namespace tensorflow {
271 namespace data {
272 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,model::AutotuneAlgorithm algorithm,bool cpu_budget,bool ram_budget,DatasetBase ** output)273 void ModelDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
274                                             DatasetBase* input,
275                                             model::AutotuneAlgorithm algorithm,
276                                             bool cpu_budget, bool ram_budget,
277                                             DatasetBase** output) {
278   input->Ref();
279   *output = input;
280 }
281 
ModelDatasetOp(OpKernelConstruction * ctx)282 ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
283     : UnaryDatasetOpKernel(ctx) {}
284 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)285 void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
286                                  DatasetBase** output) {
287   input->Ref();
288   *output = input;
289 }
290 
291 namespace {
292 REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
293                         ModelDatasetOp);
294 }  // namespace
295 }  // namespace data
296 }  // namespace tensorflow
297 #endif  // !IS_MOBILE_PLATFORM
298