• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/c/c_api_internal.h"
20 #include "tensorflow/c/tf_status_helper.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
23 #include "tensorflow/core/common_runtime/eager/execute.h"
24 #include "tensorflow/core/common_runtime/process_util.h"
25 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
26 #include "tensorflow/core/distributed_runtime/server_lib.h"
27 #include "tensorflow/core/distributed_runtime/session_mgr.h"
28 #include "tensorflow/core/distributed_runtime/worker_cache.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
30 #include "tensorflow/core/distributed_runtime/worker_env.h"
31 #include "tensorflow/core/framework/rendezvous.h"
32 #include "tensorflow/core/lib/core/error_codes.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/random/random.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/lib/strings/stringprintf.h"
38 #include "tensorflow/core/platform/cpu_info.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/host_info.h"
41 
42 namespace tensorflow {
43 namespace eager {
44 
45 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)46 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
47                      const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
48                      int* num_retvals) {
49   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
50   auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
51   if (errors::IsNotFound(status)) {
52     status = context->FindFunctionOpData(op_name, &op_reg_data);
53   }
54   TF_RETURN_IF_ERROR(status);
55 
56   const tensorflow::OpDef& op_def = op_reg_data->op_def;
57 
58   for (const auto& output_arg : op_def.output_arg()) {
59     if (!output_arg.number_attr().empty()) {
60       auto iter = attrs.find(output_arg.number_attr());
61       if (iter == attrs.end()) {
62         return errors::InvalidArgument("Unable to find number_attr ",
63                                        output_arg.number_attr(),
64                                        " for Op: ", op_name);
65       }
66       *num_retvals += iter->second.i();
67     } else if (!output_arg.type_list_attr().empty()) {
68       auto iter = attrs.find(output_arg.type_list_attr());
69       if (iter == attrs.end()) {
70         return errors::InvalidArgument("Unable to find type_list_attr ",
71                                        output_arg.type_list_attr(),
72                                        " for Op: ", op_name);
73       }
74       *num_retvals += iter->second.list().type_size();
75     } else {
76       *num_retvals += 1;
77     }
78   }
79 
80   return Status::OK();
81 }
82 }  // namespace
83 
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)84 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
85                                        CreateContextResponse* response) {
86   // make sure env_ , env_->rendezvous_mgr available
87   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
88     return tensorflow::errors::Internal(
89         "invalid eager env_ or env_->rendezvous_mgr.");
90   }
91   std::vector<std::unique_ptr<tensorflow::Device>> devices;
92 
93   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
94       // TODO(nareshmodi): Correctly set the SessionOptions.
95       SessionOptions(),
96       strings::Printf("/job:%s/replica:0/task:%d",
97                       request->server_def().job_name().data(),
98                       request->server_def().task_index()),
99       &devices));
100   response->mutable_device_attributes()->Reserve(devices.size());
101   for (const auto& d : devices) {
102     *response->add_device_attributes() = d->attributes();
103   }
104 
105   std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
106       absl::make_unique<DeviceMgr>(std::move(devices));
107 
108   auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
109   auto session_name = strings::StrCat("eager_", request->rendezvous_id());
110   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
111       session_name, request->server_def(), true));
112 
113   std::shared_ptr<WorkerSession> worker_session;
114   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
115       session_name, &worker_session));
116 
117   // Initialize remote tensor communication based on worker session.
118   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
119 
120   std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
121       SessionOptions(),
122       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
123       request->async(), std::move(device_mgr), r));
124 
125   uint64 context_id;
126   {
127     mutex_lock l(contexts_mu_);
128     do {
129       context_id = random::New64();
130     } while (contexts_.find(context_id) != contexts_.end());
131     contexts_.emplace(
132         context_id,
133         new ServerContext(std::move(ctx), request->keep_alive_secs(), env_));
134   }
135   response->set_context_id(context_id);
136 
137   return Status::OK();
138 }
139 
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)140 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
141   const tensorflow::Tensor* t = nullptr;
142 
143   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
144   TF_RETURN_IF_ERROR(handle->Tensor(&t));
145 
146   t->shape().AsProto(proto);
147 
148   return Status::OK();
149 }
150 
ExecuteOp(const Operation & operation,ServerContext * server_context,QueueResponse * queue_response)151 Status EagerServiceImpl::ExecuteOp(const Operation& operation,
152                                    ServerContext* server_context,
153                                    QueueResponse* queue_response) {
154   std::unique_ptr<tensorflow::EagerOperation> op;
155   const char* name = operation.name().c_str();  // Shorthand
156   const tensorflow::AttrTypeMap* types;
157   bool is_function = false;
158   TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(name, &types, &is_function));
159   if (is_function && !server_context->Context()->FindFunctionByName(name)) {
160     return errors::NotFound(
161         "'", name,
162         "' is neither a type of a primitive operation nor a name "
163         "of a function registered in binary running on ",
164         port::Hostname(),
165         ". Make sure the operation or function is "
166         "registered in the binary running in this process.");
167   }
168   op.reset(new tensorflow::EagerOperation(server_context->Context(), name,
169                                           is_function, types));
170 
171   TF_RETURN_IF_ERROR(op->SetDevice(operation.device().c_str()));
172 
173   for (const auto& remote_handle : operation.inputs()) {
174     tensorflow::TensorHandle* handle;
175     TF_RETURN_IF_ERROR(server_context->GetTensorHandle(
176         RemoteTensorHandleInternal(remote_handle), &handle));
177 
178     op->AddInput(handle);
179   }
180 
181   for (const auto& attr : operation.attrs()) {
182     op->MutableAttrs()->Set(attr.first, attr.second);
183   }
184 
185   int num_retvals = 0;
186   // TODO(nareshmodi): Consider caching this.
187   TF_RETURN_IF_ERROR(GetNumRetvals(server_context->Context(), operation.name(),
188                                    operation.attrs(), &num_retvals));
189 
190   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals;
191   TF_RETURN_IF_ERROR(EagerExecute(op.get(), &retvals, &num_retvals));
192 
193   server_context->AddOperationOutputs(retvals, operation.id());
194 
195   for (auto* handle : retvals) {
196     TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
197   }
198 
199   return Status::OK();
200 }
201 
Enqueue(const EnqueueRequest * request,EnqueueResponse * response)202 Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
203                                  EnqueueResponse* response) {
204   ServerContext* context = nullptr;
205   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
206   core::ScopedUnref context_unref(context);
207 
208   for (const auto& item : request->queue()) {
209     auto* queue_response = response->add_queue_response();
210     if (item.has_operation()) {
211       TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
212     } else {
213       TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
214           RemoteTensorHandleInternal(item.handle_to_decref())));
215     }
216   }
217 
218   return Status::OK();
219 }
220 
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)221 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
222                                        WaitQueueDoneResponse* response) {
223   ServerContext* context = nullptr;
224   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
225   core::ScopedUnref context_unref(context);
226 
227   if (request->op_id_size() > 0) {
228     return errors::Unimplemented(
229         "EagerServiceImpl::WaitQueueDone is not "
230         "implemented for particular op IDs.");
231   }
232   return context->Context()->AsyncWait();
233 }
234 
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)235 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
236                                    KeepAliveResponse* response) {
237   ServerContext* context = nullptr;
238   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
239   core::ScopedUnref context_unref(context);
240 
241   return Status::OK();
242 }
243 
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)244 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
245                                       CloseContextResponse* response) {
246   ServerContext* context = nullptr;
247   if (!GetServerContext(request->context_id(), &context).ok()) {
248     // Swallow the error here.
249     return Status::OK();
250   }
251 
252   core::ScopedUnref context_unref(context);
253 
254   mutex_lock l(contexts_mu_);
255   contexts_.erase(request->context_id());
256 
257   // GetServerContext returns a newly Reffed copy of ServerContext, which is
258   // unreffed by context_unref. Additionally, we need to unref it one time since
259   // we are releasing it from the map.
260   context->Unref();
261 
262   return Status::OK();
263 }
264 
RegisterFunction(const RegisterFunctionRequest * request,RegisterFunctionResponse * response)265 Status EagerServiceImpl::RegisterFunction(
266     const RegisterFunctionRequest* request,
267     RegisterFunctionResponse* response) {
268   ServerContext* context = nullptr;
269   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
270   core::ScopedUnref context_unref(context);
271 
272   return context->Context()->AddFunctionDef(request->function_def());
273 }
274 
SendTensor(const SendTensorRequest * request,SendTensorResponse * response)275 Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
276                                     SendTensorResponse* response) {
277   ServerContext* context = nullptr;
278   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
279   core::ScopedUnref context_unref(context);
280 
281   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
282   for (const auto& tensor_proto : request->tensors()) {
283     Tensor tensor;
284     if (!tensor.FromProto(tensor_proto)) {
285       return errors::InvalidArgument("Unable to parse tensor proto");
286     }
287 
288     TensorHandle* tensor_handle =
289         new TensorHandle(tensor, nullptr, nullptr, nullptr);
290 
291     TensorHandle* copied_handle = nullptr;
292     TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
293                                          request->device_name().c_str(),
294                                          &copied_handle));
295     tensors.push_back(copied_handle);
296     tensor_handle->Unref();
297   }
298 
299   context->AddOperationOutputs(tensors, request->op_id());
300 
301   return Status::OK();
302 }
303 
GetServerContext(uint64 context_id,ServerContext ** server_context)304 tensorflow::Status EagerServiceImpl::GetServerContext(
305     uint64 context_id, ServerContext** server_context) {
306   mutex_lock l(contexts_mu_);
307   auto iter = contexts_.find(context_id);
308   if (iter == contexts_.end()) {
309     *server_context = nullptr;
310     return errors::InvalidArgument(strings::Printf(
311         "Unable to find a context_id matching the specified one "
312         "(%lld). Perhaps the worker was restarted, or the context was GC'd?",
313         context_id));
314   }
315 
316   *server_context = iter->second;
317   (*server_context)->Ref();
318 
319   (*server_context)->RecordAccess();
320 
321   return Status::OK();
322 }
323 
324 }  // namespace eager
325 }  // namespace tensorflow
326