/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
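
// For illustration only: a minimal sketch of how this macro might be used to
// build a `switch` over a `DataType` value. The `HANDLE_TYPE` macro and the
// `HandleCase<T>()` helper are hypothetical, not part of this header:
//
//   #define HANDLE_TYPE(T)               \
//     case DataTypeToEnum<T>::value: {   \
//       HandleCase<T>();                 \
//       break;                           \
//     }
//
//   switch (dtype) {
//     TF_CALL_DATASET_TYPES(HANDLE_TYPE)
//     default:
//       break;
//   }
//   #undef HANDLE_TYPE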

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

class DatasetBase;
class SerializationContext;

// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, string* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
  virtual bool Contains(StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
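
// For illustration only: a minimal sketch of how an iterator implementation
// might use these interfaces in `SaveInternal`/`RestoreInternal` (declared on
// `IteratorBase`, below). The `full_name` helper comes from
// `DatasetBaseIterator`, below; the `i_` member is a hypothetical piece of
// iterator state:
//
//   Status SaveInternal(IteratorStateWriter* writer) override {
//     TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
//     return Status::OK();
//   }
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
//     return Status::OK();
//   }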

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (int i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }
  // Adds a node corresponding to the `DatasetType` to the Graph.
  // The return value of `DatasetType::op_name()` is used as the op type for
  // the node.
  // Values for the `output_types` and `output_shapes` node attributes are
  // also written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, and the context does not explicitly permit stateful functions,
  // returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);

  Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
                                   const string& function_name) const {
    const FunctionDef* function_def = flib_def.Find(function_name);
    if (!function_def) {
      return errors::InvalidArgument("Unable to find FunctionDef for ",
                                     function_name, " in registry.");
    }
    for (const NodeDef& node_def : function_def->node_def()) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def));
      // TODO(b/65524810): Hack to allow functions to capture Dataset op
      // nodes needed for FlatMap. Currently, source dataset nodes have been
      // marked stateful to avoid constant folding since we do not have a
      // good way of serializing them.
      if (IsOpWhitelisted(op_def)) {
        continue;
      }
      if (op_def->is_stateful()) {
        return errors::InvalidArgument(
            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
            "in function ", function_name, " is stateful. ",
            "Saving stateful functions is not supported yet.");
      }
    }
    return Status::OK();
  }

  // Returns whether an op has been whitelisted for use inside map_fns.
  // Uses a heuristic to whitelist source dataset ops which have been
  // marked stateful due to b/65524810.
  // Also looks up the `op_def->name` in the global
  // `WhitelistedStatefulOpRegistry`.
  bool IsOpWhitelisted(const OpDef* op_def) const {
    return (str_util::EndsWith(op_def->name(), "Dataset") &&
            op_def->output_arg_size() == 1 &&
            op_def->output_arg(0).type() == DT_VARIANT) ||
           WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
  }

  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (auto attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};
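
// For illustration only: a minimal sketch of how a dataset's
// `AsGraphDefInternal` (declared on `DatasetBase`, below) might use this
// wrapper. The `input_` and `start_` members are hypothetical fields of the
// dataset being serialized; `AddInputDataset` is declared on
// `DatasetGraphDefBuilder`, below:
//
//   Status AsGraphDefInternal(SerializationContext* ctx,
//                             DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_graph_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
//     Node* start_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(start_, &start_node));
//     return b->AddDataset(this, {input_graph_node, start_node}, output);
//   }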

class StatsAggregator;
class FunctionHandleCache;

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what we
// are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          env(ctx->env()),
          function_library(ctx->function_library()),
          lib(ctx->lib()),
          function_handle_cache(ctx->function_handle_cache()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()) {}

    explicit Params(OpKernelContext* ctx)
        : env(ctx->env()),
          lib(ctx->function_library()),
          runner(*(ctx->runner())) {
      // NOTE: need reinterpret_cast because function.h forward-declares Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };
      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        runner_threadpool_size = port::NumSchedulableCPUs();
      }
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryDefinition used to look up user-defined functions.
    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* lib = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // The `StatsAggregator` object to record statistics about the iterator.
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A `ThreadFactory` for creating threads used by iterators to perform
    // blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  Env* env() const { return params_.env; }

  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
    return params_.function_library;
  }

  FunctionLibraryRuntime* lib() { return params_.lib; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  Params params() { return params_; }

 private:
  Params params_;
};
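
// For illustration only: a minimal sketch of creating an `IteratorContext`
// inside an op kernel, then deriving a nested context that overrides a single
// field while inheriting the rest. The `my_runner` value is a hypothetical
// `std::function<void(std::function<void()>)>`:
//
//   void Compute(OpKernelContext* ctx) {
//     IteratorContext iter_ctx(ctx);
//     IteratorContext::Params params(&iter_ctx);
//     params.runner = my_runner;  // override one field; the rest are copied
//     IteratorContext nested_ctx(std::move(params));
//     // ... use `nested_ctx` ...
//   }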

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  struct Params {
    const FunctionLibraryDefinition* flib_def = nullptr;           // Not owned.
    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
    bool optimization_only = false;
  };

  explicit SerializationContext(Params params) : params_(std::move(params)) {}

  const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  bool optimization_only() { return params_.optimization_only; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};
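
// For illustration only: a minimal sketch of constructing a
// `SerializationContext` before saving an iterator, assuming `flib_def` and
// `writer` are supplied by the caller:
//
//   SerializationContext::Params params;
//   params.flib_def = flib_def;
//   SerializationContext serialization_ctx(std::move(params));
//   TF_RETURN_IF_ERROR(iterator->Save(&serialization_ctx, writer));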

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator recursively.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator recursively.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }

 private:
  friend class DatasetBase;  // for access to `AddCleanupFunction`
  friend class DatasetBaseIterator;  // for access to `node_`

  // Registers a cleanup function to be called upon object destruction.
  //
  // Registered functions are invoked in the reverse order of registration.
  void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
    cleanup_fns_.push_back(std::move(cleanup_fn));
  }

  // Associates the given performance modeling `Node` with this iterator.
  void SetNode(std::shared_ptr<model::Node> node) { node_ = node.get(); }

  std::vector<std::function<void()>> cleanup_fns_;
  model::Node* node_ = nullptr;  // Not owned.
};
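
// For illustration only: a minimal sketch of draining an iterator with
// `GetNext`, following the end-of-sequence protocol documented above. The
// `iterator` and `ctx` values are assumed to exist already:
//
//   std::vector<Tensor> out_tensors;
//   bool end_of_sequence = false;
//   while (!end_of_sequence) {
//     out_tensors.clear();
//     TF_RETURN_IF_ERROR(
//         iterator->GetNext(ctx, &out_tensors, &end_of_sequence));
//     if (!end_of_sequence) {
//       // ... consume the element in `out_tensors` ...
//     }
//   }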

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to
// `StoreDatasetInVariantTensor()` (declared below).
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
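
// For illustration only: a minimal sketch of the store/retrieve round trip,
// assuming `dataset` is a `DatasetBase*` whose ownership we want to hand to a
// tensor:
//
//   Tensor variant_t(DT_VARIANT, TensorShape({}));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset, &variant_t));
//
//   DatasetBase* retrieved = nullptr;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(variant_t, &retrieved));
//   retrieved->Ref();  // take our own reference before `variant_t` goes away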

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    *iterator = MakeIteratorInternal(output_prefix);
    if (const auto& model = ctx->model()) {
      const string& prefix = (*iterator)->prefix();
      (*iterator)->SetNode(model->AddNode(MakeNodeFactory(ctx, iterator->get()),
                                          prefix, output_prefix));
      (*iterator)->AddCleanupFunction(
          [model, prefix]() { model->RemoveNode(prefix); });
    }
    return (*iterator)->Initialize(ctx);
  }

  Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, output_prefix, iterator);
  }

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // Serializes the dataset and writes it to the `writer`.
  virtual Status Save(SerializationContext* ctx,
                      IteratorStateWriter* writer) const;

 protected:
  friend class DatasetToGraphOp;  // For access to graph related members.

  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
  };

  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

 private:
  // Returns a factory for nodes that represent the given iterator.
  static model::Node::Factory MakeNodeFactory(IteratorContext* ctx,
                                              IteratorBase* iterator) {
    return [ctx, iterator](model::Node::Args args) {
      return iterator->CreateNode(ctx, std::move(args));
    };
  }

  const string type_string_;
  const string node_name_;
};
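
// For illustration only: a minimal sketch of obtaining an iterator from a
// dataset, assuming `dataset` is a `DatasetBase*` and `ctx` is an
// `IteratorContext*`:
//
//   std::unique_ptr<IteratorBase> iterator;
//   TF_RETURN_IF_ERROR(dataset->MakeIterator(ctx, "Iterator", &iterator));
//   // `iterator` can now be drained with `GetNext()` as sketched above.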

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
    params_.dataset->Ref();
  }

  ~DatasetBaseIterator() override { params_.dataset->Unref(); }

  // The sequence of iterators leading up to this iterator.
  const string& prefix() const override { return params_.prefix; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    tracing::ScopedActivity activity(params_.prefix);
    RecordStart(ctx, true /* stop_output */);
    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
    if (s.ok() && !*end_of_sequence) RecordElement(ctx);
    RecordStop(ctx, true /* start_output */);
    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
      s = errors::Internal(
          "Iterator \"", params_.prefix,
          "\" returned OutOfRange without setting `*end_of_sequence`. This "
          "indicates that an error may have occurred. Original message: ",
          s.error_message());
      LOG(ERROR) << s;
    }
    return s;
  }

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    TF_RETURN_IF_ERROR(params_.dataset->Save(ctx, writer));
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    return strings::StrCat(params_.prefix, ":", name);
  }

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->add_buffered_bytes(-GetAllocatedBytes(element));
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->add_buffered_bytes(GetAllocatedBytes(element));
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element.
  void RecordElement(IteratorContext* ctx) {
    if (node_) {
      node_->record_element();
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx, bool stop_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = Env::Default()->NowNanos();
      if (stop_output && node_->output()) {
        node_->output()->record_stop(now_nanos);
      }
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx, bool start_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = Env::Default()->NowNanos();
      node_->record_stop(now_nanos);
      if (start_output && node_->output()) {
        node_->output()->record_start(now_nanos);
      }
    }
  }

 private:
  inline bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const { return typed_dataset_; }

 protected:
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};

// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
// graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) final;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;

  template <typename T>
  Status ParseScalarArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name, T* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a scalar");
    }
    *output = argument_t->scalar<T>()();
    return Status::OK();
  }

  template <typename T>
  Status ParseVectorArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name,
                             std::vector<T>* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsVector(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a vector");
    }
    int size = argument_t->vec<T>().size();
    output->reserve(size);
    for (int i = 0; i < size; ++i) {
      output->push_back(argument_t->vec<T>()(i));
    }
    return Status::OK();
  }
};
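
// For illustration only: a minimal sketch of a kernel subclass. The op input
// name "count" and the `MyRangeDataset` class are hypothetical:
//
//   class MyRangeDatasetOp : public DatasetOpKernel {
//    public:
//     using DatasetOpKernel::DatasetOpKernel;
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//       *output = new MyRangeDataset(DatasetContext(ctx), count);
//     }
//   };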

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const string& name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
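
// For illustration only: a minimal sketch of offloading blocking work from an
// `AsyncOpKernel::ComputeAsync`, assuming a `background_worker_` member
// constructed elsewhere (e.g. in the kernel's constructor):
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     background_worker_.Schedule([ctx, done]() {
//       // ... potentially blocking work ...
//       done();
//     });
//   }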

}  // namespace data

// TODO(b/114112161): Remove these aliases when all users have moved over to
// the `tensorflow::data` namespace.
using data::DatasetBase;
using data::DatasetContext;
using data::DatasetIterator;
using data::DatasetOpKernel;
using data::IteratorBase;
using data::IteratorContext;
using data::IteratorStateReader;
using data::IteratorStateWriter;
using data::SerializationContext;
using data::UnaryDatasetOpKernel;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_