1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/model_dataset_op.h"
16
17 #include "tensorflow/core/framework/cancellation.h"
18
19 // On mobile we do not provide model dataset op because not all of its
20 // dependencies are available there. The op is replaced with a no-op.
21 #if !defined(IS_MOBILE_PLATFORM)
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/metrics.h"
25 #include "tensorflow/core/framework/model.h"
26 #include "tensorflow/core/framework/partial_tensor_shape.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/lib/random/random.h"
29 #include "tensorflow/core/platform/cpu_info.h"
30 #include "tensorflow/core/platform/stringprintf.h"
31 #include "tensorflow/core/util/ptr_util.h"
32
33 namespace tensorflow {
34 namespace data {
namespace {

// Default share of available RAM that can be used by model's internal buffers.
constexpr double kRamBudgetShare = 0.5;

}  // namespace

// Out-of-line definitions for the static constexpr members declared in
// model_dataset_op.h; required when the constants are ODR-used (pre-C++17
// linkage rules).
/* static */ constexpr const char* const ModelDatasetOp::kDatasetType;
/* static */ constexpr const char* const ModelDatasetOp::kDatasetOp;
/* static */ constexpr const char* const ModelDatasetOp::kAlgorithm;
/* static */ constexpr const char* const ModelDatasetOp::kCpuBudget;
/* static */ constexpr const char* const ModelDatasetOp::kRamBudget;
47
48 class ModelDatasetOp::Dataset : public DatasetBase {
49 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,model::AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget)50 Dataset(OpKernelContext* ctx, const DatasetBase* input,
51 model::AutotuneAlgorithm algorithm, int64_t cpu_budget,
52 int64_t ram_budget)
53 : Dataset(DatasetContext(ctx), input, algorithm, cpu_budget, ram_budget) {
54 }
55
Dataset(DatasetContext && ctx,const DatasetBase * input,model::AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget)56 Dataset(DatasetContext&& ctx, const DatasetBase* input,
57 model::AutotuneAlgorithm algorithm, int64_t cpu_budget,
58 int64_t ram_budget)
59 : DatasetBase(std::move(ctx)),
60 input_(input),
61 algorithm_(algorithm),
62 cpu_budget_(cpu_budget),
63 ram_budget_(ram_budget),
64 traceme_metadata_(
65 {{"algorithm", algorithm == model::AutotuneAlgorithm::HILL_CLIMB
66 ? "hill climb"
67 : "gradient descent"},
68 {"cpu_budget",
69 strings::Printf("%lld", static_cast<long long>(cpu_budget))},
70 {"ram_budget",
71 strings::Printf("%lldB", static_cast<long long>(ram_budget))}}) {
72 input_->Ref();
73 }
74
~Dataset()75 ~Dataset() override { input_->Unref(); }
76
MakeIteratorInternal(const string & prefix) const77 std::unique_ptr<IteratorBase> MakeIteratorInternal(
78 const string& prefix) const override {
79 return absl::make_unique<Iterator>(
80 Iterator::Params{this, strings::StrCat(prefix, "::Model")});
81 }
82
output_dtypes() const83 const DataTypeVector& output_dtypes() const override {
84 return input_->output_dtypes();
85 }
output_shapes() const86 const std::vector<PartialTensorShape>& output_shapes() const override {
87 return input_->output_shapes();
88 }
89
DebugString() const90 string DebugString() const override { return "ModelDatasetOp::Dataset"; }
91
Cardinality() const92 int64 Cardinality() const override { return input_->Cardinality(); }
93
InputDatasets(std::vector<const DatasetBase * > * inputs) const94 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
95 inputs->push_back(input_);
96 return Status::OK();
97 }
98
CheckExternalState() const99 Status CheckExternalState() const override {
100 return input_->CheckExternalState();
101 }
102
103 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const104 Status AsGraphDefInternal(SerializationContext* ctx,
105 DatasetGraphDefBuilder* b,
106 Node** output) const override {
107 Node* input_graph_node = nullptr;
108 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
109 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
110 AttrValue algorithm_attr;
111 b->BuildAttrValue(static_cast<int64>(algorithm_), &algorithm_attr);
112 AttrValue cpu_budget_attr;
113 b->BuildAttrValue(cpu_budget_, &cpu_budget_attr);
114 AttrValue ram_budget_attr;
115 b->BuildAttrValue(ram_budget_, &ram_budget_attr);
116
117 TF_RETURN_IF_ERROR(
118 b->AddDataset(this, {input_graph_node},
119 {std::make_pair(kAlgorithm, algorithm_attr),
120 std::make_pair(kCpuBudget, cpu_budget_attr),
121 std::make_pair(kRamBudget, ram_budget_attr)},
122 output));
123 return Status::OK();
124 }
125
126 private:
127 class Iterator : public DatasetIterator<Dataset> {
128 public:
Iterator(const Params & params)129 explicit Iterator(const Params& params)
130 : DatasetIterator<Dataset>(params),
131 cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs()
132 : dataset()->cpu_budget_),
133 ram_budget_(dataset()->ram_budget_ == 0
134 ? kRamBudgetShare * port::AvailableRam()
135 : dataset()->ram_budget_) {
136 cancellation_manager_ = absl::make_unique<CancellationManager>();
137 model_ = std::make_shared<model::Model>();
138 }
139
~Iterator()140 ~Iterator() override { cancellation_manager_->StartCancel(); }
141
Initialize(IteratorContext * ctx)142 Status Initialize(IteratorContext* ctx) override {
143 return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
144 this, prefix(), &input_impl_);
145 }
146
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)147 Status GetNextInternal(IteratorContext* ctx,
148 std::vector<Tensor>* out_tensors,
149 bool* end_of_sequence) override {
150 if (!ctx->model()) {
151 mutex_lock l(mu_);
152 TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx));
153 }
154 return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
155 out_tensors, end_of_sequence);
156 }
157
158 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const159 std::shared_ptr<model::Node> CreateNode(
160 IteratorContext* ctx, model::Node::Args args) const override {
161 return model::MakeKnownRatioNode(std::move(args),
162 /*ratio=*/1);
163 }
164
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)165 Status SaveInternal(SerializationContext* ctx,
166 IteratorStateWriter* writer) override {
167 return SaveInput(ctx, writer, input_impl_);
168 }
169
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)170 Status RestoreInternal(IteratorContext* ctx,
171 IteratorStateReader* reader) override {
172 return RestoreInput(IteratorContext(CreateParams(ctx)), reader,
173 input_impl_);
174 }
175
GetTraceMeMetadata() const176 TraceMeMetadata GetTraceMeMetadata() const override {
177 return dataset()->traceme_metadata_;
178 }
179
180 private:
CreateParams(IteratorContext * ctx)181 IteratorContext::Params CreateParams(IteratorContext* ctx) {
182 IteratorContext::Params params(ctx);
183 if (!ctx->model()) {
184 params.model = model_;
185 }
186 return params;
187 }
188
EnsureOptimizationLoopThreadStarted(IteratorContext * ctx)189 Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx)
190 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
191 if (!model_thread_) {
192 model_thread_ = ctx->StartThread("tf_data_model", [this]() {
193 Status status =
194 model_->OptimizeLoop(dataset()->algorithm_, cpu_budget_,
195 ram_budget_, cancellation_manager_.get());
196 if (!status.ok()) {
197 LOG(WARNING) << "Optimization loop failed: " << status.ToString();
198 }
199 });
200 }
201 return Status::OK();
202 }
203
204 mutex mu_;
205 std::shared_ptr<model::Model> model_;
206 // Controls cancellation of `model_thread_`. Must be ordered before
207 // `model_thread_` so that `model_thread_` is destroyed first.
208 std::unique_ptr<CancellationManager> cancellation_manager_;
209 std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
210 std::unique_ptr<IteratorBase> input_impl_;
211 const int64 cpu_budget_;
212 const int64 ram_budget_;
213 };
214
215 const DatasetBase* input_;
216 const model::AutotuneAlgorithm algorithm_;
217 const int64 cpu_budget_;
218 const int64 ram_budget_;
219 const TraceMeMetadata traceme_metadata_;
220 };
221
222 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,model::AutotuneAlgorithm algorithm,bool cpu_budget,bool ram_budget,DatasetBase ** output)223 void ModelDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
224 DatasetBase* input,
225 model::AutotuneAlgorithm algorithm,
226 bool cpu_budget, bool ram_budget,
227 DatasetBase** output) {
228 *output = new ModelDatasetOp::Dataset(
229 DatasetContext(DatasetContext::Params(
230 {ModelDatasetOp::kDatasetType, ModelDatasetOp::kDatasetOp})),
231 input, algorithm, cpu_budget, ram_budget);
232 }
233
ModelDatasetOp(OpKernelConstruction * ctx)234 ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
235 : UnaryDatasetOpKernel(ctx) {
236 if (ctx->HasAttr(kAlgorithm)) {
237 int64_t algorithm;
238 OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm));
239 algorithm_ = model::AutotuneAlgorithm(algorithm);
240 } else {
241 algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB;
242 }
243 OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_));
244 OP_REQUIRES(ctx, cpu_budget_ >= 0,
245 errors::InvalidArgument("CPU budget must be positive but is ",
246 cpu_budget_, "."));
247 if (ctx->HasAttr(kRamBudget)) {
248 OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_));
249 } else {
250 ram_budget_ = 0;
251 }
252 OP_REQUIRES(ctx, ram_budget_ >= 0,
253 errors::InvalidArgument("RAM budget must be positive but is ",
254 ram_budget_, "."));
255 }
256
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)257 void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
258 DatasetBase** output) {
259 *output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_,
260 ram_budget_);
261 }
262
namespace {
// ModelDataset is a host-side transformation, so it is registered on CPU only.
REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
                        ModelDatasetOp);
}  // namespace
267 } // namespace data
268 } // namespace tensorflow
269 #else // !IS_MOBILE_PLATFORM
270 namespace tensorflow {
271 namespace data {
272 // static
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,model::AutotuneAlgorithm algorithm,bool cpu_budget,bool ram_budget,DatasetBase ** output)273 void ModelDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
274 DatasetBase* input,
275 model::AutotuneAlgorithm algorithm,
276 bool cpu_budget, bool ram_budget,
277 DatasetBase** output) {
278 input->Ref();
279 *output = input;
280 }
281
ModelDatasetOp(OpKernelConstruction * ctx)282 ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx)
283 : UnaryDatasetOpKernel(ctx) {}
284
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)285 void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
286 DatasetBase** output) {
287 input->Ref();
288 *output = input;
289 }
290
namespace {
// Same registration as the non-mobile build so graphs containing the op still
// load; the kernel is a pass-through here.
REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
                        ModelDatasetOp);
}  // namespace
295 } // namespace data
296 } // namespace tensorflow
297 #endif // !IS_MOBILE_PLATFORM
298