/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// On Windows, disable some macros that would break compilation.
#if defined(PLATFORM_WINDOWS)
#undef GetMessage
#endif

namespace tensorflow {
namespace data {
namespace {

static mutex* get_dataset_op_registry_lock() {
  static mutex dataset_op_registry_lock(LINKER_INITIALIZED);
  return &dataset_op_registry_lock;
}

static std::unordered_set<string>* get_dataset_op_registry() {
  static std::unordered_set<string>* names = new std::unordered_set<string>;
  return names;
}

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of
// `DatasetBase`, and the wrapper's copy constructor and destructor take care
// of managing the reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
    if (&other == this) return *this;
    std::swap(dataset_, other.dataset_);
    return *this;
  }

  DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* dataset_;  // Owns one reference.
};
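
// Illustrative sketch of the ownership contract described above (hypothetical
// caller code, not part of this file):
//
//   DatasetBase* d = ...;          // the caller holds one reference
//   DatasetVariantWrapper w(d);    // the wrapper takes over that reference
//   DatasetVariantWrapper w2 = w;  // copy construction calls d->Ref()
//   // When w and w2 are destroyed, each destructor calls d->Unref().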

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

Status GraphDefBuilderWrapper::AddDataset(const DatasetBase* dataset,
                                          const std::vector<Node*>& inputs,
                                          Node** output) {
  return AddDataset(dataset, inputs, {}, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset, const std::vector<Node*>& inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    enumerated_inputs[i] = std::make_pair(i, inputs[i]);
  }
  return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  return AddDataset(dataset, inputs, list_inputs, attrs,
                    /*use_dataset_name=*/false, output);
}

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    bool use_dataset_name, Node** output) {
  auto& type_string = dataset->type_string();
  auto opts = absl::make_unique<GraphDefBuilder::Options>(b_->opts());
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("output_shapes", dataset->output_shapes()));
  }
  if (has_output_types_attr) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr("output_types", dataset->output_dtypes()));
  }
  for (const auto& attr : attrs) {
    opts = absl::make_unique<GraphDefBuilder::Options>(
        opts->WithAttr(attr.first, attr.second));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(
      use_dataset_name ? dataset->node_name() : opts->GetNameForOp(type_string),
      type_string, opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (size_t i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}
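
// Illustrative usage sketch (hypothetical dataset code, not part of this
// file): a typical `AsGraphDefInternal()` implementation first serializes its
// input dataset and then adds its own node, e.g.
//
//   Node* input_graph_node = nullptr;
//   TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
//   TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
//
// where `b` is the DatasetGraphDefBuilder passed to AsGraphDefInternal and
// `input_` is the dataset's input.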

Status GraphDefBuilderWrapper::AddFunction(
    SerializationContext* ctx, const string& function_name,
    const FunctionLibraryDefinition& lib_def) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return Status::OK();
  }
  const FunctionDef* f_def = lib_def.Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = lib_def.FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions in inputs of function_name.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def));
    }
  }

  // Recursively add functions in attrs of function_name.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

Status IteratorBase::InitializeBase(IteratorContext* ctx,
                                    const IteratorBase* parent) {
  parent_ = parent;
  id_ =
      Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
  if (parent_) {
    parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
                                        reinterpret_cast<uint64>(parent_));
  }
  if (const auto& model = ctx->model()) {
    auto factory = [ctx, this](model::Node::Args args) {
      return CreateNode(ctx, std::move(args));
    };
    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
  }
  return Status::OK();
}

int64 GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64_t allocated_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

int64 GetTotalBytes(const std::vector<Tensor>& element) {
  int64_t total_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      total_bytes += dataset->TotalBytes();
    } else {
      total_bytes += tensor.TotalBytes();
    }
  }
  return total_bytes;
}

std::string FullName(const std::string& prefix, const std::string& name) {
  if (str_util::StrContains(name, kColon)) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }

  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}
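
// For illustration only: assuming the constants declared in dataset.h
// (kFullNameRandomHex is a fixed hex marker, kPipe is "|", and kColon is ":"),
// FullName("Iterator::Range", "counter") is expected to produce a checkpoint
// key of the form "<random-hex>|Iterator::Range:counter".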

Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}
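
// Illustrative round trip, assuming caller code inside a function that
// returns Status (`my_dataset` is a hypothetical DatasetBase* on which the
// caller holds a reference; not part of this file):
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   // Transfers the caller's reference on `my_dataset` to the tensor.
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(my_dataset, &t));
//   DatasetBase* recovered = nullptr;
//   // Borrows the dataset back without adding a reference.
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &recovered));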

namespace internal {

#define WARN_PROTO_FIELD_CONFLICT(reflection, field, field_type, src, dst)     \
  {                                                                            \
    auto source_value = reflection->Get##field_type(src, field);               \
    auto destination_value = reflection->Get##field_type(*dst, field);         \
    if (source_value != destination_value) {                                   \
      LOG(WARNING) << "Changing the value of option field " << field->name()   \
                   << " from " << destination_value << " to " << source_value; \
    }                                                                          \
  }

#define WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst) \
  {                                                                 \
    auto source_value = reflection->GetEnum(src, field);            \
    auto destination_value = reflection->GetEnum(*dst, field);      \
    if (source_value != destination_value) {                        \
      LOG(WARNING) << "Changing the value of option enum field "    \
                   << field->name() << " from "                     \
                   << destination_value->full_name() << " to "      \
                   << source_value->full_name();                    \
    }                                                               \
  }

void WarnProtoConflicts(const protobuf::Message& src, protobuf::Message* dst) {
  std::vector<const protobuf::FieldDescriptor*> set_src;
  std::vector<const protobuf::FieldDescriptor*> set_dst;
  const protobuf::Reflection* reflection = src.GetReflection();
  reflection->ListFields(src, &set_src);
  reflection->ListFields(*dst, &set_dst);
  std::sort(set_src.begin(), set_src.end());
  std::sort(set_dst.begin(), set_dst.end());

  std::vector<const protobuf::FieldDescriptor*> in_both;
  std::set_intersection(set_src.begin(), set_src.end(), set_dst.begin(),
                        set_dst.end(), std::back_inserter(in_both));

  for (auto field : in_both) {
    if (field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE) {
      WarnProtoConflicts(reflection->GetMessage(src, field),
                         reflection->MutableMessage(dst, field));
    } else {
      switch (field->cpp_type()) {
        case protobuf::FieldDescriptor::CPPTYPE_INT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_INT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Int64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT32:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt32, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_UINT64:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt64, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Double, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Float, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_BOOL:
          WARN_PROTO_FIELD_CONFLICT(reflection, field, Bool, src, dst);
          break;
        case protobuf::FieldDescriptor::CPPTYPE_ENUM:
          WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst);
          break;
        default: {
          LOG(ERROR) << "Unrecognized proto type for field "
                     << field->full_name();
        }
      }
    }
  }
}

#undef WARN_PROTO_ENUM_FIELD_CONFLICT
#undef WARN_PROTO_FIELD_CONFLICT

void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination) {
  WarnProtoConflicts(source, destination);
  destination->MergeFrom(source);
}

void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination) {
  destination->CheckTypeAndMergeFrom(source);
}

}  // namespace internal
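
// Illustrative sketch (hypothetical values, not part of this file): merging
// two tf.data Options protos warns about fields set to different values on
// both sides and lets the source win. Assuming Options exposes a
// `deterministic` field as in dataset_options.proto:
//
//   Options source;
//   source.set_deterministic(true);
//   Options destination;
//   destination.set_deterministic(false);
//   internal::MergeOptions(source, &destination);
//   // Logs a conflict warning; destination.deterministic() is now true.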

void DatasetBase::Initialize() {
  Status s = ComputeNumSources();
  if (!s.ok()) {
    LOG(ERROR) << s;
  }
  s = MergeOptionsFromInputs();
  if (!s.ok()) {
    LOG(ERROR) << s;
  }
}

Status DatasetBase::ComputeNumSources() {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot compute input sources for dataset of type ", type_string(),
        ", because the dataset does not implement `InputDatasets`.");
  }
  if (num_sources_ >= 0) {
    // Already computed.
    return Status::OK();
  }
  num_sources_ = 0;
  if (inputs.empty()) {
    num_sources_ = 1;
    return Status::OK();
  }
  for (const auto& input : inputs) {
    if (input->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Cannot compute input sources for dataset of type ", type_string(),
          ", because sources could not be computed for input dataset of type ",
          input->type_string());
    }
    num_sources_ += input->num_sources();
  }
  return Status::OK();
}

Status DatasetBase::Get(OpKernelContext* ctx, int64 index,
                        std::vector<Tensor>* out_tensors) {
  return errors::Unimplemented(
      "Random access is not implemented for this dataset.");
}

Status DatasetBase::MergeOptionsFromInputs() {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot merge options for dataset of type ", type_string(),
        ", because the dataset does not implement `InputDatasets`.");
  }
  if (inputs.empty()) {
    return Status::OK();
  }
  // Merge options from inputs sequentially before merging options from the
  // dataset itself. Since the options merged last take precedence, any options
  // set on the current dataset through OptionsDataset take precedence over
  // those set on its input datasets.
  Options merged_options = inputs[0]->options_;
  for (int i = 1; i < inputs.size(); ++i) {
    internal::MergeOptions(inputs[i]->options_, &merged_options);
  }
  internal::MergeOptions(options_, &merged_options);
  options_ = merged_options;
  return Status::OK();
}

Status DatasetBase::MakeIterator(
    IteratorContext* ctx, const IteratorBase* parent,
    const string& output_prefix,
    std::unique_ptr<IteratorBase>* iterator) const {
  if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset") {
    std::vector<const DatasetBase*> inputs;
    Status s = InputDatasets(&inputs);
    return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator);
  }
  profiler::TraceMe traceme(
      [&] {
        return profiler::TraceMeEncode(
            strings::StrCat("MakeIterator::", type_string()), {});
      },
      profiler::TraceMeLevel::kInfo);
  *iterator = MakeIteratorInternal(output_prefix);
  Status s = (*iterator)->InitializeBase(ctx, parent);
  if (s.ok()) {
    s.Update((*iterator)->Initialize(ctx));
  }
  if (!s.ok()) {
    // Reset the iterator to avoid returning an uninitialized iterator.
    iterator->reset();
  }
  return s;
}
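
// Illustrative usage sketch (hypothetical iterator code, not part of this
// file): a unary dataset's iterator typically forwards to its input dataset
// in Initialize(), passing itself as the parent and nesting the prefix:
//
//   Status Initialize(IteratorContext* ctx) override {
//     return dataset()->input_->MakeIterator(ctx, /*parent=*/this, prefix(),
//                                            &input_impl_);
//   }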

Status DatasetBase::MakeSplitProviders(
    std::vector<std::unique_ptr<SplitProvider>>* split_providers) const {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create split providers for dataset of type ", type_string(),
        ", because the dataset is not unary (instead having arity ",
        inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProviders(split_providers);
}

Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  return errors::Unimplemented("InputDatasets not implemented for ",
                               type_string());
}

Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
    Tensor t(DT_VARIANT, TensorShape({}));
    // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
    // increment the refcount of `dataset` here to retain ownership.
    dataset->Ref();
    TF_RETURN_IF_ERROR(
        StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
    TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
    DCHECK_NE(ctx->input_list(), nullptr);
    ctx->input_list()->emplace_back((*output)->name(), std::move(t));
    LOG_EVERY_N_SEC(WARNING, 30)
        << "Input of " << dataset->DebugString()
        << " will not be optimized because the dataset does not implement the "
           "AsGraphDefInternal() method needed to apply optimizations.";
    return Status::OK();
  }
  return status;
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dtype() == DT_VARIANT) {
    // If the input tensor is a variant, it may represent a multi-dimensional
    // array of datasets. We attempt to decode each dataset so that we can use
    // their custom serialization logic and combine the result of their
    // individual serializations using the `Pack` operation.
    //
    // If this fails, we fall back to the variant's `Variant::Encode()`-based
    // serialization.
    Status s = AddDatasetOrTensorHelper(ctx, t, output);
    if (s.ok()) {
      return s;
    }
  }
  if (t.dtype() == DT_RESOURCE && ctx->serialize_data_tensors()) {
    Status s = AddResourceHelper(ctx, t, output);
    if (!errors::IsUnimplemented(s)) {
      return s;
    }
    // Fall through to AddTensor if AsGraphDef is not implemented for this
    // resource.
  }
  return AddTensor(t, output);
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dims() == 0) {
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset));
    return AddInputDataset(ctx, dataset, output);
  }
  std::vector<NodeBuilder::NodeOut> nodes;
  for (int i = 0; i < t.dim_size(0); ++i) {
    Node* node;
    TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node));
    nodes.emplace_back(node);
  }
  auto op_name = "Pack";
  auto opts = builder()->opts();
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(std::move(nodes));
  *output = opts.FinalizeBuilder(&node_builder);
  return Status::OK();
}

Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  const ResourceHandle& handle = t.flat<ResourceHandle>()(0);
  if (ctx->device_name() != handle.device()) {
    return errors::InvalidArgument("Trying to access resource ", handle.name(),
                                   " located in device ", handle.device(),
                                   " from device ", ctx->device_name());
  }
  ResourceBase* resource;
  TF_RETURN_IF_ERROR(ctx->resource_mgr()->Lookup(handle, &resource));
  core::ScopedUnref unref(resource);
  return resource->AsGraphDef(builder(), output);
}

DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params)
    : params_(params) {
  params_.dataset->Ref();
  VLOG(2) << prefix() << " constructor";
  strings::StrAppend(&traceme_metadata_, "shapes=");
  auto& shapes = output_shapes();
  for (int i = 0; i < shapes.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&traceme_metadata_, " ");
    }
    strings::StrAppend(&traceme_metadata_, shapes.at(i).DebugString());
  }
  strings::StrAppend(&traceme_metadata_, ",types=");
  auto& types = output_dtypes();
  for (int i = 0; i < types.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&traceme_metadata_, " ");
    }
    strings::StrAppend(&traceme_metadata_, DataTypeString(types.at(i)));
  }
}

DatasetBaseIterator::~DatasetBaseIterator() {
  VLOG(2) << prefix() << " destructor";
  params_.dataset->Unref();
}

string DatasetBaseIterator::BuildTraceMeName() {
  string result =
      strings::StrCat(params_.prefix, "#", traceme_metadata_, ",id=", id_);
  if (parent_) {
    strings::StrAppend(&result, ",parent_id=", parent_id_);
  }
  TraceMeMetadata metadata = GetTraceMeMetadata();
  for (const auto& pair : metadata) {
    strings::StrAppend(&result, ",", pair.first, "=", pair.second);
  }
  strings::StrAppend(&result, "#");
  return result;
}

Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                    std::vector<Tensor>* out_tensors,
                                    bool* end_of_sequence) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " GetNext enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  out_tensors->clear();
  Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
  if (TF_PREDICT_TRUE(s.ok())) {
    if (TF_PREDICT_TRUE(!*end_of_sequence)) {
      DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
      RecordElement(ctx, out_tensors);
    } else {
      out_tensors->clear();
    }
  }
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " GetNext exit";
  return s;
}

Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
                                 bool* end_of_sequence, int* num_skipped) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " Skip enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
  if (model && model->collect_resource_usage() && node_) {
    int64_t now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " Skip exit";
  return s;
}

Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip,
                                         bool* end_of_sequence,
                                         int* num_skipped) {
  *num_skipped = 0;
  for (int i = 0; i < num_to_skip; ++i) {
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence));
    if (*end_of_sequence) {
      return Status::OK();
    }
    // RecordElement is used to count the number of elements produced and to
    // help estimate the CPU time spent in a given iterator for autotuning.
    // We only call RecordElement in this default implementation of
    // SkipInternal (which trivially calls GetNextInternal) and assume that an
    // overridden SkipInternal in a derived class has negligible cost compared
    // to its GetNextInternal.
    RecordElement(ctx, &out_tensors);
    (*num_skipped)++;
  }
  return Status::OK();
}

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
    dataset->Initialize();
  }
}

string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
                                    bool verbose) const {
  return profiler::TraceMeOp(name_view(), type_string_view());
}

// static
bool DatasetOpKernel::IsDatasetOp(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2")));
}

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

BackgroundWorker::BackgroundWorker(Env* env, const char* name)
    : env_(env), name_(name) {}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    if (!thread_) {
      thread_ = absl::WrapUnique(env_->StartThread(
          {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
    }
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}

void BackgroundWorker::WorkerLoop() {
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}
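
// Illustrative usage sketch (hypothetical caller code, not part of this file):
// the worker starts its thread lazily on the first Schedule() call and drains
// the queue in FIFO order on that single thread:
//
//   BackgroundWorker worker(Env::Default(), "my_background_worker");
//   worker.Schedule([] { /* runs on the background thread */ });
//   // The destructor signals cancellation and joins the thread.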

namespace {
class RunnerImpl : public Runner {
 public:
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  virtual void PreventTailCall() {}
};
}  // namespace

/* static */
Runner* Runner::get() {
  static Runner* singleton = new RunnerImpl;
  return singleton;
}

}  // namespace data
}  // namespace tensorflow