/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"

#include <memory>

#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace eager {

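// Registers all outputs of a remote operation in the tensor handle map,
// keyed by (operation_id, output index).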
void RemoteMgr::AddOperationOutputs(
    const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
    int64 operation_id) {
  mutex_lock l(remote_tensor_handle_mu_);
  for (int i = 0, end = handles.size(); i < end; i++) {
    // TODO(nareshmodi): Correctly handle operation_id not being unique.
    remote_tensor_handle_map_.emplace(
        RemoteTensorHandleInternal(operation_id, i), handles[i]);
  }
}

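// Registers a single operation output under (operation_id, output_num).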
void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
                                   int64 operation_id, int32 output_num) {
  mutex_lock l(remote_tensor_handle_mu_);
  remote_tensor_handle_map_.emplace(
      RemoteTensorHandleInternal(operation_id, output_num), handle);
}

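// Lookup helper that does not acquire the lock itself; the caller is
// expected to hold remote_tensor_handle_mu_. Returns InvalidArgument if no
// handle is registered under remote_handle.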
Status RemoteMgr::GetTensorHandleImpl(
    const RemoteTensorHandleInternal& remote_handle,
    tensorflow::TensorHandle** handle) {
  auto iter = remote_tensor_handle_map_.find(remote_handle);
  if (iter == remote_tensor_handle_map_.end()) {
    return errors::InvalidArgument(
        "Unable to find the relevant tensor remote_handle: Op ID: ",
        remote_handle.op_id, ", Output num: ", remote_handle.output_num);
  }

  *handle = iter->second;

  return Status::OK();
}

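// Thread-safe lookup of a registered handle: takes a shared lock and
// delegates to GetTensorHandleImpl.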
Status RemoteMgr::GetTensorHandle(
    const RemoteTensorHandleInternal& remote_handle,
    tensorflow::TensorHandle** handle) {
  tf_shared_lock l(remote_tensor_handle_mu_);
  return GetTensorHandleImpl(remote_handle, handle);
}

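// Returns the cached dtypes and partial shapes of a mirrored resource
// handle, or InvalidArgument if no entry exists for remote_handle.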
Status RemoteMgr::GetMirroredResourceShape(
    const RemoteTensorHandleInternal& remote_handle,
    std::vector<DtypeAndPartialTensorShape>* handle) {
  tf_shared_lock l(mirrored_resource_shape_mu_);
  auto iter = mirrored_resource_shape_map_.find(remote_handle);
  if (iter == mirrored_resource_shape_map_.end()) {
    return errors::InvalidArgument(
        "Unable to find the relevant mirrored resource shape: Op ID: ",
        remote_handle.op_id, ", Output num: ", remote_handle.output_num);
  }

  *handle = iter->second;

  return Status::OK();
}

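// Resolves the (op_id, output_num) address of handle on its own device and
// checks that the registered entry is the same TensorHandle, returning an
// Internal error on mismatch.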
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                        const bool wait_until_ready,
                                        int64* op_id, int32* output_num) {
  TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
                                           op_id, output_num));
  tensorflow::TensorHandle* h;
  TF_RETURN_IF_ERROR(
      GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
  if (handle != h) {
    return errors::Internal(
        "Found two different tensor handles with the same op_id:", *op_id,
        " and output_num:", *output_num);
  }
  return Status::OK();
}

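// Erases remote_handle from the tensor handle map (releasing the reference
// held on the handle), or else from the mirrored resource shape map.
// Returns InvalidArgument if neither map contains it.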
Status RemoteMgr::DeleteTensorHandle(
    const RemoteTensorHandleInternal& remote_handle) {
  {
    mutex_lock l(remote_tensor_handle_mu_);
    auto iter = remote_tensor_handle_map_.find(remote_handle);
    if (iter != remote_tensor_handle_map_.end()) {
      iter->second->Unref();
      remote_tensor_handle_map_.erase(iter);
      return Status::OK();
    }
  }
  {
    mutex_lock l(mirrored_resource_shape_mu_);
    auto iter = mirrored_resource_shape_map_.find(remote_handle);
    if (iter != mirrored_resource_shape_map_.end()) {
      mirrored_resource_shape_map_.erase(iter);
      return Status::OK();
    }
  }
  return errors::InvalidArgument(
      "Unable to find the relevant tensor remote_handle: Op ID: ",
      remote_handle.op_id, ", Output num: ", remote_handle.output_num);
}

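// Writes the wire form of `in` into `out`: its (op_id, output_num) address,
// op device, device name, dtype, and optionally the resource dtypes and
// shapes. If `in` carries no remote address for `device`, the address is
// recovered from the local tensor handle map instead.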
Status RemoteMgr::SerializeRemoteTensorHandle(
    TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
    Device* device, const string& device_name,
    const bool serialize_resource_dtype_and_shape) {
  int64 op_id;
  int32 output_num;
  if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
    tf_shared_lock l(remote_tensor_handle_mu_);
    TF_RETURN_IF_ERROR(
        GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
  }
  out->Clear();
  out->set_op_id(op_id);
  out->set_output_num(output_num);
  out->set_op_device(in->op_device() ? in->op_device()->name() : "");
  out->set_device(device_name);
  out->set_dtype(in->dtype);
  if (serialize_resource_dtype_and_shape) {
    std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
    TF_RETURN_IF_ERROR(
        in->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
    for (const auto& dtype_and_shape : resource_dtypes_and_shapes) {
      ResourceDtypeAndShape* dtype_and_shape_proto =
          out->add_resource_dtypes_and_shapes();
      dtype_and_shape_proto->set_dtype(dtype_and_shape.dtype);
      dtype_and_shape.shape.AsProto(dtype_and_shape_proto->mutable_shape());
    }
  }
  return Status::OK();
}

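// Reconstructs a TensorHandle from its wire form. If the target device is
// local, the registered handle is looked up and re-referenced; otherwise a
// lazy remote handle is created, with resource dtypes and shapes taken from
// the mirrored-shape cache or, if absent, from the message and then cached.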
Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                                TensorHandle** out) {
  Device* device;
  if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() ||
      parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) {
    TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
    (*out)->Ref();
  } else {
    // Create a remote TensorHandle for remote tensors which have not been
    // copied to the local worker yet (e.g. remote function inputs).
    const string& device_name =
        in.op_device().empty() ? in.device() : in.op_device();
    TF_RETURN_IF_ERROR(
        parent_->FindDeviceFromName(device_name.c_str(), &device));
    *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
                                                in.dtype(), device,
                                                /*is_ready=*/true, parent_);
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                  &dtypes_and_shapes)
             .ok()) {
      for (const auto& dtype_and_shape_proto :
           in.resource_dtypes_and_shapes()) {
        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
            dtype_and_shape_proto.dtype(),
            TensorShape(dtype_and_shape_proto.shape())});
      }
      mutex_lock l(mirrored_resource_shape_mu_);
      mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
          dtypes_and_shapes);
    }
    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  }

  return Status::OK();
}

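// Returns the asynchronous EagerExecutor associated with stream_id,
// creating one on first use.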
EagerExecutor& RemoteMgr::GetOrCreateExecutorForStream(uint64 stream_id) {
  mutex_lock l(executor_map_mu_);
  auto it = executor_map_.find(stream_id);
  if (it == executor_map_.end()) {
    auto it_and_bool = executor_map_.emplace(
        std::piecewise_construct, std::forward_as_tuple(stream_id),
        std::forward_as_tuple(/*async=*/true));
    DCHECK(it_and_bool.second);
    it = it_and_bool.first;
  }
  return it->second;
}

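// Shuts down and removes the executor for stream_id, logging any shutdown
// error; unknown stream ids are ignored.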
void RemoteMgr::DeleteExecutorForStream(uint64 stream_id) {
  mutex_lock l(executor_map_mu_);
  auto it = executor_map_.find(stream_id);
  if (it == executor_map_.end()) {
    return;
  }
  Status s = it->second.ShutDown();
  if (!s.ok()) {
    LOG(ERROR) << "EagerExecutor shutdown with error " << s.error_message();
  }
  executor_map_.erase(it);
}

}  // namespace eager
}  // namespace tensorflow