1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// TODO(skyewm): this is necessary to make the single_threaded_cpu_device.h
// include work. Some other include must be including eigen without defining
// this. Consider defining this in a BUILD rule.
19 #define EIGEN_USE_THREADS
20
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/common_runtime/device_factory.h"
25 #include "tensorflow/core/common_runtime/executor.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/memory_types.h"
28 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
29 #include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"
30 #include "tensorflow/core/framework/log_memory.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/versions.pb.h"
34 #include "tensorflow/core/graph/algorithm.h"
35 #include "tensorflow/core/graph/graph.h"
36 #include "tensorflow/core/graph/node_builder.h"
37 #include "tensorflow/core/graph/subgraph.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/lib/strings/strcat.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/public/session_options.h"
42
43 namespace tensorflow {
44
45 namespace {
46
47 // A simple rendezvous class.
48 // Assumes a single sender and a single receiver, no duplicate sends, and no
49 // sends of dead tensors.
50 class SimpleRendezvous : public RendezvousInterface {
51 public:
SimpleRendezvous()52 explicit SimpleRendezvous() {}
53
Send(const ParsedKey & parsed,const Args & send_args,const Tensor & val,const bool is_dead)54 Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val,
55 const bool is_dead) override {
56 if (is_dead) {
57 return errors::Internal("Send of a dead tensor");
58 }
59
60 mutex_lock l(mu_);
61 string edge_name(parsed.edge_name);
62 if (table_.count(edge_name) > 0) {
63 return errors::Internal("Send of an already sent tensor");
64 }
65 table_[edge_name] = val;
66 return Status::OK();
67 }
68
RecvAsync(const ParsedKey & parsed,const Args & recv_args,DoneCallback done)69 void RecvAsync(const ParsedKey& parsed, const Args& recv_args,
70 DoneCallback done) override {
71 Tensor tensor;
72 Status status = Status::OK();
73 {
74 string key(parsed.edge_name);
75 mutex_lock l(mu_);
76 if (table_.count(key) <= 0) {
77 status = errors::Internal("Did not find key ", key);
78 } else {
79 tensor = table_[key];
80 }
81 }
82 done(status, Args{}, recv_args, tensor, false);
83 }
84
StartAbort(const Status & status)85 void StartAbort(const Status& status) override {}
86
87 private:
88 typedef std::unordered_map<string, Tensor> Table;
89
90 mutex mu_;
91 Table table_ TF_GUARDED_BY(mu_);
92 };
93
94 } // namespace
95
// Creates and owns a fresh single-threaded CPU device; device_ is the
// non-owning view of it, kept alive by device_deleter_ for this
// GraphRunner's lifetime.
GraphRunner::GraphRunner(Env* env)
    : device_deleter_(NewSingleThreadedCpuDevice(env)),
      device_(device_deleter_.get()) {}
GraphRunner(Device * device)99 GraphRunner::GraphRunner(Device* device) : device_(device) {}
100
~GraphRunner()101 GraphRunner::~GraphRunner() {}
102
// Runs `graph` on device_: feeds the tensors in `inputs`, fetches
// `output_names`, and returns deep copies of the fetched tensors in
// `outputs`. The caller's graph is never modified — a rewritten copy is
// executed instead. Returns NotFound if no device is available, or the first
// error from graph rewriting, kernel creation, or execution.
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
                        const NamedTensorList& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  if (device_ == nullptr) {
    return errors::NotFound("Cannot find a device for GraphRunner.");
  }

  if (function_library && function_library->device() &&
      function_library->device()->device_type() != device_->device_type()) {
    // Mismatch between function_library's device_type and device_'s
    // device_type: drop the function library rather than fail outright.
    // TODO(matthewmurray) Can we create a new FunctionLibraryRuntime that is
    // identical to function_library except that it uses the given 'device_'?
    VLOG(1) << "Cannot run on: " << device_->device_type()
            << " with a function library for a "
            << function_library->device()->device_type() << " device.";
    function_library = nullptr;
  }

  // TODO(vrv): Instead of copying the entire graph, consider modifying the
  // existing graph in place and then undoing the rewrite's changes prior to
  // returning.
  std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
  CopyGraph(*graph, graph_to_run.get());

  SimpleRendezvous rendez;

  // Extract the input names and keys, and feed in the inputs.
  std::vector<string> input_names;
  for (const auto& in : inputs) {
    const string& tensor_name = in.first;
    input_names.emplace_back(tensor_name);
    // The src/dst device names are placeholders: SimpleRendezvous keys
    // tensors by edge name only, but CreateKey/ParseKey require a
    // syntactically valid key.
    string full_key = Rendezvous::CreateKey("/device:CPU:0", 1, "/device:CPU:1",
                                            tensor_name, FrameAndIter(0, 0));
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(full_key, &parsed));
    TF_RETURN_IF_ERROR(rendez.Send(parsed, Rendezvous::Args(), in.second,
                                   false /* is_dead */));
  }

  // Call RewriteGraphForExecution: rewires the copied graph so that the fed
  // tensors are read from the rendezvous and the fetched tensors are written
  // back to it.
  subgraph::RewriteGraphMetadata metadata;
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      graph_to_run.get(), input_names, output_names, {} /* target nodes */,
      device_->attributes(), false /* use_function_convention */, &metadata));

  // Create the local executor and the Rendezvous for fetching back the
  // constants.

  // Run operators on the local thread. We should not need concurrency here; we
  // should not be running expensive operators.
  auto runner = [](Executor::Args::Closure c) { c(); };

  LocalExecutorParams params;
  // The ownership of the output tensors are bound to this device's lifetime.
  params.device = device_;
  params.function_library = function_library;
  const int producer = graph_to_run->versions().producer();
  params.create_kernel = [this, function_library, producer](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_, function_library, props, producer,
                                 kernel);
  };
  // Kernels are not cached, so the executor owns and deletes them.
  params.delete_kernel = [](OpKernel* kernel) { delete kernel; };

  Executor* executor;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph_to_run, &executor));
  std::unique_ptr<Executor> executor_unref(executor);

  Executor::Args args;
  // NOTE: we could take a step id as an argument, but currently
  // there is no need since we never trace the running of a graph
  // called via this method.
  args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
  args.runner = runner;
  args.rendezvous = &rendez;
  // NOTE: Use of graph runner is limited to single-device executions
  // so a CollectiveExecutor should never be required.
  args.collective_executor = nullptr;

  CancellationManager cancellation_manager;
  args.cancellation_manager = &cancellation_manager;

  // Run the graph.
  TF_RETURN_IF_ERROR(executor->Run(args));

  // Fetch each requested output from the rendezvous, using the same
  // placeholder-device key scheme as the feeds above. Recv is the
  // synchronous wrapper the rendezvous interface provides over RecvAsync.
  outputs->resize(output_names.size());
  for (size_t i = 0; i < output_names.size(); ++i) {
    const string& output_key =
        Rendezvous::CreateKey("/device:CPU:0", 1, "/device:CPU:1",
                              output_names[i], FrameAndIter(0, 0));
    Rendezvous::ParsedKey parsed;
    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(output_key, &parsed));
    bool is_dead;
    Tensor output_tensor;
    TF_RETURN_IF_ERROR(
        rendez.Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
    // Does a deep copy so that ownership of the tensor isn't tied to the
    // allocator of the cpu device we created above. The allocator could be
    // deleted along with the device.
    (*outputs)[i] = tensor::DeepCopy(output_tensor);
  }

  return Status::OK();
}
210
211 } // namespace tensorflow
212