/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
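
// For example, `TF_CALL_DATASET_TYPES` can stamp out one `case` per primitive
// type when dispatching on a `DataType` value. A minimal illustrative sketch
// (the `SizeOfType` helper is hypothetical, not part of this header):
//
//   size_t SizeOfType(DataType dtype) {
//     switch (dtype) {
// #define HANDLE_TYPE(T)           \
//   case DataTypeToEnum<T>::value: \
//     return sizeof(T);
//       TF_CALL_DATASET_TYPES(HANDLE_TYPE)
// #undef HANDLE_TYPE
//       default:
//         return 0;
//     }
//   }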

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;

constexpr char kTFDataFunction[] = "_tf_data_function";

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

constexpr char kTFDataResourceTag[] = "tfdata";

class DatasetBase;
class SerializationContext;

inline bool IsTFDataFunction(const FunctionDef& func) {
  auto iter = func.attr().find(data::kTFDataFunction);
  return (iter != func.attr().end() && iter->second.b());
}

// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// See the comment on IteratorStateWriter for guidance on choosing between the
// Read*(key, val) and Read*(name, key, val) overloads.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) const = 0;
  virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;

  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            int64* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) const = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key,
                            Tensor* val) const = 0;

  virtual bool Contains(StringPiece key) const = 0;
  virtual bool Contains(StringPiece name, StringPiece key) const = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs, the key is expected to encode this
// name, since keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB, so if the state for an iterator
// might exceed that limit, you can pass an explicit name via the
// Write*(name, key, val) APIs to further split the state into more
// manageable chunks.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64 val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
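
// As an illustrative sketch (hypothetical subclass code, not part of this
// header), an iterator's SaveInternal/RestoreInternal would typically persist
// a cursor using keys produced by full_name():
//
//   Status SaveInternal(SerializationContext* ctx,
//                       IteratorStateWriter* writer) override {
//     return writer->WriteScalar(full_name("next_index"), next_index_);
//   }
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     return reader->ReadScalar(full_name("next_index"), &next_index_);
//   }
//
// where `next_index_` is assumed to be an int64 member of the iterator.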

// Generates a full name key for iterator checkpointing. All keys generated for
// iterator checkpoints should go through this function.
std::string FullName(const std::string& prefix, const std::string& name);
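
// For example, given the constants above, FullName("Iterator::Range", "next")
// produces a key of the form "<kFullNameRandomHex>|Iterator::Range:next",
// i.e. the random hex prefix joined to the prefix and name using kPipe and
// kColon (assuming the implementation concatenates in that order).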

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddVector(const std::vector<string>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<tstring>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // The return value of `DatasetType::type_string()` is used as the op type
  // for the node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, and the context does not explicitly permit stateful functions,
  // returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name,
                     const FunctionLibraryDefinition& lib_def);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 protected:
  GraphDefBuilder* builder() { return b_; }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);
  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (const auto& attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value,
                          const FunctionLibraryDefinition& lib_def) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;
class FunctionHandleCache;

// A utility class for running a function and ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() {}

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};

// A class which provides a sequence of splits. Iterators created with a split
// provider will iterate over only the splits provided by the split provider.
class SplitProvider {
 public:
  virtual ~SplitProvider() {}
  // Stores the next split in `*split`, setting `*end_of_splits` to indicate
  // whether there were any splits left.
  virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
  // Resets the split provider to its beginning.
  virtual Status Reset() = 0;
  // Saves the state of this split provider.
  virtual Status Save(std::function<std::string(std::string)> full_name,
                      IteratorStateWriter* writer) = 0;
  // Restores the state of this split provider.
  virtual Status Restore(std::function<std::string(std::string)> full_name,
                         IteratorStateReader* reader) = 0;
};
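
// A minimal illustrative sketch of a subclass (hypothetical, not part of
// TensorFlow): a provider that hands out the integers [0, n) as scalar
// int64 splits.
//
//   class RangeSplitProvider : public SplitProvider {
//    public:
//     explicit RangeSplitProvider(int64 n) : n_(n) {}
//     Status GetNext(Tensor* split, bool* end_of_splits) override {
//       *end_of_splits = next_ >= n_;
//       if (*end_of_splits) return Status::OK();
//       *split = Tensor(DT_INT64, TensorShape({}));
//       split->scalar<int64>()() = next_++;
//       return Status::OK();
//     }
//     Status Reset() override {
//       next_ = 0;
//       return Status::OK();
//     }
//     Status Save(std::function<std::string(std::string)> full_name,
//                 IteratorStateWriter* writer) override {
//       return writer->WriteScalar(full_name("next"), next_);
//     }
//     Status Restore(std::function<std::string(std::string)> full_name,
//                    IteratorStateReader* reader) override {
//       return reader->ReadScalar(full_name("next"), &next_);
//     }
//
//    private:
//     int64 n_;
//     int64 next_ = 0;
//   };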

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what
// we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          cancellation_manager(ctx->cancellation_manager()),
          env(ctx->env()),
          flr(ctx->flr()),
          function_handle_cache(ctx->function_handle_cache()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          split_provider(ctx->split_provider()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()),
          thread_pool(ctx->thread_pool()) {}

    explicit Params(OpKernelContext* ctx)
        : env(ctx->env()), flr(ctx->function_library()) {
      // NOTE: need reinterpret_cast because function.h forward-declares
      // Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        static const int32 kDefaultRunnerThreadpoolSize =
            port::MaxParallelism();
        runner_threadpool_size = kDefaultRunnerThreadpoolSize;
      }

      // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
      // that a symbol in the tensorflow::data namespace is always on the stack
      // when executing a function inside a Dataset.
      runner = std::bind(
          [](
              // Note: `runner` is a const reference to avoid copying it.
              const std::function<void(std::function<void()>)>& ctx_runner,
              std::function<void()> fn) {
            std::function<void()> wrapped_fn = std::bind(
                [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
                std::move(fn));
            ctx_runner(std::move(wrapped_fn));
          },
          *ctx->runner(), std::placeholders::_1);
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // The CancellationManager to be used to cancel execution of ops.
    CancellationManager* cancellation_manager;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* flr = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // An optional split provider indicating which splits to process.
    std::shared_ptr<SplitProvider> split_provider = nullptr;

    // The `StatsAggregator` object to record statistics about the iterator.
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A factory for creating threads to perform blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;

    // A shared thread pool to schedule computation into.
    thread::ThreadPoolInterface* thread_pool = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  CancellationManager* cancellation_manager() {
    return params_.cancellation_manager;
  }

  Env* env() const { return params_.env; }

  FunctionLibraryRuntime* flr() { return params_.flr; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::shared_ptr<SplitProvider> split_provider() {
    return params_.split_provider;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }

  Params params() { return params_; }

  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                       int num_threads) {
    if (params_.thread_pool) {
      // Create a `ThreadPool` instance by wrapping `params_.thread_pool`
      // (which is an instance of `thread::ThreadPoolInterface`). Notably, the
      // ownership of `params_.thread_pool` is *not* transferred onto the newly
      // created `ThreadPool` instance.
      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
    } else {
      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
                                                   name, num_threads,
                                                   /*low_latency_hint=*/false);
    }
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

 private:
  Params params_;
};

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  // Enum describing what to do during serialization when external state is
  // encountered.
  enum class ExternalStatePolicy : int64 {
    // Proceed with serialization, but log a warning about what state will be
    // lost.
    kWarn = 0,
    // Proceed with serialization without logging any warning.
    kIgnore = 1,
    // Fail the serialization with an error.
    kFail = 2,
  };

  // Handles the CheckExternalState status according to the external state
  // policy.
  Status HandleCheckExternalStateStatus(Status s) {
    if (s.ok()) {
      return s;
    }
    switch (params_.external_state_policy) {
      case ExternalStatePolicy::kWarn:
        LOG(WARNING) << s.ToString();
        return Status::OK();
      case ExternalStatePolicy::kIgnore:
        VLOG(2) << "Ignoring error status: " << s.ToString();
        return Status::OK();
      case ExternalStatePolicy::kFail:
        return s;
    }
    LOG(FATAL) << "Control should never reach here";
  }

  struct Params {
    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.

    // Indicates what to do if the dataset depends on external state.
    ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

    // Indicates whether an attempt to serialize a dataset that does not
    // implement serialization should result in an error. If set to `false`,
    // the serialized graph will replace the dataset with a placeholder
    // returned in `input_list`.
    bool fail_if_unimplemented = true;

    // Indicates whether (potentially large) data tensors should be
    // serialized, or replaced with a placeholder returned in `input_list`. The
    // latter makes sense to do when performing data agnostic graph rewrites to
    // reduce the memory usage.
    bool serialize_data_tensors = true;

    // Indicates whether datasets that use random seeds should have the values
    // of random seeds serialized or not. If the values of random seeds are
    // serialized, the deserialized dataset will have the same seeds as the
    // original dataset. Otherwise, the deserialized dataset will use different
    // seeds. This param does not affect datasets that use fixed seeds; fixed
    // seeds will always be preserved.
    bool preserve_random_seeds = true;
  };

  explicit SerializationContext(Params params) : params_(params) {}

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  ExternalStatePolicy external_state_policy() const {
    return params_.external_state_policy;
  }

  bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }

  bool serialize_data_tensors() const { return params_.serialize_data_tensors; }

  bool preserve_random_seeds() const { return params_.preserve_random_seeds; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // Implementations should never return `OutOfRange` error. If at end of
  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
  // Internally raised `OutOfRange` errors that do not imply end of sequence
  // should be converted to a different error type before being propagated to
  // the caller.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }
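
  // As an illustrative sketch of the contract above (hypothetical caller
  // code, not part of this header), a consumer drains an iterator like so:
  //
  //   std::vector<Tensor> element;
  //   bool end_of_sequence = false;
  //   while (!end_of_sequence) {
  //     element.clear();
  //     TF_RETURN_IF_ERROR(
  //         iterator->GetNext(&ctx, &element, &end_of_sequence));
  //     if (!end_of_sequence) {
  //       ProcessElement(element);  // Hypothetical per-element work.
  //     }
  //   }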

  // Skips the next `num_to_skip` outputs from the range that this iterator
  // is traversing.
  //
  // If there are not enough outputs to skip, it will set
  // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
  // store the number of outputs that are skipped. When `*end_of_sequence` is
  // `false`, `*num_skipped` should equal `num_to_skip`.
  virtual Status Skip(IteratorContext* ctx, int num_to_skip,
                      bool* end_of_sequence, int* num_skipped) = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Performs initialization of the base iterator.
  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    int64 start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
    VLOG(1) << "Saved " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // Restores the state of this iterator.
  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    int64 start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
    VLOG(1) << "Restored " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    int64 start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(input->SaveInternal(ctx, writer));
    VLOG(2) << "Saved " << input->prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    int64 start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(input->RestoreInternal(ctx, reader));
    VLOG(2) << "Restored " << input->prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

  Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return RestoreInput(&ctx, reader, input);
  }

  // Saves the state of this iterator.
  //
  // This method is used to store the state of the iterator in a checkpoint.
  // All subclasses must provide an implementation.
  virtual Status SaveInternal(SerializationContext* ctx,
                              IteratorStateWriter* writer) = 0;

  // Restores the state of this iterator.
  //
  // This method is used to restore the state of the iterator from a
  // checkpoint.
  //
  // Implementations may assume that the iterator is in a clean state. That is,
  // its `Initialize` method has been called, but its `GetNext` method has
  // never been called.
  // All subclasses must provide an implementation.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) = 0;

  // Returns a pointer to the node representing this iterator in the
  // performance model. It may be null, if performance modeling is not enabled
  // for this iterator.
  std::shared_ptr<model::Node> model_node() const { return node_; }

  // Returns the number of elements produced by this iterator.
  int64 num_elements() const {
    if (node_) return node_->num_elements();
    return 0;
  }

 private:
  // For access to `AddCleanupFunction` and `Restore`.
  friend class DatasetBase;
  friend class DatasetBaseIterator;  // for access to `node_`

  std::vector<std::function<void()>> cleanup_fns_;
  std::shared_ptr<model::Node> node_ = nullptr;
  const IteratorBase* parent_ = nullptr;  // Not owned.
  int64 id_ = 0;
  int64 parent_id_ = 0;
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the given tensor.
int64 GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to SetVariantTensorToDataset().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const;

  Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, parent, output_prefix, iterator);
  }

  // Returns a new iterator restored from the checkpoint data in `reader`.
  Status MakeIteratorFromCheckpoint(
      IteratorContext* ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    std::unique_ptr<IteratorBase> it;
    TF_RETURN_IF_ERROR(
        MakeIterator(ctx, /*parent=*/nullptr, output_prefix, &it));
    TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
    *iterator = std::move(it);
    return Status::OK();
  }

  Status MakeIteratorFromCheckpoint(
      IteratorContext&& ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
  }

  // Returns a split provider which partitions the dataset's data into splits
  // and provides them in a sequence. The split provider is stored in
  // `*split_provider`.
  virtual Status MakeSplitProvider(
      std::unique_ptr<SplitProvider>* split_provider) const;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }

  // Returns the estimated number of bytes used for tensors of this dataset.
  virtual int64 TotalBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // Stores the dataset's input datasets in `*inputs`. The pointers stored in
  // `*inputs` are borrowed. The only valid non-ok return status is
  // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
  // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
  // default implementation of `MakeSplitProvider` when there is a single input
  // dataset.
  virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;

  // Indicates whether the dataset depends on any external state which would
  // prevent it from being serializable. If so, the method returns
  // `errors::FailedPrecondition` with a message that identifies the external
  // state. Otherwise, the method returns `Status::OK()`.
  virtual Status CheckExternalState() const = 0;

 protected:
  friend Status AsGraphDef(
      OpKernelContext* ctx, const DatasetBase* dataset,
      SerializationContext&& serialization_ctx,
      GraphDef* graph_def);  // For access to graph related members.
  friend class CapturedFunction;

  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
        : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
    Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
                              Node** output);

   private:
    Status AddDatasetOrTensorHelper(SerializationContext* ctx,
                                    const Tensor& val, Node** output);
  };

  // Serializes the dataset into a `GraphDef`, which has two uses:
  //
  // 1) To perform static input pipeline optimizations, tf.data serializes the
  // dataset graph, applies graph rewrites, and then deserializes the graph.
  // If a subclass of `DatasetBase` does not implement this method, then it
  // will be excluded from static optimizations (and so will any upstream
  // datasets).
  //
  // 2) To save the dataset so that it can be restored at a later point
  // (possibly in a different environment). If a subclass of `DatasetBase` does
  // not implement this method, then this migration will not be possible.
  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

 private:
  const string type_string_;
  const string node_name_;
};

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params);

  ~DatasetBaseIterator() override;

  virtual const DatasetBase* dataset() const { return params_.dataset; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  const string& prefix() const override { return params_.prefix; }

  // Returns a name to be used for the TraceMe event.
  //
  // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
  // following format "name#arg_1=value_1,...,arg_n=value_n".
  string BuildTraceMeName();

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
              int* num_skipped) final;

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  // Internal implementation of Skip that is wrapped in tracing logic.
  virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
                              bool* end_of_sequence, int* num_skipped);

  string full_name(const string& name) const {
    return FullName(params_.prefix, name);
  }

  // Returns a map of key-value pairs to be included in the TraceMe string.
  virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method disables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(false);
    }
  }

  // When modeling is enabled, this method enables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(true);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(-GetAllocatedBytes(element), -1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(GetAllocatedBytes(element), 1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element and its size in bytes.
  void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
    if (node_) {
      int64 num_bytes = GetAllocatedBytes(*out_tensors);
      node_->record_element();
      node_->record_bytes_produced(num_bytes);
      if (node_->output()) {
        node_->output()->record_bytes_consumed(num_bytes);
      }
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = EnvTime::NowNanos();
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = EnvTime::NowNanos();
      node_->record_stop(now_nanos);
    }
  }

  // Returns whether work is currently being recorded, i.e. whether we are
  // currently between a `RecordStart` and a `RecordStop`.
  bool IsRecording(IteratorContext* ctx) {
    return collect_resource_usage(ctx) && node_->is_recording();
  }

 private:
  bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const final { return typed_dataset_; }

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name, T* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a scalar");
  }
  *output = argument_t->scalar<T>()();
  return Status::OK();
}

template <typename T>
Status ParseVectorArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name,
                           std::vector<T>* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsVector(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a vector");
  }
  int size = argument_t->vec<T>().size();
  output->reserve(size);
  for (int i = 0; i < size; ++i) {
    output->push_back(argument_t->vec<T>()(i));
  }
  return Status::OK();
}
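
// As an illustrative sketch (hypothetical kernel code, not part of this
// header; the input names "buffer_size" and "filenames" are assumptions), a
// dataset op kernel (see `DatasetOpKernel` below) might parse its inputs:
//
//   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
//     int64 buffer_size;
//     OP_REQUIRES_OK(
//         ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size));
//     std::vector<tstring> filenames;
//     OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "filenames", &filenames));
//     ...
//   }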

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) final;

  // Indicates whether the given op corresponds to an op whose kernels subclass
  // the `DatasetOpKernel` class.
  static bool IsDatasetOp(const OpDef* op_def);

  string TraceString(const OpKernelContext& ctx, bool verbose) const override;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
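
// A minimal illustrative sketch of a subclass (hypothetical, not part of
// TensorFlow; the nested `Dataset` class and the "count" input are assumed to
// be defined by the op): the framework resolves the input dataset, and the
// subclass only constructs the output dataset.
//
//   class ExampleTakeDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit ExampleTakeDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count));
//       *output = new Dataset(DatasetContext(ctx), input, count);
//     }
//   };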

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
};
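
// As an illustrative sketch (hypothetical kernel code, not part of this
// header), an `AsyncOpKernel` with a `BackgroundWorker` member named
// `background_worker_` might offload its blocking work like so:
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     background_worker_.Schedule([this, ctx, done]() {
//       DoBlockingWork(ctx);  // Hypothetical blocking call.
//       done();
//     });
//   }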

// Registry of names of ops whose kernels subclass the `DatasetOpKernel` class.
class DatasetOpRegistry {
 public:
  // Registers the op name.
  static void Register(const string& op_name);

  // Checks whether the given op name has been registered.
  static bool IsRegistered(const string& op_name);
};

// Helper class to register dataset op name.
class DatasetOpRegistrar {
 public:
  explicit DatasetOpRegistrar(const string& op_name) {
    DatasetOpRegistry::Register(op_name);
  }
};

// Macro that can be used to register an op name of a dataset op.
#define REGISTER_DATASET_OP_NAME(op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, op_name)

#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name)

#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name) \
  static ::tensorflow::data::DatasetOpRegistrar     \
      registrar__body__##ctr##__object(op_name)
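
// For example, a dataset op's registration file could pair the kernel
// registration with (the op name "ExampleTakeDataset" is hypothetical):
//
//   REGISTER_DATASET_OP_NAME("ExampleTakeDataset");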

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_