/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  const string& func_name = sig.name();
  const FunctionDef* func_def = flib_def.Find(sig.name());
  if (func_def == nullptr) {
    return errors::InvalidArgument("Function ", func_name,
                                   " not found in flib_def.");
  }

  // Build a smaller flib_def containing only the functions used by the given
  // function, plus that function itself.
  FunctionLibraryDefinition pruned_flib_def =
      flib_def.ReachableDefinitions(*func_def);
  TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));

  Graph g(pruned_flib_def);

  std::vector<Node*> input_nodes;
  input_nodes.reserve(sig.input_arg_size());

  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto input_node_builder =
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* input_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(input_node_builder).Finalize(&g, &input_node));
    input_nodes.push_back(input_node);

    // src_incarnation = 1 works because the transfer is across the same
    // device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
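    // Note: CreateKey produces a ';'-separated rendezvous key of (roughly)
    // the form "<src_device>;<hex incarnation>;<dst_device>;<tensor
    // name>;<frame:iter>". Because the _Recv above is client_terminated, the
    // caller feeds the argument tensor under this key when it runs the
    // registered graph (see Run() below).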
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

  NodeDef function_node_def;
  function_node_def.set_name(func_name);
  function_node_def.set_op(func_name);
  function_node_def.set_device(target);
  for (const auto& p : attrs) {
    (*function_node_def.mutable_attr())[p.first] = p.second;
  }
  Status status;
  Node* function_node = g.AddNode(std::move(function_node_def), &status);
  TF_RETURN_IF_ERROR(status);
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    g.AddEdge(input_nodes[i], 0, function_node, i);
  }

  // Construct output nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto output_node_builder =
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(func_name, i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* output_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(output_node_builder).Finalize(&g, &output_node));

    g.AddEdge(function_node, i, output_node, 0);

    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }

  // Inline the function node into the graph.
  InlineFunctionBodyOptions inline_options;
  inline_options.inlined_function_body_placer =
      InlinedFunctionBodyPlacer::SingleDevice();
  // When the remote call is a partition of a multi-device function, and the
  // Send/Recv nodes depend on the frame names in the original graph, we must
  // retain the original frame names. Since the graph contains a single
  // function call, we do not need to add a unique prefix to frame names
  // inside the inlined graph.
  inline_options.uniquify_frame_names = false;
  std::unique_ptr<FunctionBody> function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
                                             &function_body));
  TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
                                        function_body.get(), inline_options));

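  // At this point the graph has the shape
  //   _recv_<input> (_Recv) --> <inlined function body> --> _send_<output>
  // with one _Recv per input argument and one _Send per output; the caller
  // feeds `send_keys` and fetches `recv_keys` when running the registered
  // graph on the remote worker.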
  g.ToGraphDef(gdef);

  // Since we have inlined `function_node`, we can prune its function
  // definition from the library.
  *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto();

  return Status::OK();
}

ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache()->ReleaseWorker(function_data.target,
                                                   function_data.wi);
  }
}

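// A minimal caller sketch (hypothetical; the names below are illustrative
// only), showing the intended Instantiate -> Run -> CleanUp sequence against
// a ClusterFunctionLibraryRuntime* cflr:
//
//   FunctionLibraryRuntime::LocalHandle handle;
//   cflr->Instantiate("MyFn", lib_def, AttrSlice(), instantiate_options,
//                     &handle, [](const Status& s) { TF_CHECK_OK(s); });
//   std::vector<Tensor> rets;
//   cflr->Run(run_options, handle, {arg_tensor}, &rets,
//             [](const Status& s) { TF_CHECK_OK(s); });
//   cflr->CleanUp(run_options.step_id, handle,
//                 [](const Status& s) { TF_CHECK_OK(s); });
//
// (In real code the caller must wait for each DoneCallback before
// proceeding, since all three calls complete asynchronously.)
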
void ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  WorkerInterface* wi =
      worker_session_->worker_cache()->GetOrCreateWorker(target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache()->ListWorkers(&workers);
    done(errors::InvalidArgument(
        "Could not find worker with target: ", target,
        " Available workers: ", absl::StrJoin(workers, ", ")));
    return;
  }

  // Make an RPC and obtain a graph handle.
  GraphDef gdef;
  auto* send_keys = new std::vector<string>;
  auto* recv_keys = new std::vector<string>;
  auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) {
    const FunctionDef* fdef = lib_def->Find(function_name);
    const OpDef& sig = fdef->signature();
    TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def,
                                              &gdef, send_keys, recv_keys));
    return Status::OK();
  };
  Status s;
  if (options.lib_def) {
    s = construct_graph_fn(options.lib_def);
  } else {
    s = construct_graph_fn(&lib_def);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  auto* req = new RegisterGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  *req->mutable_graph_def() = std::move(gdef);
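  // Strip attrs that still carry their registered default value, so the
  // serialized graph stays loadable by workers built against older op
  // definitions that may not know the newer attrs.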
  StripDefaultAttributes(*OpRegistry::Global(),
                         req->mutable_graph_def()->mutable_node());
  req->mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  auto* resp = new RegisterGraphResponse;

  wi->RegisterGraphAsync(
      req, resp,
      [this, handle, req, resp, wi, function_name, target, send_keys, recv_keys,
       done](const Status& status) {
        if (status.ok()) {
          mutex_lock l(mu_);
          *handle = function_data_.size();
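          // The local handle is simply the index of the new entry in
          // function_data_; Run() and CleanUp() later use it to recover the
          // remote graph handle and worker.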
          function_data_.push_back(FunctionData(resp->graph_handle(), target,
                                                wi, *send_keys, *recv_keys));
          VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on "
                  << target << " (this: " << this << ")"
                  << " with handle: " << *handle;
        }
        done(status);
        delete recv_keys;
        delete send_keys;
        delete req;
        delete resp;
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

  RunGraphRequest* req = new RunGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  req->set_graph_handle(function_data->graph_handle);
  req->set_step_id(opts.step_id);
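  // Feed one tensor per argument, keyed by the send_keys recorded at
  // instantiation time; args must be given in input-signature order.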
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req->add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req->add_recv_key(recv_key);
  }

  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, req, resp,
      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
        Status* local_status = new Status(status);
        auto cleanup =
            gtl::MakeCleanup([call_options, req, resp, local_status, done] {
              done(*local_status);
              delete call_options;
              delete req;
              delete resp;
              delete local_status;
            });
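        // `cleanup` runs when this callback returns on any path below,
        // invoking `done` with the final status and freeing the RPC state.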
        if (!local_status->ok()) {
          return;
        }
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            local_status->Update(
                errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            local_status->Update(errors::Internal(
                "Could not convert tensor proto: ", tp->DebugString()));
            return;
          }
        }
      });
}

void ClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }
  CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
  cleanup_req->set_step_id(step_id);
  CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
  wi->CleanupGraphAsync(
      cleanup_req, cleanup_resp,
      [cleanup_req, cleanup_resp, done](const Status& cleanup_status) {
        done(cleanup_status);
        delete cleanup_req;
        delete cleanup_resp;
      });
}

}  // namespace tensorflow