1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
17
18 #include <memory>
19
20 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23
24 namespace tensorflow {
25 namespace eager {
26
AddOperationOutputs(const gtl::ArraySlice<tensorflow::TensorHandle * > handles,int64 operation_id)27 void RemoteMgr::AddOperationOutputs(
28 const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
29 int64 operation_id) {
30 mutex_lock l(remote_tensor_handle_mu_);
31 for (int i = 0, end = handles.size(); i < end; i++) {
32 // TODO(nareshmodi): Correctly handle operation_id not being unique.
33 remote_tensor_handle_map_.emplace(
34 RemoteTensorHandleInternal(operation_id, i), handles[i]);
35 }
36 }
37
AddOperationOutput(tensorflow::TensorHandle * handle,int64 operation_id,int32 output_num)38 void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
39 int64 operation_id, int32 output_num) {
40 mutex_lock l(remote_tensor_handle_mu_);
41 remote_tensor_handle_map_.emplace(
42 RemoteTensorHandleInternal(operation_id, output_num), handle);
43 }
44
GetTensorHandleImpl(const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)45 Status RemoteMgr::GetTensorHandleImpl(
46 const RemoteTensorHandleInternal& remote_handle,
47 tensorflow::TensorHandle** handle) {
48 auto iter = remote_tensor_handle_map_.find(remote_handle);
49 if (iter == remote_tensor_handle_map_.end()) {
50 return errors::InvalidArgument(
51 "Unable to find the relevant tensor remote_handle: Op ID: ",
52 remote_handle.op_id, ", Output num: ", remote_handle.output_num);
53 }
54
55 *handle = iter->second;
56
57 return Status::OK();
58 }
59
GetTensorHandle(const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)60 Status RemoteMgr::GetTensorHandle(
61 const RemoteTensorHandleInternal& remote_handle,
62 tensorflow::TensorHandle** handle) {
63 tf_shared_lock l(remote_tensor_handle_mu_);
64 return GetTensorHandleImpl(remote_handle, handle);
65 }
66
GetMirroredResourceShape(const RemoteTensorHandleInternal & remote_handle,std::vector<DtypeAndPartialTensorShape> * handle)67 Status RemoteMgr::GetMirroredResourceShape(
68 const RemoteTensorHandleInternal& remote_handle,
69 std::vector<DtypeAndPartialTensorShape>* handle) {
70 tf_shared_lock l(mirrored_resource_shape_mu_);
71 auto iter = mirrored_resource_shape_map_.find(remote_handle);
72 if (iter == mirrored_resource_shape_map_.end()) {
73 return errors::InvalidArgument(
74 "Unable to find the relevant mirrored resource shape: Op ID: ",
75 remote_handle.op_id, ", Output num: ", remote_handle.output_num);
76 }
77
78 *handle = iter->second;
79
80 return Status::OK();
81 }
82
GetRemoteTensorHandle(const tensorflow::TensorHandle * handle,const bool wait_until_ready,int64 * op_id,int32 * output_num)83 Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
84 const bool wait_until_ready,
85 int64* op_id, int32* output_num) {
86 TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
87 op_id, output_num));
88 tensorflow::TensorHandle* h;
89 TF_RETURN_IF_ERROR(
90 GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
91 if (handle != h) {
92 return errors::Internal(
93 "Found two different tensor handles with the same op_id:", *op_id,
94 " and output_num:", *output_num);
95 }
96 return Status::OK();
97 }
98
DeleteTensorHandle(const RemoteTensorHandleInternal & remote_handle)99 Status RemoteMgr::DeleteTensorHandle(
100 const RemoteTensorHandleInternal& remote_handle) {
101 {
102 mutex_lock l(remote_tensor_handle_mu_);
103 auto iter = remote_tensor_handle_map_.find(remote_handle);
104 if (iter != remote_tensor_handle_map_.end()) {
105 iter->second->Unref();
106 remote_tensor_handle_map_.erase(iter);
107 return Status::OK();
108 }
109 }
110 {
111 mutex_lock l(mirrored_resource_shape_mu_);
112 auto iter = mirrored_resource_shape_map_.find(remote_handle);
113 if (iter != mirrored_resource_shape_map_.end()) {
114 mirrored_resource_shape_map_.erase(iter);
115 return Status::OK();
116 }
117 }
118 return errors::InvalidArgument(
119 "Unable to find the relevant tensor remote_handle: Op ID: ",
120 remote_handle.op_id, ", Output num: ", remote_handle.output_num);
121 }
122
// Fills the wire-format proto `out` so the tensor handle `in` can be
// referenced from another worker.
//
// The (op_id, output_num) address is taken from the handle's own remote
// address on `device` when it has one; otherwise it is resolved through the
// locally registered mapping (GetRemoteTensorHandle, under a shared lock on
// remote_tensor_handle_mu_). `device_name` becomes the proto's device field;
// the op-device field is `in`'s op device name or "" when there is none.
// When `serialize_resource_dtype_and_shape` is true, the handle's resource
// dtypes and shapes are appended to the proto as well.
Status RemoteMgr::SerializeRemoteTensorHandle(
    TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
    Device* device, const string& device_name,
    const bool serialize_resource_dtype_and_shape) {
  int64 op_id;
  int32 output_num;
  // Prefer the address the handle itself carries; fall back to the map
  // lookup only when RemoteAddress fails for `device`.
  if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
    tf_shared_lock l(remote_tensor_handle_mu_);
    TF_RETURN_IF_ERROR(
        GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
  }
  // Reset any stale fields before populating the proto.
  out->Clear();
  out->set_op_id(op_id);
  out->set_output_num(output_num);
  out->set_op_device(in->op_device() ? in->op_device()->name() : "");
  out->set_device(device_name);
  out->set_dtype(in->dtype);
  if (serialize_resource_dtype_and_shape) {
    std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
    TF_RETURN_IF_ERROR(
        in->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
    for (const auto& dtype_and_shape : resource_dtypes_and_shapes) {
      ResourceDtypeAndShape* dtype_and_shape_proto =
          out->add_resource_dtypes_and_shapes();
      dtype_and_shape_proto->set_dtype(dtype_and_shape.dtype);
      dtype_and_shape.shape.AsProto(dtype_and_shape_proto->mutable_shape());
    }
  }
  return Status::OK();
}
153
// Resolves the wire-format handle `in` into a TensorHandle stored in `*out`.
// In both branches the caller receives a reference it owns: the lookup branch
// adds an explicit Ref(), and the create branch returns a freshly created
// handle from CreateLazyRemoteHandle.
Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                                TensorHandle** out) {
  Device* device;
  // If the tensor's device is managed by this worker (matched by op-device
  // first, then device), the handle must already be registered locally.
  if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() ||
      parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) {
    TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
    (*out)->Ref();
  } else {
    // Create a remote TensorHandle for remote tensors which have not been
    // copied to the local worker yet (e.g. remote function inputs).
    const string& device_name =
        in.op_device().empty() ? in.device() : in.op_device();
    TF_RETURN_IF_ERROR(
        parent_->FindDeviceFromName(device_name.c_str(), &device));
    *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
                                                in.dtype(), device,
                                                /*is_ready=*/true, parent_);
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
    // Cache miss: take the resource dtypes/shapes from the proto and record
    // them in mirrored_resource_shape_map_ for subsequent deserializations.
    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                  &dtypes_and_shapes)
             .ok()) {
      for (const auto& dtype_and_shape_proto :
           in.resource_dtypes_and_shapes()) {
        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
            dtype_and_shape_proto.dtype(),
            TensorShape(dtype_and_shape_proto.shape())});
      }
      mutex_lock l(mirrored_resource_shape_mu_);
      mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
          dtypes_and_shapes);
    }
    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  }

  return Status::OK();
}
191
GetOrCreateExecutorForStream(uint64 stream_id)192 EagerExecutor& RemoteMgr::GetOrCreateExecutorForStream(uint64 stream_id) {
193 mutex_lock l(executor_map_mu_);
194 auto it = executor_map_.find(stream_id);
195 if (it == executor_map_.end()) {
196 auto it_and_bool = executor_map_.emplace(
197 std::piecewise_construct, std::forward_as_tuple(stream_id),
198 std::forward_as_tuple(/*async=*/true));
199 DCHECK(it_and_bool.second);
200 it = it_and_bool.first;
201 }
202 return it->second;
203 }
204
DeleteExecutorForStream(uint64 stream_id)205 void RemoteMgr::DeleteExecutorForStream(uint64 stream_id) {
206 mutex_lock l(executor_map_mu_);
207 auto it = executor_map_.find(stream_id);
208 if (it == executor_map_.end()) {
209 return;
210 }
211 Status s = it->second.ShutDown();
212 if (!s.ok()) {
213 LOG(ERROR) << "EagerExecutor shutdown with error " << s.error_message();
214 }
215 executor_map_.erase(it);
216 }
217
218 } // namespace eager
219 } // namespace tensorflow
220