/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
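//
// For example, inside a Status-returning function, a hypothetical dispatch on
// a runtime `DataType` might expand the macro into `case` labels
// (`HANDLE_TYPE` and `ProcessTyped` are illustrative names, not part of this
// header):
//
//   #define HANDLE_TYPE(T)           \
//     case DataTypeToEnum<T>::value: \
//       return ProcessTyped<T>(tensor);
//
//   switch (tensor.dtype()) {
//     TF_CALL_DATASET_TYPES(HANDLE_TYPE)
//     default:
//       return errors::Unimplemented("Unsupported dtype");
//   }
//   #undef HANDLE_TYPE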

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

namespace internal {
// Merges Options from source to destination. If there is a conflict on a
// field, the field value from the source takes precedence.
void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination);
void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination);
}  // namespace internal

using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;

constexpr char kTFDataFunction[] = "_tf_data_function";

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

constexpr char kTFDataResourceTag[] = "tfdata";
constexpr char kTraceInfoUnavailable[] = "unavailable";

class DatasetBase;
class SerializationContext;

inline bool IsTFDataFunction(const FunctionDef& func) {
  auto iter = func.attr().find(data::kTFDataFunction);
  return (iter != func.attr().end() && iter->second.b());
}

// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// See the comment on IteratorStateWriter for guidance on choosing between
// the Read*(key, val) and Read*(name, key, val) overloads.
class IteratorStateReader {
 public:
  // Determines whether the iterator state contains the given key.
  virtual bool Contains(StringPiece key) const = 0;
  virtual bool Contains(StringPiece name, StringPiece key) const = 0;

  // Reads an integer for the given key.
  virtual Status ReadScalar(StringPiece key, int64* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            int64* val) const = 0;

  // Reads a string for the given key.
  virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) const = 0;

  // Reads a tensor for the given key.
  // TODO(jsimsa): Remove non-FLR overrides once all callers are updated.
  virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
                            StringPiece key, Tensor* val) const = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs, the key is expected to encode this
// name, as keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB, so if the state for an iterator
// might exceed that limit, you can pass an explicit name via the
// Write*(name, key, val) APIs to further split the state into more
// manageable chunks.
class IteratorStateWriter {
 public:
  // Writes an integer for the given key.
  virtual Status WriteScalar(StringPiece key, const int64_t val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64_t val) = 0;

  // Writes a string for the given key.
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;

  // Writes a tensor for the given key.
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};

// Generates a full name key for iterator checkpointing. All keys generated for
// iterator checkpoints should go through this function.
std::string FullName(const std::string& prefix, const std::string& name);
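//
// For example, a save/restore pair in an iterator might look roughly like the
// following sketch (the "num_produced" key and the `num_produced_` member are
// illustrative only):
//
//   // In SaveInternal():
//   TF_RETURN_IF_ERROR(writer->WriteScalar(
//       FullName(prefix(), "num_produced"), num_produced_));
//
//   // In RestoreInternal():
//   TF_RETURN_IF_ERROR(reader->ReadScalar(
//       FullName(prefix(), "num_produced"), &num_produced_));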

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddVector(const std::vector<string>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<tstring>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  // Adds a node for the given dataset to the `Graph`. The value of
  // `DatasetBase::type_string()` is used as the op type for the node. Values
  // for the `output_types` and `output_shapes` node attributes are also
  // written if those attributes are defined in the `OpDef`.
  //
  // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is
  // used as the op name for the node. This argument should only be set when
  // serializing `DatasetBase` instances which might not have been created
  // through op kernel execution to make sure the dataset op name is preserved
  // across serialization boundaries, which is in turn needed to make sure
  // iterator checkpoints are valid across serialization boundaries. When
  // `use_dataset_name` is set, the caller is responsible for making sure that
  // the op name is unique across the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing `Graph` of `GraphDefBuilder`.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output);
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output);
  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);
  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      bool use_dataset_name, Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. Likewise,
  // if the function with name `function_name` or any of its dependent
  // functions are stateful, and the context does not explicitly permit
  // stateful functions, returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name,
                     const FunctionLibraryDefinition& lib_def);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 protected:
  GraphDefBuilder* builder() { return b_; }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);
  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (const auto& attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value,
                          const FunctionLibraryDefinition& lib_def) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};
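
// For example, a dataset's serialization routine might use this wrapper along
// the following lines (a hedged sketch; `b` is a builder pointer, and
// `count_` and `input_node` are illustrative names):
//
//   Node* count = nullptr;
//   TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
//   TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, count}, output));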

class StatsAggregator;

// A utility class for running a function and ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() {}

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};

// A class which provides a sequence of splits. Splits represent subdivisions
// of a dataset, e.g. filenames or ranges within files. We use splitting to
// partition input data into smaller pieces for distributed processing (see
// go/tf-data-splitting-design).
//
// Datasets provide a `MakeSplitProvider` method to expose a listing of their
// splits.
//
// Iterators created with a split provider will only iterate over the splits
// provided by the split provider.
class SplitProvider {
 public:
  virtual ~SplitProvider() {}
  // Stores the next split in `*split`, setting `*end_of_splits` to indicate
  // whether there were any splits left.
  virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
  // Resets the split provider to its beginning.
  virtual Status Reset() = 0;
  // Saves the state of this split provider.
  virtual Status Save(std::function<std::string(std::string)> full_name,
                      IteratorStateWriter* writer) = 0;
  // Restores the state of this split provider.
  virtual Status Restore(std::function<std::string(std::string)> full_name,
                         IteratorStateReader* reader) = 0;
};
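
// As a rough illustration, a provider over the range [0, n) might look like
// the following sketch (hypothetical, not part of this header; locking and
// the "next_index" key name are illustrative):
//
//   class RangeSplitProvider : public SplitProvider {
//    public:
//     explicit RangeSplitProvider(int64 n) : n_(n) {}
//     Status GetNext(Tensor* split, bool* end_of_splits) override {
//       *end_of_splits = next_ >= n_;
//       if (*end_of_splits) return Status::OK();
//       Tensor t(DT_INT64, TensorShape({}));
//       t.scalar<int64>()() = next_++;
//       *split = t;
//       return Status::OK();
//     }
//     Status Reset() override {
//       next_ = 0;
//       return Status::OK();
//     }
//     Status Save(std::function<std::string(std::string)> full_name,
//                 IteratorStateWriter* writer) override {
//       return writer->WriteScalar(full_name("next_index"), next_);
//     }
//     Status Restore(std::function<std::string(std::string)> full_name,
//                    IteratorStateReader* reader) override {
//       return reader->ReadScalar(full_name("next_index"), &next_);
//     }
//
//    private:
//     int64 n_;
//     int64 next_ = 0;
//   };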

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what
// we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          cancellation_manager(ctx->cancellation_manager()),
          collective_executor(ctx->collective_executor()),
          env(ctx->env()),
          flr(ctx->flr()),
          function_handle_cache(ctx->function_handle_cache()),
          is_restoring(ctx->is_restoring()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          split_providers(ctx->split_providers()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()),
          thread_pool(ctx->thread_pool()) {}

    explicit Params(OpKernelContext* ctx)
        : collective_executor(ctx->collective_executor()),
          env(ctx->env()),
          flr(ctx->function_library()) {
      // NOTE: need reinterpret_cast because function.h forward-declares
      // Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        static const int32_t kDefaultRunnerThreadpoolSize =
            port::MaxParallelism();
        runner_threadpool_size = kDefaultRunnerThreadpoolSize;
      }

      // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
      // that a symbol in the tensorflow::data namespace is always on the stack
      // when executing a function inside a Dataset.
      runner = std::bind(
          [](
              // Note: `runner` is a const reference to avoid copying it.
              const std::function<void(std::function<void()>)>& ctx_runner,
              std::function<void()> fn) {
            std::function<void()> wrapped_fn = std::bind(
                [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
                std::move(fn));
            ctx_runner(std::move(wrapped_fn));
          },
          *ctx->runner(), std::placeholders::_1);
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // The CancellationManager to be used to cancel execution of ops.
    CancellationManager* cancellation_manager;

    // Collective support.
    CollectiveExecutor* collective_executor = nullptr;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* flr = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // Marks whether the iterator is restored from a checkpoint.
    bool is_restoring = false;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // Split providers indicating which splits to process. May be empty,
    // indicating that the iterator should process all splits.
    std::vector<std::shared_ptr<SplitProvider>> split_providers;

    // The `StatsAggregator` object to record statistics about the iterator.
    //
    // TODO(b/147325552): Remove this API and any of its uses after we switch
    // to using C++ based implementation for tf.data options (on 4/12/2021).
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A factory for creating threads to perform blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;

    // A shared thread pool to schedule computation into.
    thread::ThreadPoolInterface* thread_pool = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  CancellationManager* cancellation_manager() {
    return params_.cancellation_manager;
  }

  CollectiveExecutor* collective_executor() {
    return params_.collective_executor;
  }

  Env* env() const { return params_.env; }

  FunctionLibraryRuntime* flr() { return params_.flr; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  bool is_restoring() { return params_.is_restoring; }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::vector<std::shared_ptr<SplitProvider>> split_providers() {
    return params_.split_providers;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }

  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                       int num_threads) {
    if (params_.thread_pool) {
      // Create a `ThreadPool` instance by wrapping `params_.thread_pool`
      // (which is an instance of `thread::ThreadPoolInterface`). Notably, the
      // ownership of `params_.thread_pool` is *not* transferred onto the
      // newly created `ThreadPool` instance.
      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
    } else {
      return absl::make_unique<thread::ThreadPool>(params_.env,
                                                   ThreadOptions(), name,
                                                   num_threads,
                                                   /*low_latency_hint=*/false);
    }
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

 private:
  Params params_;
};
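
// A common pattern for iterators that create nested iterators is to derive a
// child context from the parent, overriding a few fields. A hedged sketch
// (the overridden field and `input_impl_` are illustrative):
//
//   IteratorContext::Params params(ctx);
//   params.split_providers.clear();
//   IteratorContext nested_ctx(std::move(params));
//   TF_RETURN_IF_ERROR(
//       dataset()->MakeIterator(&nested_ctx, this, prefix(), &input_impl_));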

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  // Enum describing what to do during serialization when external state is
  // encountered.
  enum class ExternalStatePolicy : int64 {
    // Proceed with serialization, but log a warning about what state will be
    // lost.
    kWarn = 0,
    // Proceed with serialization without logging any warning.
    kIgnore = 1,
    // Fail the serialization with an error.
    kFail = 2,
  };

  // Handles the CheckExternalState status according to the external state
  // policy.
  Status HandleCheckExternalStateStatus(Status s) {
    if (s.ok()) {
      return s;
    }
    switch (params_.external_state_policy) {
      case ExternalStatePolicy::kWarn:
        LOG(WARNING) << s.ToString();
        return Status::OK();
      case ExternalStatePolicy::kIgnore:
        VLOG(2) << "Ignoring error status: " << s.ToString();
        return Status::OK();
      case ExternalStatePolicy::kFail:
        return s;
    }
    LOG(FATAL) << "Control should never reach here";
  }

  struct Params {
    explicit Params() {}

    explicit Params(OpKernelContext* ctx)
        : resource_mgr(ctx->resource_manager()),
          device_name(ctx->device()->attributes().name()) {}

    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.

    // Indicates what to do if the dataset depends on external state.
    ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

    // Indicates whether an attempt to serialize a dataset that does not
    // implement serialization should result in an error. If set to `false`,
    // the serialized graph will replace the dataset with a placeholder
    // returned in `input_list`.
    bool fail_if_unimplemented = true;

    // Indicates whether (potentially large) data tensors should be
    // serialized, or replaced with a placeholder returned in `input_list`.
    // The latter makes sense when performing data-agnostic graph rewrites to
    // reduce memory usage.
    bool serialize_data_tensors = true;

    // Indicates whether datasets that use random seeds should have the values
    // of random seeds serialized or not. If the values of random seeds are
    // serialized, the deserialized dataset will have the same seeds as the
    // original dataset. Otherwise, the deserialized dataset will use different
    // seeds. This param does not affect datasets that use fixed seeds; fixed
    // seeds will always be preserved.
    bool preserve_random_seeds = true;

    // A resource manager for looking up resources during serialization.
    ResourceMgr* resource_mgr;

    // The name of the device doing the serialization.
    std::string device_name;
  };

  explicit SerializationContext(Params params) : params_(params) {}

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  ExternalStatePolicy external_state_policy() const {
    return params_.external_state_policy;
  }

  bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }

  bool serialize_data_tensors() const {
    return params_.serialize_data_tensors;
  }

  bool preserve_random_seeds() const { return params_.preserve_random_seeds; }

  const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }

  const std::string& device_name() const { return params_.device_name; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};
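
// For example, a caller that wants serialization to fail fast on external
// state might configure the context as in this sketch:
//
//   SerializationContext::Params params;
//   params.external_state_policy =
//       SerializationContext::ExternalStatePolicy::kFail;
//   SerializationContext serialization_ctx(params);
//   TF_RETURN_IF_ERROR(serialization_ctx.HandleCheckExternalStateStatus(
//       dataset->CheckExternalState()));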

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will be stored
  // in `*end_of_sequence`, and `*out_tensors` will be empty.
  //
  // Implementations should never return an `OutOfRange` error. If at the end
  // of the sequence, set `*end_of_sequence = true` and return `Status::OK()`.
  // Internally raised `OutOfRange` errors that do not imply end of sequence
  // should be converted to a different error type before being propagated to
  // the caller.
  //
  // Implementations must explicitly set `*end_of_sequence = false` if a
  // `Status::OK()` status is returned and the iterator is not at the end of
  // the sequence.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }
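
  // For example, a caller can drain an iterator with a loop like the
  // following sketch (`iterator` is assumed to have been created via
  // `DatasetBase::MakeIterator`, and `ctx` is an `IteratorContext`):
  //
  //   std::vector<Tensor> element;
  //   bool end_of_sequence = false;
  //   while (!end_of_sequence) {
  //     element.clear();
  //     TF_RETURN_IF_ERROR(
  //         iterator->GetNext(&ctx, &element, &end_of_sequence));
  //     if (!end_of_sequence) { /* consume `element` */ }
  //   }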

  // Skips the next `num_to_skip` outputs from the range that this iterator
  // is traversing.
  //
  // If there are not enough outputs to skip, it will set
  // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
  // store the number of outputs that were skipped. When `*end_of_sequence` is
  // `false`, `*num_skipped` should equal `num_to_skip`.
  virtual Status Skip(IteratorContext* ctx, int num_to_skip,
                      bool* end_of_sequence, int* num_skipped) = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Performs initialization of the base iterator.
  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    int64_t start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
    VLOG(1) << "Saved " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    int64_t start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
    VLOG(1) << "Restored " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->Save(ctx, writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->Restore(ctx, reader);
  }

  Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return RestoreInput(&ctx, reader, input);
  }

  // Saves the state of this iterator.
  //
  // This method is used to store the state of the iterator in a checkpoint.
  // Implementations must provide an override.
  virtual Status SaveInternal(SerializationContext* ctx,
                              IteratorStateWriter* writer) = 0;

  // Restores the state of this iterator.
  //
  // This method is used to restore the state of the iterator from a
  // checkpoint.
  //
  // Implementations may assume that the iterator is in a clean state. That is,
  // its `Initialize` method has been called, but its `GetNext` method has
  // never been called. Implementations must provide an override.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) = 0;

  // Returns a pointer to the node representing this iterator in the
  // performance model. It may be null, if performance modeling is not enabled
  // for this iterator.
  std::shared_ptr<model::Node> model_node() const { return node_; }

  // Returns the number of elements produced by this iterator.
  int64 num_elements() const {
    if (node_) return node_->num_elements();
    return 0;
  }

 private:
  // For access to `AddCleanupFunction` and `Restore`.
  friend class DatasetBase;
  friend class DatasetBaseIterator;  // for access to `node_`

  std::vector<std::function<void()>> cleanup_fns_;
  std::shared_ptr<model::Node> node_ = nullptr;
  const IteratorBase* parent_ = nullptr;  // Not owned.
  int64 id_ = 0;
  int64 parent_id_ = 0;
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the given tensor.
int64 GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
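
// For example (a sketch; the use of TF_RETURN_IF_ERROR assumes a
// Status-returning scope):
//
//   Tensor variant_t(DT_VARIANT, TensorShape({}));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset, &variant_t));
//   DatasetBase* unwrapped = nullptr;  // Borrowed; owned by `variant_t`.
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(variant_t, &unwrapped));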

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Initializes the dataset.
  void Initialize();

  const Options& options() const { return options_; }

  int64 num_sources() const { return num_sources_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const;

  Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, parent, output_prefix, iterator);
  }

  // Returns a new iterator restored from the checkpoint data in `reader`.
  Status MakeIteratorFromCheckpoint(
      IteratorContext* ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    std::unique_ptr<IteratorBase> it;
    IteratorContext::Params params(ctx);
    params.is_restoring = true;
    TF_RETURN_IF_ERROR(MakeIterator(IteratorContext(std::move(params)),
                                    /*parent=*/nullptr, output_prefix, &it));
    TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
    *iterator = std::move(it);
    return Status::OK();
  }

  Status MakeIteratorFromCheckpoint(
      IteratorContext&& ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
  }

  // Returns split providers which partition the dataset's data into splits
  // and provide them in a sequence. The split providers are stored in
  // `*split_providers`.
  virtual Status MakeSplitProviders(
      std::vector<std::unique_ptr<SplitProvider>>* split_providers) const;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }

  // Returns the estimated number of bytes used for tensors of this dataset.
  virtual int64 TotalBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // Stores the dataset's input datasets in `*inputs`. The pointers stored in
  // `*inputs` are borrowed. The only valid non-ok return status is
  // UNIMPLEMENTED, in case `InputDatasets` is not implemented by a dataset
  // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
  // default implementation of `MakeSplitProviders` when there is a single
  // input dataset.
  virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;

  // Indicates whether the dataset depends on any external state which would
  // prevent it from being serializable. If so, the method returns
  // `errors::FailedPrecondition` with a message that identifies the external
  // state. Otherwise, the method returns `Status::OK()`.
  virtual Status CheckExternalState() const = 0;

  // Returns the element at a particular index for a randomly accessible
  // dataset.
  virtual Status Get(OpKernelContext* ctx, int64 index,
                     std::vector<Tensor>* out_tensors);

  // Wrapper around a GraphDefBuilder which provides support for serializing
  // Datasets as GraphDefs.
  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
        : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
    Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
                              Node** output);

   private:
    Status AddDatasetOrTensorHelper(SerializationContext* ctx,
                                    const Tensor& val, Node** output);
    Status AddResourceHelper(SerializationContext* ctx, const Tensor& val,
                             Node** output);
  };

 protected:
  friend Status AsGraphDef(
      OpKernelContext* ctx, const DatasetBase* dataset,
      SerializationContext&& serialization_ctx,
      GraphDef* graph_def);  // For access to graph related members.
  friend class CapturedFunction;

  // Serializes the dataset into a `GraphDef`, which has two uses:
  //
  // 1) To perform static input pipeline optimizations, tf.data serializes the
  // dataset graph, applies graph rewrites, and then deserializes the graph.
  // If a subclass of `DatasetBase` does not implement this method, then it
  // will be excluded from static optimizations (and so will any upstream
  // datasets).
  //
  // 2) To save the dataset so that it can be restored at a later point
  // (possibly in a different environment). If a subclass of `DatasetBase`
  // does not implement this method, then this migration will not be possible.
  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

  void set_options(const Options& options) { options_ = options; }

 private:
  // Computes the number of source datasets feeding into this dataset. A source
  // dataset is a leaf in the subtree of dataset inputs.
  Status ComputeNumSources();

  // Merges options from inputs to this dataset. If there is a conflict in a
  // field value, the options set on this dataset take precedence over those
  // in the inputs. Among the inputs, precedence follows the order in which
  // the inputs appear for this dataset.
  Status MergeOptionsFromInputs();

  const string type_string_;
  const string node_name_;
  Options options_;
  // The number of source datasets feeding into the dataset. A source dataset
  // is a leaf in the subtree of dataset inputs.
  int64 num_sources_ = -1;
};
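
// As a rough sketch, a concrete dataset typically derives from `DatasetBase`
// and provides at least the following pieces (`MyDataset`, `Iterator`, and
// `count_` are illustrative names, not part of this header):
//
//   class MyDataset : public DatasetBase {
//    public:
//     explicit MyDataset(OpKernelContext* ctx, int64 count)
//         : DatasetBase(DatasetContext(ctx)), count_(count) {}
//
//     std::unique_ptr<IteratorBase> MakeIteratorInternal(
//         const string& prefix) const override {
//       return absl::make_unique<Iterator>(
//           Iterator::Params{this, strings::StrCat(prefix, "::MyDataset")});
//     }
//
//     const DataTypeVector& output_dtypes() const override { ... }
//     const std::vector<PartialTensorShape>& output_shapes() const
//         override { ... }
//     string DebugString() const override { return "MyDatasetOp::Dataset"; }
//     Status CheckExternalState() const override { return Status::OK(); }
//
//    protected:
//     Status AsGraphDefInternal(SerializationContext* ctx,
//                               DatasetGraphDefBuilder* b,
//                               Node** output) const override {
//       Node* count = nullptr;
//       TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
//       return b->AddDataset(this, {count}, output);
//     }
//
//    private:
//     const int64 count_;
//   };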

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params);

  ~DatasetBaseIterator() override;

  virtual const DatasetBase* dataset() const { return params_.dataset; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  const string& prefix() const override { return params_.prefix; }

  // Returns a name to be used for the TraceMe event.
  //
  // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
  // following format "name#arg_1=value_1,...,arg_n=value_n".
  string BuildTraceMeName();

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
              int* num_skipped) final;

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    VLOG(2) << "Attempting to save checkpoints on iterator (prefix: "
            << prefix() << ") from " << dataset()->DebugString();
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final {
    VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: "
            << prefix() << ") from " << dataset()->DebugString();
    return IteratorBase::Restore(ctx, reader);
  }

  // Internal implementation of GetNext that is wrapped in tracing logic.
  //
  // See the docstring of the `GetNext` method regarding the contract for
  // `out_tensors` and `end_of_sequence`. Implementations may assume that
  // `*out_tensors` is empty.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  // Internal implementation of Skip that is wrapped in tracing logic.
  virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
                              bool* end_of_sequence, int* num_skipped);

  string full_name(const string& name) const {
    return FullName(params_.prefix, name);
  }

  // Returns a map of key-value pairs to be included in the TraceMe string.
  virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }

  // By default we model iterators using an unknown node, which acts as a
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method disables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(false);
    }
  }

  // When modeling is enabled, this method enables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(true);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(-GetAllocatedBytes(element), -1);

      DCHECK_GE(node_->buffered_elements(), 0);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(GetAllocatedBytes(element), 1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element and its size in bytes.
  void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
    if (node_) {
      int64_t num_bytes = GetAllocatedBytes(*out_tensors);
      node_->record_element();
      node_->record_bytes_produced(num_bytes);
      if (node_->output()) {
        node_->output()->record_bytes_consumed(num_bytes);
      }
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64_t now_nanos = EnvTime::NowNanos();
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64_t now_nanos = EnvTime::NowNanos();
      node_->record_stop(now_nanos);
    }
  }

  // Returns whether work is currently being recorded, i.e. whether we are
  // currently between a `RecordStart` and a `RecordStop`.
  bool IsRecording(IteratorContext* ctx) {
    return collect_resource_usage(ctx) && node_->is_recording();
  }

 private:
  bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  string traceme_metadata_;
  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const final { return typed_dataset_; }

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};
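
// A minimal iterator sketch for the hypothetical `MyDataset` above, typically
// nested inside `MyDataset` (`i_` and the "i" key are illustrative; locking
// omitted for brevity):
//
//   class Iterator : public DatasetIterator<MyDataset> {
//    public:
//     explicit Iterator(const Params& params)
//         : DatasetIterator<MyDataset>(params) {}
//
//     Status GetNextInternal(IteratorContext* ctx,
//                            std::vector<Tensor>* out_tensors,
//                            bool* end_of_sequence) override {
//       if (i_ >= dataset()->count_) {
//         *end_of_sequence = true;
//         return Status::OK();
//       }
//       Tensor t(DT_INT64, TensorShape({}));
//       t.scalar<int64>()() = i_++;
//       out_tensors->push_back(std::move(t));
//       *end_of_sequence = false;
//       return Status::OK();
//     }
//
//    protected:
//     Status SaveInternal(SerializationContext* ctx,
//                         IteratorStateWriter* writer) override {
//       return writer->WriteScalar(full_name("i"), i_);
//     }
//     Status RestoreInternal(IteratorContext* ctx,
//                            IteratorStateReader* reader) override {
//       return reader->ReadScalar(full_name("i"), &i_);
//     }
//
//    private:
//     int64 i_ = 0;
//   };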

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name, T* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a scalar");
  }
  *output = argument_t->scalar<T>()();
  return Status::OK();
}

template <typename T>
Status ParseVectorArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name,
                           std::vector<T>* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsVector(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a vector");
  }
  int size = argument_t->vec<T>().size();
  output->reserve(size);
  for (int i = 0; i < size; ++i) {
    output->push_back(argument_t->vec<T>()(i));
  }
  return Status::OK();
}
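
// For example, inside a kernel's `MakeDataset` (where errors surface via
// `OP_REQUIRES_OK`; the "count" input name is illustrative):
//
//   int64 count;
//   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));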

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) final;

  // Indicates whether the given op corresponds to an op whose kernels subclass
  // the `DatasetOpKernel` class.
  static bool IsDatasetOp(const OpDef* op_def);

  string TraceString(const OpKernelContext& ctx, bool verbose) const override;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};
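
// A unary dataset kernel sketch (hypothetical; `MyDatasetOp`, `MyDataset`,
// and the constructor it calls are illustrative names):
//
//   class MyDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx,
//                      ParseScalarArgument<int64>(ctx, "count", &count));
//       *output = new MyDataset(ctx, input, count);  // Hypothetical ctor.
//     }
//   };
//
//   REGISTER_KERNEL_BUILDER(Name("MyDataset").Device(DEVICE_CPU),
//                           MyDatasetOp);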

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an
// `AsyncOpKernel` to avoid blocking an executor thread that may be required
// by the blocking work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for
// this purpose because its current implementation (in Eigen) uses a
// finite-length queue and will block the caller when full. This can lead to
// deadlock under heavy load. Since the number of concurrent work items in
// each user of a `BackgroundWorker` is at most one per op invocation, the
// dynamic allocation overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
};
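
// For example, an `AsyncOpKernel` might hand blocking work to a
// `BackgroundWorker` along these lines (a sketch; `background_worker_` is an
// illustrative member):
//
//   background_worker_.Schedule([ctx, done]() {
//     // ... perform blocking work, then signal completion ...
//     done();
//   });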

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_