/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <memory>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
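
// For illustration only (not part of the original header): a hypothetical
// `HANDLE_TYPE(T)` case macro could be expanded once per dataset element type
// to build such a `switch` over a DataType value `dtype`:
//
//   #define HANDLE_TYPE(T)              \
//     case DataTypeToEnum<T>::value:    \
//       /* handle tensors of type T */  \
//       break;
//
//   switch (dtype) {
//     TF_CALL_DATASET_TYPES(HANDLE_TYPE);
//     default:
//       /* handle unsupported element types */
//       break;
//   }
//   #undef HANDLE_TYPE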

namespace tensorflow {

// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, string* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
  virtual bool Contains(StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
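
// A minimal sketch (for illustration; not part of the original header) of how
// an iterator's SaveInternal/RestoreInternal overrides (see IteratorBase
// below) might use these interfaces, assuming a hypothetical int64 member
// `i_` and the `full_name()` helper from DatasetIterator (defined below):
//
//   Status SaveInternal(IteratorStateWriter* writer) override {
//     return writer->WriteScalar(full_name("i"), i_);
//   }
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     return reader->ReadScalar(full_name("i"), &i_);
//   }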

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class GraphDatasetBase;
class Node;

// Wrapper around GraphDefBuilder. Used to serialize a Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (int i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with Tensor value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddDataset(const GraphDatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // The return value of `DatasetType::op_name()` is used as the op type for
  // the node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddDataset(const GraphDatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const GraphDatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // FunctionLibraryDefinition, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, returns an InvalidArgument error.
  Status AddFunction(OpKernelContext* ctx, const string& function_name);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 private:
  void AddTensorInternal(const Tensor& val, Node** output);

  Status EnsureFunctionIsStateless(OpKernelContext* ctx,
                                   const string& function_name) const {
    const FunctionLibraryDefinition* lib_def =
        ctx->function_library()->GetFunctionLibraryDefinition();
    const FunctionDef* function_def = lib_def->Find(function_name);
    if (!function_def) {
      return errors::InvalidArgument("Unable to find FunctionDef for ",
                                     function_name, " in registry.");
    }
    for (const NodeDef& node_def : function_def->node_def()) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def));
      // TODO(b/65524810): Hack to allow functions to capture Dataset op
      // nodes needed for FlatMap. Currently, source dataset nodes have been
      // marked stateful to avoid constant folding since we do not have a
      // good way of serializing them.
      if (IsOpWhitelisted(op_def)) {
        continue;
      }
      if (op_def->is_stateful()) {
        return errors::InvalidArgument(
            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
            "in function ", function_name, " is stateful. ",
            "Saving stateful functions is not supported yet.");
      }
    }
    return Status::OK();
  }

  // Returns whether an op has been whitelisted for use inside map_fns.
  // Uses a heuristic to whitelist source dataset ops which have been
  // marked stateful due to b/65524810.
  // Also looks up the `op_def->name` in the global
  // `WhitelistedStatefulOpRegistry`.
  bool IsOpWhitelisted(const OpDef* op_def) const {
    return (StringPiece(op_def->name()).ends_with("Dataset") &&
            op_def->output_arg_size() == 1 &&
            op_def->output_arg(0).type() == DT_VARIANT) ||
           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
               op_def->name());
  }

  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (auto attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};
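
// A minimal sketch (for illustration; not part of the original header) of how
// a dataset might use this wrapper in its AsGraphDefInternal override (see
// DatasetBase below), assuming the dataset derives from GraphDatasetBase and
// has hypothetical `input_` (a parent dataset) and `count_` (int64) members:
//
//   Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
//     Node* count_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
//     return b->AddDataset(this, {input_node, count_node}, output);
//   }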

class StatsAggregator;

// A cut-down version of OpKernelContext for running computations in
// iterators. Note that we cannot simply use OpKernelContext here
// because we might run computation in an iterator whose lifetime is
// not nested within the lifetime of a single OpKernelContext
// (e.g. asynchronous prefetching).
//
// TODO(mrry): We will probably need to support more of
// OpKernelContext here. For example, should allocation be handled by
// the IteratorContext?
// TODO(mrry): We're making some daring assumptions about the lifetime
// of the runner passed in here. A runner will be deleted when the original
// step ends, but all existing runners only close over session-lifetime (or
// longer-lived) state, so we can make a copy of the function. There's nothing
// in the definition of the API from which we took the runner to guarantee that
// what we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    // Interface to operating system functionality.
    Env* env;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // A function that returns the current `StatsAggregator` instance to be
    // used when recording statistics about the iterator.
    //
    // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
    // is a property of the `IteratorResource` (which this class does not know
    // about), and (ii) it can change after the `IteratorContext` has been
    // created. Better suggestions are welcome!
    std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
        nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* lib = nullptr;
    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
  };

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Env* env() const { return params_.env; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    if (params_.stats_aggregator_getter) {
      return params_.stats_aggregator_getter();
    } else {
      return nullptr;
    }
  }

  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
    return params_.function_library;
  }

  FunctionLibraryRuntime* lib() { return params_.lib; }

  void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

 private:
  Params params_;
};
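
// A minimal sketch (for illustration; not part of the original header) of
// building an IteratorContext from an OpKernelContext `ctx`, as a kernel
// might do before calling GetNext on an iterator:
//
//   IteratorContext::Params params;
//   params.env = ctx->env();
//   params.runner = *(ctx->runner());
//   params.lib = ctx->function_library();
//   params.allocator_getter = [ctx](AllocatorAttributes attrs) {
//     return ctx->get_allocator(attrs);
//   };
//   IteratorContext iter_ctx(std::move(params));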

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {}

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Saves the state of this iterator.
  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }

 protected:
  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their parent iterators, e.g., in
  // `RepeatDatasetOp::Dataset`.
  Status SaveParent(IteratorStateWriter* writer,
                    const std::unique_ptr<IteratorBase>& parent) {
    return parent->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their parent iterators, e.g., in
  // `RepeatDatasetOp::Dataset`.
  Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader,
                       const std::unique_ptr<IteratorBase>& parent) {
    return parent->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator recursively.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator recursively.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }
};
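
// A minimal sketch (for illustration; not part of the original header) of
// draining an iterator inside a Status-returning function, assuming
// hypothetical `iterator` and `iter_ctx` objects:
//
//   std::vector<Tensor> out_tensors;
//   bool end_of_sequence = false;
//   while (!end_of_sequence) {
//     out_tensors.clear();
//     TF_RETURN_IF_ERROR(
//         iterator->GetNext(&iter_ctx, &out_tensors, &end_of_sequence));
//     if (!end_of_sequence) {
//       // Process the element stored in `out_tensors`.
//     }
//   }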

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // Ownership of the created iterator will be transferred to the caller.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  virtual std::unique_ptr<IteratorBase> MakeIterator(
      const string& prefix) const = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // A human-readable debug string for this dataset.
  virtual string DebugString() = 0;

  // Serializes the dataset and writes it to the `writer`.
  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const {
    return errors::Unimplemented("DatasetBase::Save");
  }

 protected:
  // TODO(srbs): Ideally all graph-related logic should reside in
  // GraphDatasetBase. However, that would require Datasets defined in all ops
  // to derive from GraphDatasetBase. Once that is done we can move
  // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
    Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset,
                            Node** output) {
      return dataset->AsGraphDefInternal(ctx, this, output);
    }
  };

  virtual Status AsGraphDefInternal(OpKernelContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const {
    return AsGraphDefInternal(b, node);
  }

  virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                                    Node** node) const {
    return errors::Unimplemented("AsGraphDefInternal");
  }
};

// Base class for datasets that are built by ops.
class GraphDatasetBase : public DatasetBase {
 public:
  GraphDatasetBase(OpKernelContext* ctx)
      : op_name_(ctx->op_kernel().type_string()) {}

  const string op_name() const { return op_name_; }

  Status Save(OpKernelContext* ctx,
              IteratorStateWriter* writer) const override {
    string serialized_graph_def;
    string output_node;
    TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
    return Status::OK();
  }

  // Key for storing the Dataset graph in the serialized format.
  static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  static const char kDatasetGraphOutputNodeKey[];

 private:
  Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
                   string* output_node) const;

  const string op_name_;
};

// Represents an iterator that is associated with a particular parent dataset.
template <class DatasetType>
class DatasetIterator : public IteratorBase {
 public:
  struct Params {
    // Owns one reference on the shared dataset resource.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params) : params_(params) {
    params_.dataset->Ref();
  }

  ~DatasetIterator() override { params_.dataset->Unref(); }

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const { return params_.dataset; }

  // The sequence of iterators leading up to this iterator.
  const string prefix() const { return params_.prefix; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    port::Tracing::TraceMe activity(params_.prefix);
    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
      s = errors::Internal(
          "Iterator \"", params_.prefix,
          "\" returned OutOfRange without setting `*end_of_sequence`. This "
          "indicates that an error may have occurred. Original message: ",
          s.error_message());
      LOG(ERROR) << s;
    }
    return s;
  }

  Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final {
    TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer));
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    return strings::StrCat(prefix(), ":", name);
  }

 private:
  Params params_;
};
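
// A minimal sketch (for illustration; not part of the original header) of a
// DatasetIterator subclass for a hypothetical `Dataset` that yields the
// integers [0, dataset()->count_); synchronization is omitted for brevity:
//
//   class Iterator : public DatasetIterator<Dataset> {
//    public:
//     explicit Iterator(const Params& params)
//         : DatasetIterator<Dataset>(params) {}
//
//    protected:
//     Status GetNextInternal(IteratorContext* ctx,
//                            std::vector<Tensor>* out_tensors,
//                            bool* end_of_sequence) override {
//       if (i_ >= dataset()->count_) {
//         *end_of_sequence = true;
//         return Status::OK();
//       }
//       Tensor value(ctx->allocator({}), DT_INT64, TensorShape({}));
//       value.scalar<int64>()() = i_++;
//       out_tensors->emplace_back(std::move(value));
//       *end_of_sequence = false;
//       return Status::OK();
//     }
//
//    private:
//     int64 i_ = 0;
//   };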

// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
// graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) final;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;

  template <typename T>
  Status ParseScalarArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name, T* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a scalar");
    }
    *output = argument_t->scalar<T>()();
    return Status::OK();
  }
};
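
// A minimal sketch (for illustration; not part of the original header) of a
// MakeDataset override that reads a scalar `count` input using
// ParseScalarArgument, assuming a hypothetical `Dataset` class:
//
//   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
//     int64 count;
//     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//     *output = new Dataset(ctx, count);
//   }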

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor()
// (declared below).
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);
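
// A minimal sketch (for illustration; not part of the original header) of
// retrieving a dataset from a kernel's first input and keeping it alive
// beyond the lifetime of the input tensor, given a hypothetical
// OpKernelContext `ctx`:
//
//   DatasetBase* dataset;
//   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
//   dataset->Ref();  // Acquire our own reference, per the contract above.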

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_