/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <unordered_map>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace data {
namespace {

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of
// `DatasetBase`, and the wrapper's copy constructor and destructor take care
// of managing the reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* const dataset_;  // Owns one reference.
};

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

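// A wrapper class that holds the serialized form of a dataset: the scalar
// DT_VARIANT tensor that itself contains a `DatasetVariantWrapper`. Unlike
// `DatasetVariantWrapper`, this wrapper supports `Encode()` and `Decode()`
// (by delegating to the wrapped tensor), so it can be copied between devices.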
class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

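// An op that wraps a scalar DT_VARIANT dataset tensor in a
// `WrappedDatasetVariantWrapper`, after validating that the input actually
// holds a dataset. The wrapped form can be copied between host and device
// (see the device-copy registration below).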
class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

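// The inverse of `WrapDatasetVariantOp`: extracts the original scalar
// DT_VARIANT dataset tensor from a `WrappedDatasetVariantWrapper`.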
class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

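// Device-copy function for `WrappedDatasetVariantWrapper`. The underlying
// dataset tensor stays in host memory (the kernels above are registered with
// HostMemory), so this is a plain wrapper copy and the `copy` callback is
// unused.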
static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

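// Adds a node for `dataset` to the graph being built. Single-tensor inputs in
// `inputs` and list inputs in `list_inputs` are keyed by input index, and
// together the two sets of indices must cover every input of the op. The
// "output_types" and "output_shapes" attrs are filled in automatically when
// the op defines them.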
Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  const string& type_string = dataset->type_string();
  std::unique_ptr<const GraphDefBuilder::Options> opts(
      new GraphDefBuilder::Options(b_->opts()));
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_shapes", dataset->output_shapes())));
  }
  if (has_output_types_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_types", dataset->output_dtypes())));
  }
  for (const auto& attr : attrs) {
    opts.reset(
        new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(opts->GetNameForOp(type_string), type_string,
                           opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (size_t i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}

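// Adds the function named `function_name`, along with any gradient function
// registered for it, to the graph's function library. Functions reachable
// from its body (as function-typed ops or function-valued attrs) are added
// recursively.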
Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
                                           const string& function_name) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return Status::OK();
  }
  if (!ctx->optimization_only()) {
    TF_RETURN_IF_ERROR(
        EnsureFunctionIsStateless(ctx->flib_def(), function_name));
  }
  const FunctionDef* f_def = ctx->flib_def().Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = ctx->flib_def().FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions used in the body of `function_name`.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second));
    }
  }

  // Recursively add functions in attrs of `function_name`.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

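// Returns the number of bytes allocated for the tensors in `element`. For a
// DT_VARIANT tensor that holds a dataset, this reports the dataset's own
// accounting of its allocated bytes rather than the size of the variant
// tensor itself.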
int64 GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64 allocated_bytes = 0;
  DatasetBase* dataset;
  for (const auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

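// Extracts the `DatasetBase` stored in a scalar DT_VARIANT tensor. The
// returned pointer is borrowed: the tensor's `DatasetVariantWrapper` retains
// ownership of the reference.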
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

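// Stores `dataset` in a scalar DT_VARIANT tensor, transferring ownership of
// one reference on `dataset` to the tensor's wrapper.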
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}

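// Serializes this dataset for checkpointing: builds a GraphDef that
// reconstructs the dataset, then writes the serialized graph and the name of
// its output node through `writer`.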
Status DatasetBase::Save(SerializationContext* ctx,
                         IteratorStateWriter* writer) const {
  string serialized_graph_def;
  string output_node;
  GraphDefBuilder b;
  DatasetGraphDefBuilder db(&b);
  Node* node = nullptr;
  TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
  output_node = node->name();
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
  graph_def.SerializeToString(&serialized_graph_def);
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
  return Status::OK();
}

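// Serializes `dataset` as an input of the graph under construction. If the
// dataset does not implement serialization and we are serializing only for
// graph optimization, the dataset itself is fed through a placeholder and
// recorded in the context's input list instead.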
Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (ctx->optimization_only() && errors::IsUnimplemented(status)) {
    Tensor t(DT_VARIANT, TensorShape({}));
    // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
    // increment the refcount of `dataset` here to retain ownership.
    dataset->Ref();
    TF_RETURN_IF_ERROR(
        StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
    TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
    DCHECK_NE(ctx->input_list(), nullptr);
    ctx->input_list()->emplace_back((*output)->name(), std::move(t));
    LOG(WARNING)
        << "Input of " << dataset->DebugString()
        << " will not be optimized because the dataset does not implement the "
           "AsGraphDefInternal() method needed to apply optimizations.";
    return Status::OK();
  }
  return status;
}

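// Constructs the dataset by calling the subclass's `MakeDataset()` and stores
// it in a scalar DT_VARIANT output tensor.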
void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
  }
}

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

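// A `BackgroundWorker` owns a single thread that runs scheduled work items in
// FIFO order until the worker is destroyed.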
BackgroundWorker::BackgroundWorker(Env* env, const string& name) {
  thread_.reset(env->StartThread({} /* thread_options */, name,
                                 [this]() { WorkerLoop(); }));
}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}

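// Main loop of the background thread: waits until a work item is scheduled or
// the worker is cancelled, then runs items one at a time. Each item is
// dequeued under the lock but executed outside it, so a work item may safely
// call `Schedule()`.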
void BackgroundWorker::WorkerLoop() {
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}

}  // namespace data
}  // namespace tensorflow