• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
16 
17 #include <map>
18 #include <memory>
19 
20 #include "tensorflow/core/common_runtime/eager/context.h"
21 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/distributed_runtime/call_options.h"
24 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
25 #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
26 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/graph_def_util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/cleanup.h"
32 
33 namespace tensorflow {
34 namespace eager {
35 namespace {
StripDefaultAttributesInRegisterFunctionOp(RegisterFunctionOp * register_function)36 void StripDefaultAttributesInRegisterFunctionOp(
37     RegisterFunctionOp* register_function) {
38   StripDefaultAttributes(
39       *OpRegistry::Global(),
40       register_function->mutable_function_def()->mutable_node_def());
41   for (auto& function :
42        *register_function->mutable_library()->mutable_function()) {
43     StripDefaultAttributes(*OpRegistry::Global(), function.mutable_node_def());
44   }
45 }
46 }  // namespace
47 
Instantiate(const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::LocalHandle * handle,FunctionLibraryRuntime::DoneCallback done)48 void EagerClusterFunctionLibraryRuntime::Instantiate(
49     const string& function_name, const FunctionLibraryDefinition& lib_def,
50     AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
51     FunctionLibraryRuntime::LocalHandle* handle,
52     FunctionLibraryRuntime::DoneCallback done) {
53   auto target = options.target;
54   auto released_op = std::make_unique<EagerOperation>(ctx_);
55   Status s =
56       released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr);
57   if (!s.ok()) {
58     done(s);
59     return;
60   }
61   if (!released_op->is_function()) {
62     done(errors::Internal(function_name, " is not a function."));
63     return;
64   }
65 
66   VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
67           << " (this: " << this << ")";
68   core::RefCountPtr<eager::EagerClient> eager_client;
69   s = ctx_->GetClient(target, &eager_client);
70   if (!s.ok()) {
71     done(s);
72     return;
73   }
74 
75   if (eager_client == nullptr) {
76     done(errors::InvalidArgument("Could not find eager client for target: ",
77                                  target));
78     return;
79   }
80 
81   const FunctionLibraryDefinition& func_lib_def =
82       options.lib_def ? *options.lib_def : lib_def;
83 
84   auto request = std::make_shared<EnqueueRequest>();
85   auto response = std::make_shared<EnqueueResponse>();
86 
87   request->set_context_id(context_id_);
88 
89   RegisterFunctionOp* register_function =
90       request->add_queue()->mutable_register_function();
91   *register_function->mutable_function_def() =
92       *func_lib_def.Find(function_name);
93   register_function->set_is_component_function(true);
94   *register_function->mutable_library() =
95       func_lib_def.ReachableDefinitions(register_function->function_def())
96           .ToProto();
97   StripDefaultAttributesInRegisterFunctionOp(register_function);
98 
99   const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
100   eager_client->EnqueueAsync(
101       /*call_opts=*/nullptr, request.get(), response.get(),
102       [this, request, response, handle, released_op = released_op.release(),
103        target, ret_indices, eager_client = eager_client.get(),
104        done](const Status& s) {
105         {
106           mutex_lock l(mu_);
107           *handle = function_data_.size();
108           function_data_.emplace_back(target, ret_indices, eager_client,
109                                       absl::WrapUnique(released_op));
110         }
111         done(s);
112       });
113 }
114 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done)115 void EagerClusterFunctionLibraryRuntime::Run(
116     const FunctionLibraryRuntime::Options& opts,
117     FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
118     std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
119   std::vector<FunctionArg> function_args;
120   for (const auto& tensor : args) {
121     function_args.push_back(tensor);
122   }
123   std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
124   Run(opts, handle, function_args, function_rets,
125       [rets, function_rets, done = std::move(done)](const Status& s) {
126         Status status = s;
127         if (status.ok()) {
128           for (const auto& t : *function_rets) {
129             if (t.index() == 0) {
130               rets->push_back(absl::get<Tensor>(t));
131             } else {
132               status.Update(
133                   errors::Internal("Expect a Tensor as a remote function "
134                                    "output but got a TensorShape."));
135               break;
136             }
137           }
138         }
139         delete function_rets;
140         done(status);
141       });
142 }
143 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done)144 void EagerClusterFunctionLibraryRuntime::Run(
145     const FunctionLibraryRuntime::Options& opts,
146     FunctionLibraryRuntime::LocalHandle handle,
147     gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
148     FunctionLibraryRuntime::DoneCallback done) {
149   FunctionData* function_data = nullptr;
150   {
151     mutex_lock l(mu_);
152     DCHECK_LE(handle, function_data_.size());
153     function_data = &function_data_[handle];
154   }
155 
156   EagerClient* eager_client = function_data->eager_client.get();
157   if (eager_client == nullptr) {
158     done(errors::Internal("Could not find eager client"));
159     return;
160   }
161 
162   EagerOperation* op = function_data->op.get();
163   if (!op->Inputs().empty()) {
164     done(errors::Internal("Inputs should not be set during instantiation."));
165     return;
166   }
167 
168   auto request = std::make_shared<RunComponentFunctionRequest>();
169   auto response = std::make_shared<RunComponentFunctionResponse>();
170   request->set_context_id(context_id_);
171   eager::Operation* remote_op = request->mutable_operation();
172 
173   if (function_data->ret_indices.has_value()) {
174     for (const int ret_index : function_data->ret_indices.value()) {
175       request->add_output_num(ret_index);
176     }
177   }
178 
179   for (const auto& arg : args) {
180     if (arg.index() == 0) {
181       absl::get<Tensor>(arg).AsProtoTensorContent(
182           remote_op->add_op_inputs()->mutable_tensor());
183     } else {
184       remote_op->add_op_inputs()->mutable_remote_handle()->Swap(
185           absl::get<RemoteTensorHandle*>(arg));
186     }
187   }
188 
189   // The remote component function should use the same op_id as its parent
190   // multi-device function's in order to get the global unique op_id generated
191   // by the master context.
192   if (opts.op_id.has_value()) {
193     remote_op->set_id(opts.op_id.value());
194   } else {
195     remote_op->set_id(kInvalidRemoteOpId);
196   }
197   remote_op->set_is_function(true);
198   remote_op->set_is_component_function(true);
199   remote_op->set_func_step_id(opts.step_id);
200   remote_op->set_name(op->Name());
201   op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
202   remote_op->set_device(function_data->target);
203 
204   CancellationManager* cm = opts.cancellation_manager;
205   CancellationToken token = 0;
206   auto call_opts = std::make_shared<CallOptions>();
207   if (cm != nullptr) {
208     token = cm->get_cancellation_token();
209     const bool already_cancelled = !cm->RegisterCallback(
210         token,
211         [call_opts, request, response, done]() { call_opts->StartCancel(); });
212     if (already_cancelled) {
213       done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run"));
214       return;
215     }
216   }
217 
218   // Execute component function on remote worker using RunComponentFunction RPC.
219   // Different from executing remote functions with Enqueue, this method runs
220   // a function on remote worker without tying up a thread (i.e., pure
221   // asynchronously).
222   eager_client->RunComponentFunctionAsync(
223       call_opts.get(), request.get(), response.get(),
224       [request, response, rets, call_opts, cm, token,
225        done = std::move(done)](const Status& s) {
226         if (cm != nullptr) {
227           cm->TryDeregisterCallback(token);
228         }
229         if (!s.ok()) {
230           done(s);
231           return;
232         }
233         if (!response->shape().empty() && !response->tensor().empty()) {
234           done(errors::Internal(
235               "Both shape and tensor are specified in the same response"));
236           return;
237         }
238         for (const auto& shape : response->shape()) {
239           rets->push_back(shape);
240         }
241         for (const auto& tensor_proto : response->tensor()) {
242           Tensor t;
243           if (t.FromProto(tensor_proto)) {
244             rets->push_back(std::move(t));
245           } else {
246             done(errors::Internal("Could not convert tensor proto: ",
247                                   tensor_proto.DebugString()));
248             return;
249           }
250         }
251         done(Status::OK());
252       });
253 }
254 
CleanUp(uint64 step_id,FunctionLibraryRuntime::LocalHandle handle,FunctionLibraryRuntime::DoneCallback done)255 void EagerClusterFunctionLibraryRuntime::CleanUp(
256     uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
257     FunctionLibraryRuntime::DoneCallback done) {
258   FunctionData* function_data = nullptr;
259   {
260     mutex_lock l(mu_);
261     DCHECK_LE(handle, function_data_.size());
262     function_data = &function_data_[handle];
263   }
264 
265   EagerClient* eager_client = function_data->eager_client.get();
266   if (eager_client == nullptr) {
267     done(errors::Internal("Could not find eager client"));
268     return;
269   }
270 
271   auto request = std::make_shared<EnqueueRequest>();
272   auto response = std::make_shared<EnqueueResponse>();
273   request->set_context_id(context_id_);
274   CleanupFunctionOp* cleanup_function =
275       request->add_queue()->mutable_cleanup_function();
276   cleanup_function->set_step_id(step_id);
277   // StreamingEnqueueAsync could be blocking when streaming RPC is disabled.
278   // CleanUp() needs to be non-blocking since it would be invoked inside the
279   // enqueue done callback of Run(). So we don't use StreamingEnqueueAsync here.
280   eager_client->EnqueueAsync(
281       /*call_opts=*/nullptr, request.get(), response.get(),
282       [request, response, done](const Status& status) { done(status); });
283 }
284 
CreateClusterFLR(const uint64 context_id,EagerContext * ctx,WorkerSession * worker_session)285 DistributedFunctionLibraryRuntime* CreateClusterFLR(
286     const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
287   return new EagerClusterFunctionLibraryRuntime(
288       context_id, ctx, worker_session->remote_device_mgr());
289 }
290 
291 }  // namespace eager
292 }  // namespace tensorflow
293