1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/dataset.h"
16
17 #include <unordered_map>
18
19 #include "tensorflow/core/framework/device_base.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/variant_encode_decode.h"
24 #include "tensorflow/core/framework/variant_op_registry.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 #include "tensorflow/core/graph/node_builder.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/refcount.h"
32 #include "tensorflow/core/platform/resource.h"
33 #include "tensorflow/core/platform/status.h"
34 #include "tensorflow/core/platform/strcat.h"
35 #include "tensorflow/core/profiler/lib/traceme.h"
36 #include "tensorflow/core/public/version.h"
37
38 // On Windows, disable some macros that would break compile
39 #if defined(PLATFORM_WINDOWS)
40 #undef GetMessage
41 #endif
42
43 namespace tensorflow {
44 namespace data {
45 namespace {
46
get_dataset_op_registry_lock()47 static mutex* get_dataset_op_registry_lock() {
48 static mutex dataset_op_registry_lock(LINKER_INITIALIZED);
49 return &dataset_op_registry_lock;
50 }
51
// Returns the process-wide set of registered dataset op names. The set is
// heap-allocated on first use and never deleted (presumably so it stays valid
// regardless of static destruction order), and every call returns the same
// pointer.
static std::unordered_set<std::string>* get_dataset_op_registry() {
  static auto* const registered_names = new std::unordered_set<std::string>();
  return registered_names;
}
56
UniqueNodeName(const std::string & base)57 std::string UniqueNodeName(const std::string& base) {
58 static std::atomic<int64_t> counter(0);
59 return strings::StrCat(base, "/", counter.fetch_add(1));
60 }
61
62 // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
63 // Objects of the wrapper class own a reference on an instance of `DatasetBase`,
64 // and the wrapper's copy constructor and destructor take care of managing the
65 // reference count.
66 //
67 // NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
68 // specification. In particular, we cannot currently serialize an arbitrary
69 // `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
70 // implemented.
71 class DatasetVariantWrapper {
72 public:
DatasetVariantWrapper()73 DatasetVariantWrapper() : dataset_(nullptr) {}
74
75 // Transfers ownership of `dataset` to `*this`.
DatasetVariantWrapper(DatasetBase * dataset)76 explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}
77
DatasetVariantWrapper(const DatasetVariantWrapper & other)78 DatasetVariantWrapper(const DatasetVariantWrapper& other)
79 : dataset_(other.dataset_) {
80 if (dataset_) dataset_->Ref();
81 }
82
operator =(DatasetVariantWrapper && other)83 DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
84 if (&other == this) return *this;
85 std::swap(dataset_, other.dataset_);
86 return *this;
87 }
88
89 DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;
90
~DatasetVariantWrapper()91 ~DatasetVariantWrapper() {
92 if (dataset_) dataset_->Unref();
93 }
94
get() const95 DatasetBase* get() const { return dataset_; }
96
TypeName() const97 string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
DebugString() const98 string DebugString() const {
99 if (dataset_) {
100 return dataset_->DebugString();
101 } else {
102 return "<Uninitialized DatasetVariantWrapper>";
103 }
104 }
Encode(VariantTensorData * data) const105 void Encode(VariantTensorData* data) const {
106 LOG(ERROR) << "The Encode() method is not implemented for "
107 "DatasetVariantWrapper objects.";
108 }
Decode(const VariantTensorData & data)109 bool Decode(const VariantTensorData& data) {
110 LOG(ERROR) << "The Decode() method is not implemented for "
111 "DatasetVariantWrapper objects.";
112 return false;
113 }
114
115 private:
116 DatasetBase* dataset_; // Owns one reference.
117 };
118
119 const char kWrappedDatasetVariantTypeName[] =
120 "tensorflow::data::WrappedDatasetVariant";
121
122 class WrappedDatasetVariantWrapper {
123 public:
WrappedDatasetVariantWrapper()124 WrappedDatasetVariantWrapper() {}
125
WrappedDatasetVariantWrapper(const Tensor & ds_tensor)126 explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
127 : ds_tensor_(ds_tensor) {}
128
get() const129 Tensor get() const { return ds_tensor_; }
130
TypeName() const131 string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }
132
DebugString() const133 string DebugString() const {
134 return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
135 }
136
Encode(VariantTensorData * data) const137 void Encode(VariantTensorData* data) const {
138 *(data->add_tensors()) = ds_tensor_;
139 }
140
Decode(const VariantTensorData & data)141 bool Decode(const VariantTensorData& data) {
142 ds_tensor_ = data.tensors(0);
143 return true;
144 }
145
146 private:
147 Tensor ds_tensor_;
148 };
149
150 class WrapDatasetVariantOp : public OpKernel {
151 public:
WrapDatasetVariantOp(OpKernelConstruction * ctx)152 explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
153
Compute(OpKernelContext * ctx)154 void Compute(OpKernelContext* ctx) override {
155 const Tensor& tensor = ctx->input(0);
156 OP_REQUIRES(ctx,
157 tensor.dtype() == DT_VARIANT &&
158 TensorShapeUtils::IsScalar(tensor.shape()),
159 errors::InvalidArgument(
160 "Dataset tensor must be a scalar of dtype DT_VARIANT."));
161 DatasetBase* unused;
162 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
163 Tensor* output = nullptr;
164 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
165 output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
166 }
167 };
168
169 REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
170 WrapDatasetVariantOp);
171 REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
172 .HostMemory("input_handle")
173 .HostMemory("output_handle")
174 .Device(DEVICE_GPU),
175 WrapDatasetVariantOp);
176
177 class UnwrapDatasetVariantOp : public OpKernel {
178 public:
UnwrapDatasetVariantOp(OpKernelConstruction * ctx)179 explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
180
Compute(OpKernelContext * ctx)181 void Compute(OpKernelContext* ctx) override {
182 const Tensor& tensor = ctx->input(0);
183 OP_REQUIRES(ctx,
184 tensor.dtype() == DT_VARIANT &&
185 TensorShapeUtils::IsScalar(tensor.shape()),
186 errors::InvalidArgument(
187 "Dataset tensor must be a scalar of dtype DT_VARIANT."));
188 Variant variant = tensor.scalar<Variant>()();
189 const WrappedDatasetVariantWrapper* wrapper =
190 variant.get<WrappedDatasetVariantWrapper>();
191 OP_REQUIRES(ctx, wrapper != nullptr,
192 errors::InvalidArgument(
193 "Tensor must be a WrappedDataset variant object."));
194 Tensor ds_tensor = wrapper->get();
195 OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
196 }
197 };
198
199 REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
200 UnwrapDatasetVariantOp);
201 REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
202 .HostMemory("input_handle")
203 .HostMemory("output_handle")
204 .Device(DEVICE_GPU),
205 UnwrapDatasetVariantOp);
206
WrappedDatasetVariantDeviceCopy(const WrappedDatasetVariantWrapper & from,WrappedDatasetVariantWrapper * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy)207 static Status WrappedDatasetVariantDeviceCopy(
208 const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
209 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
210 *to = WrappedDatasetVariantWrapper(from);
211 return OkStatus();
212 }
213
214 #define REGISTER_OPTIONAL_COPY(DIRECTION) \
215 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
216 WrappedDatasetVariantWrapper, DIRECTION, \
217 WrappedDatasetVariantDeviceCopy)
218
219 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
220 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
221 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
222
223 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
224 kWrappedDatasetVariantTypeName);
225
226 } // namespace
227
AddDataset(const DatasetBase * dataset,const std::vector<Node * > & inputs,Node ** output)228 Status GraphDefBuilderWrapper::AddDataset(const DatasetBase* dataset,
229 const std::vector<Node*>& inputs,
230 Node** output) {
231 return AddDataset(dataset, inputs, {}, output);
232 }
233
AddDataset(const DatasetBase * dataset,const std::vector<Node * > & inputs,const std::vector<std::pair<StringPiece,AttrValue>> & attrs,Node ** output)234 Status GraphDefBuilderWrapper::AddDataset(
235 const DatasetBase* dataset, const std::vector<Node*>& inputs,
236 const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
237 Node** output) {
238 std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
239 for (size_t i = 0; i < inputs.size(); i++) {
240 enumerated_inputs[i] = std::make_pair(i, inputs[i]);
241 }
242 return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
243 }
244
AddDataset(const DatasetBase * dataset,const std::vector<std::pair<size_t,Node * >> & inputs,const std::vector<std::pair<size_t,gtl::ArraySlice<Node * >>> & list_inputs,const std::vector<std::pair<StringPiece,AttrValue>> & attrs,Node ** output)245 Status GraphDefBuilderWrapper::AddDataset(
246 const DatasetBase* dataset,
247 const std::vector<std::pair<size_t, Node*>>& inputs,
248 const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
249 const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
250 Node** output) {
251 return AddDataset(dataset, inputs, list_inputs, attrs,
252 /*use_dataset_name=*/false, output);
253 }
254
AddDataset(const DatasetBase * dataset,const std::vector<std::pair<size_t,Node * >> & inputs,const std::vector<std::pair<size_t,gtl::ArraySlice<Node * >>> & list_inputs,const std::vector<std::pair<StringPiece,AttrValue>> & attrs,bool use_dataset_name,Node ** output)255 Status GraphDefBuilderWrapper::AddDataset(
256 const DatasetBase* dataset,
257 const std::vector<std::pair<size_t, Node*>>& inputs,
258 const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
259 const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
260 bool use_dataset_name, Node** output) {
261 auto& type_string = dataset->type_string();
262 auto opts = absl::make_unique<GraphDefBuilder::Options>(b_->opts());
263 // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
264 // attributes defined. It will be nice to have a consistent pattern.
265 bool has_output_types_attr = HasAttr(type_string, "output_types");
266 bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
267 if (has_output_shapes_attr) {
268 opts = absl::make_unique<GraphDefBuilder::Options>(
269 opts->WithAttr("output_shapes", dataset->output_shapes()));
270 }
271 if (has_output_types_attr) {
272 opts = absl::make_unique<GraphDefBuilder::Options>(
273 opts->WithAttr("output_types", dataset->output_dtypes()));
274 }
275 bool has_metadata_attr = HasAttr(type_string, "metadata");
276 if (has_metadata_attr) {
277 std::string serialized_metadata;
278 dataset->metadata().SerializeToString(&serialized_metadata);
279 opts = absl::make_unique<GraphDefBuilder::Options>(
280 opts->WithAttr("metadata", serialized_metadata));
281 }
282 for (const auto& attr : attrs) {
283 opts = absl::make_unique<GraphDefBuilder::Options>(
284 opts->WithAttr(attr.first, attr.second));
285 }
286 if (opts->HaveError()) {
287 return errors::Internal("AddDataset: Failed to build Options with error ",
288 opts->StatusToString());
289 }
290 NodeBuilder node_builder(
291 use_dataset_name ? dataset->node_name() : opts->GetNameForOp(type_string),
292 type_string, opts->op_registry());
293 {
294 size_t total_size = inputs.size() + list_inputs.size();
295 auto inputs_iter = inputs.begin();
296 auto list_inputs_iter = list_inputs.begin();
297 for (int i = 0; i < total_size; i++) {
298 if (inputs_iter != inputs.end() && inputs_iter->first == i) {
299 node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
300 inputs_iter++;
301 } else if (list_inputs_iter != list_inputs.end() &&
302 list_inputs_iter->first == i) {
303 std::vector<NodeBuilder::NodeOut> nodeout_inputs;
304 nodeout_inputs.reserve(list_inputs_iter->second.size());
305 for (Node* n : list_inputs_iter->second) {
306 nodeout_inputs.emplace_back(n);
307 }
308 node_builder.Input(nodeout_inputs);
309 list_inputs_iter++;
310 } else {
311 return errors::InvalidArgument("No input found for index ", i);
312 }
313 }
314 }
315 *output = opts->FinalizeBuilder(&node_builder);
316 if (*output == nullptr) {
317 return errors::Internal("AddDataset: Failed to build ", type_string,
318 " op with error ", opts->StatusToString());
319 }
320 return OkStatus();
321 }
322
AddFunction(SerializationContext * ctx,const string & function_name,const FunctionLibraryDefinition & lib_def)323 Status GraphDefBuilderWrapper::AddFunction(
324 SerializationContext* ctx, const string& function_name,
325 const FunctionLibraryDefinition& lib_def) {
326 if (b_->HasFunction(function_name)) {
327 VLOG(1) << "Function with name " << function_name << "already exists in"
328 << " the graph. It will not be added again.";
329 return OkStatus();
330 }
331 const FunctionDef* f_def = lib_def.Find(function_name);
332 if (f_def == nullptr) {
333 return errors::InvalidArgument("Unable to find FunctionDef for ",
334 function_name, " in the registry.");
335 }
336 FunctionDefLibrary def;
337 *def.add_function() = *f_def;
338 const string gradient_func = lib_def.FindGradient(function_name);
339 if (!gradient_func.empty()) {
340 GradientDef* g_def = def.add_gradient();
341 g_def->set_function_name(function_name);
342 g_def->set_gradient_func(gradient_func);
343 }
344 TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
345
346 // Recursively add functions in inputs of function_name.
347 for (const NodeDef& node_def : f_def->node_def()) {
348 const OpRegistrationData* op_reg_data = nullptr;
349 TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data));
350 if (op_reg_data->is_function_op) {
351 TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def));
352 }
353 // Recursively add functions in attrs of this NodeDef.
354 for (const auto& pair : node_def.attr()) {
355 TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def));
356 }
357 }
358
359 // Recursively add functions in attrs of function_name.
360 for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
361 TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def));
362 }
363 return OkStatus();
364 }
365
AddPlaceholderInternal(const Tensor & val,Node ** output)366 void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
367 Node** output) {
368 *output = ops::SourceOp(
369 "Placeholder",
370 b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
371 }
372
AddTensorInternal(const Tensor & val,Node ** output)373 void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
374 Node** output) {
375 *output = ops::SourceOp(
376 "Const",
377 b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
378 }
379
HasAttr(const string & name,const string & attr_name) const380 bool GraphDefBuilderWrapper::HasAttr(const string& name,
381 const string& attr_name) const {
382 const OpDef* op_def = nullptr;
383 Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
384 if (!s.ok() || op_def == nullptr) {
385 return false;
386 }
387 return HasAttr(op_def, attr_name);
388 }
389
GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext * ctx)390 int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx) {
391 thread::ThreadPool* thread_pool =
392 ctx->device()->tensorflow_device_thread_pool();
393 if (thread_pool) {
394 return thread_pool->NumThreads();
395 } else {
396 static const int32_t kDefaultRunnerThreadpoolSize = port::MaxParallelism();
397 return kDefaultRunnerThreadpoolSize;
398 }
399 }
400
InitializeBase(IteratorContext * ctx,const IteratorBase * parent)401 Status IteratorBase::InitializeBase(IteratorContext* ctx,
402 const IteratorBase* parent) {
403 parent_ = parent;
404 id_ =
405 Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
406 if (parent_) {
407 parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
408 reinterpret_cast<uint64>(parent_));
409 }
410 if (const auto& model = ctx->model()) {
411 auto factory = [ctx, this](model::Node::Args args) {
412 return CreateNode(ctx, std::move(args));
413 };
414 model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
415 cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
416 }
417 return OkStatus();
418 }
419
GetAllocatedBytes(const std::vector<Tensor> & element)420 int64_t GetAllocatedBytes(const std::vector<Tensor>& element) {
421 int64_t allocated_bytes = 0;
422 DatasetBase* dataset;
423 for (auto& tensor : element) {
424 if (tensor.dtype() == DT_VARIANT &&
425 GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
426 allocated_bytes += dataset->AllocatedBytes();
427 } else {
428 allocated_bytes += tensor.AllocatedBytes();
429 }
430 }
431 return allocated_bytes;
432 }
433
GetTotalBytes(const std::vector<Tensor> & element)434 int64_t GetTotalBytes(const std::vector<Tensor>& element) {
435 int64_t total_bytes = 0;
436 DatasetBase* dataset;
437 for (auto& tensor : element) {
438 if (tensor.dtype() == DT_VARIANT &&
439 GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
440 total_bytes += dataset->TotalBytes();
441 } else {
442 total_bytes += tensor.TotalBytes();
443 }
444 }
445 return total_bytes;
446 }
447
FullName(const std::string & prefix,const std::string & name)448 std::string FullName(const std::string& prefix, const std::string& name) {
449 if (str_util::StrContains(name, kColon)) {
450 LOG(ERROR) << name << " should not contain " << kColon;
451 }
452
453 return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
454 }
455
GetDatasetFromVariantTensor(const Tensor & tensor,DatasetBase ** out_dataset)456 Status GetDatasetFromVariantTensor(const Tensor& tensor,
457 DatasetBase** out_dataset) {
458 if (!(tensor.dtype() == DT_VARIANT &&
459 TensorShapeUtils::IsScalar(tensor.shape()))) {
460 return errors::InvalidArgument(
461 "Dataset tensor must be a scalar of dtype DT_VARIANT.");
462 }
463 const Variant& variant = tensor.scalar<Variant>()();
464 const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
465 if (wrapper == nullptr) {
466 return errors::InvalidArgument("Tensor must be a Dataset object.");
467 }
468 *out_dataset = wrapper->get();
469 if (*out_dataset == nullptr) {
470 return errors::Internal("Read uninitialized Dataset variant.");
471 }
472 return OkStatus();
473 }
474
StoreDatasetInVariantTensor(DatasetBase * dataset,Tensor * tensor)475 Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
476 if (!(tensor->dtype() == DT_VARIANT &&
477 TensorShapeUtils::IsScalar(tensor->shape()))) {
478 return errors::InvalidArgument(
479 "Dataset tensor must be a scalar of dtype DT_VARIANT.");
480 }
481 tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
482 return OkStatus();
483 }
484
485 namespace internal {
486
487 #define WARN_PROTO_FIELD_CONFLICT(reflection, field, field_type, src, dst) \
488 { \
489 auto source_value = reflection->Get##field_type(src, field); \
490 auto destination_value = reflection->Get##field_type(*dst, field); \
491 if (source_value != destination_value) { \
492 LOG(WARNING) << "Changing the value of option field " << field->name() \
493 << " from " << destination_value << " to " << source_value; \
494 } \
495 }
496
497 #define WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst) \
498 { \
499 auto source_value = reflection->GetEnum(src, field); \
500 auto destination_value = reflection->GetEnum(*dst, field); \
501 if (source_value != destination_value) { \
502 LOG(WARNING) << "Changing the value of option enum field " \
503 << field->name() << " from " \
504 << destination_value->full_name() << " to " \
505 << source_value->full_name(); \
506 } \
507 }
508
WarnProtoConflicts(const protobuf::Message & src,protobuf::Message * dst)509 void WarnProtoConflicts(const protobuf::Message& src, protobuf::Message* dst) {
510 std::vector<const protobuf::FieldDescriptor*> set_src;
511 std::vector<const protobuf::FieldDescriptor*> set_dst;
512 const protobuf::Reflection* reflection = src.GetReflection();
513 reflection->ListFields(src, &set_src);
514 reflection->ListFields(*dst, &set_dst);
515 std::sort(set_src.begin(), set_src.end());
516 std::sort(set_dst.begin(), set_dst.end());
517
518 std::vector<const protobuf::FieldDescriptor*> in_both;
519 std::set_intersection(set_src.begin(), set_src.end(), set_dst.begin(),
520 set_dst.end(), std::back_inserter(in_both));
521
522 for (auto field : in_both) {
523 if (field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE) {
524 WarnProtoConflicts(reflection->GetMessage(src, field),
525 reflection->MutableMessage(dst, field));
526 } else {
527 switch (field->cpp_type()) {
528 case protobuf::FieldDescriptor::CPPTYPE_INT32:
529 WARN_PROTO_FIELD_CONFLICT(reflection, field, Int32, src, dst);
530 break;
531 case protobuf::FieldDescriptor::CPPTYPE_INT64:
532 WARN_PROTO_FIELD_CONFLICT(reflection, field, Int64, src, dst);
533 break;
534 case protobuf::FieldDescriptor::CPPTYPE_UINT32:
535 WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt32, src, dst);
536 break;
537 case protobuf::FieldDescriptor::CPPTYPE_UINT64:
538 WARN_PROTO_FIELD_CONFLICT(reflection, field, UInt64, src, dst);
539 break;
540 case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
541 WARN_PROTO_FIELD_CONFLICT(reflection, field, Double, src, dst);
542 break;
543 case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
544 WARN_PROTO_FIELD_CONFLICT(reflection, field, Float, src, dst);
545 break;
546 case protobuf::FieldDescriptor::CPPTYPE_BOOL:
547 WARN_PROTO_FIELD_CONFLICT(reflection, field, Bool, src, dst);
548 break;
549 case protobuf::FieldDescriptor::CPPTYPE_ENUM:
550 WARN_PROTO_ENUM_FIELD_CONFLICT(reflection, field, src, dst);
551 break;
552 default: {
553 LOG(ERROR) << "Unrecognized proto type for field "
554 << field->full_name();
555 }
556 }
557 }
558 }
559 }
560
561 #undef WARN_PROTO_ENUM_FIELD_CONFLICT
562 #undef WARN_PROTO_FIELD_CONFLICT
563
MergeOptions(const protobuf::Message & source,protobuf::Message * destination)564 void MergeOptions(const protobuf::Message& source,
565 protobuf::Message* destination) {
566 WarnProtoConflicts(source, destination);
567 destination->MergeFrom(source);
568 }
569
MergeOptions(const protobuf::MessageLite & source,protobuf::MessageLite * destination)570 void MergeOptions(const protobuf::MessageLite& source,
571 protobuf::MessageLite* destination) {
572 destination->CheckTypeAndMergeFrom(source);
573 }
574
575 } // namespace internal
576
Initialize(const Metadata & metadata)577 void DatasetBase::Initialize(const Metadata& metadata) {
578 Status s = ComputeNumSources();
579 if (!s.ok()) {
580 LOG(ERROR) << s;
581 }
582 s = MergeOptionsFromInputs();
583 if (!s.ok()) {
584 LOG(ERROR) << s;
585 }
586 metadata_ = metadata;
587 if (metadata_.name() == "") {
588 static std::atomic<int64_t> id_counter(0);
589 *metadata_.mutable_name() =
590 strings::StrCat(type_string(), ":", id_counter.fetch_add(1));
591 }
592 }
593
ComputeNumSources()594 Status DatasetBase::ComputeNumSources() {
595 std::vector<const DatasetBase*> inputs;
596 Status s = InputDatasets(&inputs);
597 if (errors::IsUnimplemented(s)) {
598 return errors::Unimplemented(
599 "Cannot compute input sources for dataset of type ", type_string(),
600 ", because the dataset does not implement `InputDatasets`.");
601 }
602 if (num_sources_ >= 0) {
603 // Already computed.
604 return OkStatus();
605 }
606 num_sources_ = 0;
607 if (inputs.empty()) {
608 num_sources_ = 1;
609 return OkStatus();
610 }
611 for (const auto& input : inputs) {
612 if (input->num_sources() < 0) {
613 return errors::FailedPrecondition(
614 "Cannot compute input sources for dataset of type ", type_string(),
615 ", because sources could not be computed for input dataset of type ",
616 input->type_string());
617 }
618 num_sources_ += input->num_sources();
619 }
620 return OkStatus();
621 }
622
CheckRandomAccessCompatible(const int64 index) const623 Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const {
624 CardinalityOptions options;
625 options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_MODERATE);
626 int64 cardinality = Cardinality(options);
627 if (cardinality == kInfiniteCardinality ||
628 cardinality == kUnknownCardinality) {
629 return tensorflow::errors::FailedPrecondition(
630 "Dataset of type ", this->DebugString(), " has ",
631 cardinality == kInfiniteCardinality ? "infinite" : "unknown",
632 " cardinality, which does not support random access.");
633 }
634 if (index < 0 || index >= cardinality) {
635 return errors::OutOfRange("Index out of range [0, ", cardinality,
636 "):", index);
637 }
638 return OkStatus();
639 }
640
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const641 Status DatasetBase::Get(OpKernelContext* ctx, int64 index,
642 std::vector<Tensor>* out_tensors) const {
643 return errors::Unimplemented(
644 "Random access is not implemented for this dataset.");
645 }
646
Finalize(OpKernelContext * ctx,std::function<StatusOr<core::RefCountPtr<DatasetBase>> ()> make_finalized_dataset) const647 StatusOr<DatasetBase*> DatasetBase::Finalize(
648 OpKernelContext* ctx,
649 std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
650 make_finalized_dataset) const {
651 mutex_lock l(mu_);
652 if (!finalized_dataset_) {
653 TF_ASSIGN_OR_RETURN(finalized_dataset_, make_finalized_dataset());
654 }
655 return finalized_dataset_.get();
656 }
657
MergeOptionsFromInputs()658 Status DatasetBase::MergeOptionsFromInputs() {
659 std::vector<const DatasetBase*> inputs;
660 Status s = InputDatasets(&inputs);
661 if (errors::IsUnimplemented(s)) {
662 return errors::Unimplemented(
663 "Cannot merge options for dataset of type ", type_string(),
664 ", because the dataset does not implement `InputDatasets`.");
665 }
666 if (inputs.empty()) {
667 return OkStatus();
668 }
669 // Merge options from inputs sequentially before merging options from dataset.
670 // Since the last options merged takes precedence, the options that may be set
671 // for the current dataset through OptionsDataset takes precedence over those
672 // set on the input datasets.
673 Options merged_options = inputs[0]->options_;
674 for (int i = 1; i < inputs.size(); ++i) {
675 internal::MergeOptions(inputs[i]->options_, &merged_options);
676 }
677 internal::MergeOptions(options_, &merged_options);
678 options_ = merged_options;
679 return OkStatus();
680 }
681
MakeIterator(IteratorContext * ctx,const IteratorBase * parent,const string & output_prefix,std::unique_ptr<IteratorBase> * iterator) const682 Status DatasetBase::MakeIterator(
683 IteratorContext* ctx, const IteratorBase* parent,
684 const string& output_prefix,
685 std::unique_ptr<IteratorBase>* iterator) const {
686 if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset") {
687 std::vector<const DatasetBase*> inputs;
688 Status s = InputDatasets(&inputs);
689 return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator);
690 }
691 profiler::TraceMe traceme(
692 [&] {
693 return profiler::TraceMeEncode(
694 strings::StrCat("MakeIterator::", type_string()), {});
695 },
696 profiler::TraceMeLevel::kInfo);
697 *iterator = MakeIteratorInternal(output_prefix);
698 Status s = (*iterator)->InitializeBase(ctx, parent);
699 if (s.ok()) {
700 s.Update((*iterator)->Initialize(ctx));
701 }
702 if (!s.ok()) {
703 // Reset the iterator to avoid returning an uninitialized iterator.
704 iterator->reset();
705 }
706 return s;
707 }
708
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * split_providers) const709 Status DatasetBase::MakeSplitProviders(
710 std::vector<std::unique_ptr<SplitProvider>>* split_providers) const {
711 std::vector<const DatasetBase*> inputs;
712 Status s = InputDatasets(&inputs);
713 if (errors::IsUnimplemented(s)) {
714 return errors::Unimplemented(
715 "Cannot create split providers for dataset of type ", type_string(),
716 ", because the dataset implements neither `InputDatasets` nor "
717 "`MakeSplitProvider`.");
718 }
719 if (inputs.size() != 1) {
720 return errors::Unimplemented(
721 "Cannot create split providers for dataset of type ", type_string(),
722 ", because the dataset is not unary (instead having arity ",
723 inputs.size(),
724 "), and no custom implementation of `MakeSplitProvider` is defined.");
725 }
726 return inputs[0]->MakeSplitProviders(split_providers);
727 }
728
Cardinality() const729 int64_t DatasetBase::Cardinality() const {
730 mutex_lock l(cardinality_mu_);
731 if (cardinality_ == kUnknownCardinality) {
732 cardinality_ = CardinalityInternal();
733 }
734 return cardinality_;
735 }
736
Cardinality(CardinalityOptions options) const737 int64_t DatasetBase::Cardinality(CardinalityOptions options) const {
738 mutex_lock l(cardinality_mu_);
739 if (cardinality_ == kUnknownCardinality) {
740 cardinality_ = CardinalityInternal(options);
741 }
742 return cardinality_;
743 }
744
InputDatasets(std::vector<const DatasetBase * > * inputs) const745 Status DatasetBase::InputDatasets(
746 std::vector<const DatasetBase*>* inputs) const {
747 return errors::Unimplemented("InputDatasets not implemented for ",
748 type_string());
749 }
750
AddInputDataset(SerializationContext * ctx,const DatasetBase * dataset,Node ** output)751 Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
752 SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
753 Status status = dataset->AsGraphDefInternal(ctx, this, output);
754 if (ctx->is_graph_rewrite()) {
755 if (status.ok()) {
756 // Record cardinality in an unregistered attributes so that rewrites have
757 // this information.
758 (*output)->AddAttr(kCardinalityAttrForRewrite, dataset->Cardinality());
759 } else if (errors::IsUnimplemented(status)) {
760 Tensor t(DT_VARIANT, TensorShape({}));
761 // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
762 // increment the refcount of `dataset` here to retain ownership.
763 dataset->Ref();
764 TF_RETURN_IF_ERROR(
765 StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
766 TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
767 DCHECK_NE(ctx->input_list(), nullptr);
768 ctx->input_list()->emplace_back((*output)->name(), std::move(t));
769 LOG_EVERY_N_SEC(WARNING, 30)
770 << "Input of " << dataset->DebugString()
771 << " will not be optimized because the dataset does not implement "
772 "the "
773 "AsGraphDefInternal() method needed to apply optimizations.";
774 return OkStatus();
775 }
776 }
777 return status;
778 }
779
AddDatasetOrTensor(SerializationContext * ctx,const Tensor & t,Node ** output)780 Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
781 SerializationContext* ctx, const Tensor& t, Node** output) {
782 if (t.dtype() == DT_VARIANT) {
783 // If the input tensor is a variant, it may represent a multi-dimensional
784 // array of datasets. We attempt to decode each dataset so that we can use
785 // their custom serialization logic and combine the result of their
786 // individual serializations using the `Pack` operation.
787 //
788 // If this fails, we fallback to using its Variant::Encode() based
789 // serialization.
790 Status s = AddDatasetOrTensorHelper(ctx, t, output);
791 if (s.ok()) {
792 return s;
793 }
794 }
795 if (t.dtype() == DT_RESOURCE && !ctx->is_graph_rewrite()) {
796 Status s = AddResourceHelper(ctx, t, output);
797 if (!errors::IsUnimplemented(s)) {
798 // Fall through to AddTensor if AsGraphDef is not implemented for this
799 // resource.
800 return s;
801 }
802 }
803 return AddTensor(t, output);
804 }
805
AddIdentity(SerializationContext * ctx,const std::string & name_prefix,Node ** input,Node ** output)806 Status DatasetBase::DatasetGraphDefBuilder::AddIdentity(
807 SerializationContext* ctx, const std::string& name_prefix, Node** input,
808 Node** output) {
809 *output =
810 ops::UnaryOp("Identity", *input,
811 builder()->opts().WithName(UniqueNodeName(name_prefix)));
812 return OkStatus();
813 }
814
AddDatasetOrTensorHelper(SerializationContext * ctx,const Tensor & t,Node ** output)815 Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper(
816 SerializationContext* ctx, const Tensor& t, Node** output) {
817 if (t.dims() == 0) {
818 DatasetBase* dataset;
819 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset));
820 return AddInputDataset(ctx, dataset, output);
821 }
822 std::vector<NodeBuilder::NodeOut> nodes;
823 for (int i = 0; i < t.dim_size(0); ++i) {
824 Node* node;
825 TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node));
826 nodes.emplace_back(node);
827 }
828 auto op_name = "Pack";
829 auto opts = builder()->opts();
830 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
831 opts.op_registry());
832 node_builder.Input(std::move(nodes));
833 *output = opts.FinalizeBuilder(&node_builder);
834 return OkStatus();
835 }
836
AddResourceHelper(SerializationContext * ctx,const Tensor & t,Node ** output)837 Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper(
838 SerializationContext* ctx, const Tensor& t, Node** output) {
839 const ResourceHandle& handle = t.flat<ResourceHandle>()(0);
840 if (ctx->device_name() != handle.device()) {
841 return errors::InvalidArgument("Trying to access resource ", handle.name(),
842 " located in device ", handle.device(),
843 " from device ", ctx->device_name());
844 }
845 ResourceBase* resource;
846 TF_RETURN_IF_ERROR(ctx->resource_mgr()->Lookup(handle, &resource));
847 core::ScopedUnref unref(resource);
848 return resource->AsGraphDef(builder(), output);
849 }
850
DatasetBaseIterator(const BaseParams & params)851 DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params)
852 : params_(params) {
853 params_.dataset->Ref();
854 VLOG(2) << prefix() << " constructor";
855 strings::StrAppend(&traceme_metadata_, "name=", dataset()->metadata().name());
856 strings::StrAppend(&traceme_metadata_, ",shapes=");
857 auto& shapes = output_shapes();
858 for (int i = 0; i < shapes.size(); ++i) {
859 if (i > 0) {
860 strings::StrAppend(&traceme_metadata_, " ");
861 }
862 strings::StrAppend(&traceme_metadata_, shapes.at(i).DebugString());
863 }
864 strings::StrAppend(&traceme_metadata_, ",types=");
865 auto& types = output_dtypes();
866 for (int i = 0; i < types.size(); ++i) {
867 if (i > 0) {
868 strings::StrAppend(&traceme_metadata_, " ");
869 }
870 strings::StrAppend(&traceme_metadata_, DataTypeString(types.at(i)));
871 }
872 }
873
~DatasetBaseIterator()874 DatasetBaseIterator::~DatasetBaseIterator() {
875 VLOG(2) << prefix() << " destructor";
876 params_.dataset->Unref();
877 }
878
BuildTraceMeName()879 string DatasetBaseIterator::BuildTraceMeName() {
880 string result =
881 strings::StrCat(params_.prefix, "#", traceme_metadata_, ",id=", id_);
882 if (parent_) {
883 strings::StrAppend(&result, ",parent_id=", parent_id_);
884 }
885 TraceMeMetadata metadata = GetTraceMeMetadata();
886 for (const auto& pair : metadata) {
887 strings::StrAppend(&result, ",", pair.first, "=", pair.second);
888 }
889 strings::StrAppend(&result, "#");
890 return result;
891 }
892
GetNext(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)893 Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
894 std::vector<Tensor>* out_tensors,
895 bool* end_of_sequence) {
896 profiler::TraceMe activity([&] { return BuildTraceMeName(); },
897 profiler::TraceMeLevel::kInfo);
898 DVLOG(3) << prefix() << " GetNext enter";
899 auto model = ctx->model();
900 if (collect_resource_usage(ctx)) {
901 int64_t now_nanos = EnvTime::NowNanos();
902 auto output = node_->output();
903 if (output) {
904 output->record_stop(now_nanos);
905 }
906 node_->record_start(now_nanos);
907 }
908 out_tensors->clear();
909 Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
910 if (TF_PREDICT_TRUE(s.ok())) {
911 if (TF_PREDICT_TRUE(!*end_of_sequence)) {
912 DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
913 RecordElement(ctx, out_tensors);
914 } else {
915 out_tensors->clear();
916 }
917 }
918 if (collect_resource_usage(ctx)) {
919 int64_t now_nanos = EnvTime::NowNanos();
920 node_->record_stop(now_nanos);
921 auto output = node_->output();
922 if (output) {
923 output->record_start(now_nanos);
924 }
925 }
926 if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
927 s = errors::Internal("Iterator \"", params_.prefix,
928 "\" returned `OutOfRange`. This indicates an "
929 "implementation error as `OutOfRange` errors are not "
930 "expected to be returned here. Original message: ",
931 s.error_message());
932 LOG(ERROR) << s;
933 }
934 DVLOG(3) << prefix() << " GetNext exit";
935 return s;
936 }
937
Skip(IteratorContext * ctx,int num_to_skip,bool * end_of_sequence,int * num_skipped)938 Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
939 bool* end_of_sequence, int* num_skipped) {
940 profiler::TraceMe activity([&] { return BuildTraceMeName(); },
941 profiler::TraceMeLevel::kInfo);
942 DVLOG(3) << prefix() << " Skip enter";
943 auto model = ctx->model();
944 if (collect_resource_usage(ctx)) {
945 int64_t now_nanos = EnvTime::NowNanos();
946 auto output = node_->output();
947 if (output) {
948 output->record_stop(now_nanos);
949 }
950 node_->record_start(now_nanos);
951 }
952 Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
953 if (collect_resource_usage(ctx)) {
954 int64_t now_nanos = EnvTime::NowNanos();
955 node_->record_stop(now_nanos);
956 auto output = node_->output();
957 if (output) {
958 output->record_start(now_nanos);
959 }
960 }
961 if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
962 s = errors::Internal("Iterator \"", params_.prefix,
963 "\" returned `OutOfRange`. This indicates an "
964 "implementation error as `OutOfRange` errors are not "
965 "expected to be returned here. Original message: ",
966 s.error_message());
967 LOG(ERROR) << s;
968 }
969 DVLOG(3) << prefix() << " Skip exit";
970 return s;
971 }
972
SkipInternal(IteratorContext * ctx,int num_to_skip,bool * end_of_sequence,int * num_skipped)973 Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip,
974 bool* end_of_sequence,
975 int* num_skipped) {
976 *num_skipped = 0;
977 for (int i = 0; i < num_to_skip; ++i) {
978 std::vector<Tensor> out_tensors;
979 TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence));
980 if (*end_of_sequence) {
981 return OkStatus();
982 }
983 // RecordElement is used to count the number of element computed and
984 // help calculate the CPU time spent on a given iterator to do the
985 // autotuning.
986 // Here we only call RecordElement in the default implementation of
987 // SkipInternal (which trivially calls GetNextInternal) and assume
988 // that the overridden SkipInternal in the derived class will have
989 // negligible cost compare to its GetNextInternal.
990 RecordElement(ctx, &out_tensors);
991 (*num_skipped)++;
992 }
993 return OkStatus();
994 }
995
Compute(OpKernelContext * ctx)996 void DatasetOpKernel::Compute(OpKernelContext* ctx) {
997 DatasetBase* dataset = nullptr;
998 MakeDataset(ctx, &dataset);
999 if (ctx->status().ok()) {
1000 Tensor* output = nullptr;
1001 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
1002 OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
1003 dataset->Initialize(metadata_);
1004 }
1005 }
1006
TraceString(const OpKernelContext & ctx,bool verbose) const1007 string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
1008 bool verbose) const {
1009 return profiler::TraceMeOp(name_view(), type_string_view());
1010 }
1011
1012 // static
IsDatasetOp(const OpDef & op_def)1013 bool DatasetOpKernel::IsDatasetOp(const OpDef& op_def) {
1014 if (op_def.output_arg_size() != 1) return false;
1015 if (op_def.output_arg(0).type() != DT_VARIANT) return false;
1016 absl::string_view op_name = op_def.name();
1017 if (op_name == "DatasetFromGraph") return true;
1018 if (absl::EndsWith(op_name, "Dataset")) return true;
1019 // Check if the suffix matches "DatasetV[0-9]+".
1020 size_t index = op_name.length() - 1;
1021 while (index >= 0 && isdigit(op_name[index])) {
1022 index--;
1023 }
1024 constexpr absl::string_view kDatasetPrefix = "DatasetV";
1025 constexpr absl::string_view::size_type kPrefixLength = kDatasetPrefix.size();
1026 if (index < kPrefixLength - 1 || index == op_name.length() - 1) return false;
1027 return op_name.substr(index - kPrefixLength + 1, kPrefixLength) ==
1028 kDatasetPrefix;
1029 }
1030
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)1031 void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
1032 DatasetBase** output) {
1033 DatasetBase* input;
1034 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
1035 MakeDataset(ctx, input, output);
1036 }
1037
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)1038 void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
1039 DatasetBase** output) {
1040 DatasetBase* input;
1041 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
1042 DatasetBase* another_input;
1043 OP_REQUIRES_OK(ctx,
1044 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
1045 MakeDataset(ctx, input, another_input, output);
1046 }
1047
// Out-of-line definitions of the `DatasetBase` string-key constants declared
// in the header.
// NOTE(review): presumably the keys under which a serialized dataset graph and
// the name of its output node are stored/looked up — confirm at the use sites.
const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";
1051
BackgroundWorker(Env * env,const char * name)1052 BackgroundWorker::BackgroundWorker(Env* env, const char* name)
1053 : env_(env), name_(name) {}
1054
// Shuts the worker down: marks it cancelled, wakes the worker thread, and then
// blocks until that thread has exited. The worker checks `cancelled_` only
// between work items, so an item that is already running is allowed to
// finish, but items still in `work_queue_` are never executed.
BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  // Wake the worker if it is waiting for work so it can observe `cancelled_`.
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}
1068
Schedule(std::function<void ()> work_item)1069 void BackgroundWorker::Schedule(std::function<void()> work_item) {
1070 {
1071 mutex_lock l(mu_);
1072 if (!thread_) {
1073 thread_ = absl::WrapUnique(env_->StartThread(
1074 {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
1075 }
1076 work_queue_.push_back(std::move(work_item));
1077 }
1078 cond_var_.notify_one();
1079 }
1080
WorkerLoop()1081 void BackgroundWorker::WorkerLoop() {
1082 tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
1083 while (true) {
1084 std::function<void()> work_item = nullptr;
1085 {
1086 mutex_lock l(mu_);
1087 while (!cancelled_ && work_queue_.empty()) {
1088 cond_var_.wait(l);
1089 }
1090 if (cancelled_) {
1091 return;
1092 }
1093 DCHECK(!work_queue_.empty());
1094 work_item = std::move(work_queue_.front());
1095 work_queue_.pop_front();
1096 }
1097 DCHECK(work_item != nullptr);
1098 work_item();
1099 }
1100 }
1101
namespace {
// Default `Runner` implementation: runs each function synchronously on the
// calling thread.
class RunnerImpl : public Runner {
 public:
  // Invokes `f` inline. A `ResourceTagger(kTFDataResourceTag, "Runner")` is
  // alive for the whole call.
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  // Intentionally empty; it exists only so that `Run()` still has work to do
  // after `f()`. Keep it virtual, and avoid marking this class `final`: either
  // change could let the compiler devirtualize/inline it away and re-enable
  // the tail call.
  virtual void PreventTailCall() {}
};
}  // namespace
1119
1120 /* static */
get()1121 Runner* Runner::get() {
1122 static Runner* singleton = new RunnerImpl;
1123 return singleton;
1124 }
1125
1126 } // namespace data
1127 } // namespace tensorflow
1128