1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/standalone.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23
24 #include "absl/memory/memory.h"
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/device_mgr.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/graph_runner.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
32 #include "tensorflow/core/data/root_dataset.h"
33 #include "tensorflow/core/framework/dataset.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/graph/graph.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/refcount.h"
38 #include "tensorflow/core/public/version.h"
39 #include "tensorflow/core/util/ptr_util.h"
40
41 namespace tensorflow {
42 namespace data {
43 namespace standalone {
44
45 namespace {
46
CreateParams(ProcessFunctionLibraryRuntime * pflr,DeviceMgr * device_mgr,std::function<void (std::function<void ()>)> * runner)47 OpKernelContext::Params CreateParams(
48 ProcessFunctionLibraryRuntime* pflr, DeviceMgr* device_mgr,
49 std::function<void(std::function<void()>)>* runner) {
50 OpKernelContext::Params params;
51 params.function_library = pflr->GetFLR("/device:CPU:0");
52 params.device = device_mgr->ListDevices()[0];
53 params.runner = runner;
54 return params;
55 }
56
57 } // namespace
58
GetNext(std::vector<Tensor> * outputs,bool * end_of_input)59 Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
60 return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
61 }
62
// Takes ownership of both `iterator` and `ctx` (presumably via smart-pointer
// members, given the `ctx_.get()` usage in `GetNext` — member declarations are
// in the header, not visible here). NOTE(review): `ctx` must outlive
// `iterator`; confirm the member declaration order guarantees this.
Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
    : iterator_(iterator), ctx_(ctx) {}
65
FromGraph(Params params,const GraphDef & graph_def,std::unique_ptr<Dataset> * result)66 Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
67 std::unique_ptr<Dataset>* result) {
68 Graph graph(OpRegistry::Global());
69 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
70
71 // Instantiate enough of the TF runtime to run `graph` on a single CPU device.
72 auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
73 "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
74 Device* device = device_mgr->ListDevices()[0];
75 // Create a copy of the `FunctionLibraryDefinition` to extend lifetime beyond
76 // the lifetime of `graph`.
77 auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
78 OpRegistry::Global(), graph_def.library());
79 auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
80 device_mgr.get(), Env::Default(), /*config=*/nullptr,
81 TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
82 /*thread_pool=*/nullptr, /*parent=*/nullptr,
83 /*session_metadata=*/nullptr,
84 Rendezvous::Factory{
85 [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
86 *r = new IntraProcessRendezvous(device_mgr);
87 return Status::OK();
88 }});
89
90 string fetch_node = "";
91 for (const auto& node : graph_def.node()) {
92 if (node.op() == "_Retval") {
93 fetch_node = node.input(0);
94 }
95 }
96 if (fetch_node.empty()) {
97 return errors::NotFound("Failed to find a _Retval op in the given dataset");
98 }
99
100 // Run graph up to `output_node` and extract the `DatasetBase` stored in the
101 // DT_VARIANT output tensor.
102 std::vector<Tensor> outputs;
103 GraphRunner graph_runner(device);
104 TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"), {},
105 {fetch_node}, &outputs));
106 data::DatasetBase* dataset;
107 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
108
109 data::DatasetBase* finalized_dataset;
110 std::unique_ptr<thread::ThreadPool> pool(
111 NewThreadPoolFromSessionOptions(params.session_options));
112 std::function<void(std::function<void()>)> runner =
113 [&pool](std::function<void()> c) { pool->Schedule(std::move(c)); };
114 OpKernelContext::Params op_params =
115 CreateParams(pflr.get(), device_mgr.get(), &runner);
116 OpKernelContext ctx(&op_params, /*num_outputs=*/0);
117 TF_RETURN_IF_ERROR(data::FinalizeDataset(&ctx, dataset, &finalized_dataset));
118 core::ScopedUnref unref(finalized_dataset);
119 *result = WrapUnique(new Dataset(finalized_dataset, device_mgr.release(),
120 pflr.release(), flib_def.release(),
121 pool.release(), std::move(runner)));
122 return Status::OK();
123 } // static
124
MakeIterator(std::vector<std::unique_ptr<SplitProvider>> split_providers,std::unique_ptr<Iterator> * result)125 Status Dataset::MakeIterator(
126 std::vector<std::unique_ptr<SplitProvider>> split_providers,
127 std::unique_ptr<Iterator>* result) {
128 // Create an `IteratorContext`, which bundles together the necessary runtime
129 // support to create and get elements from an iterator.
130 std::unique_ptr<IteratorContext> ctx;
131 // NOTE(mrry): In the current API, an `IteratorContext` is always initially
132 // created from an `OpKernelContext*`, so we need to create `OpKernelContext`
133 // with a valid subset of parameters.
134 OpKernelContext::Params op_params =
135 CreateParams(pflr_.get(), device_mgr_.get(), &runner_);
136 OpKernelContext op_ctx(&op_params, /*num_outputs=*/0);
137 IteratorContext::Params params(&op_ctx);
138 params.cancellation_manager = &cancellation_manager_;
139 params.function_handle_cache = function_handle_cache_.get();
140 params.resource_mgr = &resource_mgr_;
141 std::move(split_providers.begin(), split_providers.end(),
142 std::back_inserter(params.split_providers));
143 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
144 params.thread_pool = &unbounded_thread_pool_;
145 ctx = absl::make_unique<IteratorContext>(std::move(params));
146
147 // Create the iterator from the dataset.
148 std::unique_ptr<IteratorBase> iterator;
149 TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
150 "Iterator", &iterator));
151 *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
152
153 return Status::OK();
154 }
155
MakeIterator(std::unique_ptr<Iterator> * result)156 Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
157 return MakeIterator(/*split_providers=*/{}, result);
158 }
159
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * result)160 Status Dataset::MakeSplitProviders(
161 std::vector<std::unique_ptr<SplitProvider>>* result) {
162 return dataset_->MakeSplitProviders(result);
163 }
164
Get() const165 const DatasetBase* Dataset::Get() const { return dataset_; }
166
// Takes ownership of every raw-pointer argument (all assembled and released
// by `FromGraph`). `runner` is expected to schedule work onto `pool` (or some
// executor that outlives this object) — see the capture in `FromGraph`.
Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
                 ProcessFunctionLibraryRuntime* pflr,
                 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
                 std::function<void(std::function<void()>)> runner)
    : dataset_(dataset),
      device_mgr_(device_mgr),
      flib_def_(flib_def),
      pflr_(pflr),
      interop_threadpool_(pool),
      runner_(std::move(runner)),
      unbounded_thread_pool_(Env::Default(), "tf_data_standalone") {
  // Hold a reference to the dataset for the lifetime of this object; it is
  // released in the destructor.
  dataset_->Ref();
  // Created in the body (after `pflr_` is initialized) because the cache is
  // built from the CPU device's function library runtime.
  function_handle_cache_ =
      absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
}
182
// Releases the dataset reference acquired in the constructor; the remaining
// owned runtime state cleans up via the members' own destructors.
Dataset::~Dataset() { dataset_->Unref(); }
184
185 } // namespace standalone
186 } // namespace data
187 } // namespace tensorflow
188