/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/resource.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace data {
namespace {

static mutex* get_dataset_op_registry_lock() {
  static mutex dataset_op_registry_lock(LINKER_INITIALIZED);
  return &dataset_op_registry_lock;
}

static std::unordered_set<string>* get_dataset_op_registry() {
  static std::unordered_set<string>* names = new std::unordered_set<string>;
  return names;
}

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of `DatasetBase`,
// and the wrapper's copy constructor and destructor take care of managing the
// reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  DatasetVariantWrapper& operator=(DatasetVariantWrapper&& other) {
    if (&other == this) return *this;
    std::swap(dataset_, other.dataset_);
    return *this;
  }

  DatasetVariantWrapper& operator=(const DatasetVariantWrapper& other) = delete;

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* dataset_;  // Owns one reference.
};
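
// For illustration, a minimal sketch of how a `DatasetBase*` round-trips
// through a DT_VARIANT scalar tensor using the helpers defined later in this
// file (`StoreDatasetInVariantTensor` / `GetDatasetFromVariantTensor`):
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset, &t));  // takes ownership
//   DatasetBase* unwrapped = nullptr;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &unwrapped));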

const char kWrappedDatasetVariantTypeName[] =
    "tensorflow::data::WrappedDatasetVariant";

class WrappedDatasetVariantWrapper {
 public:
  WrappedDatasetVariantWrapper() {}

  explicit WrappedDatasetVariantWrapper(const Tensor& ds_tensor)
      : ds_tensor_(ds_tensor) {}

  Tensor get() const { return ds_tensor_; }

  string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; }

  string DebugString() const {
    return "tensorflow::WrappedDatasetVariantWrapper::DebugString";
  }

  void Encode(VariantTensorData* data) const {
    *(data->add_tensors()) = ds_tensor_;
  }

  bool Decode(const VariantTensorData& data) {
    ds_tensor_ = data.tensors(0);
    return true;
  }

 private:
  Tensor ds_tensor_;
};

class WrapDatasetVariantOp : public OpKernel {
 public:
  explicit WrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    DatasetBase* unused;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(tensor, &unused));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    output->scalar<Variant>()() = WrappedDatasetVariantWrapper(tensor);
  }
};

REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant").Device(DEVICE_CPU),
                        WrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("WrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        WrapDatasetVariantOp);

class UnwrapDatasetVariantOp : public OpKernel {
 public:
  explicit UnwrapDatasetVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                tensor.dtype() == DT_VARIANT &&
                    TensorShapeUtils::IsScalar(tensor.shape()),
                errors::InvalidArgument(
                    "Dataset tensor must be a scalar of dtype DT_VARIANT."));
    Variant variant = tensor.scalar<Variant>()();
    const WrappedDatasetVariantWrapper* wrapper =
        variant.get<WrappedDatasetVariantWrapper>();
    OP_REQUIRES(ctx, wrapper != nullptr,
                errors::InvalidArgument(
                    "Tensor must be a WrappedDataset variant object."));
    Tensor ds_tensor = wrapper->get();
    OP_REQUIRES_OK(ctx, ctx->set_output("output_handle", ds_tensor));
  }
};

REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant").Device(DEVICE_CPU),
                        UnwrapDatasetVariantOp);
REGISTER_KERNEL_BUILDER(Name("UnwrapDatasetVariant")
                            .HostMemory("input_handle")
                            .HostMemory("output_handle")
                            .Device(DEVICE_GPU),
                        UnwrapDatasetVariantOp);

static Status WrappedDatasetVariantDeviceCopy(
    const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  *to = WrappedDatasetVariantWrapper(from);
  return Status::OK();
}

#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      WrappedDatasetVariantWrapper, DIRECTION,          \
      WrappedDatasetVariantDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(WrappedDatasetVariantWrapper,
                                       kWrappedDatasetVariantTypeName);

}  // namespace

Status GraphDefBuilderWrapper::AddDataset(
    const DatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  const string& type_string = dataset->type_string();
  std::unique_ptr<const GraphDefBuilder::Options> opts(
      new GraphDefBuilder::Options(b_->opts()));
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(type_string, "output_types");
  bool has_output_shapes_attr = HasAttr(type_string, "output_shapes");
  if (has_output_shapes_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_shapes", dataset->output_shapes())));
  }
  if (has_output_types_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_types", dataset->output_dtypes())));
  }
  for (const auto& attr : attrs) {
    opts.reset(
        new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(opts->GetNameForOp(type_string), type_string,
                           opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (int i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", type_string,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}
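
// For illustration, a typical unary dataset's `AsGraphDefInternal()` uses this
// helper roughly as follows (a sketch; `input_` stands in for whatever the
// concrete dataset calls its input member):
//
//   Node* input_graph_node = nullptr;
//   TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
//   TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
//   return Status::OK();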

Status GraphDefBuilderWrapper::AddFunction(
    SerializationContext* ctx, const string& function_name,
    const FunctionLibraryDefinition& lib_def) {
  if (b_->HasFunction(function_name)) {
    VLOG(1) << "Function with name " << function_name << " already exists in"
            << " the graph. It will not be added again.";
    return Status::OK();
  }
  const FunctionDef* f_def = lib_def.Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = lib_def.FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add functions invoked in the body of `function_name`.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(lib_def.LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name(), lib_def));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second, lib_def));
    }
  }

  // Recursively add functions in attrs of `function_name`.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
                                                    Node** output) {
  *output = ops::SourceOp(
      "Placeholder",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

Status IteratorBase::InitializeBase(IteratorContext* ctx,
                                    const IteratorBase* parent) {
  parent_ = parent;
  id_ =
      Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
  if (parent_) {
    parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()),
                                        reinterpret_cast<uint64>(parent_));
  }
  if (const auto& model = ctx->model()) {
    auto factory = [ctx, this](model::Node::Args args) {
      return CreateNode(ctx, std::move(args));
    };
    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
  }
  return Status::OK();
}

int64 GetAllocatedBytes(const std::vector<Tensor>& element) {
  int64 allocated_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      allocated_bytes += dataset->AllocatedBytes();
    } else {
      allocated_bytes += tensor.AllocatedBytes();
    }
  }
  return allocated_bytes;
}

int64 GetTotalBytes(const std::vector<Tensor>& element) {
  int64 total_bytes = 0;
  DatasetBase* dataset;
  for (auto& tensor : element) {
    if (tensor.dtype() == DT_VARIANT &&
        GetDatasetFromVariantTensor(tensor, &dataset).ok()) {
      total_bytes += dataset->TotalBytes();
    } else {
      total_bytes += tensor.TotalBytes();
    }
  }
  return total_bytes;
}

std::string FullName(const std::string& prefix, const std::string& name) {
  if (str_util::StrContains(name, kColon)) {
    LOG(ERROR) << name << " should not contain " << kColon;
  }

  return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name);
}
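
// For illustration, the resulting key has the form
// "<kFullNameRandomHex>|<prefix>:<name>": e.g. FullName("Iterator::Range",
// "next") yields the fixed random-hex sentinel, a pipe, then
// "Iterator::Range:next". The prefix shown here is only an example.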

Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}

Status DatasetBase::MakeIterator(
    IteratorContext* ctx, const IteratorBase* parent,
    const string& output_prefix,
    std::unique_ptr<IteratorBase>* iterator) const {
  *iterator = MakeIteratorInternal(output_prefix);
  Status s = (*iterator)->InitializeBase(ctx, parent);
  if (s.ok()) {
    s.Update((*iterator)->Initialize(ctx));
  }
  if (!s.ok()) {
    // Reset the iterator to avoid returning an uninitialized iterator.
    iterator->reset();
  }
  return s;
}

Status DatasetBase::MakeSplitProvider(
    std::unique_ptr<SplitProvider>* split_provider) const {
  std::vector<const DatasetBase*> inputs;
  Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create a split provider for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create a split provider for dataset of type ", type_string(),
        ", because the dataset is not unary (having arity ", inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProvider(split_provider);
}

Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  return errors::Unimplemented("InputDatasets not implemented for ",
                               type_string());
}

Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
    SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
  Status status = dataset->AsGraphDefInternal(ctx, this, output);
  if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
    Tensor t(DT_VARIANT, TensorShape({}));
    // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
    // increment the refcount of `dataset` here to retain ownership.
    dataset->Ref();
    TF_RETURN_IF_ERROR(
        StoreDatasetInVariantTensor(const_cast<DatasetBase*>(dataset), &t));
    TF_RETURN_IF_ERROR(AddPlaceholder(t, output));
    DCHECK_NE(ctx->input_list(), nullptr);
    ctx->input_list()->emplace_back((*output)->name(), std::move(t));
    LOG_EVERY_N_SEC(WARNING, 30)
        << "Input of " << dataset->DebugString()
        << " will not be optimized because the dataset does not implement the "
           "AsGraphDefInternal() method needed to apply optimizations.";
    return Status::OK();
  }
  return status;
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensor(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dtype() == DT_VARIANT) {
    // If the input tensor is a variant, it may represent a multi-dimensional
    // array of datasets. We attempt to decode each dataset so that we can use
    // their custom serialization logic and combine the result of their
    // individual serializations using the `Pack` operation.
    //
    // If this fails, we fall back to using its Variant::Encode() based
    // serialization.
    Status s = AddDatasetOrTensorHelper(ctx, t, output);
    if (s.ok()) {
      return s;
    }
  }
  return AddTensor(t, output);
}

Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper(
    SerializationContext* ctx, const Tensor& t, Node** output) {
  if (t.dims() == 0) {
    DatasetBase* dataset;
    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &dataset));
    return AddInputDataset(ctx, dataset, output);
  }
  std::vector<NodeBuilder::NodeOut> nodes;
  for (int i = 0; i < t.dim_size(0); ++i) {
    Node* node;
    TF_RETURN_IF_ERROR(AddDatasetOrTensorHelper(ctx, t.SubSlice(i), &node));
    nodes.emplace_back(node);
  }
  auto op_name = "Pack";
  auto opts = builder()->opts();
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(std::move(nodes));
  *output = opts.FinalizeBuilder(&node_builder);
  return Status::OK();
}

DatasetBaseIterator::DatasetBaseIterator(const BaseParams& params)
    : params_(params) {
  params_.dataset->Ref();
  VLOG(2) << prefix() << " constructor";
}

DatasetBaseIterator::~DatasetBaseIterator() {
  VLOG(2) << prefix() << " destructor";
  params_.dataset->Unref();
}

string DatasetBaseIterator::BuildTraceMeName() {
  string result = strings::StrCat(params_.prefix, "#id=", id_);
  if (parent_) {
    strings::StrAppend(&result, ",parent_id=", parent_id_);
  }

  TraceMeMetadata metadata = GetTraceMeMetadata();
  for (const auto& pair : metadata) {
    strings::StrAppend(&result, ",", pair.first, "=", pair.second);
  }
  strings::StrAppend(&result, "#");
  return result;
}

Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                    std::vector<Tensor>* out_tensors,
                                    bool* end_of_sequence) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " GetNext enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
  if (TF_PREDICT_TRUE(s.ok() && !*end_of_sequence)) {
    DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
    RecordElement(ctx, out_tensors);
  }
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " GetNext exit";
  return s;
}

Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip,
                                 bool* end_of_sequence, int* num_skipped) {
  profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                             profiler::TraceMeLevel::kInfo);
  DVLOG(3) << prefix() << " Skip enter";
  auto model = ctx->model();
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    auto output = node_->output();
    if (output) {
      output->record_stop(now_nanos);
    }
    node_->record_start(now_nanos);
  }
  Status s = SkipInternal(ctx, num_to_skip, end_of_sequence, num_skipped);
  if (model && model->collect_resource_usage() && node_) {
    int64 now_nanos = EnvTime::NowNanos();
    node_->record_stop(now_nanos);
    auto output = node_->output();
    if (output) {
      output->record_start(now_nanos);
    }
  }
  if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) {
    s = errors::Internal("Iterator \"", params_.prefix,
                         "\" returned `OutOfRange`. This indicates an "
                         "implementation error as `OutOfRange` errors are not "
                         "expected to be returned here. Original message: ",
                         s.error_message());
    LOG(ERROR) << s;
  }
  DVLOG(3) << prefix() << " Skip exit";
  return s;
}

Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip,
                                         bool* end_of_sequence,
                                         int* num_skipped) {
  *num_skipped = 0;
  for (int i = 0; i < num_to_skip; ++i) {
    std::vector<Tensor> out_tensors;
    TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence));
    if (*end_of_sequence) {
      return Status::OK();
    }
    // RecordElement is used to count the number of elements computed and to
    // help calculate the CPU time spent on a given iterator for autotuning.
    // Here we only call RecordElement in the default implementation of
    // SkipInternal (which trivially calls GetNextInternal) and assume that an
    // overridden SkipInternal in a derived class will have negligible cost
    // compared to its GetNextInternal.
    RecordElement(ctx, &out_tensors);
    (*num_skipped)++;
  }
  return Status::OK();
}

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
  }
}

string DatasetOpKernel::TraceString(const OpKernelContext& ctx,
                                    bool verbose) const {
  return profiler::TraceMeOp(name_view(), type_string_view());
}

// static
bool DatasetOpKernel::IsDatasetOp(const OpDef* op_def) {
  if (DatasetOpRegistry::IsRegistered(op_def->name())) {
    return true;
  }

  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2")));
}
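
// For illustration, an op named e.g. "RangeDataset" or "ParallelMapDatasetV2"
// matches the fallback heuristic above: it has a single DT_VARIANT output and
// a name ending in "Dataset" or "DatasetV2".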

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char DatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

BackgroundWorker::BackgroundWorker(Env* env, const char* name)
    : env_(env), name_(name) {}

BackgroundWorker::~BackgroundWorker() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
  }
  cond_var_.notify_one();
  // Block until the background thread has terminated.
  //
  // NOTE(mrry): We explicitly free and join the thread here because
  // `WorkerLoop()` uses other members of this object, and so we must join
  // the thread before destroying them.
  thread_.reset();
}

void BackgroundWorker::Schedule(std::function<void()> work_item) {
  {
    mutex_lock l(mu_);
    if (!thread_) {
      thread_ = absl::WrapUnique(env_->StartThread(
          {} /* thread_options */, name_, [this]() { WorkerLoop(); }));
    }
    work_queue_.push_back(std::move(work_item));
  }
  cond_var_.notify_one();
}
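
// For illustration, a minimal usage sketch (the thread name is arbitrary):
// callers construct a worker and enqueue closures that run serially on its
// single background thread.
//
//   BackgroundWorker worker(Env::Default(), "tf_data_example_worker");
//   worker.Schedule([] { /* deferred work */ });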

void BackgroundWorker::WorkerLoop() {
  tensorflow::ResourceTagger tag(kTFDataResourceTag, "Background");
  while (true) {
    std::function<void()> work_item = nullptr;
    {
      mutex_lock l(mu_);
      while (!cancelled_ && work_queue_.empty()) {
        cond_var_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      DCHECK(!work_queue_.empty());
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
    }
    DCHECK(work_item != nullptr);
    work_item();
  }
}

// static
void DatasetOpRegistry::Register(const string& op_name) {
  mutex_lock l(*get_dataset_op_registry_lock());
  get_dataset_op_registry()->insert(op_name);
}

// static
bool DatasetOpRegistry::IsRegistered(const string& op_name) {
  mutex_lock l(*get_dataset_op_registry_lock());
  std::unordered_set<string>* op_names = get_dataset_op_registry();
  return op_names->find(op_name) != op_names->end();
}

namespace {
class RunnerImpl : public Runner {
 public:
  void Run(const std::function<void()>& f) override {
    tensorflow::ResourceTagger tag(kTFDataResourceTag, "Runner");
    f();

    // NOTE: We invoke a virtual function to prevent `f` being tail-called, and
    // thus ensure that this function remains on the stack until after `f`
    // returns.
    PreventTailCall();
  }

 private:
  virtual void PreventTailCall() {}
};
}  // namespace

/* static */
Runner* Runner::get() {
  static Runner* singleton = new RunnerImpl;
  return singleton;
}

}  // namespace data
}  // namespace tensorflow