/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <memory>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {

// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, string* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
  virtual bool Contains(StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class GraphDatasetBase;
class Node;

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}
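
  // Illustrative usage sketch only (not part of the API declared here): a
  // dataset's `AsGraphDefInternal` implementation would typically use this
  // wrapper roughly as follows, where `count_` and `input_node` are
  // hypothetical members/locals of that dataset:
  //
  //   Node* count_node = nullptr;
  //   TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
  //   TF_RETURN_IF_ERROR(
  //       b->AddDataset(this, {input_node, count_node}, output));
  //   return Status::OK();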

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (int i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with Tensor value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddDataset(const GraphDatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // Return value of `DatasetType::op_name()` is used as the op type for the
  // node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddDataset(const GraphDatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const GraphDatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // FunctionLibraryDefinition, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, returns an InvalidArgument error.
  Status AddFunction(OpKernelContext* ctx, const string& function_name);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 private:
  void AddTensorInternal(const Tensor& val, Node** output);

  Status EnsureFunctionIsStateless(OpKernelContext* ctx,
                                   const string& function_name) const {
    const FunctionLibraryDefinition* lib_def =
        ctx->function_library()->GetFunctionLibraryDefinition();
    const FunctionDef* function_def = lib_def->Find(function_name);
    if (!function_def) {
      return errors::InvalidArgument("Unable to find FunctionDef for ",
                                     function_name, " in registry.");
    }
    for (const NodeDef& node_def : function_def->node_def()) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def));
      // TODO(b/65524810): Hack to allow functions to capture Dataset op
      // nodes needed for FlatMap. Currently, source dataset nodes have been
      // marked stateful to avoid constant folding since we do not have a
      // good way of serializing them.
      if (IsOpWhitelisted(op_def)) {
        continue;
      }
      if (op_def->is_stateful()) {
        return errors::InvalidArgument(
            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
            "in function ", function_name, " is stateful. ",
            "Saving stateful functions is not supported yet.");
      }
    }
    return Status::OK();
  }

  // Returns whether an op has been whitelisted for use inside map_fns.
  // Uses a heuristic to whitelist source dataset ops which have been
  // marked stateful due to b/65524810.
  // Also looks up the `op_def->name` in the global
  // `WhitelistedStatefulOpRegistry`.
  bool IsOpWhitelisted(const OpDef* op_def) const {
    return (StringPiece(op_def->name()).ends_with("Dataset") &&
            op_def->output_arg_size() == 1 &&
            op_def->output_arg(0).type() == DT_VARIANT) ||
           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
               op_def->name());
  }

  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (auto attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;

// A cut-down version of OpKernelContext for running computations in
// iterators. Note that we cannot simply use OpKernelContext here
// because we might run computation in an iterator whose lifetime is
// not nested within the lifetime of a single OpKernelContext
// (e.g. asynchronous prefetching).
//
// TODO(mrry): We will probably need to support more of
// OpKernelContext here. For example, should allocation be handled by
// the IteratorContext?
// TODO(mrry): We're making some daring assumptions about the lifetime
// of the runner passed in here. A runner will be deleted when the original
// step ends, but all existing runners only close over session-lifetime (or
// longer-lived) state, so we can make a copy of the function. There's nothing
// in the definition of the API from which we took the runner to guarantee that
// what we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    // Interface to operating system functionality.
    Env* env;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // A function that returns the current `StatsAggregator` instance to be
    // used when recording statistics about the iterator.
    //
    // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
    // is a property of the `IteratorResource` (which this class does not know
    // about), and (ii) it can change after the `IteratorContext` has been
    // created. Better suggestions are welcome!
    std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
        nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* lib = nullptr;
    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
  };
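
  // Illustrative sketch only (a typical call site, not an API defined here):
  // a kernel that drives an iterator usually populates `Params` from its
  // `OpKernelContext* ctx` along these lines, where `iter_ctx` is a
  // hypothetical local:
  //
  //   IteratorContext::Params params;
  //   params.env = ctx->env();
  //   params.runner = *(ctx->runner());
  //   params.lib = ctx->function_library();
  //   IteratorContext iter_ctx(std::move(params));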

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Env* env() const { return params_.env; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    if (params_.stats_aggregator_getter) {
      return params_.stats_aggregator_getter();
    } else {
      return nullptr;
    }
  }

  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
    return params_.function_library;
  }

  FunctionLibraryRuntime* lib() { return params_.lib; }

  void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; }

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

 private:
  Params params_;
};

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {}

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Saves the state of this iterator.
  virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }
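
  // Illustrative sketch only (a hypothetical subclass): a concrete iterator
  // typically persists its position in `SaveInternal` and restores it in
  // `RestoreInternal`; `index_` is an assumed int64 member and `full_name()`
  // is the key-prefixing helper provided by `DatasetIterator` below:
  //
  //   Status SaveInternal(IteratorStateWriter* writer) override {
  //     return writer->WriteScalar(full_name("index"), index_);
  //   }
  //
  //   Status RestoreInternal(IteratorContext* ctx,
  //                          IteratorStateReader* reader) override {
  //     return reader->ReadScalar(full_name("index"), &index_);
  //   }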

 protected:
  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their parent iterators, e.g., in
  // `RepeatDatasetOp::Dataset`.
  Status SaveParent(IteratorStateWriter* writer,
                    const std::unique_ptr<IteratorBase>& parent) {
    return parent->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their parent iterators, e.g., in
  // `RepeatDatasetOp::Dataset`.
  Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader,
                       const std::unique_ptr<IteratorBase>& parent) {
    return parent->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator recursively.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator recursively.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }
};

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // Ownership of the created iterator will be transferred to the caller.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  virtual std::unique_ptr<IteratorBase> MakeIterator(
      const string& prefix) const = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // A human-readable debug string for this dataset.
  virtual string DebugString() = 0;

  // Serializes the dataset and writes it to the `writer`.
  virtual Status Save(OpKernelContext* ctx,
                      IteratorStateWriter* writer) const {
    return errors::Unimplemented("DatasetBase::Save");
  }
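
  // Illustrative usage sketch only (not an API defined here): a caller that
  // holds a `DatasetBase* dataset` and has built an `IteratorContext iter_ctx`
  // typically drains the dataset like this:
  //
  //   std::unique_ptr<IteratorBase> iterator =
  //       dataset->MakeIterator("Iterator");
  //   std::vector<Tensor> out_tensors;
  //   bool end_of_sequence = false;
  //   while (!end_of_sequence) {
  //     out_tensors.clear();
  //     TF_RETURN_IF_ERROR(
  //         iterator->GetNext(&iter_ctx, &out_tensors, &end_of_sequence));
  //   }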

 protected:
  // TODO(srbs): Ideally all graph related logic should reside in
  // GraphDatasetBase. However, that would require Datasets defined in all ops
  // to derive from GraphDatasetBase. Once that is done we can move
  // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
    Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset,
                            Node** output) {
      return dataset->AsGraphDefInternal(ctx, this, output);
    }
  };

  virtual Status AsGraphDefInternal(OpKernelContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const {
    return AsGraphDefInternal(b, node);
  }

  virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                                    Node** node) const {
    return errors::Unimplemented("AsGraphDefInternal");
  }
};

// Base-class for datasets that are built by ops.
class GraphDatasetBase : public DatasetBase {
 public:
  GraphDatasetBase(OpKernelContext* ctx)
      : op_name_(ctx->op_kernel().type_string()) {}

  const string op_name() const { return op_name_; }

  Status Save(OpKernelContext* ctx,
              IteratorStateWriter* writer) const override {
    string serialized_graph_def;
    string output_node;
    TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
    return Status::OK();
  }

  // Key for storing the Dataset graph in the serialized format.
  static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  static const char kDatasetGraphOutputNodeKey[];

 private:
  Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
                   string* output_node) const;

  const string op_name_;
};

// Represents an iterator that is associated with a particular parent dataset.
template <class DatasetType>
class DatasetIterator : public IteratorBase {
 public:
  struct Params {
    // Owns one reference on the shared dataset resource.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params) : params_(params) {
    params_.dataset->Ref();
  }

  ~DatasetIterator() override { params_.dataset->Unref(); }
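
  // Illustrative sketch only (a hypothetical subclass): a concrete iterator
  // produces its elements by overriding `GetNextInternal` (declared below).
  // `i_`, `mu_`, and `dataset()->count_` are assumed members of that
  // hypothetical iterator/dataset pair:
  //
  //   Status GetNextInternal(IteratorContext* ctx,
  //                          std::vector<Tensor>* out_tensors,
  //                          bool* end_of_sequence) override {
  //     mutex_lock l(mu_);
  //     if (i_ < dataset()->count_) {
  //       Tensor value(DT_INT64, TensorShape({}));
  //       value.scalar<int64>()() = i_++;
  //       out_tensors->push_back(std::move(value));
  //       *end_of_sequence = false;
  //     } else {
  //       *end_of_sequence = true;
  //     }
  //     return Status::OK();
  //   }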

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const { return params_.dataset; }

  // The sequence of iterators leading up to this iterator.
  const string prefix() const { return params_.prefix; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    port::Tracing::TraceMe activity(params_.prefix);
    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
      s = errors::Internal(
          "Iterator \"", params_.prefix,
          "\" returned OutOfRange without setting `*end_of_sequence`. This "
          "indicates that an error may have occurred. Original message: ",
          s.error_message());
      LOG(ERROR) << s;
    }
    return s;
  }

  Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final {
    TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer));
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    return strings::StrCat(prefix(), ":", name);
  }

 private:
  Params params_;
};

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) final;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;

  template <typename T>
  Status ParseScalarArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name, T* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a scalar");
    }
    *output = argument_t->scalar<T>()();
    return Status::OK();
  }
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
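
// Illustrative sketch only (a hypothetical kernel, not part of this header):
// a unary dataset op kernel typically parses its scalar arguments and wraps
// `input` in a new dataset. `Dataset` is an assumed nested class and "count"
// an assumed op input:
//
//   class MyUnaryDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyUnaryDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//       *output = new Dataset(ctx, input, count);
//     }
//   };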

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_