1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16
17 #include <utility>
18
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/rendezvous_util.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/core/util/device_name_utils.h"
23
24 namespace tensorflow {
25
26 const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
27
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,DistributedFunctionLibraryRuntime * parent)28 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
29 const DeviceMgr* device_mgr, Env* env, int graph_def_version,
30 const FunctionLibraryDefinition* lib_def,
31 const OptimizerOptions& optimizer_options,
32 DistributedFunctionLibraryRuntime* parent)
33 : device_mgr_(device_mgr),
34 lib_def_(lib_def),
35 next_handle_(0),
36 parent_(parent) {
37 if (device_mgr == nullptr) {
38 flr_map_[nullptr] =
39 NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
40 lib_def, optimizer_options, this);
41 return;
42 }
43 for (Device* d : device_mgr->ListDevices()) {
44 flr_map_[d] =
45 NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
46 lib_def, optimizer_options, this);
47 }
48 }
49
ProcessFunctionLibraryRuntime(const DeviceMgr * device_mgr,Env * env,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,DistributedFunctionLibraryRuntime * parent)50 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
51 const DeviceMgr* device_mgr, Env* env, int graph_def_version,
52 const FunctionLibraryDefinition* lib_def,
53 const OptimizerOptions& optimizer_options,
54 CustomKernelCreator custom_kernel_creator,
55 DistributedFunctionLibraryRuntime* parent)
56 : device_mgr_(device_mgr),
57 lib_def_(lib_def),
58 next_handle_(0),
59 parent_(parent) {
60 if (device_mgr == nullptr) {
61 flr_map_[nullptr] = NewFunctionLibraryRuntime(
62 nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
63 std::move(custom_kernel_creator), this);
64 return;
65 }
66 for (Device* d : device_mgr->ListDevices()) {
67 flr_map_[d] = NewFunctionLibraryRuntime(
68 device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
69 custom_kernel_creator, this);
70 }
71 }
72
73 /* static */
SendTensors(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,gtl::ArraySlice<Tensor> tensors_to_send,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,Rendezvous * rendezvous)74 Status ProcessFunctionLibraryRuntime::SendTensors(
75 const string& source_device, const string& target_device,
76 const string& key_prefix, int64 src_incarnation,
77 gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
78 const std::vector<AllocatorAttributes>& alloc_attrs,
79 Rendezvous* rendezvous) {
80 std::vector<string> keys;
81 for (int i = 0; i < tensors_to_send.size(); ++i) {
82 string name = strings::StrCat(key_prefix, i);
83 string key = Rendezvous::CreateKey(source_device, src_incarnation,
84 target_device, name, FrameAndIter(0, 0));
85 keys.push_back(key);
86 }
87 TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
88 rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
89 return Status::OK();
90 }
91
92 /* static */
ReceiveTensorsAsync(const string & source_device,const string & target_device,const string & key_prefix,int64 src_incarnation,int64 num_tensors,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,Rendezvous * rendezvous,std::vector<Tensor> * received_tensors,const StatusCallback & done)93 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
94 const string& source_device, const string& target_device,
95 const string& key_prefix, int64 src_incarnation, int64 num_tensors,
96 DeviceContext* device_context,
97 const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
98 std::vector<Tensor>* received_tensors, const StatusCallback& done) {
99 std::vector<string> keys;
100 for (int64 i = 0; i < num_tensors; ++i) {
101 string name = strings::StrCat(key_prefix, i);
102 string key = Rendezvous::CreateKey(source_device, src_incarnation,
103 target_device, name, FrameAndIter(0, 0));
104 keys.push_back(key);
105 }
106 RecvOutputsFromRendezvousAsync(
107 rendezvous, device_context, alloc_attrs, keys, received_tensors,
108 [done](const Status& status) { done(status); });
109 }
110
GetDeviceIncarnation(const string & device_name,int64 * incarnation)111 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
112 const string& device_name, int64* incarnation) {
113 FunctionLibraryRuntime* flr = GetFLR(device_name);
114 if (flr == nullptr) {
115 return errors::InvalidArgument("Device name: ", device_name, " not found");
116 }
117 *incarnation = flr->device()->attributes().incarnation();
118 return Status::OK();
119 }
120
GetDeviceContext(const string & device_name,DeviceContext ** device_context)121 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
122 const string& device_name, DeviceContext** device_context) {
123 *device_context = nullptr;
124 FunctionLibraryRuntime* flr = GetFLR(device_name);
125 if (flr == nullptr) {
126 return errors::InvalidArgument("Device name: ", device_name, " not found.");
127 }
128 Device* device = flr->device();
129 string device_type = device->parsed_name().type;
130 if (device_type == "CPU") return Status::OK();
131 if (device_type == "GPU") {
132 auto* dev_info = flr->device()->tensorflow_gpu_device_info();
133 if (dev_info) {
134 *device_context = dev_info->default_context;
135 return Status::OK();
136 }
137 }
138 return errors::Internal("Device type: ", device_type,
139 " is currently unsupported for remote ",
140 "function executions");
141 }
142
GetFLR(const string & device_name) const143 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
144 const string& device_name) const {
145 Device* device = nullptr;
146 if (device_name != kDefaultFLRDevice) {
147 if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
148 LOG(ERROR) << "Could not find device: " << device_name;
149 return nullptr;
150 }
151 }
152 const auto& iter = flr_map_.find(device);
153 if (iter == flr_map_.end()) {
154 LOG(ERROR) << "Could not find device: " << device_name;
155 return nullptr;
156 }
157 return iter->second.get();
158 }
159
AddHandle(const string & function_key,const string & device_name,FunctionLibraryRuntime::LocalHandle local_handle)160 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
161 const string& function_key, const string& device_name,
162 FunctionLibraryRuntime::LocalHandle local_handle) {
163 mutex_lock l(mu_);
164 FunctionLibraryRuntime::Handle h =
165 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
166 if (h != kInvalidHandle) {
167 if (function_data_.count(h) != 0) return h;
168 }
169 h = next_handle_;
170 function_data_.insert({h, FunctionData(device_name, local_handle)});
171 table_[function_key] = h;
172 next_handle_++;
173 return h;
174 }
175
GetHandle(const string & function_key) const176 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
177 const string& function_key) const {
178 mutex_lock l(mu_);
179 FunctionLibraryRuntime::Handle h =
180 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
181 if (h != kInvalidHandle) {
182 if (function_data_.count(h) == 0) return kInvalidHandle;
183 }
184 return h;
185 }
186
IsInstantiatedOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle)187 bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
188 const string& device_name, FunctionLibraryRuntime::Handle handle) {
189 return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
190 }
191
192 FunctionLibraryRuntime::LocalHandle
GetHandleOnDevice(const string & device_name,FunctionLibraryRuntime::Handle handle)193 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
194 const string& device_name, FunctionLibraryRuntime::Handle handle) {
195 mutex_lock l(mu_);
196 if (function_data_.count(handle) == 0) {
197 return kInvalidLocalHandle;
198 }
199 const FunctionData& function_data = function_data_[handle];
200 if (function_data.target_device != device_name) {
201 return kInvalidLocalHandle;
202 }
203 return function_data.local_handle;
204 }
205
GetDeviceName(FunctionLibraryRuntime::Handle handle)206 string ProcessFunctionLibraryRuntime::GetDeviceName(
207 FunctionLibraryRuntime::Handle handle) {
208 mutex_lock l(mu_);
209 CHECK_EQ(1, function_data_.count(handle));
210 const FunctionData& function_data = function_data_[handle];
211 return function_data.target_device;
212 }
213
Instantiate(const string & function_name,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::Handle * handle)214 Status ProcessFunctionLibraryRuntime::Instantiate(
215 const string& function_name, AttrSlice attrs,
216 const FunctionLibraryRuntime::InstantiateOptions& options,
217 FunctionLibraryRuntime::Handle* handle) {
218 *handle = kInvalidHandle;
219 FunctionLibraryRuntime* flr = GetFLR(options.target);
220 if (flr != nullptr) {
221 return flr->Instantiate(function_name, attrs, options, handle);
222 }
223 if (parent_ == nullptr) {
224 return errors::Internal(
225 "Currently don't support instantiating functions on device: ",
226 options.target);
227 }
228 FunctionLibraryRuntime::Handle cluster_handle;
229 TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
230 options, &cluster_handle));
231 string function_key = Canonicalize(function_name, attrs);
232 *handle = AddHandle(function_key, options.target, cluster_handle);
233 return Status::OK();
234 }
235
RemoveHandle(FunctionLibraryRuntime::Handle handle)236 Status ProcessFunctionLibraryRuntime::RemoveHandle(
237 FunctionLibraryRuntime::Handle handle) {
238 mutex_lock l(mu_);
239 function_data_.erase(handle);
240 return Status::OK();
241 }
242
ReleaseHandle(FunctionLibraryRuntime::Handle handle)243 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
244 FunctionLibraryRuntime::Handle handle) {
245 FunctionLibraryRuntime* flr = nullptr;
246 string target_device;
247 {
248 mutex_lock l(mu_);
249 CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
250 target_device = function_data_[handle].target_device;
251 }
252 flr = GetFLR(target_device);
253 if (flr != nullptr) {
254 return flr->ReleaseHandle(handle);
255 }
256 return errors::InvalidArgument("Handle not found: ", handle);
257 }
258
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::Handle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done)259 void ProcessFunctionLibraryRuntime::Run(
260 const FunctionLibraryRuntime::Options& opts,
261 FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
262 std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
263 if (!opts.remote_execution) {
264 done(errors::InvalidArgument(
265 "ProcessFunctionLibraryRuntime::Run should only be called when there ",
266 "is a remote execution."));
267 return;
268 }
269
270 FunctionLibraryRuntime* flr = nullptr;
271 string target_device;
272 FunctionLibraryRuntime::LocalHandle local_handle;
273 {
274 mutex_lock l(mu_);
275 if (function_data_.count(handle) == 0) {
276 done(errors::NotFound("Handle: ", handle, " not found."));
277 return;
278 }
279 target_device = function_data_[handle].target_device;
280 local_handle = function_data_[handle].local_handle;
281 }
282 flr = GetFLR(target_device);
283 if (flr != nullptr) {
284 auto rendezvous = opts.rendezvous;
285 string source_device = opts.source_device;
286 DeviceContext* device_context;
287 Status s = GetDeviceContext(source_device, &device_context);
288 if (!s.ok()) {
289 done(s);
290 return;
291 }
292 int64 src_incarnation, target_incarnation;
293 s = GetDeviceIncarnation(source_device, &src_incarnation);
294 s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
295 if (!s.ok()) {
296 done(s);
297 return;
298 }
299
300 // Send the args over to the target device.
301 s = SendTensors(source_device, target_device, "arg_", src_incarnation, args,
302 device_context, opts.args_alloc_attrs, rendezvous);
303 if (!s.ok()) {
304 done(s);
305 return;
306 }
307 const std::vector<AllocatorAttributes>& rets_alloc_attrs =
308 opts.rets_alloc_attrs;
309 std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
310 flr->Run(opts, handle, args, remote_rets,
311 [source_device, target_device, target_incarnation, rendezvous,
312 device_context, rets_alloc_attrs, remote_rets, rets,
313 done](const Status& status) {
314 if (!status.ok()) {
315 delete remote_rets;
316 done(status);
317 return;
318 }
319 int64 num_returns = remote_rets->size();
320 delete remote_rets;
321 // Now receive the return values from the target.
322 ReceiveTensorsAsync(target_device, source_device, "ret_",
323 target_incarnation, num_returns,
324 device_context, rets_alloc_attrs, rendezvous,
325 rets, done);
326 });
327 return;
328 }
329 if (parent_ != nullptr) {
330 parent_->Run(opts, local_handle, args, rets, done);
331 return;
332 }
333 done(errors::Internal("Could not find device"));
334 }
335
Clone(Env * env,int graph_def_version,const OptimizerOptions & optimizer_options,CustomKernelCreator custom_kernel_creator,std::unique_ptr<FunctionLibraryDefinition> * out_lib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * out_pflr)336 Status ProcessFunctionLibraryRuntime::Clone(
337 Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
338 CustomKernelCreator custom_kernel_creator,
339 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
340 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) {
341 out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
342 out_pflr->reset(new ProcessFunctionLibraryRuntime(
343 device_mgr_, env, graph_def_version, out_lib_def->get(),
344 optimizer_options, std::move(custom_kernel_creator), parent_));
345 return Status::OK();
346 }
347
348 } // namespace tensorflow
349