/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

class DatasetBase;
class SerializationContext;

// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, string* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
  virtual bool Contains(StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
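//
// A `DatasetBase` subclass typically uses this wrapper from its
// `AsGraphDefInternal()` override. A rough, hypothetical sketch (member names
// such as `input_` and `count_` are illustrative only):
//
//   Status AsGraphDefInternal(SerializationContext* ctx,
//                             DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
//     Node* count_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
//     return b->AddDataset(this, {input_node, count_node}, output);
//   }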
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (int i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // Return value of `DatasetType::op_name()` is used as the op type for the
  // node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of
  // GraphDefBuilder.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, and the context does not explicitly permit stateful functions,
  // returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);

  Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
                                   const string& function_name) const {
    const FunctionDef* function_def = flib_def.Find(function_name);
    if (!function_def) {
      return errors::InvalidArgument("Unable to find FunctionDef for ",
                                     function_name, " in registry.");
    }
    for (const NodeDef& node_def : function_def->node_def()) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def));
      // TODO(b/65524810): Hack to allow functions to capture Dataset op
      // nodes needed for FlatMap. Currently, source dataset nodes have been
      // marked stateful to avoid constant folding since we do not have a
      // good way of serializing them.
      if (IsOpWhitelisted(op_def)) {
        continue;
      }
      if (op_def->is_stateful()) {
        return errors::InvalidArgument(
            "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
            "in function ", function_name, " is stateful. ",
            "Saving stateful functions is not supported yet.");
      }
    }
    return Status::OK();
  }

  // Returns whether an op has been whitelisted for use inside map_fns.
  // Uses a heuristic to whitelist source dataset ops which have been
  // marked stateful due to b/65524810.
  // Also looks up the `op_def->name` in the global
  // `WhitelistedStatefulOpRegistry`.
  bool IsOpWhitelisted(const OpDef* op_def) const {
    return (str_util::EndsWith(op_def->name(), "Dataset") &&
            op_def->output_arg_size() == 1 &&
            op_def->output_arg(0).type() == DT_VARIANT) ||
           WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
  }

  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (auto attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;
class FunctionHandleCache;

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what
// we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          env(ctx->env()),
          function_library(ctx->function_library()),
          lib(ctx->lib()),
          function_handle_cache(ctx->function_handle_cache()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()) {}

    explicit Params(OpKernelContext* ctx)
        : env(ctx->env()),
          lib(ctx->function_library()),
          runner(*(ctx->runner())) {
      // NOTE: need reinterpret_cast because function.h forward-declares
      // Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };
      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        runner_threadpool_size = port::NumSchedulableCPUs();
      }
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryDefinition used to look up user-defined functions.
    std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* lib = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // The `StatsAggregator` object to record statistics about the iterator.
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A `ThreadFactory` for creating threads used by iterators to perform
    // blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  Env* env() const { return params_.env; }

  std::shared_ptr<const FunctionLibraryDefinition> function_library() {
    return params_.function_library;
  }

  FunctionLibraryRuntime* lib() { return params_.lib; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  Params params() { return params_; }

 private:
  Params params_;
};

// Aggregates runtime support needed for dataset and iterator serialization.
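//
// A rough usage sketch for saving a dataset (hypothetical; the actual call
// sites are the dataset/iterator serialization kernels):
//
//   SerializationContext::Params params;
//   params.flib_def = &flib_def;  // FunctionLibraryDefinition used to
//                                 // resolve user-defined functions.
//   SerializationContext serialization_ctx(std::move(params));
//   TF_RETURN_IF_ERROR(dataset->Save(&serialization_ctx, writer));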
class SerializationContext {
 public:
  struct Params {
    const FunctionLibraryDefinition* flib_def = nullptr;           // Not owned.
    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
    bool optimization_only = false;
  };

  explicit SerializationContext(Params params) : params_(std::move(params)) {}

  const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  bool optimization_only() { return params_.optimization_only; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }
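
  // A rough sketch of driving an iterator (hypothetical; context setup and
  // error handling are omitted):
  //
  //   std::vector<Tensor> element;
  //   bool end_of_sequence = false;
  //   while (!end_of_sequence) {
  //     TF_RETURN_IF_ERROR(
  //         iterator->GetNext(&ctx, &element, &end_of_sequence));
  //     if (!end_of_sequence) {
  //       // Process `element`.
  //     }
  //   }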

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator recursively.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator recursively.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }

 private:
  friend class DatasetBase;          // for access to `AddCleanupFunction`
  friend class DatasetBaseIterator;  // for access to `node_`

  // Registers a cleanup function to be called upon object destruction.
  //
  // Registered functions are invoked in the reverse order of registration.
  void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
    cleanup_fns_.push_back(std::move(cleanup_fn));
  }

  // Associates the given performance modeling `Node` with this iterator.
  void SetNode(std::shared_ptr<model::Node> node) { node_ = node.get(); }

  std::vector<std::function<void()>> cleanup_fns_;
  model::Node* node_ = nullptr;  // Not owned.
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor.
// The consumer must either acquire its own reference to the dataset by
// calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not destroyed
// or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    *iterator = MakeIteratorInternal(output_prefix);
    if (const auto& model = ctx->model()) {
      const string& prefix = (*iterator)->prefix();
      (*iterator)->SetNode(model->AddNode(MakeNodeFactory(ctx, iterator->get()),
                                          prefix, output_prefix));
      (*iterator)->AddCleanupFunction(
          [model, prefix]() { model->RemoveNode(prefix); });
    }
    return (*iterator)->Initialize(ctx);
  }

  Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, output_prefix, iterator);
  }

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }
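
  // A rough sketch of obtaining an iterator over this dataset (hypothetical;
  // in practice this is driven by the iterator ops):
  //
  //   std::unique_ptr<IteratorBase> iterator;
  //   TF_RETURN_IF_ERROR(
  //       dataset->MakeIterator(&iterator_ctx, "Iterator", &iterator));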

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // Serializes the dataset and writes it to the `writer`.
  virtual Status Save(SerializationContext* ctx,
                      IteratorStateWriter* writer) const;

 protected:
  friend class DatasetToGraphOp;  // For access to graph related members.

  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
  };

  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

 private:
  // Returns a factory for nodes that represent the given iterator.
  static model::Node::Factory MakeNodeFactory(IteratorContext* ctx,
                                              IteratorBase* iterator) {
    return [ctx, iterator](model::Node::Args args) {
      return iterator->CreateNode(ctx, std::move(args));
    };
  }

  const string type_string_;
  const string node_name_;
};

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
    params_.dataset->Ref();
  }

  ~DatasetBaseIterator() override { params_.dataset->Unref(); }

  // The sequence of iterators leading up to this iterator.
  const string& prefix() const override { return params_.prefix; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    tracing::ScopedActivity activity(params_.prefix);
    RecordStart(ctx, true /* stop_output */);
    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
    if (s.ok() && !*end_of_sequence) RecordElement(ctx);
    RecordStop(ctx, true /* start_output */);
    if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
      s = errors::Internal(
          "Iterator \"", params_.prefix,
          "\" returned OutOfRange without setting `*end_of_sequence`. This "
          "indicates that an error may have occurred. Original message: ",
          s.error_message());
      LOG(ERROR) << s;
    }
    return s;
  }

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    TF_RETURN_IF_ERROR(params_.dataset->Save(ctx, writer));
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    return strings::StrCat(params_.prefix, ":", name);
  }

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->add_buffered_bytes(-GetAllocatedBytes(element));
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->add_buffered_bytes(GetAllocatedBytes(element));
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element.
  void RecordElement(IteratorContext* ctx) {
    if (node_) {
      node_->record_element();
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx, bool stop_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = Env::Default()->NowNanos();
      if (stop_output && node_->output()) {
        node_->output()->record_stop(now_nanos);
      }
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx, bool start_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = Env::Default()->NowNanos();
      node_->record_stop(now_nanos);
      if (start_output && node_->output()) {
        node_->output()->record_start(now_nanos);
      }
    }
  }

 private:
  inline bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const { return typed_dataset_; }

 protected:
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) final;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;

  template <typename T>
  Status ParseScalarArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name, T* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a scalar");
    }
    *output = argument_t->scalar<T>()();
    return Status::OK();
  }

  template <typename T>
  Status ParseVectorArgument(OpKernelContext* ctx,
                             const StringPiece& argument_name,
                             std::vector<T>* output) {
    const Tensor* argument_t;
    TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
    if (!TensorShapeUtils::IsVector(argument_t->shape())) {
      return errors::InvalidArgument(argument_name, " must be a vector");
    }
    int size = argument_t->vec<T>().size();
    output->reserve(size);
    for (int i = 0; i < size; ++i) {
      output->push_back(argument_t->vec<T>()(i));
    }
    return Status::OK();
  }
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
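
// A rough sketch of a `UnaryDatasetOpKernel` subclass (hypothetical; the
// `MyUnaryDatasetOp`/`MyDataset` names and the "count" input are illustrative
// only):
//
//   class MyUnaryDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyUnaryDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//       // `MyDataset` would be a DatasetBase subclass defined by the op.
//       *output = new MyDataset(DatasetContext(ctx), input, count);
//     }
//   };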
// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const string& name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};

}  // namespace data

// TODO(b/114112161): Remove these aliases when all users have moved over to
// the `tensorflow::data` namespace.
using data::DatasetBase;
using data::DatasetContext;
using data::DatasetIterator;
using data::DatasetOpKernel;
using data::IteratorBase;
using data::IteratorContext;
using data::IteratorStateReader;
using data::IteratorStateWriter;
using data::SerializationContext;
using data::UnaryDatasetOpKernel;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_