/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const FunctionLibraryDefinition& flib_def, GraphDef* gdef,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  const string& func_name = sig.name();
  const FunctionDef* func_def = flib_def.Find(sig.name());
  if (func_def == nullptr) {
    return errors::InvalidArgument("Function ", func_name,
                                   " not found in flib_def.");
  }

  // Build a smaller flib_def containing only the functions used by the given
  // function, plus that function itself.
  FunctionLibraryDefinition pruned_flib_def =
      flib_def.ReachableDefinitions(*func_def);
  TF_RETURN_IF_ERROR(pruned_flib_def.CopyFunctionDefFrom(func_name, flib_def));

  Graph g(pruned_flib_def);

  std::vector<Node*> input_nodes;
  input_nodes.reserve(sig.input_arg_size());

  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto input_node_builder =
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* input_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(input_node_builder).Finalize(&g, &input_node));
    input_nodes.push_back(input_node);

    // src_incarnation = 1 works because the transfer is across the same device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

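  // Add a node that calls the function itself. Its inputs are wired to the
  // _Recv nodes constructed above, and the call is inlined into the graph
  // further below.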
  NodeDef function_node_def;
  function_node_def.set_name(func_name);
  function_node_def.set_op(func_name);
  i = 0;
  function_node_def.set_device(target);
  for (const auto& p : attrs) {
    (*function_node_def.mutable_attr())[p.first] = p.second;
  }
  Status status;
  Node* function_node = g.AddNode(std::move(function_node_def), &status);
  TF_RETURN_IF_ERROR(status);
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    g.AddEdge(input_nodes[i], 0, function_node, i);
  }

  // Construct output nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    auto output_node_builder =
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(func_name, i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target);

    Node* output_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(output_node_builder).Finalize(&g, &output_node));

    g.AddEdge(function_node, i, output_node, 0);

    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }

  // Inline function node into the graph.
  InlineFunctionBodyOptions inline_options;
  inline_options.inlined_function_body_placer =
      InlinedFunctionBodyPlacer::SingleDevice();
  // When the remote call is a partition of a multi-device function, and the
  // Send/Recv nodes depend on the frame names in the original graph, we must
  // retain the original frame names. Since the graph contains a single function
  // call, we do not need to add a unique prefix to frame names inside the
  // inlined graph.
  inline_options.uniquify_frame_names = false;
  std::unique_ptr<FunctionBody> function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*func_def, attrs, &pruned_flib_def,
                                             &function_body));
  TF_RETURN_IF_ERROR(InlineFunctionBody(pruned_flib_def, &g, function_node,
                                        function_body.get(), inline_options));

  g.ToGraphDef(gdef);

  // Since we have inlined `function_node`, we can prune its function definition
  // from the library.
  *(gdef->mutable_library()) = flib_def.ReachableDefinitions(*gdef).ToProto();

  return Status::OK();
}

ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache()->ReleaseWorker(function_data.target,
                                                   function_data.wi);
  }
}

void ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  auto target = options.target;
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  WorkerInterface* wi =
      worker_session_->worker_cache()->GetOrCreateWorker(target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache()->ListWorkers(&workers);
    done(errors::InvalidArgument(
        "Could not find worker with target: ", target,
        " Available workers: ", absl::StrJoin(workers, ", ")));
    return;
  }

  // Make RPC and obtain a graph handle.
  GraphDef gdef;
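  // The send/recv keys are heap-allocated so that they outlive this function;
  // the RegisterGraphAsync callback below takes ownership and deletes them.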
  auto* send_keys = new std::vector<string>;
  auto* recv_keys = new std::vector<string>;
  auto construct_graph_fn = [&](const FunctionLibraryDefinition* lib_def) {
    const FunctionDef* fdef = lib_def->Find(function_name);
    const OpDef& sig = fdef->signature();
    TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, *lib_def,
                                              &gdef, send_keys, recv_keys));
    return Status::OK();
  };
  Status s;
  if (options.lib_def) {
    s = construct_graph_fn(options.lib_def);
  } else {
    s = construct_graph_fn(&lib_def);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  auto* req = new RegisterGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  *req->mutable_graph_def() = std::move(gdef);
  StripDefaultAttributes(*OpRegistry::Global(),
                         req->mutable_graph_def()->mutable_node());
  req->mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  auto* resp = new RegisterGraphResponse;

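  // Register the constructed graph with the remote worker. On success, the
  // returned graph handle is recorded together with the target, the worker
  // interface, and the rendezvous keys so that Run() can execute it later.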
  wi->RegisterGraphAsync(
      req, resp,
      [this, handle, req, resp, wi, function_name, target, send_keys, recv_keys,
       done](const Status& status) {
        if (status.ok()) {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.push_back(FunctionData(resp->graph_handle(), target,
                                                wi, *send_keys, *recv_keys));
          VLOG(1) << "CFLR::Instantiate: [Success] " << function_name << " on "
                  << target << " (this: " << this << ")"
                  << " with handle: " << *handle;
        }
        done(status);
        delete recv_keys;
        delete send_keys;
        delete req;
        delete resp;
      });
}

void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

  RunGraphRequest* req = new RunGraphRequest;
  req->set_session_handle(worker_session_->session_name());
  req->set_create_worker_session_called(create_worker_session_called_);
  req->set_graph_handle(function_data->graph_handle);
  req->set_step_id(opts.step_id);
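  // Feed each argument tensor under its corresponding rendezvous send key, and
  // request the function's outputs back under the recv keys.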
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req->add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req->add_recv_key(recv_key);
  }

  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
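  // Run the registered graph on the remote worker. The callback converts the
  // returned tensors into `rets` and deletes the heap-allocated request state.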
  wi->RunGraphAsync(
      call_options, req, resp,
      [call_options, req, resp, rets, recv_keys, done](const Status& status) {
        Status* local_status = new Status(status);
        auto cleanup =
            gtl::MakeCleanup([call_options, req, resp, local_status, done] {
              done(*local_status);
              delete call_options;
              delete req;
              delete resp;
              delete local_status;
            });
        if (!local_status->ok()) {
          return;
        }
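        // Index the returned tensors by rendezvous key, then convert them back
        // into Tensors in recv_key order.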
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            local_status->Update(
                errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            local_status->Update(errors::Internal(
                "Could not convert tensor proto: ", tp->DebugString()));
            return;
          }
        }
      });
}

void ClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    DCHECK_LE(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }
  CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
  cleanup_req->set_step_id(step_id);
  CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
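  // Ask the remote worker to release the per-step resources associated with
  // this step id.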
  wi->CleanupGraphAsync(
      cleanup_req, cleanup_resp,
      [cleanup_req, cleanup_resp, done](const Status& cleanup_status) {
        done(cleanup_status);
        delete cleanup_req;
        delete cleanup_resp;
      });
}

}  // namespace tensorflow