/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <atomic>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"

// On Windows, disable some macros that would break compile
#if defined(PLATFORM_WINDOWS)
#undef GetMessage
#endif

namespace tensorflow {
namespace data {
namespace {

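// Returns the mutex that serializes access to the registry of dataset op
// names below.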
static mutex* get_dataset_op_registry_lock() {
  static mutex dataset_op_registry_lock(LINKER_INITIALIZED);
  return &dataset_op_registry_lock;
}

static std::unordered_set<string>* get_dataset_op_registry() {
  static std::unordered_set<string>* names = new std::unordered_set<string>;
  return names;
}

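// Returns a node name that is unique within this process, formed by appending
// a monotonically increasing counter to `base`.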
std::string UniqueNodeName(const std::string& base) {
  static std::atomic<int64_t> counter(0);
  return strings::StrCat(base, "/", counter.fetch_add(1));
}

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of `DatasetBase`,
// and the wrapper's copy constructor and destructor take care of managing the
// reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
    if (&other == this) return *this;
    std::swap(dataset_, other.dataset_);
    return *this;
  }

  DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* dataset_;  // Owns one reference.
};

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

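// Op kernel that wraps a scalar DT_VARIANT dataset tensor in a
// `WrappedDatasetVariantWrapper`, whose `Encode()`/`Decode()` and the
// device-copy functions registered below allow the handle to cross device
// boundaries.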
class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return OkStatus();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

Status GraphDefBuilderWrapper::AddDataset(const DatasetBase* dataset,
                                          const std::vector<Node*>& inputs,
                                          Node** output) {
  return AddDataset(dataset, inputs, {}, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset, const std::vector<Node*>& inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    enumerated_inputs[i] = std::make_pair(i, inputs[i]);
  }
  return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  return AddDataset(dataset, inputs, list_inputs, attrs,
                    /*use_dataset_name=*/false, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    bool use_dataset_name, Node** output) {
  auto& type_string = dataset->type_string();
  auto opts = absl::make_unique<GraphDefBuilder::Options>(b_->opts());
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("output_shapes", dataset->output_shapes()));
  }
  if (has_output_types_attr) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("output_types", dataset->output_dtypes()));
  }
  bool has_metadata_attr = HasAttr(type_string, "metadata");
  if (has_metadata_attr) {
    std::string serialized_metadata;
    dataset->metadata().SerializeToString(&serialized_metadata);
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("metadata", serialized_metadata));
  }
  for (const auto& attr : attrs) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr(attr.first, attr.second));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(
      use_dataset_name ? dataset->node_name() : opts->GetNameForOp(type_string),
      type_string, opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    // `inputs` and `list_inputs` are expected to be sorted by input index, so
    // a single pass over [0, total_size) can interleave them in order.
    for (size_t i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return OkStatus();
}

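// Adds the definition of `function_name` from `lib_def` to the graph being
// built, along with its registered gradient (if any), recursing into any
// functions referenced by its nodes and attrs.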
Status GraphDefBuilderWrapper::AddFunction(
    SerializationContext* ctx, const string& function_name,
    const FunctionLibraryDefinition& lib_def) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return OkStatus();
  }
  const FunctionDef* f_def = lib_def.Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = lib_def.FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions in inputs of function_name.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def));
    }
  }

  // Recursively add functions in attrs of function_name.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def));
  }
  return OkStatus();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

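// Returns the number of threads available to an iterator's `runner`: the
// device's own threadpool size if the device has one, otherwise the maximum
// parallelism supported by the host.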
int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx) {
  thread::ThreadPool* thread_pool =
      ctx->device()->tensorflow_device_thread_pool();
  if (thread_pool) {
    return thread_pool->NumThreads();
  } else {
    static const int32_t kDefaultRunnerThreadpoolSize = port::MaxParallelism();
    return kDefaultRunnerThreadpoolSize;
  }
}

Status IteratorBase::InitializeBase(IteratorContext* ctx,
                                    const IteratorBase* parent) {
  parent_ = parent;
  id_ =
      Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
  if (parent_) {
    parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
                                        reinterpret_cast<uint64>(parent_));
  }
  if (const auto& model = ctx->model()) {
    auto factory = [ctx, this](model::Node::Args args) {
      return CreateNode(ctx, std::move(args));
    };
    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
  }
  return OkStatus();
}

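// Returns the bytes allocated for a dataset element. For a DT_VARIANT tensor
// that wraps a dataset, the wrapped dataset's own accounting is used instead
// of the size of the variant tensor itself.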
int64_t GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64_t allocated_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

int64_t GetTotalBytes(const std::vector<Tensor>& element) {
  int64_t total_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      total_bytes += dataset->TotalBytes();
    } else {
      total_bytes += tensor.TotalBytes();
    }
  }
  return total_bytes;
}

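// Returns a checkpoint key of the form "<random hex>|<prefix>:<name>". The
// ':' delimiter separates the prefix from the name, so `name` itself must not
// contain it.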
std::string FullName(const std::string& prefix, const std::string& name) {
  if (str_util::StrContains(name, kColon)) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }

  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}

Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return OkStatus();
}

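// Stores `dataset` in a scalar DT_VARIANT `tensor`, transferring ownership of
// one reference on `dataset` to the wrapper held by the tensor.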
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return OkStatus();
}

namespace internal {

#define WARN_PROTO_FIELD_CONFLICT(reflection, field, field_type, src, dst)     \
  {                                                                            \
    auto source_value = reflection->Get##field_type(src, field);               \
    auto destination_value = reflection->Get##field_type(*dst, field);         \
    if (source_value != destination_value) {                                   \
      LOG(WARNING) << "Changing the value of option field " << field->name()   \
                   << " from " << destination_value << " to " << source_value; \
    }                                                                          \
  }

#define WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst) \
  {                                                                 \
    auto source_value = reflection->GetEnum(src, field);            \
    auto destination_value = reflection->GetEnum(*dst, field);      \
    if (source_value != destination_value) {                        \
      LOG(WARNING) << "Changing the value of option enum field "    \
                   << field->name() << " from "                     \
                   << destination_value->full_name() << " to "      \
                   << source_value->full_name();                    \
    }                                                               \
  }

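// Logs a warning for every field that is set in both `src` and `*dst` with
// differing values (recursing into message-valued fields), since merging
// `src` into `*dst` will overwrite the destination's values.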
void WarnProtoConflicts(const protobuf::Message& src, protobuf::Message* dst) {
  std::vector<const protobuf::FieldDescriptor*> set_src;
  std::vector<const protobuf::FieldDescriptor*> set_dst;
  const protobuf::Reflection* reflection = src.GetReflection();
  reflection->ListFields(src, &set_src);
  reflection->ListFields(*dst, &set_dst);
  std::sort(set_src.begin(), set_src.end());
  std::sort(set_dst.begin(), set_dst.end());

  std::vector<const protobuf::FieldDescriptor*> in_both;
  std::set_intersection(set_src.begin(), set_src.end(), set_dst.begin(),
                        set_dst.end(), std::back_inserter(in_both));

  for (auto field : in_both) {
    if (field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE) {
      WarnProtoConflicts(reflection->GetMessage(src, field),
                         reflection->MutableMessage(dst, field));
    } else {
      switch (field->cpp_type()) {
        case protobuf::FieldDescriptor::CPPTYPE_INT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_INT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Double, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Float, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_BOOL:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Bool, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_ENUM:
          WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst);
          break;
        default: {
          LOG(ERROR) << "Unrecognized proto type for field "
                     << field->full_name();
        }
      }
    }
  }
}

#undef WARN_PROTO_ENUM_FIELD_CONFLICT
#undef WARN_PROTO_FIELD_CONFLICT

void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination) {
  WarnProtoConflicts(source, destination);
  destination->MergeFrom(source);
}

void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination) {
  destination->CheckTypeAndMergeFrom(source);
}

}  // namespace internal

void DatasetBase::Initialize(const Metadata& metadata) {
  Status s = ComputeNumSources();
  if (!s.ok()) {
    LOG(ERROR) << s;
  }
  s = MergeOptionsFromInputs();
  if (!s.ok()) {
    LOG(ERROR) << s;
  }
  metadata_ = metadata;
  if (metadata_.name().empty()) {
    static std::atomic<int64_t> id_counter(0);
    *metadata_.mutable_name() =
        strings::StrCat(type_string(), ":", id_counter.fetch_add(1));
  }
}

Status DatasetBase::ComputeNumSources() {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot compute input sources for dataset of type ", type_string(),
        ", because the dataset does not implement `InputDatasets`.");
  }
  if (num_sources_ >= 0) {
    // Already computed.
    return OkStatus();
  }
  num_sources_ = 0;
  if (inputs.empty()) {
    num_sources_ = 1;
    return OkStatus();
  }
  for (const auto& input : inputs) {
    if (input->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Cannot compute input sources for dataset of type ", type_string(),
          ", because sources could not be computed for input dataset of type ",
          input->type_string());
    }
    num_sources_ += input->num_sources();
  }
  return OkStatus();
}

Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const {
  CardinalityOptions options;
  options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_MODERATE);
  int64 cardinality = Cardinality(options);
  if (cardinality == kInfiniteCardinality ||
      cardinality == kUnknownCardinality) {
    return tensorflow::errors::FailedPrecondition(
        "Dataset of type ", this->DebugString(), " has ",
        cardinality == kInfiniteCardinality ? "infinite" : "unknown",
        " cardinality, which does not support random access.");
  }
  if (index < 0 || index >= cardinality) {
    return errors::OutOfRange("Index out of range [0, ", cardinality,
                              "): ", index);
  }
  return OkStatus();
}

Status DatasetBase::Get(OpKernelContext* ctx, int64 index,
                        std::vector<Tensor>* out_tensors) const {
  return errors::Unimplemented(
      "Random access is not implemented for this dataset.");
}

StatusOr<DatasetBase*> DatasetBase::Finalize(
    OpKernelContext* ctx,
    std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
        make_finalized_dataset) const {
  mutex_lock l(mu_);
  if (!finalized_dataset_) {
    TF_ASSIGN_OR_RETURN(finalized_dataset_, make_finalized_dataset());
  }
  return finalized_dataset_.get();
}

Status DatasetBase::MergeOptionsFromInputs() {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot merge options for dataset of type ", type_string(),
        ", because the dataset does not implement `InputDatasets`.");
  }
  if (inputs.empty()) {
    return OkStatus();
  }
  // Merge options from inputs sequentially before merging options from this
  // dataset. Since the last options merged take precedence, any options set
  // on the current dataset through OptionsDataset take precedence over those
  // set on the input datasets.
  Options merged_options = inputs[0]->options_;
  for (int i = 1; i < inputs.size(); ++i) {
    internal::MergeOptions(inputs[i]->options_, &merged_options);
  }
  internal::MergeOptions(options_, &merged_options);
  options_ = merged_options;
  return OkStatus();
}

Status DatasetBase::MakeIterator(
    IteratorContext* ctx, const IteratorBase* parent,
    const string& output_prefix,
    std::unique_ptr<IteratorBase>* iterator) const {
  if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset") {
    // `OptionsDataset` and `FinalizeDataset` are transparent pass-throughs, so
    // make an iterator for their input dataset directly.
    std::vector<const DatasetBase*> inputs;
    Status s = InputDatasets(&inputs);
    return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator);
  }
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode(
            strings::StrCat("MakeIterator::", type_string()), {});
      },
      profiler::TraceMeLevel::kInfo);
  *iterator = MakeIteratorInternal(output_prefix);
  Status s = (*iterator)->InitializeBase(ctx, parent);
  if (s.ok()) {
    s.Update((*iterator)->Initialize(ctx));
  }
  if (!s.ok()) {
    // Reset the iterator to avoid returning an uninitialized iterator.
    iterator->reset();
  }
  return s;
}

Status DatasetBase::MakeSplitProviders(
    std::vector<std::unique_ptr<SplitProvider>>* split_providers) const {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset is not unary (instead having arity ",
        inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProviders(split_providers);
}

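// Returns the dataset's cardinality, computing it on first use and caching
// the result under `cardinality_mu_`.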
int64_t DatasetBase::Cardinality() const {
  mutex_lock l(cardinality_mu_);
  if (cardinality_ == kUnknownCardinality) {
    cardinality_ = CardinalityInternal();
  }
  return cardinality_;
}

int64_t DatasetBase::Cardinality(CardinalityOptions options) const {
  mutex_lock l(cardinality_mu_);
  if (cardinality_ == kUnknownCardinality) {
    cardinality_ = CardinalityInternal(options);
  }
  return cardinality_;
}

Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  return errors::Unimplemented("InputDatasets not implemented for ",
                               type_string());
}

Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (ctx->is_graph_rewrite()) {
    if (status.ok()) {
      // Record cardinality in an unregistered attribute so that rewrites have
      // this information.
      (*output)->AddAttr(kCardinalityAttrForRewrite, dataset->Cardinality());
    } else if (errors::IsUnimplemented(status)) {
      Tensor t(DT_VARIANT, TensorShape({}));
      // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
      // increment the refcount of `dataset` here to retain ownership.
      dataset->Ref();
      TF_RETURN_IF_ERROR(
          StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
      TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
      DCHECK_NE(ctx->input_list(), nullptr);
      ctx->input_list()->emplace_back((*output)->name(), std::move(t));
      LOG_EVERY_N_SEC(WARNING, 30)
          << "Input of " << dataset->DebugString()
          << " will not be optimized because the dataset does not implement "
             "the AsGraphDefInternal() method needed to apply optimizations.";
      return OkStatus();
    }
  }
  return status;
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dtype() == DT_VARIANT) {
    // If the input tensor is a variant, it may represent a multi-dimensional
    // array of datasets. We attempt to decode each dataset so that we can use
    // their custom serialization logic and combine the result of their
    // individual serializations using the `Pack` operation.
    //
    // If this fails, we fall back to using its Variant::Encode() based
    // serialization.
    Status s = AddDatasetOrTensorHelper(ctx, t, output);
    if (s.ok()) {
      return s;
    }
  }
  if (t.dtype() == DT_RESOURCE && !ctx->is_graph_rewrite()) {
    Status s = AddResourceHelper(ctx, t, output);
    if (!errors::IsUnimplemented(s)) {
      // Fall through to AddTensor if AsGraphDef is not implemented for this
      // resource.
      return s;
    }
  }
  return AddTensor(t, output);
}

Status DatasetBase::DatasetGraphDefBuilder::AddIdentity(
    SerializationContext* ctx, const std::string& name_prefix, Node** input,
    Node** output) {
  *output =
      ops::UnaryOp("Identity", *input,
                   builder()->opts().WithName(UniqueNodeName(name_prefix)));
  return OkStatus();
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dims() == 0) {
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset));
    return AddInputDataset(ctx, dataset, output);
  }
  std::vector<NodeBuilder::NodeOut> nodes;
  for (int i = 0; i < t.dim_size(0); ++i) {
    Node* node;
    TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node));
    nodes.emplace_back(node);
  }
  auto op_name = "Pack";
  auto opts = builder()->opts();
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(std::move(nodes));
  *output = opts.FinalizeBuilder(&node_builder);
  return OkStatus();
}

Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  const ResourceHandle& handle = t.flat<ResourceHandle>()(0);
  if (ctx->device_name() != handle.device()) {
    return errors::InvalidArgument("Trying to access resource ", handle.name(),
                                   " located on device ", handle.device(),
                                   " from device ", ctx->device_name());
  }
  ResourceBase* resource;
  TF_RETURN_IF_ERROR(ctx->resource_mgr()->Lookup(handle, &resource));
  core::ScopedUnref unref(resource);
  return resource->AsGraphDef(builder(), output);
}

DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params)
    : params_(params) {
  params_.dataset->Ref();
  VLOG(2) << prefix() << " constructor";
  strings::StrAppend(&traceme_metadata_, "name=", dataset()->metadata().name());
  strings::StrAppend(&traceme_metadata_, ",shapes=");
  auto& shapes = output_shapes();
  for (int i = 0; i < shapes.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&traceme_metadata_, " ");
    }
    strings::StrAppend(&traceme_metadata_, shapes.at(i).DebugString());
  }
  strings::StrAppend(&traceme_metadata_, ",types=");
  auto& types = output_dtypes();
  for (int i = 0; i < types.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&traceme_metadata_, " ");
    }
    strings::StrAppend(&traceme_metadata_, DataTypeString(types.at(i)));
  }
}

DatasetBaseIterator::~DatasetBaseIterator() {
  VLOG(2) << prefix() << " destructor";
  params_.dataset->Unref();
}

string DatasetBaseIterator::BuildTraceMeName() {
  string result =
      strings::StrCat(params_.prefix, "#", traceme_metadata_, ",id=", id_);
  if (parent_) {
    strings::StrAppend(&result, ",parent_id=", parent_id_);
  }
  TraceMeMetadata metadata = GetTraceMeMetadata();
  for (const auto& pair : metadata) {
    strings::StrAppend(&result, ",", pair.first, "=", pair.second);
  }
  strings::StrAppend(&result, "#");
  return result;
}

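// Wraps `GetNextInternal` with tracing and autotuning bookkeeping: time is
// attributed to this iterator's model node rather than its output's,
// successfully produced elements are recorded, and an escaping `OutOfRange`
// is converted to an `Internal` error because it indicates an implementation
// bug in the iterator.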
Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                    std::vector<Tensor>* out_tensors,
                                    bool* end_of_sequence) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " GetNext enter";
  auto model = ctx->model();
  if (collect_resource_usage(ctx)) {
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  out_tensors->clear();
  Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
  if (TF_PREDICT_TRUE(s.ok())) {
    if (TF_PREDICT_TRUE(!*end_of_sequence)) {
      DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
      RecordElement(ctx, out_tensors);
    } else {
      out_tensors->clear();
    }
  }
  if (collect_resource_usage(ctx)) {
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " GetNext exit";
  return s;
}

Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
                                 bool* end_of_sequence, int* num_skipped) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " Skip enter";
  auto model = ctx->model();
  if (collect_resource_usage(ctx)) {
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
  if (collect_resource_usage(ctx)) {
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " Skip exit";
  return s;
}

Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip,
                                         bool* end_of_sequence,
                                         int* num_skipped) {
  *num_skipped = 0;
  for (int i = 0; i < num_to_skip; ++i) {
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence));
    if (*end_of_sequence) {
      return OkStatus();
    }
    // `RecordElement` counts the number of elements produced and helps
    // estimate the CPU time spent in a given iterator for autotuning.
    //
    // We only call `RecordElement` in this default implementation of
    // `SkipInternal` (which trivially calls `GetNextInternal`), and assume
    // that an overridden `SkipInternal` in a derived class has negligible
    // cost compared to its `GetNextInternal`.
    RecordElement(ctx, &out_tensors);
    (*num_skipped)++;
  }
  return OkStatus();
}

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
    dataset->Initialize(metadata_);
  }
}

string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
                                    bool verbose) const {
  return profiler::TraceMeOp(name_view(), type_string_view());
}

// static
bool DatasetOpKernel::IsDatasetOp(const OpDef& op_def) {
  if (op_def.output_arg_size() != 1) return false;
  if (op_def.output_arg(0).type() != DT_VARIANT) return false;
  absl::string_view op_name = op_def.name();
  if (op_name == "DatasetFromGraph") return true;
  if (absl::EndsWith(op_name, "Dataset")) return true;
  // Check if the suffix matches "DatasetV[0-9]+". Use a signed index for the
  // backwards scan: with an unsigned type, the `index >= 0` bound would be
  // vacuously true and the scan could run off the front of the string.
  int64_t index = static_cast<int64_t>(op_name.length()) - 1;
  while (index >= 0 && isdigit(op_name[index])) {
    index--;
  }
  constexpr absl::string_view kDatasetPrefix = "DatasetV";
  constexpr int64_t kPrefixLength = kDatasetPrefix.size();
  if (index < kPrefixLength - 1 ||
      index == static_cast<int64_t>(op_name.length()) - 1) {
    return false;
  }
  return op_name.substr(index - kPrefixLength + 1, kPrefixLength) ==
         kDatasetPrefix;
}

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

BackgroundWorker::BackgroundWorker(Env* env, const char* name)
    : env_(env), name_(name) {}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

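// Enqueues `work_item` for execution on the background thread. The thread is
// started lazily, so none is created for a worker that is never scheduled on.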
void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    if (!thread_) {
      thread_ = absl::WrapUnique(env_->StartThread(
          {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
    }
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}

void BackgroundWorker::WorkerLoop() {
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}

namespace {
class RunnerImpl : public Runner {
 public:
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  virtual void PreventTailCall() {}
};
}  // namespace

/* static */
Runner* Runner::get() {
  static Runner* singleton = new RunnerImpl;
  return singleton;
}

}  // namespace data
}  // namespace tensorflow