/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <unordered_map>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace data {
namespace {

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of
// `DatasetBase`, and the wrapper's copy constructor and destructor take care
// of managing the reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* const dataset_;  // Owns one reference.
};

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

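// A wrapper class for storing an already-wrapped dataset variant tensor in a
// second DT_VARIANT tensor. Unlike `DatasetVariantWrapper`, this wrapper
// implements `Encode()` and `Decode()`, so it can participate in the variant
// decode and device-copy registrations below.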
class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const {
    return "tensorflow::WrappedDatasetVariantWrapper";
  }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

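// An op kernel that validates that its scalar DT_VARIANT input holds a
// dataset, and re-wraps the input tensor in a `WrappedDatasetVariantWrapper`.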
class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

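// The inverse of `WrapDatasetVariantOp`: extracts the original dataset variant
// tensor from a `WrappedDatasetVariantWrapper` and forwards it as the output.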
class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  const string& type_string = dataset->type_string();
  std::unique_ptr<const GraphDefBuilder::Options> opts(
      new GraphDefBuilder::Options(b_->opts()));
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_shapes", dataset->output_shapes())));
  }
  if (has_output_types_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_types", dataset->output_dtypes())));
  }
  for (auto attr : attrs) {
    opts.reset(
        new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(opts->GetNameForOp(type_string), type_string,
                           opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (size_t i = 0; i < total_size; ++i) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}
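// (Illustrative sketch, not part of this translation unit.) A dataset's
// `AsGraphDefInternal()` usually reaches `AddDataset()` through the simpler
// convenience overloads declared in dataset.h. The members `input_` and
// `count_` below are hypothetical:
//
//   Status AsGraphDefInternal(SerializationContext* ctx,
//                             DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_graph_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
//     Node* count_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
//     return b->AddDataset(this, {input_graph_node, count_node}, output);
//   }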

Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
                                           const string& function_name) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return Status::OK();
  }
  if (!ctx->optimization_only()) {
    TF_RETURN_IF_ERROR(
        EnsureFunctionIsStateless(ctx->flib_def(), function_name));
  }
  const FunctionDef* f_def = ctx->flib_def().Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = ctx->flib_def().FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions in inputs of function_name.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second));
    }
  }

  // Recursively add functions in attrs of function_name.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

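// Returns the total number of bytes allocated for the tensors in `element`.
// When a tensor holds a nested dataset variant, the nested dataset's own
// `AllocatedBytes()` is counted instead of the size of the variant tensor.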
int64 GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64 allocated_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}
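// (Illustrative sketch, not part of this translation unit.) The two functions
// above form a round trip between a `DatasetBase*` and a scalar DT_VARIANT
// tensor; `dataset` stands for any caller-owned `DatasetBase*`:
//
//   Tensor variant_tensor(DT_VARIANT, TensorShape({}));
//   // Transfers ownership of one reference on `dataset` to the tensor.
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset, &variant_tensor));
//
//   DatasetBase* recovered = nullptr;
//   // `recovered` aliases the stored dataset; no new reference is taken.
//   TF_RETURN_IF_ERROR(
//       GetDatasetFromVariantTensor(variant_tensor, &recovered));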

Status DatasetBase::Save(SerializationContext* ctx,
                         IteratorStateWriter* writer) const {
  string serialized_graph_def;
  string output_node;
  GraphDefBuilder b;
  DatasetGraphDefBuilder db(&b);
  Node* node = nullptr;
  TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
  output_node = node->name();
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
  graph_def.SerializeToString(&serialized_graph_def);
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
  return Status::OK();
}

Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (ctx->optimization_only() && errors::IsUnimplemented(status)) {
    Tensor t(DT_VARIANT, TensorShape({}));
    // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
    // increment the refcount of `dataset` here to retain ownership.
    dataset->Ref();
    TF_RETURN_IF_ERROR(
        StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
    TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
    DCHECK_NE(ctx->input_list(), nullptr);
    ctx->input_list()->emplace_back((*output)->name(), std::move(t));
    LOG(WARNING)
        << "Input of " << dataset->DebugString()
        << " will not be optimized because the dataset does not implement the "
           "AsGraphDefInternal() method needed to apply optimizations.";
    return Status::OK();
  }
  return status;
}

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
  }
}

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}
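// (Illustrative sketch, not part of this translation unit.) A concrete dataset
// kernel subclasses one of the `*DatasetOpKernel` classes above and overrides
// the `MakeDataset()` overload that receives the already-unwrapped inputs.
// `MyDataset` is a hypothetical `DatasetBase` subclass:
//
//   class MyDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       *output = new MyDataset(ctx, input);
//     }
//   };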

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

BackgroundWorker::BackgroundWorker(Env* env, const string& name) {
  thread_.reset(env->StartThread({} /* thread_options */, name,
                                 [this]() { WorkerLoop(); }));
}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}
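// (Illustrative sketch, not part of this translation unit.) Typical usage
// schedules closures that run sequentially on the single background thread:
//
//   BackgroundWorker worker(Env::Default(), "my_background_worker");
//   worker.Schedule([]() {
//     // Runs on the background thread, after previously scheduled items.
//   });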

void BackgroundWorker::WorkerLoop() {
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}

}  // namespace data
}  // namespace tensorflow