/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/dataset_metadata.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

namespace internal {
// Merges Options from source to destination. If there is a conflict on a
// field, the field value from the source takes precedence.
void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination);
void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination);
}  // namespace internal

using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;

constexpr char kTFDataFunction[] = "_tf_data_function";

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

constexpr char kTFDataResourceTag[] = "tfdata";
constexpr char kTraceInfoUnavailable[] = "unavailable";
constexpr char kMetadata[] = "metadata";

constexpr char kCardinalityAttrForRewrite[] = "_cardinality";

class DatasetBase;
class SerializationContext;

inline bool IsTFDataFunction(const FunctionDef& func) {
  auto iter = func.attr().find(data::kTFDataFunction);
  return (iter != func.attr().end() && iter->second.b());
}

// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// See the comment on IteratorStateWriter for guidance on choosing between
// the Read*(key, val) and Read*(name, key, val) overloads.
class IteratorStateReader {
 public:
  // Determines whether the iterator state contains the given key.
  virtual bool Contains(StringPiece key) const = 0;
  virtual bool Contains(StringPiece name, StringPiece key) const = 0;

  // Reads an integer for the given key.
  virtual Status ReadScalar(StringPiece key, int64_t* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            int64_t* val) const = 0;

  // Reads a string for the given key.
  virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) const = 0;

  // Reads a tensor for the given key.
  // TODO(jsimsa): Remove non-FLR overrides once all callers are updated.
  virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
                            StringPiece key, Tensor* val) const = 0;

  virtual ~IteratorStateReader() {}
};
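
// Example (illustrative sketch, not part of the original header): restoring a
// scalar counter inside an iterator's `RestoreInternal`. The key name
// "counter" and the member `counter_` are hypothetical; keys are produced via
// the `full_name()` helper documented below.
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("counter"), &counter_));
//     return OkStatus();
//   }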

// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs, the key is expected to encode this
// name, since keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB, so if the state for an iterator
// might exceed that limit, you can pass an explicit name via the
// Write*(name, key, val) APIs, which lets you split the state into more
// manageable chunks.
class IteratorStateWriter {
 public:
  // Writes an integer for the given key.
  virtual Status WriteScalar(StringPiece key, const int64_t val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64_t val) = 0;

  // Writes a string for the given key.
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;

  // Writes a tensor for the given key.
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
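
// Example (illustrative sketch, not part of the original header): saving a
// scalar counter inside an iterator's `SaveInternal`, the counterpart of the
// read example above. The key name "counter" and member `counter_` are
// hypothetical.
//
//   Status SaveInternal(SerializationContext* ctx,
//                       IteratorStateWriter* writer) override {
//     return writer->WriteScalar(full_name("counter"), counter_);
//   }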

// Generates a full name key for iterator checkpointing. All keys generated for
// iterator checkpoints should go through this function.
std::string FullName(const std::string& prefix, const std::string& name);

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return OkStatus();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64_t>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return OkStatus();
  }

  Status AddVector(const std::vector<string>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
                          TensorShape({static_cast<int64_t>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<tstring>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return OkStatus();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return OkStatus();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return OkStatus();
  }

  // Adds a node for the given dataset to the `Graph`. The value of
  // `DatasetBase::type_string()` is used as the op type for the node. Values
  // for the `output_types` and `output_shapes` node attributes are also
  // written if those attributes are defined in the `OpDef`.
  //
  // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is
  // used as the op name for the node. This argument should only be set when
  // serializing `DatasetBase` instances which might not have been created
  // through op kernel execution to make sure the dataset op name is preserved
  // across serialization boundaries, which is in turn needed to make sure
  // iterator checkpoints are valid across serialization boundaries. When
  // `use_dataset_name` is set, the caller is responsible for making sure that
  // the op name is unique across the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing `Graph` of `GraphDefBuilder`.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output);
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output);
  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);
  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      bool use_dataset_name, Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, and the context does not explicitly permit stateful functions,
  // returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name,
                     const FunctionLibraryDefinition& lib_def);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

  template <typename T>
  AttrValue BuildAttrValue(const T& value) {
    AttrValue attr;
    SetAttrValue(value, &attr);
    return attr;
  }

 protected:
  GraphDefBuilder* builder() { return b_; }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);
  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (const auto& attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value,
                          const FunctionLibraryDefinition& lib_def) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
      }
    }
    return OkStatus();
  }

  GraphDefBuilder* b_;
};
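
// Example (illustrative sketch, not part of the original header): a typical
// `AsGraphDefInternal` implementation (see `DatasetBase` below) serializes a
// dataset's inputs and parameters using this builder. `input_` and `count_`
// are hypothetical members of a hypothetical dataset; `AddInputDataset` comes
// from `DatasetGraphDefBuilder`, defined further below.
//
//   Status AsGraphDefInternal(SerializationContext* ctx,
//                             DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_graph_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
//     Node* count_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
//     return b->AddDataset(this, {input_graph_node, count_node}, output);
//   }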

class StatsAggregator;

// A utility class for running a function and ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() {}

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};

// A class which provides a sequence of splits. Splits represent subdivisions
// of a dataset, e.g. filenames or ranges within files. We use splitting to
// partition input data into smaller pieces for distributed processing (see
// go/tf-data-splitting-design).
//
// Datasets provide a `MakeSplitProvider` method to expose a listing of their
// splits.
//
// Iterators created with a split provider will only iterate over the splits
// provided by the split provider.
class SplitProvider {
 public:
  virtual ~SplitProvider() {}
  // Stores the next split in `*split`, setting `*end_of_splits` to indicate
  // whether there were any splits left.
  virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
  // Resets the split provider to its beginning.
  virtual Status Reset() = 0;
  // Saves the state of this split provider.
  virtual Status Save(std::function<std::string(std::string)> full_name,
                      IteratorStateWriter* writer) = 0;
  // Restores the state of this split provider.
  virtual Status Restore(std::function<std::string(std::string)> full_name,
                         IteratorStateReader* reader) = 0;
};
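
// Example (illustrative sketch, not part of the original header): a minimal
// in-memory split provider that hands out integer indices. The class name and
// locking strategy are assumptions, not part of this interface's contract.
//
//   class IndexSplitProvider : public SplitProvider {
//    public:
//     explicit IndexSplitProvider(int64_t n) : n_(n) {}
//     Status GetNext(Tensor* split, bool* end_of_splits) override {
//       mutex_lock l(mu_);
//       *end_of_splits = (i_ >= n_);
//       if (*end_of_splits) return OkStatus();
//       *split = Tensor(DT_INT64, TensorShape({}));
//       split->scalar<int64_t>()() = i_++;
//       return OkStatus();
//     }
//     Status Reset() override {
//       mutex_lock l(mu_);
//       i_ = 0;
//       return OkStatus();
//     }
//     Status Save(std::function<std::string(std::string)> full_name,
//                 IteratorStateWriter* writer) override {
//       mutex_lock l(mu_);
//       return writer->WriteScalar(full_name("i"), i_);
//     }
//     Status Restore(std::function<std::string(std::string)> full_name,
//                    IteratorStateReader* reader) override {
//       mutex_lock l(mu_);
//       return reader->ReadScalar(full_name("i"), &i_);
//     }
//
//    private:
//     mutex mu_;
//     int64_t i_ TF_GUARDED_BY(mu_) = 0;
//     const int64_t n_;
//   };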

// Returns the runner threadpool size from an OpKernelContext.
int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx);

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what
// we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          cancellation_manager(ctx->cancellation_manager()),
          collective_executor(ctx->collective_executor()),
          env(ctx->env()),
          flr(ctx->flr()),
          function_handle_cache(ctx->function_handle_cache()),
          interleave_depth(ctx->interleave_depth()),
          is_restoring(ctx->is_restoring()),
          model(ctx->model()),
          options(ctx->options()),
          resource_mgr(ctx->resource_mgr()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          split_providers(ctx->split_providers()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()),
          thread_pool(ctx->thread_pool()) {}

    explicit Params(OpKernelContext* ctx)
        : collective_executor(ctx->collective_executor()),
          env(ctx->env()),
          flr(ctx->function_library()) {
      // NOTE: need reinterpret_cast because function.h forward-declares
      // Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);

      // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
      // that a symbol in the tensorflow::data namespace is always on the stack
      // when executing a function inside a Dataset.
      runner = std::bind(
          [](
              // Note: `runner` is a const reference to avoid copying it.
              const std::function<void(std::function<void()>)>& ctx_runner,
              std::function<void()> fn) {
            std::function<void()> wrapped_fn = std::bind(
                [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
                std::move(fn));
            ctx_runner(std::move(wrapped_fn));
          },
          *ctx->runner(), std::placeholders::_1);
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // The CancellationManager to be used to cancel execution of ops.
    CancellationManager* cancellation_manager;

    // Collective support.
    CollectiveExecutor* collective_executor = nullptr;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* flr = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree.
    int64 interleave_depth = 0;

    // Marks whether the iterator is restored from a checkpoint.
    bool is_restoring = false;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // The input pipeline options.
    const Options* options = nullptr;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // Split providers indicating which splits to process. May be empty,
    // indicating that the iterator should process all splits.
    std::vector<std::shared_ptr<SplitProvider>> split_providers;

    // The `StatsAggregator` object to record statistics about the iterator.
    //
    // TODO(b/147325552): Remove this API and any of its uses after we switch
    // to using C++ based implementation for tf.data options (on 4/12/2021).
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A factory for creating threads to perform blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;

    // A shared thread pool to schedule computation into.
    thread::ThreadPoolInterface* thread_pool = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  CancellationManager* cancellation_manager() {
    return params_.cancellation_manager;
  }

  CollectiveExecutor* collective_executor() {
    return params_.collective_executor;
  }

  Env* env() const { return params_.env; }

  FunctionLibraryRuntime* flr() { return params_.flr; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  int64 interleave_depth() { return params_.interleave_depth; }

  bool is_restoring() { return params_.is_restoring; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  const Options* options() { return params_.options; }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::vector<std::shared_ptr<SplitProvider>> split_providers() {
    return params_.split_providers;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }

  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                       int num_threads) {
    if (params_.thread_pool) {
      // Create a `ThreadPool` instance by wrapping `params_.thread_pool`
      // (which is an instance of `thread::ThreadPoolInterface`). Notably, the
      // ownership of `params_.thread_pool` is *not* transferred onto the newly
      // created `ThreadPool` instance.
      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
    } else {
      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
                                                   name, num_threads,
                                                   /*low_latency_hint=*/false);
    }
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

 private:
  Params params_;
};
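
// Example (illustrative sketch, not part of the original header): deriving a
// nested context that overrides one field while inheriting the rest. This is
// the pattern used elsewhere in this header, e.g. in
// `DatasetBase::MakeIteratorFromCheckpoint` below.
//
//   IteratorContext::Params params(ctx);  // copy fields from `ctx`
//   params.split_providers.clear();       // override as needed
//   IteratorContext nested_ctx(std::move(params));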

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  // Enum describing what to do during serialization when external state is
  // encountered.
  enum class ExternalStatePolicy : int64 {
    // Proceed with serialization, but log a warning about what state will be
    // lost.
    kWarn = 0,
    // Proceed with serialization without logging any warning.
    kIgnore = 1,
    // Fail the serialization with an error.
    kFail = 2,
  };

  // Handles the CheckExternalState status according to the external state
  // policy.
  Status HandleCheckExternalStateStatus(Status s) {
    if (s.ok()) {
      return s;
    }
    switch (params_.external_state_policy) {
      case ExternalStatePolicy::kWarn:
        LOG(WARNING) << s.ToString();
        return OkStatus();
      case ExternalStatePolicy::kIgnore:
        VLOG(2) << "Ignoring error status: " << s.ToString();
        return OkStatus();
      case ExternalStatePolicy::kFail:
        return s;
    }
    LOG(FATAL) << "Control should never reach here";
  }

  struct Params {
    explicit Params() {}

    explicit Params(OpKernelContext* ctx)
        : resource_mgr(ctx->resource_manager()),
          device_name(ctx->device()->attributes().name()) {}

    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.

    // Indicates what to do if the dataset depends on external state.
    ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

    // Indicates whether the serialization is for rewrites.
    //
    // If true:
    // * A dataset that doesn't implement serialization is replaced with a
    //   placeholder returned in `input_list`.
    // * Data tensors are replaced with a placeholder returned in
    //   `input_list`.
    // * Datasets that use random seeds should not serialize the random seeds.
    //   This doesn't affect datasets that use fixed seeds; fixed seeds will
    //   always be preserved.
    // * Cardinality is serialized as an unregistered attribute
    //   `_cardinality`.
    // If false:
    // * A dataset that doesn't implement serialization should result in an
    //   error.
    // * Data tensors (potentially large) should be serialized.
    // * Datasets that use random seeds should serialize the random seeds.
    bool is_graph_rewrite = false;

    // A resource manager for looking up resources during serialization.
    ResourceMgr* resource_mgr;

    // The name of the device doing the serialization.
    std::string device_name;
  };

  explicit SerializationContext(Params params) : params_(params) {}

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  ExternalStatePolicy external_state_policy() const {
    return params_.external_state_policy;
  }

  bool is_graph_rewrite() const { return params_.is_graph_rewrite; }

  const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }

  const std::string& device_name() const { return params_.device_name; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};
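
// Example (illustrative sketch, not part of the original header): building a
// serialization context that fails on external state, e.g. before writing an
// iterator checkpoint. `dataset` is a hypothetical `DatasetBase*`.
//
//   SerializationContext::Params params;
//   params.external_state_policy =
//       SerializationContext::ExternalStatePolicy::kFail;
//   SerializationContext serialization_ctx(params);
//   TF_RETURN_IF_ERROR(serialization_ctx.HandleCheckExternalStateStatus(
//       dataset->CheckExternalState()));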

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will be stored
  // in `*end_of_sequence`, and `*out_tensors` will be empty.
  //
  // Implementations should never return an `OutOfRange` error. If at end of
  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
  // Internally raised `OutOfRange` errors that do not imply end of sequence
  // should be converted to a different error type before being propagated to
  // the caller.
  //
  // Implementations must explicitly set `*end_of_sequence = false` if a
  // `Status::OK()` status is returned and the iterator is not at the end of
  // the sequence.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  // Skips the next `num_to_skip` outputs from the range that this iterator
  // is traversing.
  //
  // If there are not enough outputs to skip, it will set
  // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
  // store the number of outputs that are skipped. When `*end_of_sequence` is
  // `false`, `*num_skipped` should equal `num_to_skip`.
  virtual Status Skip(IteratorContext* ctx, int num_to_skip,
                      bool* end_of_sequence, int* num_skipped) = 0;

  virtual Status Skip(IteratorContext&& ctx, int num_to_skip,
                      bool* end_of_sequence, int* num_skipped) {
    return Skip(&ctx, num_to_skip, end_of_sequence, num_skipped);
  }

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); }

  // Performs initialization of the base iterator.
  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    int64_t start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
    VLOG(1) << "Saved " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return OkStatus();
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    int64_t start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
    VLOG(1) << "Restored " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return OkStatus();
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->Save(ctx, writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->Restore(ctx, reader);
  }

  Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return RestoreInput(&ctx, reader, input);
  }

  // Saves the state of this iterator.
  //
  // This method is used to store the state of the iterator in a checkpoint.
  virtual Status SaveInternal(SerializationContext* ctx,
                              IteratorStateWriter* writer) = 0;

  // Restores the state of this iterator.
  //
  // This method is used to restore the state of the iterator from a
  // checkpoint.
  //
  // Implementations may assume that the iterator is in a clean state. That is,
  // its `Initialize` method has been called, but its `GetNext` method has
  // never been called.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) = 0;

  // Returns a pointer to the node representing this iterator in the
  // performance model. It may be null if performance modeling is not enabled
  // for this iterator.
  std::shared_ptr<model::Node> model_node() const { return node_; }

  // Returns the number of elements produced by this iterator.
  int64_t num_elements() const {
    if (node_) return node_->num_elements();
    return 0;
  }

 private:
  // For access to `AddCleanupFunction` and `Restore`.
  friend class DatasetBase;
  friend class DatasetBaseIterator;  // for access to `node_`

  std::vector<std::function<void()>> cleanup_fns_;
  std::shared_ptr<model::Node> node_ = nullptr;
  const IteratorBase* parent_ = nullptr;  // Not owned.
  int64_t id_ = 0;
  int64_t parent_id_ = 0;
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64_t GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the given tensor.
int64_t GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to SetVariantTensorToDataset().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
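
// Example (illustrative sketch, not part of the original header): round-
// tripping a dataset through a DT_VARIANT tensor, as dataset op kernels do
// when passing datasets between ops. `dataset` is a hypothetical
// `DatasetBase*` whose ownership is transferred to `variant_t`.
//
//   Tensor variant_t(DT_VARIANT, TensorShape({}));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset, &variant_t));
//   DatasetBase* unwrapped = nullptr;  // borrowed; owned by `variant_t`
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(variant_t, &unwrapped));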

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Initializes the dataset.
  void Initialize(const Metadata& metadata);

  const Metadata& metadata() const { return metadata_; }

  const Options& options() const { return options_; }

  int64_t num_sources() const { return num_sources_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const;

  Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, parent, output_prefix, iterator);
  }

  // Returns a new iterator restored from the checkpoint data in `reader`.
  Status MakeIteratorFromCheckpoint(
      IteratorContext* ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    std::unique_ptr<IteratorBase> it;
    IteratorContext::Params params(ctx);
    params.is_restoring = true;
    IteratorContext restore_ctx(std::move(params));
    TF_RETURN_IF_ERROR(MakeIterator(&restore_ctx,
                                    /*parent=*/nullptr, output_prefix, &it));
    TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader));
    *iterator = std::move(it);
    return OkStatus();
  }

  Status MakeIteratorFromCheckpoint(
      IteratorContext&& ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
  }

  // Returns split providers which partition the dataset's data into splits
  // and provide them in a sequence. The split providers are stored in
  // `*split_providers`.
  virtual Status MakeSplitProviders(
      std::vector<std::unique_ptr<SplitProvider>>* split_providers) const;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64_t AllocatedBytes() const { return 0; }

  // Returns the estimated number of bytes used for tensors of this dataset.
  virtual int64_t TotalBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  // TODO(shilpakrish): Remove this overload once all callers are migrated
  // to the API which passes in the options parameter.
  ABSL_DEPRECATED("Use the overload that passes in the options parameter.")
  int64_t Cardinality() const;

  // Returns the cardinality of this dataset based on the options.
  int64_t Cardinality(CardinalityOptions options) const;

  // Internal implementation of cardinality for a dataset.
  // TODO(shilpakrish): Remove this overload once all callers are migrated
  // to the API which passes in the options parameter.
  ABSL_DEPRECATED("Use the overload that passes in the options parameter.")
  virtual int64_t CardinalityInternal() const { return kUnknownCardinality; }

  // Internal implementation of cardinality for a dataset based on the
  // options.
  virtual int64_t CardinalityInternal(CardinalityOptions options) const {
    return kUnknownCardinality;
  }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // Stores the dataset's input datasets in `*inputs`. The pointers stored in
  // `*inputs` are borrowed. The only valid non-ok return status is
  // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
  // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
  // default implementation of `MakeSplitProvider` when there is a single input
  // dataset.
  virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;

  // Indicates whether the dataset depends on any external state which would
  // prevent it from being serializable. If so, the method returns
  // `errors::FailedPrecondition` with a message that identifies the external
  // state. Otherwise, the method returns `Status::OK()`.
  virtual Status CheckExternalState() const = 0;

  // Indicates whether the dataset is compatible with random access.
  Status CheckRandomAccessCompatible(const int64 index) const;

  // Returns the element at a particular index for a randomly accessible
  // dataset.
  virtual Status Get(OpKernelContext* ctx, int64 index,
                     std::vector<Tensor>* out_tensors) const;

  // Returns a finalized version of the dataset. The returned DatasetBase is
  // unowned and lives for as long as this dataset.
  virtual StatusOr<DatasetBase*> Finalize(
      OpKernelContext* ctx,
      std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
          make_finalized_dataset) const;

  // Wrapper around a GraphDefBuilder which provides support for serializing
  // Datasets as GraphDefs.
  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
        : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
    Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
                              Node** output);
    Status AddIdentity(SerializationContext* ctx,
                       const std::string& name_prefix, Node** input,
                       Node** output);

   private:
    Status AddDatasetOrTensorHelper(SerializationContext* ctx,
                                    const Tensor& val, Node** output);
    Status AddResourceHelper(SerializationContext* ctx, const Tensor& val,
                             Node** output);
  };

 protected:
  friend class CapturedFunction;

  // Serializes the dataset into a `GraphDef`, which has two uses:
  //
  // 1) To perform static input pipeline optimizations, tf.data serializes the
  // dataset graph, applies graph rewrites, and then deserializes the graph.
  // If a subclass of `DatasetBase` does not implement this method, then it
  // will be excluded from static optimizations (and so will any upstream
  // datasets).
  //
  // 2) To save the dataset so that it can be restored at a later point
  // (possibly in a different environment). If a subclass of `DatasetBase`
  // does not implement this method, then this migration will not be possible.
  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

  void set_options(const Options& options) { options_ = options; }

 private:
  // Computes and stores the cardinality of a given dataset.
  Status ComputeCardinality();

  // Computes the number of source datasets feeding into this dataset. A source
  // dataset is a leaf in the subtree of dataset inputs.
  Status ComputeNumSources();

  // Merges options from inputs to this dataset. If there is a conflict in a
  // field value, the options set on this dataset take precedence over those in
  // the inputs. The order of precedence on the inputs is the same order as
  // they appear for this dataset.
  Status MergeOptionsFromInputs();

  const string type_string_;
  const string node_name_;
  Metadata metadata_;
  Options options_;
  mutable mutex mu_;
  mutable mutex cardinality_mu_;
  mutable core::RefCountPtr<DatasetBase> finalized_dataset_;
  // The number of source datasets feeding into the dataset. A source dataset
  // is a leaf in the subtree of dataset inputs.
  int64_t num_sources_ = -1;
  mutable int64_t cardinality_ TF_GUARDED_BY(cardinality_mu_) =
      kUnknownCardinality;
};

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params);

  ~DatasetBaseIterator() override;

  virtual const DatasetBase* dataset() const { return params_.dataset; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  const string& prefix() const override { return params_.prefix; }

  // Returns a name to be used for the TraceMe event.
  //
  // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
  // following format "name#arg_1=value_1,...,arg_n=value_n".
  string BuildTraceMeName();

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
              int* num_skipped) final;

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    VLOG(2) << "Attempting to save checkpoints on iterator (prefix: "
            << prefix() << ") from " << dataset()->DebugString();
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final {
    VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: "
            << prefix() << ") from " << dataset()->DebugString();
    return IteratorBase::Restore(ctx, reader);
  }

  // Internal implementation of GetNext that is wrapped in tracing logic.
  //
  // See the docstring of the `GetNext` method regarding the contract for
  // `out_tensors` and `end_of_sequence`. Implementations may assume that
  // `*out_tensors` is empty.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  // Internal implementation of Skip that is wrapped in tracing logic.
  virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
                              bool* end_of_sequence, int* num_skipped);

  string full_name(const string& name) const {
    return FullName(params_.prefix, name);
  }

  // Returns a map of key-value pairs to be included in the TraceMe string.
  virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method disables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(false);
    }
  }

  // When modeling is enabled, this method enables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(true);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(-GetAllocatedBytes(element), -1);

      DCHECK_GE(node_->buffered_elements(), 0);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(GetAllocatedBytes(element), 1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element and its size in bytes.
  void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
    if (collect_resource_usage(ctx)) {
      int64_t num_bytes = GetAllocatedBytes(*out_tensors);
      node_->record_element();
      node_->record_bytes_produced(num_bytes);
      if (node_->output()) {
        node_->output()->record_bytes_consumed(num_bytes);
      }
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64_t now_nanos = EnvTime::NowNanos();
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64_t now_nanos = EnvTime::NowNanos();
      node_->record_stop(now_nanos);
    }
  }

  // Returns whether work is currently being recorded, i.e. whether we are
  // currently between a `RecordStart` and a `RecordStop`.
  bool IsRecording(IteratorContext* ctx) {
    return node_ && node_->is_recording();
  }

 private:
  bool collect_resource_usage(IteratorContext* ctx) {
    return ctx->model() && node_;
  }

  string traceme_metadata_;
  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const final { return typed_dataset_; }

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};
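
// Example (illustrative sketch, not part of the original header): the core of
// a `DatasetIterator` subclass. `MyDataset`, `count_`, and `i_` are
// hypothetical, but the body follows the protocol documented on `GetNext`
// above (never return OutOfRange; set `*end_of_sequence` explicitly).
//
//   class Iterator : public DatasetIterator<MyDataset> {
//    public:
//     explicit Iterator(const Params& params)
//         : DatasetIterator<MyDataset>(params) {}
//
//     Status GetNextInternal(IteratorContext* ctx,
//                            std::vector<Tensor>* out_tensors,
//                            bool* end_of_sequence) override {
//       mutex_lock l(mu_);
//       if (i_ >= dataset()->count_) {
//         *end_of_sequence = true;
//         return OkStatus();
//       }
//       Tensor t(ctx->allocator({}), DT_INT64, TensorShape({}));
//       t.scalar<int64_t>()() = i_++;
//       out_tensors->push_back(std::move(t));
//       *end_of_sequence = false;
//       return OkStatus();
//     }
//
//    private:
//     mutex mu_;
//     int64_t i_ TF_GUARDED_BY(mu_) = 0;
//   };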

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name, T* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a scalar");
  }
  *output = argument_t->scalar<T>()();
  return OkStatus();
}
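
// Example (illustrative sketch, not part of the original header): parsing a
// scalar op input inside a `MakeDataset` override. The input name "count" is
// hypothetical.
//
//   int64_t count;
//   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, "count", &count));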

template <typename T>
Status ParseVectorArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name,
                           std::vector<T>* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsVector(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a vector");
  }
  int size = argument_t->vec<T>().size();
  output->reserve(size);
  for (int i = 0; i < size; ++i) {
    output->push_back(argument_t->vec<T>()(i));
  }
  return OkStatus();
}

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
    if (ctx->HasAttr(kMetadata)) {
      std::string serialized_metadata;
      OP_REQUIRES_OK(ctx, ctx->GetAttr(kMetadata, &serialized_metadata));
      OP_REQUIRES(ctx, metadata_.ParseFromString(serialized_metadata),
                  errors::InvalidArgument(absl::StrCat(
                      "Could not parse the 'metadata' attribute.")));
    }
  }

  void Compute(OpKernelContext* ctx) final;

  // Checks whether the given op is a tf.data operation.
  //
  // NOTE: The check uses a heuristic and can produce both false positives and
  // false negatives. In particular, tf.data operations are expected to use
  // names that end with "Dataset" or "DatasetV[0-9]+".
  static bool IsDatasetOp(const OpDef& op_def);

  string TraceString(const OpKernelContext& ctx, bool verbose) const override;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;

 private:
  Metadata metadata_;
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
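
// Example (illustrative sketch, not part of the original header): a unary
// dataset op kernel that wraps its input dataset. `MyDatasetOp` and its
// nested `Dataset` (a hypothetical `DatasetBase` subclass) are assumptions.
//
//   void MyDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                                 DatasetBase** output) {
//     int64_t count;
//     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, "count", &count));
//     *output = new Dataset(DatasetContext(ctx), input, count);
//   }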

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
};
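
// Example (illustrative sketch, not part of the original header): offloading
// the blocking body of an `AsyncOpKernel::ComputeAsync` to a worker so the
// executor thread is not blocked. The member `background_worker_` is a
// hypothetical field, constructed e.g. as BackgroundWorker(env, "my_worker").
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     background_worker_.Schedule([this, ctx, done = std::move(done)]() {
//       // ... blocking work using `ctx` ...
//       done();
//     });
//   }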

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_