/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace data {
namespace {

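// Guards and stores the process-wide set of registered dataset op names used
// by `DatasetOpRegistry` below. Both objects are created on first use and
// never destroyed, so they remain valid for the lifetime of the process.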
static mutex* get_dataset_op_registry_lock() {
  static mutex dataset_op_registry_lock(LINKER_INITIALIZED);
  return &dataset_op_registry_lock;
}

static std::unordered_set<string>* get_dataset_op_registry() {
  static std::unordered_set<string>* names = new std::unordered_set<string>;
  return names;
}

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of `DatasetBase`,
// and the wrapper's copy constructor and destructor take care of managing the
// reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
    if (&other == this) return *this;
    std::swap(dataset_, other.dataset_);
    return *this;
  }

  DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* dataset_;  // Owns one reference.
};

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

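// A wrapper class that stores the scalar DT_VARIANT tensor holding a
// `DatasetVariantWrapper`. Unlike `DatasetVariantWrapper`, it implements
// `Encode()` and `Decode()` (by copying the underlying tensor), and the
// device-copy functions registered below allow it to move between hosts and
// devices.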
class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

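// Checks that the input is a scalar DT_VARIANT tensor holding a `DatasetBase`
// and re-wraps that tensor in a `WrappedDatasetVariantWrapper`, producing a
// variant handle that can cross device boundaries.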
class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

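// The inverse of `WrapDatasetVariantOp`: extracts the original dataset variant
// tensor from a `WrappedDatasetVariantWrapper` and forwards it as the output.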
class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

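// Copying a wrapped dataset variant between devices only requires copying the
// wrapper itself; the wrapped tensor stays in host memory (see the
// `HostMemory` kernel registrations above), so the asynchronous tensor-copy
// callback is not needed.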
static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

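// Builds a node for `dataset` in the graph under construction: single inputs
// and list inputs are interleaved by index, the dataset's output types and
// shapes are attached when the op defines the corresponding attributes, and
// any caller-supplied attrs are added as well.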
Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  const string& type_string = dataset->type_string();
  std::unique_ptr<const GraphDefBuilder::Options> opts(
      new GraphDefBuilder::Options(b_->opts()));
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_shapes", dataset->output_shapes())));
  }
  if (has_output_types_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_types", dataset->output_dtypes())));
  }
  for (const auto& attr : attrs) {
    opts.reset(
        new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(opts->GetNameForOp(type_string), type_string,
                           opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (size_t i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}

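// Adds the definition of `function_name` from `lib_def` to the graph's
// function library, together with its registered gradient (if any) and,
// recursively, any functions referenced by its nodes or attrs.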
Status GraphDefBuilderWrapper::AddFunction(
    SerializationContext* ctx, const string& function_name,
    const FunctionLibraryDefinition& lib_def) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return Status::OK();
  }
  const FunctionDef* f_def = lib_def.Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = lib_def.FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions in inputs of function_name.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def));
    }
  }

  // Recursively add functions in attrs of function_name.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

Status IteratorBase::InitializeBase(IteratorContext* ctx,
                                    const IteratorBase* parent) {
  parent_ = parent;
  id_ =
      Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
  if (parent_) {
    parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
                                        reinterpret_cast<uint64>(parent_));
  }
  if (const auto& model = ctx->model()) {
    auto factory = [ctx, this](model::Node::Args args) {
      return CreateNode(ctx, std::move(args));
    };
    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
  }
  return Status::OK();
}

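// For DT_VARIANT tensors that hold a dataset, byte accounting is delegated to
// the dataset itself; all other tensors report their own sizes.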
int64 GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64 allocated_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

int64 GetTotalBytes(const std::vector<Tensor>& element) {
  int64 total_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      total_bytes += dataset->TotalBytes();
    } else {
      total_bytes += tensor.TotalBytes();
    }
  }
  return total_bytes;
}

std::string FullName(const std::string& prefix, const std::string& name) {
  if (str_util::StrContains(name, kColon)) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }

  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}

Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

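// Note that the stored `DatasetVariantWrapper` takes ownership of one
// reference on `dataset`, which is released when the variant is destroyed.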
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}

Status DatasetBase::MakeIterator(
    IteratorContext* ctx, const IteratorBase* parent,
    const string& output_prefix,
    std::unique_ptr<IteratorBase>* iterator) const {
  *iterator = MakeIteratorInternal(output_prefix);
  Status s = (*iterator)->InitializeBase(ctx, parent);
  if (s.ok()) {
    s.Update((*iterator)->Initialize(ctx));
  }
  if (!s.ok()) {
    // Reset the iterator to avoid returning an uninitialized iterator.
    iterator->reset();
  }
  return s;
}

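// Default implementation of `MakeSplitProvider`: delegate to the single input
// dataset. Datasets that are not unary must provide their own
// `MakeSplitProvider` to support splitting.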
Status DatasetBase::MakeSplitProvider(
    std::unique_ptr<SplitProvider>* split_provider) const {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create a split provider for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create a split provider for dataset of type ", type_string(),
        ", because the dataset is not unary (having arity ", inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProvider(split_provider);
}

Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  return errors::Unimplemented("InputDatasets not implemented for ",
                               type_string());
}

Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
    Tensor t(DT_VARIANT, TensorShape({}));
    // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
    // increment the refcount of `dataset` here to retain ownership.
    dataset->Ref();
    TF_RETURN_IF_ERROR(
        StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
    TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
    DCHECK_NE(ctx->input_list(), nullptr);
    ctx->input_list()->emplace_back((*output)->name(), std::move(t));
    LOG_EVERY_N_SEC(WARNING, 30)
        << "Input of " << dataset->DebugString()
        << " will not be optimized because the dataset does not implement the "
           "AsGraphDefInternal() method needed to apply optimizations.";
    return Status::OK();
  }
  return status;
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dtype() == DT_VARIANT) {
    // If the input tensor is a variant, it may represent a multi-dimensional
    // array of datasets. We attempt to decode each dataset so that we can use
    // their custom serialization logic and combine the result of their
    // individual serializations using the `Pack` operation.
    //
    // If this fails, we fall back to using its Variant::Encode() based
    // serialization.
    Status s = AddDatasetOrTensorHelper(ctx, t, output);
    if (s.ok()) {
      return s;
    }
  }
  return AddTensor(t, output);
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dims() == 0) {
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset));
    return AddInputDataset(ctx, dataset, output);
  }
  std::vector<NodeBuilder::NodeOut> nodes;
  for (int i = 0; i < t.dim_size(0); ++i) {
    Node* node;
    TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node));
    nodes.emplace_back(node);
  }
  auto op_name = "Pack";
  auto opts = builder()->opts();
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(std::move(nodes));
  *output = opts.FinalizeBuilder(&node_builder);
  return Status::OK();
}

DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params)
    : params_(params) {
  params_.dataset->Ref();
  VLOG(2) << prefix() << " constructor";
}

DatasetBaseIterator::~DatasetBaseIterator() {
  VLOG(2) << prefix() << " destructor";
  params_.dataset->Unref();
}

string DatasetBaseIterator::BuildTraceMeName() {
  string result = strings::StrCat(params_.prefix, "#id=", id_);
  if (parent_) {
    strings::StrAppend(&result, ",parent_id=", parent_id_);
  }

  TraceMeMetadata metadata = GetTraceMeMetadata();
  for (const auto& pair : metadata) {
    strings::StrAppend(&result, ",", pair.first, "=", pair.second);
  }
  strings::StrAppend(&result, "#");
  return result;
}

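// Wraps `GetNextInternal()` with profiling and autotuning bookkeeping: when
// resource-usage collection is enabled, processing time is attributed to this
// iterator's model node rather than its output, and an unexpected `OutOfRange`
// status is converted into an `Internal` error.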
Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                    std::vector<Tensor>* out_tensors,
                                    bool* end_of_sequence) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " GetNext enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
  if (TF_PREDICT_TRUE(s.ok() && !*end_of_sequence)) {
    DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
    RecordElement(ctx, out_tensors);
  }
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " GetNext exit";
  return s;
}

Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
                                 bool* end_of_sequence, int* num_skipped) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " Skip enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " Skip exit";
  return s;
}

Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip,
                                         bool* end_of_sequence,
                                         int* num_skipped) {
  *num_skipped = 0;
  for (int i = 0; i < num_to_skip; ++i) {
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence));
    if (*end_of_sequence) {
      return Status::OK();
    }
    // `RecordElement` counts the number of elements produced, which is used
    // to estimate the CPU time spent in a given iterator for autotuning.
    // We only call `RecordElement` in this default implementation of
    // `SkipInternal` (which trivially calls `GetNextInternal`) and assume
    // that an overridden `SkipInternal` in a derived class has negligible
    // cost compared to its `GetNextInternal`.
    RecordElement(ctx, &out_tensors);
    (*num_skipped)++;
  }
  return Status::OK();
}

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
  }
}

string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
                                    bool verbose) const {
  return profiler::TraceMeOp(name_view(), type_string_view());
}

// static
bool DatasetOpKernel::IsDatasetOp(const OpDef* op_def) {
  if (DatasetOpRegistry::IsRegistered(op_def->name())) {
    return true;
  }

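  // Otherwise, fall back to a heuristic: an op with a single DT_VARIANT output
  // whose name ends in "Dataset" or "DatasetV2" is treated as a dataset op.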
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2")));
}

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

BackgroundWorker::BackgroundWorker(Env* env, const char* name)
    : env_(env), name_(name) {}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

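// Enqueues `work_item` to run on the background thread; the thread itself is
// started lazily on the first call.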
void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    if (!thread_) {
      thread_ = absl::WrapUnique(env_->StartThread(
          {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
    }
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}

void BackgroundWorker::WorkerLoop() {
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}

// static
void DatasetOpRegistry::Register(const string& op_name) {
  mutex_lock l(*get_dataset_op_registry_lock());
  get_dataset_op_registry()->insert(op_name);
}

// static
bool DatasetOpRegistry::IsRegistered(const string& op_name) {
  mutex_lock l(*get_dataset_op_registry_lock());
  std::unordered_set<string>* op_names = get_dataset_op_registry();
  return op_names->find(op_name) != op_names->end();
}

namespace {
class RunnerImpl : public Runner {
 public:
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  virtual void PreventTailCall() {}
};
}  // namespace

/* static */
Runner* Runner::get() {
  static Runner* singleton = new RunnerImpl;
  return singleton;
}

}  // namespace data
}  // namespace tensorflow