/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// On Windows, disable some macros that would break compilation.
#if defined(PLATFORM_WINDOWS)
#undef GetMessage
#endif

namespace tensorflow {
namespace data {
namespace {

static mutex* get_dataset_op_registry_lock() {
  static mutex dataset_op_registry_lock(LINKER_INITIALIZED);
  return &dataset_op_registry_lock;
}

static std::unordered_set<string>* get_dataset_op_registry() {
  static std::unordered_set<string>* names = new std::unordered_set<string>;
  return names;
}

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of `DatasetBase`,
// and the wrapper's copy constructor and destructor take care of managing the
// reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
    if (&other == this) return *this;
    std::swap(dataset_, other.dataset_);
    return *this;
  }

  DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* dataset_;  // Owns one reference.
};

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

Status GraphDefBuilderWrapper::AddDataset(const DatasetBase* dataset,
                                          const std::vector<Node*>& inputs,
                                          Node** output) {
  return AddDataset(dataset, inputs, {}, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset, const std::vector<Node*>& inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    enumerated_inputs[i] = std::make_pair(i, inputs[i]);
  }
  return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  return AddDataset(dataset, inputs, list_inputs, attrs,
                    /*use_dataset_name=*/false, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    bool use_dataset_name, Node** output) {
  auto& type_string = dataset->type_string();
  auto opts = absl::make_unique<GraphDefBuilder::Options>(b_->opts());
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("output_shapes", dataset->output_shapes()));
  }
  if (has_output_types_attr) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("output_types", dataset->output_dtypes()));
  }
  for (const auto& attr : attrs) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr(attr.first, attr.second));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(
      use_dataset_name ? dataset->node_name() : opts->GetNameForOp(type_string),
      type_string, opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (int i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}

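// Example (sketch, not part of the original file): how a concrete dataset's
// `AsGraphDefInternal()` might call `AddDataset()`. The `input_` member, the
// `"count"` attribute, and `count_` are hypothetical stand-ins for whatever
// inputs and attrs a real dataset defines.
//
//   Node* input_node = nullptr;
//   TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
//   AttrValue count_attr;
//   count_attr.set_i(count_);
//   TF_RETURN_IF_ERROR(
//       b->AddDataset(this, /*inputs=*/{input_node},
//                     /*attrs=*/{{"count", count_attr}}, output));
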
Status GraphDefBuilderWrapper::AddFunction(
    SerializationContext* ctx, const string& function_name,
    const FunctionLibraryDefinition& lib_def) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return Status::OK();
  }
  const FunctionDef* f_def = lib_def.Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = lib_def.FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions in inputs of function_name.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def));
    }
  }

  // Recursively add functions in attrs of function_name.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

Status IteratorBase::InitializeBase(IteratorContext* ctx,
                                    const IteratorBase* parent) {
  parent_ = parent;
  id_ =
      Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
  if (parent_) {
    parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
                                        reinterpret_cast<uint64>(parent_));
  }
  if (const auto& model = ctx->model()) {
    auto factory = [ctx, this](model::Node::Args args) {
      return CreateNode(ctx, std::move(args));
    };
    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
  }
  return Status::OK();
}

int64 GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64_t allocated_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

int64 GetTotalBytes(const std::vector<Tensor>& element) {
  int64_t total_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      total_bytes += dataset->TotalBytes();
    } else {
      total_bytes += tensor.TotalBytes();
    }
  }
  return total_bytes;
}

std::string FullName(const std::string& prefix, const std::string& name) {
  if (str_util::StrContains(name, kColon)) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }

  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}

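// Example (sketch): `FullName("Iterator::Range", "buffer_size")` produces a
// checkpoint key of the form
// "<kFullNameRandomHex>|Iterator::Range:buffer_size", where the hex prefix is
// a fixed constant that marks the string as a full name. The
// "Iterator::Range" prefix here is illustrative.
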
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}

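// Example (sketch): round-tripping a dataset through a DT_VARIANT scalar.
// `my_dataset` is a hypothetical `DatasetBase*` whose reference is donated to
// the tensor by `StoreDatasetInVariantTensor()`.
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(my_dataset, &t));
//   DatasetBase* recovered = nullptr;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &recovered));
//   // `recovered` now aliases `my_dataset`; the tensor owns the reference.
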
namespace internal {

#define WARN_PROTO_FIELD_CONFLICT(reflection, field, field_type, src, dst)     \
  {                                                                            \
    auto source_value = reflection->Get##field_type(src, field);               \
    auto destination_value = reflection->Get##field_type(*dst, field);         \
    if (source_value != destination_value) {                                   \
      LOG(WARNING) << "Changing the value of option field " << field->name()   \
                   << " from " << destination_value << " to " << source_value; \
    }                                                                          \
  }

#define WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst) \
  {                                                                 \
    auto source_value = reflection->GetEnum(src, field);            \
    auto destination_value = reflection->GetEnum(*dst, field);      \
    if (source_value != destination_value) {                        \
      LOG(WARNING) << "Changing the value of option enum field "    \
                   << field->name() << " from "                     \
                   << destination_value->full_name() << " to "      \
                   << source_value->full_name();                    \
    }                                                               \
  }

void WarnProtoConflicts(const protobuf::Message& src, protobuf::Message* dst) {
  std::vector<const protobuf::FieldDescriptor*> set_src;
  std::vector<const protobuf::FieldDescriptor*> set_dst;
  const protobuf::Reflection* reflection = src.GetReflection();
  reflection->ListFields(src, &set_src);
  reflection->ListFields(*dst, &set_dst);
  std::sort(set_src.begin(), set_src.end());
  std::sort(set_dst.begin(), set_dst.end());

  std::vector<const protobuf::FieldDescriptor*> in_both;
  std::set_intersection(set_src.begin(), set_src.end(), set_dst.begin(),
                        set_dst.end(), std::back_inserter(in_both));

  for (auto field : in_both) {
    if (field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE) {
      WarnProtoConflicts(reflection->GetMessage(src, field),
                         reflection->MutableMessage(dst, field));
    } else {
      switch (field->cpp_type()) {
        case protobuf::FieldDescriptor::CPPTYPE_INT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_INT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Double, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Float, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_BOOL:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Bool, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_ENUM:
          WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst);
          break;
        default: {
          LOG(ERROR) << "Unrecognized proto type for field "
                     << field->full_name();
        }
      }
    }
  }
}

#undef WARN_PROTO_ENUM_FIELD_CONFLICT
#undef WARN_PROTO_FIELD_CONFLICT

void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination) {
  WarnProtoConflicts(source, destination);
  destination->MergeFrom(source);
}

void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination) {
  destination->CheckTypeAndMergeFrom(source);
}

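// Example (sketch): merging two `Options` protos. Fields set in `source`
// overwrite those in `destination`, and each changed scalar value is reported
// by `WarnProtoConflicts()`. The `deterministic` field name is an assumption
// about the `Options` proto.
//
//   Options source, destination;
//   destination.set_deterministic(true);
//   source.set_deterministic(false);
//   internal::MergeOptions(source, &destination);  // logs a conflict warning
//   // destination.deterministic() is now false.
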
}  // namespace internal

void DatasetBase::Initialize() {
  Status s = ComputeNumSources();
  if (!s.ok()) {
    LOG(ERROR) << s;
  }
  s = MergeOptionsFromInputs();
  if (!s.ok()) {
    LOG(ERROR) << s;
  }
}

Status DatasetBase::ComputeNumSources() {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot compute input sources for dataset of type ", type_string(),
        ", because the dataset does not implement `InputDatasets`.");
  }
  if (num_sources_ >= 0) {
    // Already computed.
    return Status::OK();
  }
  num_sources_ = 0;
  if (inputs.empty()) {
    num_sources_ = 1;
    return Status::OK();
  }
  for (const auto& input : inputs) {
    if (input->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Cannot compute input sources for dataset of type ", type_string(),
          ", because sources could not be computed for input dataset of type ",
          input->type_string());
    }
    num_sources_ += input->num_sources();
  }
  return Status::OK();
}

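// Example (sketch): a leaf dataset such as RangeDataset has no inputs and so
// reports num_sources() == 1, while a ZipDataset over two leaf datasets
// reports 2, the sum over its inputs.
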
Status DatasetBase::Get(OpKernelContext* ctx, int64 index,
                        std::vector<Tensor>* out_tensors) {
  return errors::Unimplemented(
      "Random access is not implemented for this dataset.");
}

Status DatasetBase::MergeOptionsFromInputs() {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot merge options for dataset of type ", type_string(),
        ", because the dataset does not implement `InputDatasets`.");
  }
  if (inputs.empty()) {
    return Status::OK();
  }
  // Merge options from the inputs sequentially before merging options from
  // the dataset itself. Since the options merged last take precedence, the
  // options that may be set for the current dataset through OptionsDataset
  // take precedence over those set on the input datasets.
  Options merged_options = inputs[0]->options_;
  for (int i = 1; i < inputs.size(); ++i) {
    internal::MergeOptions(inputs[i]->options_, &merged_options);
  }
  internal::MergeOptions(options_, &merged_options);
  options_ = merged_options;
  return Status::OK();
}

Status DatasetBase::MakeIterator(
    IteratorContext* ctx, const IteratorBase* parent,
    const string& output_prefix,
    std::unique_ptr<IteratorBase>* iterator) const {
  if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset") {
    std::vector<const DatasetBase*> inputs;
    TF_RETURN_IF_ERROR(InputDatasets(&inputs));
    return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator);
  }
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode(
            strings::StrCat("MakeIterator::", type_string()), {});
      },
      profiler::TraceMeLevel::kInfo);
  *iterator = MakeIteratorInternal(output_prefix);
  Status s = (*iterator)->InitializeBase(ctx, parent);
  if (s.ok()) {
    s.Update((*iterator)->Initialize(ctx));
  }
  if (!s.ok()) {
    // Reset the iterator to avoid returning an uninitialized iterator.
    iterator->reset();
  }
  return s;
}

Status DatasetBase::MakeSplitProviders(
    std::vector<std::unique_ptr<SplitProvider>>* split_providers) const {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset is not unary (instead having arity ",
        inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProviders(split_providers);
}

Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  return errors::Unimplemented("InputDatasets not implemented for ",
                               type_string());
}

Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
    Tensor t(DT_VARIANT, TensorShape({}));
    // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
    // increment the refcount of `dataset` here to retain ownership.
    dataset->Ref();
    TF_RETURN_IF_ERROR(
        StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
    TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
    DCHECK_NE(ctx->input_list(), nullptr);
    ctx->input_list()->emplace_back((*output)->name(), std::move(t));
    LOG_EVERY_N_SEC(WARNING, 30)
        << "Input of " << dataset->DebugString()
        << " will not be optimized because the dataset does not implement the "
           "AsGraphDefInternal() method needed to apply optimizations.";
    return Status::OK();
  }
  return status;
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dtype() == DT_VARIANT) {
    // If the input tensor is a variant, it may represent a multi-dimensional
    // array of datasets. We attempt to decode each dataset so that we can use
    // their custom serialization logic and combine the result of their
    // individual serializations using the `Pack` operation.
    //
    // If this fails, we fall back to using its Variant::Encode() based
    // serialization.
    Status s = AddDatasetOrTensorHelper(ctx, t, output);
    if (s.ok()) {
      return s;
    }
  }
  if (t.dtype() == DT_RESOURCE && ctx->serialize_data_tensors()) {
    Status s = AddResourceHelper(ctx, t, output);
    if (!errors::IsUnimplemented(s)) {
      // Fall through to AddTensor if AsGraphDef is not implemented for this
      // resource.
      return s;
    }
  }
  return AddTensor(t, output);
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dims() == 0) {
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset));
    return AddInputDataset(ctx, dataset, output);
  }
  std::vector<NodeBuilder::NodeOut> nodes;
  for (int i = 0; i < t.dim_size(0); ++i) {
    Node* node;
    TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node));
    nodes.emplace_back(node);
  }
  auto op_name = "Pack";
  auto opts = builder()->opts();
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(std::move(nodes));
  *output = opts.FinalizeBuilder(&node_builder);
  return Status::OK();
}

Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  const ResourceHandle& handle = t.flat<ResourceHandle>()(0);
  if (ctx->device_name() != handle.device()) {
    return errors::InvalidArgument("Trying to access resource ", handle.name(),
                                   " located in device ", handle.device(),
                                   " from device ", ctx->device_name());
  }
  ResourceBase* resource;
  TF_RETURN_IF_ERROR(ctx->resource_mgr()->Lookup(handle, &resource));
  core::ScopedUnref unref(resource);
  return resource->AsGraphDef(builder(), output);
}

DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params)
    : params_(params) {
  params_.dataset->Ref();
  VLOG(2) << prefix() << " constructor";
  strings::StrAppend(&traceme_metadata_, "shapes=");
  auto& shapes = output_shapes();
  for (int i = 0; i < shapes.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&traceme_metadata_, " ");
    }
    strings::StrAppend(&traceme_metadata_, shapes.at(i).DebugString());
  }
  strings::StrAppend(&traceme_metadata_, ",types=");
  auto& types = output_dtypes();
  for (int i = 0; i < types.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&traceme_metadata_, " ");
    }
    strings::StrAppend(&traceme_metadata_, DataTypeString(types.at(i)));
  }
}

DatasetBaseIterator::~DatasetBaseIterator() {
  VLOG(2) << prefix() << " destructor";
  params_.dataset->Unref();
}

string DatasetBaseIterator::BuildTraceMeName() {
  string result =
      strings::StrCat(params_.prefix, "#", traceme_metadata_, ",id=", id_);
  if (parent_) {
    strings::StrAppend(&result, ",parent_id=", parent_id_);
  }
  TraceMeMetadata metadata = GetTraceMeMetadata();
  for (const auto& pair : metadata) {
    strings::StrAppend(&result, ",", pair.first, "=", pair.second);
  }
  strings::StrAppend(&result, "#");
  return result;
}

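// Example (sketch): for an iterator with prefix "Iterator::Range", the trace
// name built above looks like
// "Iterator::Range#shapes=[],types=int64,id=123,parent_id=45#", with any
// iterator-specific metadata from GetTraceMeMetadata() appended before the
// closing '#'. The concrete values here are illustrative.
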
Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                    std::vector<Tensor>* out_tensors,
                                    bool* end_of_sequence) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " GetNext enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  out_tensors->clear();
  Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
  if (TF_PREDICT_TRUE(s.ok())) {
    if (TF_PREDICT_TRUE(!*end_of_sequence)) {
      DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
      RecordElement(ctx, out_tensors);
    } else {
      out_tensors->clear();
    }
  }
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " GetNext exit";
  return s;
}

Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
                                 bool* end_of_sequence, int* num_skipped) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " Skip enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " Skip exit";
  return s;
}

Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip,
                                         bool* end_of_sequence,
                                         int* num_skipped) {
  *num_skipped = 0;
  for (int i = 0; i < num_to_skip; ++i) {
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence));
    if (*end_of_sequence) {
      return Status::OK();
    }
    // RecordElement is used to count the number of elements computed and to
    // help calculate the CPU time spent on a given iterator for autotuning.
    // Here we only call RecordElement in the default implementation of
    // SkipInternal (which trivially calls GetNextInternal) and assume that
    // an overridden SkipInternal in a derived class will have negligible
    // cost compared to its GetNextInternal.
    RecordElement(ctx, &out_tensors);
    (*num_skipped)++;
  }
  return Status::OK();
}

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
    dataset->Initialize();
  }
}

string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
                                    bool verbose) const {
  return profiler::TraceMeOp(name_view(), type_string_view());
}

// static
bool DatasetOpKernel::IsDatasetOp(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2")));
}

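// Example (sketch): under this heuristic, ops named "RangeDataset" or
// "ShuffleDatasetV2" (a single DT_VARIANT output plus a "Dataset"/"DatasetV2"
// suffix) are classified as dataset ops, while "Iterator" or
// "IteratorGetNext" are not.
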
void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

BackgroundWorker::BackgroundWorker(Env* env, const char* name)
    : env_(env), name_(name) {}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    if (!thread_) {
      thread_ = absl::WrapUnique(env_->StartThread(
          {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
    }
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}

void BackgroundWorker::WorkerLoop() {
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}

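// Example (sketch): typical BackgroundWorker usage. Work items run in FIFO
// order on a single thread that is started lazily on the first Schedule()
// call; the destructor sets the cancellation flag and joins the thread.
//
//   BackgroundWorker worker(Env::Default(), "my_background_thread");
//   worker.Schedule([] { /* e.g. fill a prefetch buffer and notify */ });
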
namespace {
class RunnerImpl : public Runner {
 public:
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  virtual void PreventTailCall() {}
};
}  // namespace

/* static */
Runner* Runner::get() {
  static Runner* singleton = new RunnerImpl;
  return singleton;
}

}  // namespace data
}  // namespace tensorflow