/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
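
// For illustration only, a minimal sketch of expanding the macro above into a
// `switch` over a runtime `DataType`. The `HANDLE_TYPE` name and the case
// body are hypothetical, not part of this header:
//
//   #define HANDLE_TYPE(T)              \
//     case DataTypeToEnum<T>::value:    \
//       /* handle tensors of type T */  \
//       break;
//
//   switch (dtype) {
//     TF_CALL_DATASET_TYPES(HANDLE_TYPE)
//     default:
//       return errors::Unimplemented("Unsupported dtype");
//   }
//   #undef HANDLE_TYPE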

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";
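
// A serialized iterator-state key therefore has the shape
// `<kFullNameRandomHex><kPipe><prefix><kColon><name>`. For a hypothetical
// iterator prefix "Iterator::Range" and state name "counter", the key would
// read:
//
//   60d899aa0d8ce4351e7c3b419e92d25b|Iterator::Range:counter
//
// (See `DatasetBaseIterator::full_name()` below, which assembles such keys.)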

class DatasetBase;
class SerializationContext;

// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// Please see the comment on IteratorStateWriter for guidance on using the
// Read*(key, val) vs. Read*(name, key, val) overloads.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, tstring* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;

  virtual Status ReadScalar(StringPiece name, StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key, Tensor* val) = 0;

  virtual bool Contains(StringPiece key) = 0;
  virtual bool Contains(StringPiece name, StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs, the key is expected to encode this
// name, as keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB, so if the state for an iterator
// might exceed the 2 GB limit, you can pass an explicit name via the
// Write*(name, key, val) APIs, allowing you to further split up the state
// into more manageable chunks.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64 val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
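
// For illustration only, a minimal sketch of how a concrete iterator might
// use these interfaces from its SaveInternal()/RestoreInternal() overrides
// (the "counter" state name and the `i_` member are hypothetical):
//
//   Status SaveInternal(IteratorStateWriter* writer) override {
//     return writer->WriteScalar(full_name("counter"), i_);
//   }
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     return reader->ReadScalar(full_name("counter"), &i_);
//   }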

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddVector(const std::vector<string>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<tstring>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // Return value of `DatasetType::type_string()` is used as the op type for
  // the node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, and the context does not explicitly permit stateful functions,
  // returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name,
                     const FunctionLibraryDefinition& lib_def);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }
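
  // For illustration only, a minimal sketch of how a dataset's
  // AsGraphDefInternal() (declared on DatasetBase below) might combine these
  // helpers. The `input_` and `count_` members are hypothetical:
  //
  //   Status AsGraphDefInternal(SerializationContext* ctx,
  //                             DatasetGraphDefBuilder* b,
  //                             Node** output) const override {
  //     Node* input_node = nullptr;
  //     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
  //     Node* count_node = nullptr;
  //     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
  //     return b->AddDataset(this, {input_node, count_node}, output);
  //   }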

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);
  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (auto attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value,
                          const FunctionLibraryDefinition& lib_def) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;
class FunctionHandleCache;

// A utility class for running a function and ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() {}

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what
// we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          cancellation_manager(ctx->cancellation_manager()),
          env(ctx->env()),
          flr(ctx->flr()),
          function_handle_cache(ctx->function_handle_cache()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()),
          thread_pool(ctx->thread_pool()) {}

    explicit Params(OpKernelContext* ctx)
        : env(ctx->env()), flr(ctx->function_library()) {
      // NOTE: need reinterpret_cast because function.h forward-declares
      // Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        runner_threadpool_size = port::MaxParallelism();
      }

      // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
      // that a symbol in the tensorflow::data namespace is always on the stack
      // when executing a function inside a Dataset.
      runner = std::bind(
          [](
              // Note: `runner` is a const reference to avoid copying it.
              const std::function<void(std::function<void()>)>& ctx_runner,
              std::function<void()> fn) {
            std::function<void()> wrapped_fn = std::bind(
                [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
                std::move(fn));
            ctx_runner(std::move(wrapped_fn));
          },
          *ctx->runner(), std::placeholders::_1);
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // The CancellationManager to be used to cancel execution of ops.
    CancellationManager* cancellation_manager;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* flr = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // The `StatsAggregator` object to record statistics about the iterator.
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A factory for creating threads to perform blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;

    // A shared thread pool to schedule computation into.
    thread::ThreadPoolInterface* thread_pool = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  CancellationManager* cancellation_manager() {
    return params_.cancellation_manager;
  }

  Env* env() const { return params_.env; }

  FunctionLibraryRuntime* flr() { return params_.flr; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }

  Params params() { return params_; }

  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                       int num_threads) {
    if (params_.thread_pool) {
      // Create a `ThreadPool` instance by wrapping `params_.thread_pool`
      // (which is an instance of `thread::ThreadPoolInterface`). Notably, the
      // ownership of `params_.thread_pool` is *not* transferred onto the
      // newly created `ThreadPool` instance.
      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
    } else {
      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
                                                   name, num_threads,
                                                   /*low_latency_hint=*/false);
    }
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }
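
  // For illustration only, a hedged sketch of how an asynchronous iterator
  // might start a background worker via this context. The thread name and
  // the `PrefetchLoop` method are hypothetical:
  //
  //   std::shared_ptr<IteratorContext> new_ctx =
  //       std::make_shared<IteratorContext>(*ctx);
  //   worker_thread_ = ctx->StartThread(
  //       "tf_data_prefetch",
  //       [this, new_ctx]() { PrefetchLoop(new_ctx.get()); });
  //
  // The returned `std::unique_ptr<Thread>` joins the thread on destruction,
  // so the iterator should signal the loop to exit before dropping it.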

 private:
  Params params_;
};

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  // Enum describing what to do during serialization when external state is
  // encountered.
  enum class ExternalStatePolicy : int64 {
    // Proceed with serialization, but log a warning about what state will be
    // lost.
    kWarn = 0,
    // Proceed with serialization without logging any warning.
    kIgnore = 1,
    // Fail the serialization with an error.
    kFail = 2,
  };

  struct Params {
    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.

    // Indicates what to do if the dataset depends on external state.
    ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

    // Indicates whether an attempt to serialize a dataset that does not
    // implement serialization should result in an error. If set to `false`,
    // the serialized graph will replace the dataset with a placeholder
    // returned in `input_list`.
    bool fail_if_unimplemented = true;

    // Indicates whether (potentially large) data tensors should be
    // serialized, or replaced with a placeholder returned in `input_list`.
    // The latter makes sense when performing data-agnostic graph rewrites to
    // reduce memory usage.
    bool serialize_data_tensors = true;

    // Indicates whether datasets that use random seeds should have the values
    // of random seeds serialized or not. If the values of random seeds are
    // serialized, the deserialized dataset will have the same seeds as the
    // original dataset. Otherwise, the deserialized dataset will use
    // different seeds. This param does not affect datasets that use fixed
    // seeds; fixed seeds will always be preserved.
    bool preserve_random_seeds = true;
  };

  explicit SerializationContext(Params params) : params_(params) {}

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  ExternalStatePolicy external_state_policy() const {
    return params_.external_state_policy;
  }

  bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }

  bool serialize_data_tensors() const { return params_.serialize_data_tensors; }

  bool preserve_random_seeds() const { return params_.preserve_random_seeds; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // Implementations should never return an `OutOfRange` error. If at end of
  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
  // Internally raised `OutOfRange` errors that do not imply end of sequence
  // should be converted to a different error type before being propagated to
  // the caller.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }
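
  // For illustration only, a minimal sketch of draining an iterator with the
  // GetNext() contract described above:
  //
  //   std::vector<Tensor> element;
  //   bool end_of_sequence = false;
  //   while (!end_of_sequence) {
  //     element.clear();
  //     TF_RETURN_IF_ERROR(
  //         iterator->GetNext(ctx, &element, &end_of_sequence));
  //     if (!end_of_sequence) {
  //       // ... consume `element` ...
  //     }
  //   }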

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // Restores the state of this iterator.
  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator.
  //
  // This method is used to store the state of the iterator in a checkpoint.
  //
  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
  // implementations have an override.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator.
  //
  // This method is used to restore the state of the iterator from a
  // checkpoint.
  //
  // Implementations may assume that the iterator is in a clean state. That
  // is, its `Initialize` method has been called, but its `GetNext` method has
  // never been called.
  //
  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
  // implementations have an override.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }

  // Returns the number of elements produced by this iterator.
  int64 num_elements() const {
    if (node_) return node_->num_elements();
    return 0;
  }

 private:
  // For access to `AddCleanupFunction` and `Restore`.
  friend class DatasetBase;
  friend class DatasetBaseIterator;  // for access to `node_`

  // Registers a cleanup function to be called upon object destruction.
  //
  // Registered functions are invoked in the reverse order of registration.
  void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
    cleanup_fns_.push_back(std::move(cleanup_fn));
  }

  // Associates the given performance modeling `Node` with this iterator.
  void SetNode(std::shared_ptr<model::Node> node) { node_ = node.get(); }

  std::vector<std::function<void()>> cleanup_fns_;
  model::Node* node_ = nullptr;  // Not owned.
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the given tensor.
int64 GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to SetVariantTensorToDataset().
//
// The retrieved pointer is a borrowed reference to the dataset, which is
// owned by the tensor. The consumer must either acquire its own reference to
// the dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is
// not destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    *iterator = MakeIteratorInternal(output_prefix);
    if (const auto& model = ctx->model()) {
      const string& prefix = (*iterator)->prefix();
      (*iterator)->SetNode(model->AddNode(MakeNodeFactory(ctx, iterator->get()),
                                          prefix, output_prefix));
      (*iterator)->AddCleanupFunction(
          [model, prefix]() { model->RemoveNode(prefix); });
    }
    Status s = (*iterator)->Initialize(ctx);
    if (!s.ok()) {
      // Reset the iterator to avoid returning an uninitialized iterator.
      iterator->reset();
    }
    return s;
  }

  Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, output_prefix, iterator);
  }
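
  // For illustration only, a minimal sketch of creating and consuming an
  // iterator from a borrowed `DatasetBase*`. The "Iterator" prefix string is
  // a conventional choice, not mandated by this header:
  //
  //   std::unique_ptr<IteratorBase> iterator;
  //   TF_RETURN_IF_ERROR(dataset->MakeIterator(ctx, "Iterator", &iterator));
  //   std::vector<Tensor> element;
  //   bool end_of_sequence = false;
  //   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &element, &end_of_sequence));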

  // Returns a new iterator restored from the checkpoint data in `reader`.
  Status MakeIteratorFromCheckpoint(
      IteratorContext* ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    std::unique_ptr<IteratorBase> it;
    TF_RETURN_IF_ERROR(MakeIterator(ctx, output_prefix, &it));
    TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
    *iterator = std::move(it);
    return Status::OK();
  }

  Status MakeIteratorFromCheckpoint(
      IteratorContext&& ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
  }

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }

  // Returns the estimated number of bytes used for tensors of this dataset.
  virtual int64 TotalBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // If the dataset is stateful it will not be possible to save its graph or
  // checkpoint the state of its iterators.
  //
  // TODO(jsimsa): Remove this method once all `DatasetBase` implementations
  // are migrated over to `CheckExternalState`.
  virtual bool IsStateful() const { return false; }

  // Indicates whether the dataset depends on any external state. If so, the
  // method returns `errors::FailedPrecondition` with a message that
  // identifies the external state. Otherwise, the method returns
  // `Status::OK()`.
  //
  // TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
  // implementations have an override.
  virtual Status CheckExternalState() const {
    if (IsStateful()) {
      return errors::FailedPrecondition("Dataset cannot be serialized.");
    }
    return Status::OK();
  }

 protected:
  friend Status AsGraphDef(
      OpKernelContext* ctx, const DatasetBase* dataset,
      SerializationContext&& serialization_ctx,
      GraphDef* graph_def);  // For access to graph related members.
  friend class CapturedFunction;

  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
        : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
  };

  // Serializes the dataset into a `GraphDef`, which has two uses:
  //
  // 1) To perform static input pipeline optimizations, tf.data serializes
  // the dataset graph, applies graph rewrites, and then deserializes the
  // graph. If a subclass of `DatasetBase` does not implement this method,
  // then it will be excluded from static optimizations (and so will any
  // upstream datasets).
  //
  // 2) To save the dataset so that it can be restored at a later point
  // (possibly in a different environment). If a subclass of `DatasetBase`
  // does not implement this method, then this migration will not be possible.
  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

 private:
  // Returns a factory for nodes that represent the given iterator.
  static model::Node::Factory MakeNodeFactory(IteratorContext* ctx,
                                              IteratorBase* iterator) {
    return [ctx, iterator](model::Node::Args args) {
      return iterator->CreateNode(ctx, std::move(args));
    };
  }

  const string type_string_;
  const string node_name_;
};

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
    params_.dataset->Ref();
    VLOG(2) << prefix() << " constructor";
  }

  ~DatasetBaseIterator() override {
    VLOG(2) << prefix() << " destructor";
    params_.dataset->Unref();
  }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  virtual const DatasetBase* dataset() const { return params_.dataset; }

  // The sequence of iterators leading up to this iterator.
  const string& prefix() const override { return params_.prefix; }

  // Returns a name to be used for the TraceMe event.
  //
  // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
  // following format "name#arg_1=value_1,...,arg_n=value_n".
  virtual string BuildTraceMeName() { return params_.prefix; }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    Status s = params_.dataset->CheckExternalState();
    if (!s.ok()) {
      if (ctx->external_state_policy() ==
          SerializationContext::ExternalStatePolicy::kWarn) {
        LOG(WARNING) << "Dataset contains external state: " << s.ToString();
      }
      if (ctx->external_state_policy() ==
          SerializationContext::ExternalStatePolicy::kFail) {
        return s;
      }
    }
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    if (str_util::StrContains(name, kColon)) {
      LOG(ERROR) << name << " should not contain " << kColon;
    }

    return strings::StrCat(kFullNameRandomHex, kPipe, params_.prefix, kColon,
                           name);
  }

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method disables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(false);
    }
  }

  // When modeling is enabled, this method enables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(true);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(-GetAllocatedBytes(element), -1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(GetAllocatedBytes(element), 1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element.
  void RecordElement(IteratorContext* ctx) {
    if (node_) {
      node_->record_element();
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx, bool stop_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = EnvTime::NowNanos();
      if (stop_output && node_->output()) {
        node_->output()->record_stop(now_nanos);
      }
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx, bool start_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = EnvTime::NowNanos();
      node_->record_stop(now_nanos);
      if (start_output && node_->output()) {
        node_->output()->record_start(now_nanos);
      }
    }
  }

 private:
  inline bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const final { return typed_dataset_; }

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name, T* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a scalar");
  }
  *output = argument_t->scalar<T>()();
  return Status::OK();
}

template <typename T>
Status ParseVectorArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name,
                           std::vector<T>* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsVector(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a vector");
  }
  int size = argument_t->vec<T>().size();
  output->reserve(size);
  for (int i = 0; i < size; ++i) {
    output->push_back(argument_t->vec<T>()(i));
  }
  return Status::OK();
}
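
// For illustration only, a hedged sketch of using these helpers inside a
// dataset op kernel's MakeDataset(). The "num_epochs" and "filenames" input
// names are hypothetical:
//
//   int64 num_epochs = 0;
//   OP_REQUIRES_OK(
//       ctx, ParseScalarArgument<int64>(ctx, "num_epochs", &num_epochs));
//   std::vector<tstring> filenames;
//   OP_REQUIRES_OK(
//       ctx, ParseVectorArgument<tstring>(ctx, "filenames", &filenames));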

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) final;

  // Indicates whether the given op corresponds to an op whose kernels
  // subclass the `DatasetOpKernel` class.
  static bool IsDatasetOp(const OpDef* op_def);

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
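
// For illustration only, a minimal sketch of a unary dataset op kernel built
// on the classes above. `MyDatasetOp` and its inner `Dataset` class are
// hypothetical:
//
//   class MyDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       // Typically: parse op inputs, then allocate the dataset, e.g.
//       //   *output = new Dataset(DatasetContext(ctx), input, ...);
//     }
//   };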

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// A simple background worker that executes closures asynchronously and
// without blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an
// `AsyncOpKernel` to avoid blocking an executor thread that may be required
// by the blocking work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for
// this purpose because its current implementation (in Eigen) uses a
// finite-length queue and will block the caller when full. This can lead to
// deadlock under heavy load. Since the number of concurrent work items in
// each user of a `BackgroundWorker` is at most one per op invocation, the
// dynamic allocation overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
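
// For illustration only, a hedged sketch of offloading an `AsyncOpKernel`'s
// blocking work to a `BackgroundWorker`. The worker would typically be held
// as a kernel member rather than a local, the thread name is an arbitrary
// choice, and `done` stands for the kernel's completion callback:
//
//   BackgroundWorker worker(Env::Default(), "my_op_background_worker");
//   worker.Schedule([done]() {
//     // ... perform the potentially blocking work, then signal completion.
//     done();
//   });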

// Registry of names of ops whose kernels subclass the `DatasetOpKernel`
// class.
class DatasetOpRegistry {
 public:
  // Registers the op name.
  static void Register(const string& op_name);

  // Checks whether the given op name has been registered.
  static bool IsRegistered(const string& op_name);
};

// Helper class to register a dataset op name.
class DatasetOpRegistrar {
 public:
  explicit DatasetOpRegistrar(const string& op_name) {
    DatasetOpRegistry::Register(op_name);
  }
};

// Macro that can be used to register an op name of a dataset op.
#define REGISTER_DATASET_OP_NAME(op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, op_name)

#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name)

#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name) \
  static ::tensorflow::data::DatasetOpRegistrar     \
      registrar__body__##ctr##__object(op_name)
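
// For illustration only, registering a hypothetical op name so that
// `DatasetOpRegistry::IsRegistered("MyDataset")` returns true:
//
//   REGISTER_DATASET_OP_NAME("MyDataset");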

}  // namespace data

// TODO(b/114112161): Remove these aliases when all users have moved over to
// the `tensorflow::data` namespace.
using data::DatasetBase;
using data::DatasetContext;
using data::DatasetIterator;
using data::DatasetOpKernel;
using data::IteratorBase;
using data::IteratorContext;
using data::IteratorStateReader;
using data::IteratorStateWriter;
using data::SerializationContext;
using data::UnaryDatasetOpKernel;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_