/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_
#define TENSORFLOW_CORE_DATA_STANDALONE_H_

#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/data/unbounded_thread_pool.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace standalone {

// The purpose of the API in this file is to facilitate standalone execution of
// a tf.data input pipeline graph.
//
// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which
// encapsulate the TensorFlow runtime.
//
// The `Dataset` abstraction represents an input pipeline as a collection of
// data sources and a logical plan of transformations that operate over the
// data.
//
// The `Iterator` abstraction represents an execution of an input pipeline that
// can be used to enumerate its elements.
//
// Example usage:
//
//   // Create a `Dataset` by running the `graph_def` graph.
//   tensorflow::data::standalone::Dataset::Params params;
//   std::unique_ptr<tensorflow::data::standalone::Dataset> dataset;
//   Status s = tensorflow::data::standalone::Dataset::FromGraph(
//       params, graph_def, &dataset);
//   if (!s.ok()) { /* error handling */ }
//
//   std::unique_ptr<tensorflow::data::standalone::Iterator> iterator;
//   s = dataset->MakeIterator(&iterator);
//   if (!s.ok()) { /* error handling */ }
//
//   bool end_of_input = false;
//   while (!end_of_input) {
//     std::vector<tensorflow::Tensor> outputs;
//     s = iterator->GetNext(&outputs, &end_of_input);
//     if (!s.ok()) { /* error handling */ }
//     if (!end_of_input) { /* output handling */ }
//   }

class Dataset;

// Represents an execution of an input pipeline that can be used to enumerate
// its elements.
class Iterator {
 public:
  // Returns the next element of the input pipeline (if there is one) and an
  // indication of whether the end of the input pipeline has been reached.
  Status GetNext(std::vector<Tensor>* outputs, bool* end_of_input);

 private:
  friend class Dataset;

  Iterator(IteratorBase* iterator, IteratorContext* ctx);

  std::unique_ptr<IteratorBase> iterator_;
  std::unique_ptr<IteratorContext> ctx_;
};

// Represents an input pipeline as a collection of data sources and a logical
// plan of transformations that operate over the data.
class Dataset {
 public:
  // Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration).
  struct Params {
    SessionOptions session_options;
  };

  // Creates a new `Dataset` instance by running the given dataset graph.
  static Status FromGraph(Params params, const GraphDef& graph_def,
                          std::unique_ptr<Dataset>* result);

  ~Dataset();

  // Creates an iterator for this dataset.
  Status MakeIterator(std::unique_ptr<Iterator>* result);

  // Creates an iterator for this dataset, reading splits from the given split
  // providers.
  Status MakeIterator(
      std::vector<std::unique_ptr<SplitProvider>> split_providers,
      std::unique_ptr<Iterator>* result);
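
  // A minimal usage sketch for the split-provider overload above, assuming a
  // `dataset` already created via `Dataset::FromGraph`. The variable names
  // (`providers`, `iterator`) are illustrative only; error handling is
  // abbreviated as in the file-level example:
  //
  //   std::vector<std::unique_ptr<SplitProvider>> providers;
  //   Status s = dataset->MakeSplitProviders(&providers);
  //   if (!s.ok()) { /* error handling */ }
  //
  //   std::unique_ptr<Iterator> iterator;
  //   s = dataset->MakeIterator(std::move(providers), &iterator);
  //   if (!s.ok()) { /* error handling */ }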

  // Creates split providers for this dataset.
  Status MakeSplitProviders(
      std::vector<std::unique_ptr<SplitProvider>>* result);

  // Returns a pointer to the underlying dataset.
  const DatasetBase* Get() const;

 private:
  Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
          DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
          FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
          std::function<void(std::function<void()>)> runner);

  DatasetBase* finalized_dataset_;  // owned
  DatasetBase* original_dataset_;   // owned
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<thread::ThreadPool> interop_threadpool_;
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  std::function<void(std::function<void()>)> runner_;
  ResourceMgr resource_mgr_;
  CancellationManager cancellation_manager_;
  UnboundedThreadPool unbounded_thread_pool_;
};

}  // namespace standalone
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_STANDALONE_H_