1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/c/c_api_internal.h"
20 #include "tensorflow/c/tf_status_helper.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
23 #include "tensorflow/core/common_runtime/eager/execute.h"
24 #include "tensorflow/core/common_runtime/process_util.h"
25 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
26 #include "tensorflow/core/distributed_runtime/server_lib.h"
27 #include "tensorflow/core/distributed_runtime/session_mgr.h"
28 #include "tensorflow/core/distributed_runtime/worker_cache.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
30 #include "tensorflow/core/distributed_runtime/worker_env.h"
31 #include "tensorflow/core/framework/rendezvous.h"
32 #include "tensorflow/core/lib/core/error_codes.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/random/random.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/lib/strings/stringprintf.h"
38 #include "tensorflow/core/platform/cpu_info.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/host_info.h"
41
42 namespace tensorflow {
43 namespace eager {
44
45 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)46 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
47 const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
48 int* num_retvals) {
49 const tensorflow::OpRegistrationData* op_reg_data = nullptr;
50 auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
51 if (errors::IsNotFound(status)) {
52 status = context->FindFunctionOpData(op_name, &op_reg_data);
53 }
54 TF_RETURN_IF_ERROR(status);
55
56 const tensorflow::OpDef& op_def = op_reg_data->op_def;
57
58 for (const auto& output_arg : op_def.output_arg()) {
59 if (!output_arg.number_attr().empty()) {
60 auto iter = attrs.find(output_arg.number_attr());
61 if (iter == attrs.end()) {
62 return errors::InvalidArgument("Unable to find number_attr ",
63 output_arg.number_attr(),
64 " for Op: ", op_name);
65 }
66 *num_retvals += iter->second.i();
67 } else if (!output_arg.type_list_attr().empty()) {
68 auto iter = attrs.find(output_arg.type_list_attr());
69 if (iter == attrs.end()) {
70 return errors::InvalidArgument("Unable to find type_list_attr ",
71 output_arg.type_list_attr(),
72 " for Op: ", op_name);
73 }
74 *num_retvals += iter->second.list().type_size();
75 } else {
76 *num_retvals += 1;
77 }
78 }
79
80 return Status::OK();
81 }
82 } // namespace
83
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)84 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
85 CreateContextResponse* response) {
86 // make sure env_ , env_->rendezvous_mgr available
87 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
88 return tensorflow::errors::Internal(
89 "invalid eager env_ or env_->rendezvous_mgr.");
90 }
91 std::vector<std::unique_ptr<tensorflow::Device>> devices;
92
93 TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
94 // TODO(nareshmodi): Correctly set the SessionOptions.
95 SessionOptions(),
96 strings::Printf("/job:%s/replica:0/task:%d",
97 request->server_def().job_name().data(),
98 request->server_def().task_index()),
99 &devices));
100 response->mutable_device_attributes()->Reserve(devices.size());
101 for (const auto& d : devices) {
102 *response->add_device_attributes() = d->attributes();
103 }
104
105 std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
106 absl::make_unique<DeviceMgr>(std::move(devices));
107
108 auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
109 auto session_name = strings::StrCat("eager_", request->rendezvous_id());
110 TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
111 session_name, request->server_def(), true));
112
113 std::shared_ptr<WorkerSession> worker_session;
114 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
115 session_name, &worker_session));
116
117 // Initialize remote tensor communication based on worker session.
118 TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
119
120 std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
121 SessionOptions(),
122 tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
123 request->async(), std::move(device_mgr), r));
124
125 uint64 context_id;
126 {
127 mutex_lock l(contexts_mu_);
128 do {
129 context_id = random::New64();
130 } while (contexts_.find(context_id) != contexts_.end());
131 contexts_.emplace(
132 context_id,
133 new ServerContext(std::move(ctx), request->keep_alive_secs(), env_));
134 }
135 response->set_context_id(context_id);
136
137 return Status::OK();
138 }
139
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)140 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
141 const tensorflow::Tensor* t = nullptr;
142
143 // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
144 TF_RETURN_IF_ERROR(handle->Tensor(&t));
145
146 t->shape().AsProto(proto);
147
148 return Status::OK();
149 }
150
ExecuteOp(const Operation & operation,ServerContext * server_context,QueueResponse * queue_response)151 Status EagerServiceImpl::ExecuteOp(const Operation& operation,
152 ServerContext* server_context,
153 QueueResponse* queue_response) {
154 std::unique_ptr<tensorflow::EagerOperation> op;
155 const char* name = operation.name().c_str(); // Shorthand
156 const tensorflow::AttrTypeMap* types;
157 bool is_function = false;
158 TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(name, &types, &is_function));
159 if (is_function && !server_context->Context()->FindFunctionByName(name)) {
160 return errors::NotFound(
161 "'", name,
162 "' is neither a type of a primitive operation nor a name "
163 "of a function registered in binary running on ",
164 port::Hostname(),
165 ". Make sure the operation or function is "
166 "registered in the binary running in this process.");
167 }
168 op.reset(new tensorflow::EagerOperation(server_context->Context(), name,
169 is_function, types));
170
171 TF_RETURN_IF_ERROR(op->SetDevice(operation.device().c_str()));
172
173 for (const auto& remote_handle : operation.inputs()) {
174 tensorflow::TensorHandle* handle;
175 TF_RETURN_IF_ERROR(server_context->GetTensorHandle(
176 RemoteTensorHandleInternal(remote_handle), &handle));
177
178 op->AddInput(handle);
179 }
180
181 for (const auto& attr : operation.attrs()) {
182 op->MutableAttrs()->Set(attr.first, attr.second);
183 }
184
185 int num_retvals = 0;
186 // TODO(nareshmodi): Consider caching this.
187 TF_RETURN_IF_ERROR(GetNumRetvals(server_context->Context(), operation.name(),
188 operation.attrs(), &num_retvals));
189
190 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals;
191 TF_RETURN_IF_ERROR(EagerExecute(op.get(), &retvals, &num_retvals));
192
193 server_context->AddOperationOutputs(retvals, operation.id());
194
195 for (auto* handle : retvals) {
196 TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
197 }
198
199 return Status::OK();
200 }
201
Enqueue(const EnqueueRequest * request,EnqueueResponse * response)202 Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
203 EnqueueResponse* response) {
204 ServerContext* context = nullptr;
205 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
206 core::ScopedUnref context_unref(context);
207
208 for (const auto& item : request->queue()) {
209 auto* queue_response = response->add_queue_response();
210 if (item.has_operation()) {
211 TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
212 } else {
213 TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
214 RemoteTensorHandleInternal(item.handle_to_decref())));
215 }
216 }
217
218 return Status::OK();
219 }
220
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)221 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
222 WaitQueueDoneResponse* response) {
223 ServerContext* context = nullptr;
224 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
225 core::ScopedUnref context_unref(context);
226
227 if (request->op_id_size() > 0) {
228 return errors::Unimplemented(
229 "EagerServiceImpl::WaitQueueDone is not "
230 "implemented for particular op IDs.");
231 }
232 return context->Context()->AsyncWait();
233 }
234
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)235 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
236 KeepAliveResponse* response) {
237 ServerContext* context = nullptr;
238 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
239 core::ScopedUnref context_unref(context);
240
241 return Status::OK();
242 }
243
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)244 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
245 CloseContextResponse* response) {
246 ServerContext* context = nullptr;
247 if (!GetServerContext(request->context_id(), &context).ok()) {
248 // Swallow the error here.
249 return Status::OK();
250 }
251
252 core::ScopedUnref context_unref(context);
253
254 mutex_lock l(contexts_mu_);
255 contexts_.erase(request->context_id());
256
257 // GetServerContext returns a newly Reffed copy of ServerContext, which is
258 // unreffed by context_unref. Additionally, we need to unref it one time since
259 // we are releasing it from the map.
260 context->Unref();
261
262 return Status::OK();
263 }
264
RegisterFunction(const RegisterFunctionRequest * request,RegisterFunctionResponse * response)265 Status EagerServiceImpl::RegisterFunction(
266 const RegisterFunctionRequest* request,
267 RegisterFunctionResponse* response) {
268 ServerContext* context = nullptr;
269 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
270 core::ScopedUnref context_unref(context);
271
272 return context->Context()->AddFunctionDef(request->function_def());
273 }
274
SendTensor(const SendTensorRequest * request,SendTensorResponse * response)275 Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
276 SendTensorResponse* response) {
277 ServerContext* context = nullptr;
278 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
279 core::ScopedUnref context_unref(context);
280
281 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
282 for (const auto& tensor_proto : request->tensors()) {
283 Tensor tensor;
284 if (!tensor.FromProto(tensor_proto)) {
285 return errors::InvalidArgument("Unable to parse tensor proto");
286 }
287
288 TensorHandle* tensor_handle =
289 new TensorHandle(tensor, nullptr, nullptr, nullptr);
290
291 TensorHandle* copied_handle = nullptr;
292 TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
293 request->device_name().c_str(),
294 &copied_handle));
295 tensors.push_back(copied_handle);
296 tensor_handle->Unref();
297 }
298
299 context->AddOperationOutputs(tensors, request->op_id());
300
301 return Status::OK();
302 }
303
GetServerContext(uint64 context_id,ServerContext ** server_context)304 tensorflow::Status EagerServiceImpl::GetServerContext(
305 uint64 context_id, ServerContext** server_context) {
306 mutex_lock l(contexts_mu_);
307 auto iter = contexts_.find(context_id);
308 if (iter == contexts_.end()) {
309 *server_context = nullptr;
310 return errors::InvalidArgument(strings::Printf(
311 "Unable to find a context_id matching the specified one "
312 "(%lld). Perhaps the worker was restarted, or the context was GC'd?",
313 context_id));
314 }
315
316 *server_context = iter->second;
317 (*server_context)->Ref();
318
319 (*server_context)->RecordAccess();
320
321 return Status::OK();
322 }
323
324 } // namespace eager
325 } // namespace tensorflow
326